BDEEstimator
- class bde.ml.models.BDEEstimator(model_class=<class 'bde.ml.models.FullyConnectedModule'>, model_kwargs=None, n_chains=1, chain_len=1, warmup=1, n_samples=1, optimizer_class=<function adam>, optimizer_kwargs=None, loss=<bde.ml.loss.GaussianNLLLoss object>, batch_size=1, epochs=1, metrics=None, validation_size=None, seed=42)
Scikit-learn-compatible implementation of a BDE estimator.
Methods
- burn_in_loop(rng, params, n_burns, warmup): Perform burn-in for sampler.
- fit(X[, y, n_devices]): Fit the function to the given data.
- get_metadata_routing(): Get metadata routing of this object.
- get_params([deep]): Get parameters for this estimator.
- history_description(): Make a readable version of the training history.
- init_inner_params(n_features, optimizer, rng_key): Create trainable model state.
- load(path): Load estimator from file.
- log_prior(params): Calculate the log of the prior probability for a set of params.
- logdensity_for_batch(params, carry, batch): Evaluate log-density for a batch of data.
- mcmc_sampling(model_states, rng_key, train, ...): Perform MCMC burn-in and sampling.
- predict(X): Apply the fitted model to the input data.
- predict_as_de(X[, n_devices]): Predict with the model as a deep ensemble.
- predict_with_credibility_eti(X[, a]): Make a prediction with a credible interval.
- sample_from_samples(x[, n_devices, batch_size]): Take samples from sampled Gaussian distributions based on model params.
- save(path): Save estimator to file.
- set_fit_request(*[, n_devices]): Request metadata passed to the fit method.
- set_params(**params): Set the parameters of this estimator.
- tree_flatten(): Specify how to serialize the estimator into a JAX pytree.
- tree_unflatten(aux_data, children): Specify how to build an estimator from a JAX pytree.
- static burn_in_loop(rng, params, n_burns, warmup)
Perform burn-in for sampler.
- fit(X, y=None, n_devices=-1)
Fit the function to the given data.
- Parameters:
- X
The input data.
- y
The labels. If y is None, X is assumed to include the labels as well.
- n_devices
Number of devices to use for parallelization. -1 means using all available devices. If a number greater than the number of available devices is given, all available devices are used.
- Returns:
- BDEEstimator
The fitted estimator.
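The n_devices rule described above can be sketched in plain Python. This is a hypothetical illustration rather than the library's implementation: resolve_n_devices is an invented name, and the number of available devices is passed in explicitly instead of being queried from JAX (where jax.device_count() would typically supply it).

```python
def resolve_n_devices(n_devices: int, available: int) -> int:
    """Hypothetical helper mirroring the documented n_devices rule."""
    if available < 1:
        raise ValueError("at least one device must be available")
    # -1 requests every available device.
    if n_devices == -1:
        return available
    if n_devices < 1:
        raise ValueError("n_devices must be -1 or a positive integer")
    # Requests above the available count are capped at the maximum.
    return min(n_devices, available)


print(resolve_n_devices(-1, 8))  # 8
print(resolve_n_devices(16, 8))  # 8
print(resolve_n_devices(2, 8))   # 2
```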
- get_metadata_routing()
Get metadata routing of this object.
Please check User Guide on how the routing mechanism works.
- Returns:
- routing : MetadataRequest
A MetadataRequest encapsulating routing information.
- get_params(deep=True)
Get parameters for this estimator.
- Parameters:
- deep : bool, default=True
If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns:
- params : dict
Parameter names mapped to their values.
- history_description()
Make a readable version of the training history.
- init_inner_params(n_features, optimizer, rng_key)
Create trainable model state.
- Parameters:
- n_features
Number of input features.
- optimizer
Optimization algorithm used for training.
- rng_key
Randomness key.
- Returns:
- train_state.TrainState
Initialized training state.
- Return type:
TrainState
- classmethod load(path)
Load estimator from file.
- log_prior(params)
Calculate the log of the prior probability for a set of params.
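The docstring does not say which prior is used. As a sketch only, assuming independent zero-mean Gaussian priors N(0, sigma^2) on every parameter and flattening the parameter pytree into a plain list of floats, the log-prior is a sum of per-parameter Gaussian log-densities. gaussian_log_prior and sigma are hypothetical names.

```python
import math
from typing import Iterable


def gaussian_log_prior(params: Iterable[float], sigma: float = 1.0) -> float:
    """Sum of log N(p | 0, sigma**2) over all (flattened) parameters.

    Assumption: an isotropic zero-mean Gaussian prior; the estimator's
    actual prior is not documented.
    """
    if sigma <= 0.0:
        raise ValueError("sigma must be positive")
    log_norm = -0.5 * math.log(2.0 * math.pi * sigma**2)  # normalizing constant
    return sum(log_norm - 0.5 * (p / sigma) ** 2 for p in params)


print(gaussian_log_prior([0.0, 1.0, -1.0]))
```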
- logdensity_for_batch(params, carry, batch)
Evaluate log-density for a batch of data.
- Parameters:
- params
Parameters of the evaluated model.
- carry
Log-prior plus the log-density of all previous batches.
- batch
Current batch to be evaluated.
- Returns:
- Tuple:
Updated log-density (carry + current batch value).
None (used for compatibility with jax.lax.scan).
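The carry/None return convention lets the per-batch evaluation be folded over all batches with jax.lax.scan. The sketch below imitates that fold with a pure-Python stand-in for scan. The linear model y ~ N(w * x + b, 1), the parameter dict, and the (x, y) batch layout are assumptions made for illustration, not the estimator's real model.

```python
import math
from functools import partial
from typing import Any, Callable, List, Sequence, Tuple

Batch = List[Tuple[float, float]]  # hypothetical layout: (x, y) pairs


def scan(f: Callable, init: Any, xs: Sequence) -> Tuple[Any, List[Any]]:
    """Pure-Python stand-in for jax.lax.scan: thread a carry through xs."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)  # jax.lax.scan would stack these outputs instead
    return carry, ys


def logdensity_for_batch(params: dict, carry: float, batch: Batch) -> Tuple[float, None]:
    """Add one batch's log-likelihood to the running carry.

    Assumed model: y ~ Normal(w * x + b, 1).
    """
    w, b = params["w"], params["b"]
    batch_ll = sum(
        -0.5 * math.log(2.0 * math.pi) - 0.5 * (y - (w * x + b)) ** 2
        for x, y in batch
    )
    return carry + batch_ll, None  # None only keeps the scan signature


# In the estimator the initial carry is the log-prior; 0.0 is used here for clarity.
params = {"w": 2.0, "b": 0.0}
batches = [[(1.0, 2.0)], [(2.0, 4.0), (0.0, 0.0)]]
total, ys = scan(partial(logdensity_for_batch, params), 0.0, batches)
```

Because every residual in this toy data is zero, total equals three times the Gaussian normalizing term, -1.5 * log(2 * pi).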
- mcmc_sampling(model_states, rng_key, train, n_devices, parallel_batch_size, mask)
Perform MCMC burn-in and sampling.
- Parameters:
- model_states
Initial model states for sampler.
- rng_key
A key used to initialize the random processes.
- train
Training dataset. The dataset must include only 1 batch (full-batch).
- n_devices
Exact number of computational devices used for parallelization.
- parallel_batch_size
The number of chains to compute on each device when running parallel computations.
- mask
A 2D array indicating which chains are padding added for parallelization (shape = (n_devices, batch_size)).
- Chains corresponding to the value one in the mask are evaluated and sampled from.
- Chains corresponding to the value zero in the mask are ignored when possible and return pseudo-samples, which can be discarded.
- Returns:
- List[Dict]
Samples from all MCMC chains.
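One plausible way to build a padding mask like the one mcmc_sampling expects is to spread n_chains over n_devices rows of equal length and mark the surplus slots with zero. The helper make_chain_mask and its device-major layout are hypothetical; the estimator may arrange chains differently.

```python
import math
from typing import List, Tuple


def make_chain_mask(n_chains: int, n_devices: int) -> Tuple[int, List[List[int]]]:
    """Pad n_chains into an (n_devices, parallel_batch_size) grid.

    Returns parallel_batch_size and a mask where 1 marks a real chain
    and 0 marks a padding slot whose pseudo-samples can be discarded.
    """
    if n_chains < 1 or n_devices < 1:
        raise ValueError("n_chains and n_devices must be positive")
    parallel_batch_size = math.ceil(n_chains / n_devices)
    mask = [
        [1 if d * parallel_batch_size + j < n_chains else 0
         for j in range(parallel_batch_size)]
        for d in range(n_devices)
    ]
    return parallel_batch_size, mask


size, mask = make_chain_mask(n_chains=5, n_devices=2)
print(size, mask)  # 3 [[1, 1, 1], [1, 1, 0]]
```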
- predict(X)
Apply the fitted model to the input data.
- predict_as_de(X, n_devices=-1)
Predict with the model as a deep ensemble.
This method ignores the MCMC samples and uses only the initialization parameters.
- Parameters:
- X
The input data.
- n_devices
Number of devices to use for parallelization. -1 means using all available devices. If a number greater than the number of available devices is given, all available devices are used.
- Returns:
- Array
Predicted values (mean_of_predictions).
- Return type:
Array
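Treating each set of initialization parameters as an independent ensemble member, a deep-ensemble prediction is the mean of the members' predictions. In the sketch below each member is a hypothetical callable returning a point prediction; the real model's outputs (which may also carry a predicted standard deviation) are not modeled.

```python
from statistics import fmean
from typing import Callable, List, Sequence

Member = Callable[[float], float]  # hypothetical: one trained network


def predict_as_deep_ensemble(members: Sequence[Member], xs: Sequence[float]) -> List[float]:
    """Average the point predictions of all ensemble members for each input."""
    if not members:
        raise ValueError("the ensemble needs at least one member")
    return [fmean(member(x) for member in members) for x in xs]


members = [lambda x: 2.0 * x, lambda x: 2.0 * x + 1.0]
print(predict_as_deep_ensemble(members, [0.0, 1.0]))  # [0.5, 2.5]
```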
- predict_with_credibility_eti(X, a=0.95)
Make a prediction with an equal-tailed credible interval.
- Parameters:
- X
The input data.
- a
Probability mass of the credible interval, between 0 and 1.
- Returns:
- Tuple of 3 arrays:
Predicted values (median of samples).
Lower bound of the credible interval per prediction.
Upper bound of the credible interval per prediction.
- Return type:
Tuple[Array, Array, Array]
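An equal-tailed interval (ETI) with mass a leaves probability (1 - a) / 2 in each tail, so its bounds are the (1 - a) / 2 and (1 + a) / 2 quantiles of the samples, and the point prediction is the median. The sketch handles the samples for a single prediction. credible_interval_eti and quantile are hypothetical helpers, and the linear interpolation between order statistics matches NumPy's default quantile method but is not necessarily what the estimator uses.

```python
import math
from typing import Sequence, Tuple


def quantile(samples: Sequence[float], q: float) -> float:
    """Quantile by linear interpolation between order statistics.

    This matches NumPy's default ("linear") quantile method.
    """
    if not samples:
        raise ValueError("need at least one sample")
    xs = sorted(samples)
    pos = q * (len(xs) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (pos - lo) * (xs[hi] - xs[lo])


def credible_interval_eti(samples: Sequence[float], a: float = 0.95) -> Tuple[float, float, float]:
    """Return (median, lower, upper) of an equal-tailed interval with mass a."""
    if not 0.0 <= a <= 1.0:
        raise ValueError("a must lie between 0 and 1")
    tail = (1.0 - a) / 2.0  # probability left in each tail
    return quantile(samples, 0.5), quantile(samples, tail), quantile(samples, 1.0 - tail)


samples = [float(i) for i in range(101)]  # toy posterior samples 0, 1, ..., 100
median, lower, upper = credible_interval_eti(samples, a=0.9)
```

For these toy samples the interval is approximately (5, 95) around a median of 50.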
- sample_from_samples(x, n_devices=1, batch_size=-1)
Take samples from sampled Gaussian distributions based on model params.
The mean and std predicted by the sampled models define Gaussian distributions. Take samples from these distributions.
- Parameters:
- x
The input data.
- n_devices
…
- batch_size
…
- Returns:
- Array
The last dimension holds the samples taken from each predicted distribution for each batch element. The shape is:
(b_1, b_2, ..., b_n, output_size / 2, self.n_samples * self.n_chains)
- Return type:
Array
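The output shape above suggests that each sampled model predicts output_size values, half of them means and half standard deviations, and that one draw is taken per sampled model (n_samples * n_chains in total). The sketch below follows that reading for a single input; which half holds the means, and the name sample_from_predicted_gaussians, are assumptions.

```python
import random
from typing import List, Sequence


def sample_from_predicted_gaussians(
    model_outputs: Sequence[Sequence[float]], seed: int = 0
) -> List[List[float]]:
    """Draw one sample per sampled model from each predicted Gaussian (single input).

    Assumption: each output vector holds output_size // 2 means followed by
    output_size // 2 standard deviations. The result has one row per predicted
    distribution and one column per sampled model.
    """
    rng = random.Random(seed)
    size = len(model_outputs[0])
    if size % 2:
        raise ValueError("output size must be even (means followed by stds)")
    half = size // 2
    draws: List[List[float]] = [[] for _ in range(half)]
    for out in model_outputs:  # one output vector per sampled model
        means, stds = out[:half], out[half:]
        for k in range(half):
            draws[k].append(rng.gauss(means[k], stds[k]))
    return draws


# Two sampled models, two outputs each (means 1/2 and 3/4, stds 0.1).
draws = sample_from_predicted_gaussians([[1.0, 2.0, 0.1, 0.1], [3.0, 4.0, 0.1, 0.1]])
```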
- set_fit_request(*, n_devices='$UNCHANGED$')
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to fit.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- Parameters:
- n_devices : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for n_devices parameter in fit.
- Returns:
- self : object
The updated object.
- set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it's possible to update each component of a nested object.
- Parameters:
- **params : dict
Estimator parameters.
- Returns:
- self : estimator instance
Estimator instance.
- tree_flatten()
Specify how to serialize the estimator into a JAX pytree.
- classmethod tree_unflatten(aux_data, children)
Specify how to build an estimator from a JAX pytree.
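JAX rebuilds a registered class from the pair returned by tree_flatten: the children (array-valued leaves that JAX transforms) and hashable aux_data (static configuration). The toy class below shows that round-trip contract. TinyEstimator is hypothetical, and how BDEEstimator actually splits its attributes between children and aux_data is not documented here. With JAX installed, decorating such a class with @jax.tree_util.register_pytree_node_class registers it as a pytree node.

```python
from typing import Any, Tuple


class TinyEstimator:
    """Hypothetical toy class following the tree_flatten / tree_unflatten contract."""

    def __init__(self, weights: Any, n_chains: int = 1, seed: int = 42):
        self.weights = weights    # array-valued state: a pytree child
        self.n_chains = n_chains  # static settings: kept in aux_data
        self.seed = seed

    def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[int, int]]:
        children = (self.weights,)
        aux_data = (self.n_chains, self.seed)  # must be hashable for JAX
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data: Tuple[int, int], children: Tuple[Any, ...]) -> "TinyEstimator":
        (weights,) = children
        n_chains, seed = aux_data
        return cls(weights, n_chains=n_chains, seed=seed)


est = TinyEstimator([1.0, 2.0], n_chains=4)
children, aux_data = est.tree_flatten()
rebuilt = TinyEstimator.tree_unflatten(aux_data, children)
```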
Examples using bde.ml.models.BDEEstimator
sphx_glr_auto_examples_example01.py