BDEEstimator

class bde.ml.models.BDEEstimator(model_class=<class 'bde.ml.models.FullyConnectedModule'>, model_kwargs=None, n_chains=1, chain_len=1, warmup=1, n_samples=1, optimizer_class=<function adam>, optimizer_kwargs=None, loss=<bde.ml.loss.GaussianNLLLoss object>, batch_size=1, epochs=1, metrics=None, validation_size=None, seed=42)

scikit-learn-compatible implementation of a BDE estimator.

The estimator samples from the parameter distribution of a probabilistic machine learning model to approximate the posterior predictive distribution. This makes it possible to report uncertainty and confidence values alongside the predicted values.

Attributes:
model_class

The neural network model class wrapped by the estimator.

model_kwargs

Keyword arguments used to initialize the wrapped model.

optimizer_class

The optimizer class used by the estimator for training.

optimizer_kwargs

Keyword arguments used to initialize the optimizer.

loss

A class representing the loss function.

batch_size

Number of samples per batch (size of first dimension).

epochs

Number of epochs for the DE-Initialization stage (per chain).

metrics

A list of metrics to evaluate during training, by default None.

validation_size

The size of the validation set, or a tuple containing validation data, by default None.

seed

Random seed for initialization.

n_chains

Number of MCMC sampling chains.

chain_len

Number of sampling steps during the MCMC-Sampling stage (per chain).

warmup

Number of warmup (burn-in) steps before the MCMC-Sampling (per chain).

n_samples

Number of samples to take from each Gaussian-distribution during prediction.

params_

Parameters produced by the DE-Initialization stage.

samples_

Model parameters sampled during the MCMC-Sampling stage.

history_

Loss and metric records from the DE-Initialization stage.

model_

The initialized model class.

is_fitted_

A flag indicating whether the model has been fitted.

n_features_in_

Number of input features (the size of the last input dimension).

Methods

fit(X, y=None, n_devices=-1)

Fit the model to the training data.

predict(X)

Predict the output for the given input data using the trained model.

sample_from_samples(x, n_devices=1, batch_size=-1)

Take samples from sampled Gaussian distributions based on model params.

predict_with_credibility_eti(X, a=0.95)

Make prediction with a credible interval.

predict_as_de(X, n_devices=-1)

Predict with model as a deep ensemble.

tree_flatten()

Serialize module into a JAX PyTree.

tree_unflatten(aux_data, children)

Build module from a serialized PyTree.

log_prior(params)

Calculate the log of the prior probability for a set of params.

logdensity_for_batch(params, carry, batch)

Evaluate log-density for a batch of data.

burn_in_loop(rng, params, n_burns, warmup)

Perform burn-in for sampler.

mcmc_sampling(model_states, rng_key, train, n_devices, parallel_batch_size, mask)

Perform MCMC-burn-in and sampling.

Methods

burn_in_loop(rng, params, n_burns, warmup)

Perform burn-in for sampler.

fit(X[, y, n_devices])

Fit the function to the given data.

get_metadata_routing()

Get metadata routing of this object.

get_params([deep])

Get parameters for this estimator.

history_description()

Make a readable version of the training history.

init_inner_params(n_features, optimizer, rng_key)

Create trainable model state.

load(path)

Load estimator from file.

log_prior(params)

Calculate the log of the prior probability for a set of params.

logdensity_for_batch(params, carry, batch)

Evaluate log-density for a batch of data.

mcmc_sampling(model_states, rng_key, train, ...)

Perform MCMC-burn-in and sampling.

predict(X)

Apply the fitted model to the input data.

predict_as_de(X[, n_devices])

Predict with model as a deep ensemble.

predict_with_credibility_eti(X[, a])

Make prediction with a credible interval.

sample_from_samples(x[, n_devices, batch_size])

Take samples from sampled Gaussian distributions based on model params.

save(path)

Save estimator to file.

set_fit_request(*[, n_devices])

Request metadata passed to the fit method.

set_params(**params)

Set the parameters of this estimator.

tree_flatten()

Specify how to serialize estimator into a JAX pytree.

tree_unflatten(aux_data, children)

Specify how to build an estimator from a JAX pytree.

static burn_in_loop(rng, params, n_burns, warmup)

Perform burn-in for sampler.

fit(X, y=None, n_devices=-1)

Fit the function to the given data.

Parameters:
X

The input data.

y

The labels. If y is None, X is assumed to include the labels as well.

n_devices

Number of devices to use for parallelization. -1 means using all available devices. If a number greater than all available devices is given, the max number of devices is used.

Returns:
BDEEstimator

The fitted estimator.

Return type:

BDEEstimator
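
The n_devices rule described above (-1 selects every device, oversized requests fall back to the maximum) can be sketched as a small helper. This is an illustration only: resolve_n_devices is a hypothetical name, not part of the library, and rejecting other non-positive values is an assumption the documentation does not confirm.

```python
def resolve_n_devices(requested: int, available: int) -> int:
    """Map a requested device count onto the devices that actually exist."""
    if available < 1:
        raise ValueError("at least one device must be available")
    if requested == -1:
        # -1 is the documented shorthand for "use all available devices".
        return available
    if requested < 1:
        # Assumption: any other non-positive value is treated as an error.
        raise ValueError(f"n_devices must be -1 or positive, got {requested}")
    # Requests above the available count fall back to the maximum.
    return min(requested, available)


print(resolve_n_devices(-1, available=4))  # 4
print(resolve_n_devices(2, available=4))   # 2
print(resolve_n_devices(16, available=4))  # 4
```

The same rule is documented for predict_as_de, so a helper of this kind would apply to both methods.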

get_metadata_routing()

Get metadata routing of this object.

Please check User Guide on how the routing mechanism works.

Returns:
routing : MetadataRequest

A MetadataRequest encapsulating routing information.

get_params(deep=True)

Get parameters for this estimator.

Parameters:
deep : bool, default=True

If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns:
params : dict

Parameter names mapped to their values.

history_description()

Make a readable version of the training history.

Returns:
Dict

Each key corresponds to an evaluation metric or loss, and each value is an array holding that quantity's value for every epoch.

Raises:
AssertionError

If the model is not fitted.

Return type:

Dict[str, Array]

init_inner_params(n_features, optimizer, rng_key)

Create trainable model state.

Parameters:
n_features

Number of input features.

optimizer

Optimization algorithm used for training.

rng_key

Randomness key.

Returns:
train_state.TrainState

Initialized training state.

Return type:

TrainState

classmethod load(path)

Load estimator from file.

Return type:

FullyConnectedEstimator

Parameters:

path (str | Path)

log_prior(params)

Calculate the log of the prior probability for a set of params.

Parameters:
params

A PyTree of model parameters.

Returns:
float

The log-prior probability of the parameters.

Return type:

float

logdensity_for_batch(params, carry, batch)

Evaluate log-density for a batch of data.

Parameters:
params

Parameters of the evaluated model.

carry

Sum of the log-prior and the log-density of all previous batches.

batch

Current batch to be evaluated.

Returns:
Tuple:
  • Updated logdensity (carry + current batch value).

  • None (used for compatibility with jax.lax.scan)

Return type:

Tuple[Array, None]
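
The carry-based accumulation described above can be sketched in plain Python. The sketch makes several assumptions that are not taken from the library: scan is a tiny stand-in that only mimics the (carry, x) -> (carry, y) contract of jax.lax.scan, the model is a hypothetical one-parameter line with Gaussian noise, and the Normal(0, 10) prior on the slope is invented for the example.

```python
import math
from typing import Callable, List, Sequence, Tuple


def scan(step: Callable, init: float, xs: Sequence) -> Tuple[float, List]:
    """Pure-Python stand-in for the carry/output contract of jax.lax.scan."""
    carry, outputs = init, []
    for x in xs:
        carry, out = step(carry, x)
        outputs.append(out)
    return carry, outputs


def gaussian_logpdf(y: float, mean: float, std: float) -> float:
    """Log-density of Normal(mean, std) evaluated at y."""
    return -0.5 * math.log(2.0 * math.pi * std**2) - 0.5 * ((y - mean) / std) ** 2


def logdensity_for_batch(params, carry, batch):
    """Add this batch's log-likelihood to the running total; emit None per step."""
    slope, std = params  # hypothetical model: y ~ Normal(slope * x, std)
    xs, ys = batch
    batch_value = sum(gaussian_logpdf(y, slope * x, std) for x, y in zip(xs, ys))
    return carry + batch_value, None


params = (2.0, 1.0)
# Assumed prior for illustration only: slope ~ Normal(0, 10).
log_prior = gaussian_logpdf(params[0], 0.0, 10.0)
batches = [([0.0, 1.0], [0.0, 2.0]), ([2.0], [4.0])]

# The carry starts at the log-prior and grows by one batch value per step.
total, per_step = scan(
    lambda carry, batch: logdensity_for_batch(params, carry, batch),
    log_prior,
    batches,
)
print(total)
print(per_step)  # [None, None]
```

Returning None as the second tuple element keeps the function usable as a scan step while discarding per-step outputs, since only the final carry (the full log-density) is needed.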

mcmc_sampling(model_states, rng_key, train, n_devices, parallel_batch_size, mask)

Perform MCMC-burn-in and sampling.

Parameters:
model_states

Initial model states for sampler.

rng_key

A key used to initialize the random processes.

train

Training dataset. The dataset must include only 1 batch (full-batch).

n_devices

Exact number of computational devices used for parallelization.

parallel_batch_size

The number of chains to compute on each device when running parallel computations.

mask

A 2D-array indicating which chains are used for padding during parallelization (shape = (n_devices, batch_size)).

  • Chains corresponding to the value one in the mask will be evaluated and sampled from.

  • Chains corresponding to the value zero in the mask will be ignored when possible and return pseudo-samples which can be discarded.

Returns:
List[Dict]

Samples from all MCMC chains.

Return type:

List[Dict]

Parameters:
  • model_states (TrainState)

  • rng_key (PRNGKeyArray)

  • train (BasicDataset)

  • n_devices (int)

  • parallel_batch_size (int)

  • mask (Array)
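
To make the padding mask concrete, the sketch below builds a mask of shape (n_devices, batch_size) for a given number of chains and then discards the pseudo-samples of padded slots. The helper names build_chain_mask and drop_padding are hypothetical, and the row-major fill order (real chains first, padding last) is an assumption; the library may lay the mask out differently.

```python
from typing import Any, List


def build_chain_mask(n_chains: int, n_devices: int, batch_size: int) -> List[List[int]]:
    """Return a (n_devices, batch_size) mask: 1 = real chain, 0 = padding."""
    capacity = n_devices * batch_size
    if not 0 <= n_chains <= capacity:
        raise ValueError(f"{n_chains} chains do not fit into {capacity} slots")
    # Assumed layout: real chains fill the slots first, padding comes last.
    flat = [1] * n_chains + [0] * (capacity - n_chains)
    return [flat[d * batch_size:(d + 1) * batch_size] for d in range(n_devices)]


def drop_padding(mask: List[List[int]], results: List[List[Any]]) -> List[Any]:
    """Keep only the results whose mask entry is one."""
    return [
        result
        for mask_row, result_row in zip(mask, results)
        for keep, result in zip(mask_row, result_row)
        if keep == 1
    ]


mask = build_chain_mask(n_chains=5, n_devices=2, batch_size=3)
print(mask)  # [[1, 1, 1], [1, 1, 0]]

# Placeholder per-slot results; the last slot is a padded pseudo-sample.
results = [["c0", "c1", "c2"], ["c3", "c4", "pad"]]
print(drop_padding(mask, results))  # ['c0', 'c1', 'c2', 'c3', 'c4']
```

Padding lets every device receive the same number of chains, which keeps the per-device workload rectangular even when n_chains is not a multiple of n_devices.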

predict(X)

Apply the fitted model to the input data.

Parameters:
X

The input data.

Returns:
Array

Predicted values (mean of samples).

Return type:

Array

Parameters:

X (Array | ndarray | bool | number | bool | int | float | complex)

predict_as_de(X, n_devices=-1)

Predict with model as a deep ensemble.

This method ignores the samples data and uses the initialization params only.

Parameters:
X

The input data.

n_devices

Number of devices to use for parallelization. -1 means using all available devices. If a number greater than all available devices is given, the max number of devices is used.

Returns:
Array

Predicted values (mean of predictions).

Return type:

Array

predict_with_credibility_eti(X, a=0.95)

Make prediction with a credible interval.

Parameters:
X

The input data.

a

Probability mass covered by the credible interval (a value between 0 and 1).

Returns:
Tuple of three arrays:
  • Predicted values (median of samples).

  • Lower bound of the credible interval per prediction.

  • Upper bound of the credible interval per prediction.

Return type:

Tuple[Array, Array, Array]
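
An equal-tailed interval (ETI) with mass a leaves (1 - a) / 2 of the probability in each tail. The following standard-library sketch computes the median and the two bounds from a list of samples for a single prediction. The helper names and the linear-interpolation quantile rule are assumptions made for illustration; the library may compute its quantiles differently.

```python
import math
import statistics
from typing import Sequence, Tuple


def quantile(sorted_values: Sequence[float], q: float) -> float:
    """Linearly interpolated quantile of already-sorted, non-empty values."""
    pos = q * (len(sorted_values) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(sorted_values) - 1)
    return sorted_values[lo] + (pos - lo) * (sorted_values[hi] - sorted_values[lo])


def equal_tailed_interval(samples: Sequence[float], a: float = 0.95) -> Tuple[float, float, float]:
    """Return (median, lower, upper) so that each tail holds (1 - a) / 2."""
    if not 0.0 < a < 1.0:
        raise ValueError("a must lie strictly between 0 and 1")
    ordered = sorted(samples)
    tail = (1.0 - a) / 2.0
    return statistics.median(ordered), quantile(ordered, tail), quantile(ordered, 1.0 - tail)


# 101 evenly spaced stand-in samples: 0, 1, ..., 100.
samples = [float(i) for i in range(101)]
median, lower, upper = equal_tailed_interval(samples, a=0.9)
print(median, round(lower, 6), round(upper, 6))  # 50.0 5.0 95.0
```

Using the median rather than the mean as the point prediction keeps the returned value inside the interval by construction.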

sample_from_samples(x, n_devices=1, batch_size=-1)

Take samples from sampled Gaussian distributions based on model params.

The mean and std predicted by the sampled models define Gaussian distributions. Take samples from these distributions.

Parameters:
x

The input data.

n_devices

batch_size

Returns:
Array

The last dimension of the array holds the samples taken from each predicted distribution in each batch. The shape is: (b_1, b_2, ..., b_n, output_size / 2, self.n_samples * self.n_chains)

Return type:

Array
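
The idea of sampling from the predicted Gaussians can be sketched for a single input point as follows. The sketch assumes the model output has already been split into (mean, std) pairs, one per sampled parameter set, which is one reading of the output_size / 2 term in the documented shape. sample_from_gaussians is a hypothetical helper, not the library's implementation.

```python
import random
from typing import List, Sequence, Tuple


def sample_from_gaussians(
    predictions: Sequence[Tuple[float, float]],
    n_samples: int,
    seed: int = 42,
) -> List[float]:
    """Draw n_samples values from each Normal(mean, std) in predictions."""
    rng = random.Random(seed)  # local generator, so results are reproducible
    draws: List[float] = []
    for mean, std in predictions:
        draws.extend(rng.gauss(mean, std) for _ in range(n_samples))
    return draws


# Two stand-in (mean, std) predictions for the same input point.
predictions = [(0.0, 1.0), (10.0, 0.5)]
draws = sample_from_gaussians(predictions, n_samples=3)
print(len(draws))  # 6
```

The pooled draws approximate the predictive distribution at that point, so their mean, median, or quantiles can serve as summary statistics.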

save(path)

Save estimator to file.

Return type:

None

Parameters:

path (str | Path)

set_fit_request(*, n_devices='$UNCHANGED$')

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:
n_devices : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for n_devices parameter in fit.

Returns:
self : object

The updated object.

Return type:

BDEEstimator

set_params(**params)

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters:
**params : dict

Estimator parameters.

Returns:
self : estimator instance

Estimator instance.

tree_flatten()

Specify how to serialize estimator into a JAX pytree.

Returns:
Tuple[Sequence[ArrayLike], Any]
A tuple with 2 elements:
  • The children, containing arrays & pytrees.

  • The aux_data, containing static and hashable data.

Return type:

Tuple[Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex, Dict]], Any]

classmethod tree_unflatten(aux_data, children)

Specify how to build an estimator from a JAX pytree.

Parameters:
aux_data

Contains static, hashable data.

children

Contain arrays & pytrees.

Returns:
BDEEstimator

Reconstructed estimator.

Return type:

BDEEstimator
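
The flatten/unflatten pair can be illustrated with a minimal round trip. TinyEstimator is a hypothetical stand-in, not the real BDEEstimator, and its split of fields into children and aux_data is an assumption chosen for the example. The method signatures follow the contract that jax.tree_util.register_pytree_node_class expects: tree_flatten returns (children, aux_data) and tree_unflatten receives (aux_data, children).

```python
from typing import Any, List, Tuple


class TinyEstimator:
    """Hypothetical stand-in showing the pytree flatten/unflatten protocol."""

    def __init__(self, n_chains: int = 1, seed: int = 42):
        self.n_chains = n_chains
        self.seed = seed
        self.params_ = None  # would be filled in by a fit step

    def tree_flatten(self) -> Tuple[List[Any], Tuple[int, int]]:
        children = [self.params_]              # arrays and nested pytrees
        aux_data = (self.n_chains, self.seed)  # static and hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> "TinyEstimator":
        n_chains, seed = aux_data
        obj = cls(n_chains=n_chains, seed=seed)
        (obj.params_,) = children
        return obj


est = TinyEstimator(n_chains=4, seed=0)
est.params_ = {"w": [0.1, 0.2], "b": [0.0]}

children, aux_data = est.tree_flatten()
restored = TinyEstimator.tree_unflatten(aux_data, children)
print(restored.n_chains, restored.seed, restored.params_)
# 4 0 {'w': [0.1, 0.2], 'b': [0.0]}
```

With JAX installed, decorating such a class with jax.tree_util.register_pytree_node_class registers it as a pytree node, so it can be passed through transformations such as jax.jit. Keeping hyperparameters in aux_data matters because JAX treats aux_data as static and compares it by hash and equality.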

Examples using bde.ml.models.BDEEstimator

sphx_glr_auto_examples_example01.py

Plot Template Estimator.