BDEEstimator#

class bde.ml.models.BDEEstimator(model_class=<class 'bde.ml.models.FullyConnectedModule'>, model_kwargs=None, n_chains=1, chain_len=1, warmup=1, n_samples=1, optimizer_class=<function adam>, optimizer_kwargs=None, loss=<bde.ml.loss.GaussianNLLLoss object>, batch_size=1, epochs=1, metrics=None, validation_size=None, seed=42)#

Scikit-learn-compatible implementation of a BDE estimator.

# TODO: Describe BDE estimator.

Attributes:
# TODO: List

Methods

burn_in_loop(rng, params, n_burns, warmup)

Perform burn-in for sampler.

fit(X[, y, n_devices])

Fit the model to the given data.

get_metadata_routing()

Get metadata routing of this object.

get_params([deep])

Get parameters for this estimator.

history_description()

Make a readable version of the training history.

init_inner_params(n_features, optimizer, rng_key)

Create trainable model state.

load(path)

Load estimator from file.

log_prior(params)

Calculate the log of the prior probability for a set of params.

logdensity_for_batch(params, carry, batch)

Evaluate log-density for a batch of data.

mcmc_sampling(model_states, rng_key, train, ...)

Perform MCMC burn-in and sampling.

predict(X)

Apply the fitted model to the input data.

predict_as_de(X[, n_devices])

Predict with model as a deep ensemble.

predict_with_credibility_eti(X[, a])

Make a prediction with a credible interval.

sample_from_samples(x[, n_devices, batch_size])

Take samples from sampled Gaussian distributions based on model params.

save(path)

Save estimator to file.

set_fit_request(*[, n_devices])

Request metadata passed to the fit method.

set_params(**params)

Set the parameters of this estimator.

tree_flatten()

Specify how to serialize estimator into a JAX pytree.

tree_unflatten(aux_data, children)

Specify how to build an estimator from a JAX pytree.

static burn_in_loop(rng, params, n_burns, warmup)#

Perform burn-in for sampler.

fit(X, y=None, n_devices=-1)#

Fit the model to the given data.

Parameters:
X

The input data.

y

The labels. If y is None, X is assumed to include the labels as well.

n_devices

Number of devices to use for parallelization. -1 uses all available devices. If the given number exceeds the number of available devices, all available devices are used.

Returns:
BDEEstimator

The fitted estimator.

Return type:

BDEEstimator
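The n_devices rule described above can be sketched in a few lines. This is a minimal illustration and not the library's code: the helper name resolve_n_devices and the error raised for other non-positive values are assumptions made for the example.

```python
def resolve_n_devices(requested: int, available: int) -> int:
    """Map a requested device count onto the devices that actually exist.

    Hypothetical helper, not part of bde.
    """
    if available < 1:
        raise ValueError("at least one device must be available")
    if requested == -1:
        # -1 is the "use every available device" sentinel.
        return available
    if requested < 1:
        # Assumption: other non-positive values are rejected.
        raise ValueError("n_devices must be -1 or a positive integer")
    # Requests above the available count are clamped, not rejected.
    return min(requested, available)


print(resolve_n_devices(-1, 4))  # all available devices
print(resolve_n_devices(2, 4))   # exactly the requested number
print(resolve_n_devices(16, 4))  # clamped to the available number
```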

get_metadata_routing()#

Get metadata routing of this object.

Please check User Guide on how the routing mechanism works.

Returns:
routing : MetadataRequest

A MetadataRequest encapsulating routing information.

get_params(deep=True)#

Get parameters for this estimator.

Parameters:
deep : bool, default=True

If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns:
params : dict

Parameter names mapped to their values.

history_description()#

Make a readable version of the training history.

Returns:
Dict

Each key corresponds to an evaluation metric or loss, and each value is an array holding its values across epochs.

Raises:
AssertionError

If the model is not fitted.

Return type:

Dict[str, Array]

init_inner_params(n_features, optimizer, rng_key)#

Create trainable model state.

Parameters:
n_features

Number of input features.

optimizer

Optimization algorithm used for training.

rng_key

Randomness key.

Returns:
train_state.TrainState

Initialized training state.

Return type:

TrainState

classmethod load(path)#

Load estimator from file.

Return type:

FullyConnectedEstimator

Parameters:

path (str | Path)

log_prior(params)#

Calculate the log of the prior probability for a set of params.

logdensity_for_batch(params, carry, batch)#

Evaluate log-density for a batch of data.

Parameters:
params

Parameters of the evaluated model.

carry

Log-prior plus the log-density of the previous batches.

batch

Current batch to be evaluated.

Returns:
Tuple:
  • Updated logdensity (carry + current batch value).

  • None (used for compatibility with jax.lax.scan)

Return type:

Tuple[Array, None]
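The (carry, batch) -> (carry, None) contract above is what lets the log-density be accumulated batch by batch. The sketch below is an illustration under stated assumptions, not the library's implementation: it uses a pure-Python stand-in for jax.lax.scan, a toy linear model with a Gaussian likelihood, and a made-up log-prior value.

```python
import math
from functools import partial


def gaussian_logpdf(y, mean, std):
    """Log-density of y under a normal distribution N(mean, std**2)."""
    return -0.5 * math.log(2 * math.pi * std**2) - (y - mean) ** 2 / (2 * std**2)


def logdensity_for_batch(params, carry, batch):
    """Add the log-likelihood of one batch to the running total (toy model)."""
    xs, ys = batch
    batch_ll = sum(
        gaussian_logpdf(y, params["w"] * x + params["b"], params["std"])
        for x, y in zip(xs, ys)
    )
    # The second element is unused; it only keeps the scan signature.
    return carry + batch_ll, None


def scan(f, init, xs):
    """Pure-Python stand-in for jax.lax.scan: thread a carry through xs."""
    carry, outputs = init, []
    for x in xs:
        carry, out = f(carry, x)
        outputs.append(out)
    return carry, outputs


params = {"w": 2.0, "b": 0.0, "std": 1.0}  # hypothetical parameters
batches = [([0.0, 1.0], [0.0, 2.0]), ([2.0], [4.0])]
log_prior = -1.5  # hypothetical value used as the initial carry

total, _ = scan(partial(logdensity_for_batch, params), log_prior, batches)
print(total)
```

Starting the carry at the log-prior means the final carry is the unnormalized log-posterior over all batches.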

mcmc_sampling(model_states, rng_key, train, n_devices, parallel_batch_size, mask)#

Perform MCMC burn-in and sampling.

Parameters:
model_states

Initial model states for sampler.

rng_key

A key used to initialize the random processes.

train

Training dataset. The dataset must include only 1 batch (full-batch).

n_devices

Exact number of computational devices used for parallelization.

parallel_batch_size

The number of chains to compute on each device when running parallel computations.

mask

A 2D array indicating which chains are used for padding during parallelization (shape = (n_devices, batch_size)).

  • Chains corresponding to the value one in the mask will be evaluated and sampled from.

  • Chains corresponding to the value zero in the mask will be ignored when possible and return pseudo-samples which can be discarded.

Returns:
List[Dict]

Samples from all MCMC chains.

Return type:

List[Dict]

Parameters:
  • model_states (TrainState)

  • rng_key (PRNGKeyArray)

  • train (BasicDataset)

  • n_devices (int)

  • parallel_batch_size (int)

  • mask (Array)
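To make the padding mask concrete, the sketch below builds a mask of shape (n_devices, batch_size) for a chain count that does not divide evenly across devices. It is only an illustration: the helper name build_chain_mask and the row-by-row placement of chains are assumptions, and the library may lay out chains differently.

```python
import math


def build_chain_mask(n_chains: int, n_devices: int):
    """Return (batch_size, mask), where mask has shape (n_devices, batch_size).

    Hypothetical helper: 1 marks a real chain, 0 marks a padding chain.
    """
    # Each device gets the same number of slots, so round up.
    batch_size = math.ceil(n_chains / n_devices)
    mask = []
    for device in range(n_devices):
        row = []
        for slot in range(batch_size):
            chain_index = device * batch_size + slot
            row.append(1 if chain_index < n_chains else 0)
        mask.append(row)
    return batch_size, mask


batch_size, mask = build_chain_mask(n_chains=5, n_devices=2)
print(batch_size)  # slots per device
print(mask)        # the trailing slot on the last device is padding
```

The number of ones in the mask equals the number of real chains, so samples from zero-valued slots can be dropped after sampling.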

predict(X)#

Apply the fitted model to the input data.

Parameters:
X

The input data.

Returns:
Array

Predicted values (mean of samples).

Return type:

Array

Parameters:

X (ArrayLike)

predict_as_de(X, n_devices=-1)#

Predict with model as a deep ensemble.

This method ignores the sampled parameters and uses only the initialization parameters.

Parameters:
X

The input data.

n_devices

Number of devices to use for parallelization. -1 uses all available devices. If the given number exceeds the number of available devices, all available devices are used.

Returns:
Array

Predicted values (mean of predictions).

Return type:

Array

predict_with_credibility_eti(X, a=0.95)#

Make a prediction with a credible interval.

Parameters:
X

The input data.

a

Probability mass covered by the credible interval (a value between 0 and 1).

Returns:
Tuple of 3 arrays:
  • Predicted values (median of samples).

  • Lower bound of the credible interval per prediction.

  • Upper bound of the credible interval per prediction.

Return type:

Tuple[Array, Array, Array]
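An equal-tailed interval (ETI) of size a leaves (1 - a) / 2 of the probability mass in each tail. The sketch below computes the median and the two bounds for the samples of a single prediction. It is a plain-Python illustration with hypothetical helper names; the library's own quantile computation may differ in interpolation details.

```python
import statistics


def quantile(sorted_values, q):
    """Linearly interpolated quantile of an already sorted list."""
    position = q * (len(sorted_values) - 1)
    lower = int(position)
    upper = min(lower + 1, len(sorted_values) - 1)
    weight = position - lower
    return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight


def equal_tailed_interval(samples, a=0.95):
    """Return (median, lower bound, upper bound) of the ETI of size a."""
    ordered = sorted(samples)
    tail = (1 - a) / 2  # probability mass left in each tail
    return (
        statistics.median(ordered),
        quantile(ordered, tail),
        quantile(ordered, 1 - tail),
    )


# Hypothetical posterior samples for one prediction: 0.0, 1.0, ..., 100.0
samples = [float(i) for i in range(101)]
median, lower, upper = equal_tailed_interval(samples, a=0.9)
print(median, lower, upper)
```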

sample_from_samples(x, n_devices=1, batch_size=-1)#

Take samples from sampled Gaussian distributions based on model params.

The mean and std predicted by the sampled models define Gaussian distributions. Take samples from these distributions.

Parameters:
x

The input data.

n_devices

batch_size

Returns:
Array

The last dimension of the array holds the samples drawn from each predicted distribution in each batch. The shape is: (b_1, b_2, ..., b_n, output_size / 2, self.n_samples * self.n_chains)

Return type:

Array
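The idea of sampling from the predicted Gaussians can be shown for a single input point: each sampled model contributes one (mean, std) pair, and one value is drawn from each. The sketch below is an illustration using the standard library, not the library's JAX code; the function name, the seed argument, and the example values are assumptions.

```python
import random


def sample_predictive(predictions, seed=0):
    """Draw one value from each (mean, std) pair.

    Hypothetical helper: `predictions` holds one pair per sampled model.
    """
    rng = random.Random(seed)  # local generator for reproducible draws
    return [rng.gauss(mean, std) for mean, std in predictions]


# Made-up (mean, std) outputs of 4 sampled models for one input point.
predictions = [(1.0, 0.1), (1.2, 0.2), (0.9, 0.1), (1.1, 0.3)]
draws = sample_predictive(predictions)
print(len(draws))  # one draw per sampled model
```

With n_samples * n_chains sampled models, this yields the n_samples * n_chains entries found in the last dimension of the returned array.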

save(path)#

Save estimator to file.

Return type:

None

Parameters:

path (str | Path)

set_fit_request(*, n_devices='$UNCHANGED$')#

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:
n_devices : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for n_devices parameter in fit.

Returns:
self : object

The updated object.

Return type:

BDEEstimator

set_params(**params)#

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters:
**params : dict

Estimator parameters.

Returns:
self : estimator instance

Estimator instance.
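The <component>__<parameter> naming can be illustrated by splitting each key at its first double underscore. The sketch below is a simplified stand-in for that routing and not scikit-learn's implementation; the step names "scaler" and "bde" are hypothetical names for steps of a Pipeline.

```python
def split_nested_params(params):
    """Group parameters by the component named before the first '__'.

    Hypothetical helper: returns (own parameters, nested parameters).
    """
    own, nested = {}, {}
    for name, value in params.items():
        component, delimiter, sub_name = name.partition("__")
        if delimiter:
            # "bde__n_chains" targets parameter n_chains of component "bde".
            nested.setdefault(component, {})[sub_name] = value
        else:
            own[name] = value
    return own, nested


own, nested = split_nested_params(
    {"verbose": True, "scaler__with_mean": False, "bde__n_chains": 4}
)
print(own)
print(nested)
```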

tree_flatten()#

Specify how to serialize estimator into a JAX pytree.

Returns:
Tuple[Sequence[ArrayLike], Any]
A tuple with 2 elements:
  • The children, containing arrays & pytrees.

  • The aux_data, containing static and hashable data.

Return type:

Tuple[Sequence[Union[ArrayLike, Dict]], Any]

classmethod tree_unflatten(aux_data, children)#

Specify how to build an estimator from a JAX pytree.

Parameters:
aux_data

Contains static, hashable data.

children

Contains arrays and pytrees.

Returns:
BDEEstimator

Reconstructed estimator.

Return type:

BDEEstimator
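The tree_flatten / tree_unflatten pair follows the usual JAX pytree protocol: array-like state goes into children, static and hashable configuration goes into aux_data, and the two together must be enough to rebuild the object. The sketch below shows the round trip on a toy class without importing JAX. TinyEstimator and its attributes are hypothetical; in real code such a class would typically be registered with jax.tree_util.register_pytree_node_class.

```python
class TinyEstimator:
    """Toy stand-in used only to illustrate the pytree protocol."""

    def __init__(self, n_chains, weights=None):
        self.n_chains = n_chains  # static configuration
        self.weights = weights    # array-like state

    def tree_flatten(self):
        children = (self.weights,)   # arrays and nested pytrees
        aux_data = (self.n_chains,)  # static data, kept hashable as a tuple
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (weights,) = children
        return cls(*aux_data, weights=weights)


original = TinyEstimator(n_chains=4, weights=[0.1, 0.2])
children, aux_data = original.tree_flatten()
rebuilt = TinyEstimator.tree_unflatten(aux_data, children)
print(rebuilt.n_chains, rebuilt.weights)
```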


Examples using bde.ml.models.BDEEstimator#

Plot Template Estimator.