bde.ml package

Submodules

bde.ml.loss module

Loss Functions for Bayesian Neural Networks.

This module contains implementations of loss functions and their wrappers used in training Bayesian Neural Networks within the Bayesian Deep Ensembles (BDE) framework.

Classes

  • Loss: Defines the API used by loss-related classes.

  • LossMSE: A callable class for computing MSE loss.

  • NLLLoss: Base class for Negative Log Likelihood Loss functions.

  • GaussianNLLLoss: A callable class for computing the Gaussian negative log-likelihood loss.

Functions

  • flax_training_loss_wrapper_regression: Wraps a regression loss function for training.

  • flax_training_loss_wrapper_classification: Wraps a classification loss function for training.

class bde.ml.loss.GaussianNLLLoss(epsilon=1e-06, mean_weight=1.0, is_full=False)

Bases: NLLLoss

Gaussian negative log likelihood loss.

A callable, JAX-compatible class for computing the negative log-likelihood loss of a Gaussian distribution. This loss is commonly used in probabilistic models to measure how well a predicted distribution (mean and variance) explains the true labels.

Mathematically, it is defined as:

\[\ell_{\text{Gaussian NLLLoss}} = \frac{1}{2}[ \log{(var)} + \frac{(\hat\mu - \mu)^2}{var} + \log{(2\pi)} ]\]

With its configurable parameters, this implementation computes:

\[\begin{split}\ell_{\text{Gaussian NLLLoss}} = \frac{1}{2}[ \log{(var)} + \omega_{\text{mean weight}} \cdot \frac{(\hat\mu - \mu)^2}{var} + \begin{cases} \log{(2\pi)} && \text{"is full" is True } \\ 0 && \text{"is full" is False } \end{cases} ]\end{split}\]

where

\[var = \max(\sigma^2, \epsilon)\]
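As a plain NumPy sketch of the formula above (not the library's implementation), the hypothetical function below takes the predicted mean and standard deviation as separate arguments. GaussianNLLLoss instead receives both packed in y_pred and separates them with _split_pred; that packing layout is not documented here, so the split is left out.

```python
import numpy as np


def gaussian_nll(y_true, mu_hat, sigma, epsilon=1e-6, mean_weight=1.0, is_full=False):
    """Hypothetical per-element Gaussian NLL following the weighted formula above."""
    # Clamp the variance from below for numerical stability: var = max(sigma^2, epsilon).
    var = np.maximum(np.square(sigma), epsilon)
    # Log-variance term plus the weighted squared error of the predicted mean.
    loss = np.log(var) + mean_weight * np.square(mu_hat - y_true) / var
    if is_full:
        # Optional constant term; it shifts the loss but does not change its gradients.
        loss = loss + np.log(2.0 * np.pi)
    return 0.5 * loss


# Prediction off by 2 with unit standard deviation: 0.5 * (log(1) + 4) = 2.0
print(gaussian_nll(0.0, 2.0, 1.0))
```

Because is_full only adds a constant, toggling it changes reported loss values but not training. A mean_weight below 1 down-weights the squared-error term relative to the log-variance term.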
Attributes:
params : dict[str, …]

Defines loss-related parameters:

  • epsilon : float

    A stability factor for the variance.

  • mean_weight : float

    A scale factor for the mean (squared-error) term.

  • is_full : bool

    If True, include the constant term log(2π); otherwise it is omitted.


Methods

__call__(y_true, y_pred)

Computes the negative log-likelihood loss for the given predictions and labels.

_split_pred(y_true, y_pred)

Splits the predicted values into predictions and their corresponding uncertainties.

apply_reduced()

Evaluates the reduced loss (inherited from base class).

apply_reduced(y_true, y_pred, **kwargs)

Evaluate the reduced loss.

The loss is evaluated separately for each item in the batch, and the mean of these values is returned.

Parameters:
y_true

The ground truth.

y_pred

The prediction.

**kwargs

Other keywords that may be passed to the unreduced loss function.

Returns:
Array

The reduced loss value.

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex]

params: dict[str, Any]

tree_flatten()

Specify how to serialize the loss function into a JAX pytree.

Returns:
A tuple with 2 elements:
  • The children, containing arrays & pytrees

  • The aux_data, containing static and hashable data.

Return type:

Tuple[Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]], Any]

classmethod tree_unflatten(aux_data, children)

Specify how to build a module from a JAX pytree.

Parameters:
aux_data

Contains static, hashable data.

children

Contains arrays & pytrees.

Returns:
GaussianNLLLoss

Reconstructed loss function.

Return type:

GaussianNLLLoss
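The tree_flatten/tree_unflatten pair is JAX's custom-pytree protocol. A minimal sketch, assuming a hypothetical ToyGaussianNLL whose hyperparameters are all static: they go into aux_data (which must be hashable), so the object can be passed directly into jax.jit, and a branch such as is_full is resolved at trace time. How bde itself splits its fields between children and aux_data is not documented here.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyGaussianNLL:
    """Hypothetical stand-in for a loss with only static hyperparameters."""

    def __init__(self, epsilon=1e-6, mean_weight=1.0, is_full=False):
        self.params = {"epsilon": epsilon, "mean_weight": mean_weight, "is_full": is_full}

    def tree_flatten(self):
        # No array leaves: every hyperparameter is static, hashable aux_data.
        children = ()
        aux_data = (self.params["epsilon"], self.params["mean_weight"], self.params["is_full"])
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from the static data; there are no children to restore.
        return cls(*aux_data)

    def __call__(self, y_true, mu, sigma):
        var = jnp.maximum(sigma**2, self.params["epsilon"])
        loss = jnp.log(var) + self.params["mean_weight"] * (mu - y_true) ** 2 / var
        if self.params["is_full"]:  # plain Python bool, so this is fine under jit
            loss = loss + jnp.log(2 * jnp.pi)
        return 0.5 * loss


loss = ToyGaussianNLL(is_full=True)
leaves, treedef = jax.tree_util.tree_flatten(loss)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
value = jax.jit(lambda f, y, m, s: f(y, m, s))(loss, 1.0, 1.0, 1.0)
```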

class bde.ml.loss.Loss

Bases: ABC

An abstract class for implementing the API of loss functions.

apply_reduced(y_true, y_pred, **kwargs)

Evaluate the reduced loss.

The loss is evaluated separately for each item in the batch, and the mean of these values is returned.

Parameters:
y_true

The ground truth.

y_pred

The prediction.

**kwargs

Other keywords that may be passed to the unreduced loss function.

Returns:
Array

The reduced loss value.

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex]
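A minimal sketch of the unreduced/reduced split described above, using hypothetical ToyLoss/ToyMSE classes and NumPy instead of JAX. Here __call__ returns one loss value per batch item and apply_reduced averages them. Whether LossMSE averages over the feature axis the same way is an assumption.

```python
from abc import ABC, abstractmethod

import numpy as np


class ToyLoss(ABC):
    """Hypothetical stand-in for the Loss API (not bde's implementation)."""

    @abstractmethod
    def __call__(self, y_true, y_pred, **kwargs):
        """Return the unreduced loss: one value per item in the batch."""

    def apply_reduced(self, y_true, y_pred, **kwargs):
        # Evaluate each batch item separately, then return the mean.
        return np.mean(self(y_true, y_pred, **kwargs))


class ToyMSE(ToyLoss):
    def __call__(self, y_true, y_pred, **kwargs):
        # Inputs are (batch, features); average the squared error over features.
        return np.mean(np.square(np.asarray(y_true) - np.asarray(y_pred)), axis=-1)


y_true = np.array([[0.0, 0.0], [1.0, 1.0]])
y_pred = np.array([[1.0, 1.0], [1.0, 3.0]])
per_item = ToyMSE()(y_true, y_pred)  # [1.0, 2.0]
reduced = ToyMSE().apply_reduced(y_true, y_pred)  # 1.5
```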

abstract tree_flatten()

Specify how to serialize the loss function into a JAX pytree.

Returns:
A tuple with 2 elements:
  • The children, containing arrays & pytrees

  • The aux_data, containing static and hashable data.

Return type:

Tuple[Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]], Any]

abstract classmethod tree_unflatten(aux_data, children)

Specify how to build a module from a JAX pytree.

Parameters:
aux_data

Contains static, hashable data.

children

Contains arrays & pytrees.

Returns:
Loss

Reconstructed loss function.

Return type:

Loss

class bde.ml.loss.LossMSE

Bases: Loss

A class wrapper for the mean squared error (MSE) loss.

apply_reduced(y_true, y_pred, **kwargs)

Evaluate the reduced loss.

The loss is evaluated separately for each item in the batch, and the mean of these values is returned.

Parameters:
y_true

The ground truth.

y_pred

The prediction.

**kwargs

Other keywords that may be passed to the unreduced loss function.

Returns:
Array

The reduced loss value.

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex]

tree_flatten()

Specify how to serialize the loss function into a JAX pytree.

Returns:
A tuple with 2 elements:
  • The children, containing arrays & pytrees

  • The aux_data, containing static and hashable data.

Return type:

Tuple[Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]], Any]

classmethod tree_unflatten(aux_data, children)

Specify how to build a module from a JAX pytree.

Parameters:
aux_data

Contains static, hashable data.

children

Contains arrays & pytrees.

Returns:
LossMSE

Reconstructed loss function.

Return type:

LossMSE

class bde.ml.loss.NLLLoss

Bases: Loss, ABC

Negative log likelihood loss.

A base class for loss classes representing the negative log-likelihood loss of a particular probability distribution.

\[\ell_{\text{NLL-loss}} = -\log{\mathcal{P}(\text{data} | \text{model})}\]
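For the Gaussian case, this definition reduces to the closed form used by GaussianNLLLoss with is_full=True and mean_weight=1. A short NumPy check, with the density written out by hand (gaussian_pdf is a hypothetical helper):

```python
import numpy as np


def gaussian_pdf(y, mu, sigma):
    """Density of N(mu, sigma^2) evaluated at y."""
    return np.exp(-0.5 * ((y - mu) / sigma) ** 2) / np.sqrt(2.0 * np.pi * sigma**2)


y, mu, sigma = 1.3, 0.4, 0.7
# Generic definition: -log P(data | model).
nll_direct = -np.log(gaussian_pdf(y, mu, sigma))
# Closed form: 0.5 * [log(var) + (y - mu)^2 / var + log(2*pi)] with var = sigma^2.
var = sigma**2
nll_closed = 0.5 * (np.log(var) + (y - mu) ** 2 / var + np.log(2.0 * np.pi))
print(np.isclose(nll_direct, nll_closed))
```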
apply_reduced(y_true, y_pred, **kwargs)

Evaluate the reduced loss.

The loss is evaluated separately for each item in the batch, and the mean of these values is returned.

Parameters:
y_true

The ground truth.

y_pred

The prediction.

**kwargs

Other keywords that may be passed to the unreduced loss function.

Returns:
Array

The reduced loss value.

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex]

params: dict[str, Any]

abstract tree_flatten()

Specify how to serialize the loss function into a JAX pytree.

Returns:
A tuple with 2 elements:
  • The children, containing arrays & pytrees

  • The aux_data, containing static and hashable data.

Return type:

Tuple[Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]], Any]

abstract classmethod tree_unflatten(aux_data, children)

Specify how to build a module from a JAX pytree.

Parameters:
aux_data

Contains static, hashable data.

children

Contains arrays & pytrees.

Returns:
Loss

Reconstructed loss function.

Return type:

Loss

bde.ml.loss.flax_training_loss_wrapper_classification(f_loss)

Wrap a classification loss function for use in Flax training.

This function wraps a classification loss function so that it can be used in the training loop of a Flax model.

Parameters:
f_loss

The loss function to wrap. It should take the true labels and predicted labels as input and return the computed loss value.

Returns:
Callable[[TrainState, dict, tuple[ArrayLike, ArrayLike]], float]

A function that can be used in the training loop, taking the model state, parameters, and a batch of data as input and returning the loss.

Return type:

Callable[[TrainState, dict, tuple[Union[Array, ndarray, bool, number, bool, int, float, complex], Union[Array, ndarray, bool, number, bool, int, float, complex]]], float]

Parameters:

f_loss (Callable[[Array | ndarray | bool | number | bool | int | float | complex, Array | ndarray | bool | number | bool | int | float | complex], float])

bde.ml.loss.flax_training_loss_wrapper_regression(f_loss)

Wrap a regression loss function for use in Flax training.

This function wraps a regression loss function so that it can be used in the training loop of a Flax model.

Parameters:
f_loss

The loss function to wrap. It should take the true labels and predicted labels as input and return the computed loss value.

Returns:
Callable[[TrainState, dict, tuple[ArrayLike, ArrayLike]], float]

A function that can be used in the training loop, taking the model state, parameters, and a batch of data as input and returning the loss.

Return type:

Callable[[TrainState, dict, tuple[Union[Array, ndarray, bool, number, bool, int, float, complex], Union[Array, ndarray, bool, number, bool, int, float, complex]]], float]

Parameters:

f_loss (Callable[[Array | ndarray | bool | number | bool | int | float | complex, Array | ndarray | bool | number | bool | int | float | complex], float])
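The wrapped function's (state, params, batch) signature keeps params as a separate argument, so gradients can be taken with respect to it, for example with jax.grad(loss_fn, argnums=1). Below is a minimal sketch of such a wrapper. It assumes the model is applied as state.apply_fn({"params": params}, x), the usual Flax convention, which is not confirmed for bde. The names are hypothetical, and a SimpleNamespace stands in for TrainState so the example runs without Flax.

```python
from types import SimpleNamespace

import numpy as np


def training_loss_wrapper_regression(f_loss):
    """Hypothetical re-implementation of the wrapper pattern (not bde's code)."""

    def loss_fn(state, params, batch):
        x, y_true = batch
        # Assumption: Flax-style call with a {"params": ...} variables dict.
        y_pred = state.apply_fn({"params": params}, x)
        return f_loss(y_true, y_pred)

    return loss_fn


def mse(y_true, y_pred):
    return np.mean(np.square(y_true - y_pred))


# Stand-in for TrainState: a linear model y = x @ w.
state = SimpleNamespace(apply_fn=lambda variables, x: x @ variables["params"]["w"])
params = {"w": np.array([[2.0]])}
x = np.array([[1.0], [2.0]])
loss_fn = training_loss_wrapper_regression(mse)
print(loss_fn(state, params, (x, np.array([[2.0], [5.0]]))))
```

The classification wrapper documented above has the same call signature; only the kind of f_loss it is given differs.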

bde.ml.models module

Models.

This module contains classes and functions for defining and managing various neural network models used in the Bayesian Deep Ensembles (BDE) framework. It includes basic building blocks like fully connected layers and estimators that adhere to the scikit-learn API.

Classes

  • BasicModule: An abstract base class defining an API for neural network modules.

  • FullyConnectedModule: A fully connected neural network module.

  • FullyConnectedEstimator: A scikit-learn-compatible estimator for training models.

  • BDEEstimator: A scikit-learn-compatible implementation of Bayesian Deep Ensembles (BDEs).

Functions

  • init_dense_model: Utility function for initializing a fully connected dense model.

class bde.ml.models.BDEEstimator(model_class=<class 'bde.ml.models.FullyConnectedModule'>, model_kwargs=None, n_chains=1, chain_len=1, warmup=1, n_samples=1, optimizer_class=<function adam>, optimizer_kwargs=None, loss=<bde.ml.loss.GaussianNLLLoss object>, batch_size=1, epochs=1, metrics=None, validation_size=None, seed=42)

Bases: FullyConnectedEstimator

Scikit-learn-compatible implementation of a BDE estimator.

# TODO: Describe BDE estimator.

Attributes:
# TODO: List

Methods

fit(X, y=None)

Fit the model to the training data.

predict(X)

Predict the output for the given input data using the trained model.

static burn_in_loop(rng, params, n_burns, warmup)

Perform burn-in for the sampler.

fit(X, y=None, n_devices=-1)

Fit the model to the given data.

Parameters:
X

The input data.

y

The labels. If y is None, X is assumed to include the labels as well.

n_devices

Number of devices to use for parallelization. -1 means using all available devices. If the requested number exceeds the number of available devices, all available devices are used.

Returns:
BDEEstimator

The fitted estimator.

Return type:

BDEEstimator
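The n_devices rule described for fit can be written as a small pure-Python helper. resolve_n_devices is hypothetical; in JAX, the number of available devices could come from jax.local_device_count(). The ValueError branch for other values is an added assumption, since the documentation does not say how they are handled.

```python
def resolve_n_devices(n_devices: int, n_available: int) -> int:
    """Map the user-facing n_devices argument to an actual device count (hypothetical)."""
    if n_devices == -1:
        # -1 means "use every available device".
        return n_available
    if n_devices < 1:
        # Assumption: other non-positive values are rejected.
        raise ValueError(f"n_devices must be -1 or a positive integer, got {n_devices}")
    # Requests beyond what is available are capped at the maximum.
    return min(n_devices, n_available)


print(resolve_n_devices(-1, 4), resolve_n_devices(8, 4), resolve_n_devices(2, 4))
```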

get_metadata_routing()

Get metadata routing of this object.

Please check the User Guide on how the routing mechanism works.

Returns:
routing : MetadataRequest

A MetadataRequest encapsulating routing information.

get_params(deep=True)

Get parameters for this estimator.

Parameters:
deep : bool, default=True

If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns:
params : dict

Parameter names mapped to their values.

history_description()

Make a readable version of the training history.

Returns:
Dict

Each key corresponds to an evaluation metric or loss, and each value is an array of its values across epochs.

Raises:
AssertionError

If the model is not fitted.

Return type:

Dict[str, Array]

init_inner_params(n_features, optimizer, rng_key)

Create trainable model state.

Parameters:
n_features

Number of input features.

optimizer

Optimization algorithm used for training.

rng_key

Random number generator (PRNG) key.

Returns:
train_state.TrainState

Initialized training state.

Return type:

TrainState

classmethod load(path)

Load an estimator from a file.

Return type:

FullyConnectedEstimator

Parameters:

path (str | Path)

log_prior(params)

Calculate the log of the prior probability for a set of params.

logdensity_for_batch(params, carry, batch)

Evaluate log-density for a batch of data.

Parameters:
params

Parameters of the evaluated model.

carry

The log-prior plus the log-density of the previous batches.

batch

Current batch to be evaluated.

Returns:
Tuple:
  • Updated logdensity (carry + current batch value).

  • None (used for compatibility with jax.lax.scan)

Return type:

Tuple[Array, None]
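logdensity_for_batch has the (carry, x) -> (carry, y) shape that jax.lax.scan expects, so the full-data log-density can be accumulated batch by batch, starting from the log-prior. A minimal sketch with a hypothetical linear model, a standard-normal prior, and a unit-variance Gaussian likelihood (constants dropped); none of these modelling choices are taken from bde.

```python
import jax
import jax.numpy as jnp


def log_prior(params):
    # Hypothetical isotropic standard-normal prior (up to a constant).
    return -0.5 * jnp.sum(params**2)


def batch_loglik(params, batch):
    # Hypothetical unit-variance Gaussian likelihood of a linear model (up to a constant).
    x, y = batch
    return -0.5 * jnp.sum((y - x @ params) ** 2)


def logdensity_for_batch(params, carry, batch):
    # Add this batch's log-likelihood to the running total; None is the per-step output.
    return carry + batch_loglik(params, batch), None


def logdensity(params, batches):
    # batches = (xs, ys), each stacked along a leading "batch index" axis.
    total, _ = jax.lax.scan(
        lambda carry, batch: logdensity_for_batch(params, carry, batch),
        log_prior(params),
        batches,
    )
    return total


params = jnp.array([0.5, -1.0])
xs = jnp.arange(12.0).reshape(3, 2, 2)  # 3 batches of 2 samples with 2 features
ys = jnp.ones((3, 2))
print(logdensity(params, (xs, ys)))
```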

mcmc_sampling(model_states, rng_key, train, n_devices, parallel_batch_size, mask)

Perform MCMC burn-in and sampling.

Parameters:
model_states

Initial model states for sampler.

rng_key

A key used to initialize the random processes.

train

The training dataset. It must consist of a single batch (full-batch).

n_devices

Exact number of computational devices used for parallelization.

parallel_batch_size

The number of chains to compute on each device when running parallel computations.

mask

A 2D array indicating which chains are used for padding during parallelization (shape = (n_devices, batch_size)).

  • Chains corresponding to the value one in the mask will be evaluated and sampled from.

  • Chains corresponding to the value zero in the mask will be ignored when possible and return pseudo-samples which can be discarded.

Returns:
List[Dict]

Samples from all MCMC chains.

Return type:

List[Dict]

Parameters:
  • model_states (TrainState)

  • rng_key (PRNGKeyArray)

  • train (BasicDataset)

  • n_devices (int)

  • parallel_batch_size (int)

  • mask (Array)
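A sketch of how n_chains might be spread over devices with padding, producing parallel_batch_size and a mask with one row per device and one column per chain slot. The helper chain_layout is hypothetical. The assumption that real chains fill the mask row by row, with padding at the end, is not confirmed by the documentation.

```python
import math

import numpy as np


def chain_layout(n_chains: int, n_devices: int):
    """Hypothetical: spread n_chains over n_devices, padding with dummy chains."""
    # Each device runs the same number of chains, so round up.
    parallel_batch_size = math.ceil(n_chains / n_devices)
    flat = np.zeros(n_devices * parallel_batch_size, dtype=np.int32)
    flat[:n_chains] = 1  # 1 = real chain (evaluated and sampled), 0 = padding
    mask = flat.reshape(n_devices, parallel_batch_size)
    return parallel_batch_size, mask


pbs, mask = chain_layout(n_chains=5, n_devices=2)
print(pbs)  # 3
print(mask.tolist())  # [[1, 1, 1], [1, 1, 0]]
```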

predict(X)

Apply the fitted model to the input data.

Parameters:
X

The input data.

Returns:
Array

Predicted values (mean of samples).

Return type:

Array

Parameters:

X (Array | ndarray | bool | number | bool | int | float | complex)

predict_as_de(X, n_devices=-1)

Predict with model as a deep ensemble.

This method ignores the MCMC samples and uses only the parameters used to initialize the sampler.

Parameters:
X

The input data.

n_devices

Number of devices to use for parallelization. -1 means using all available devices. If the requested number exceeds the number of available devices, all available devices are used.

Returns:
Array

Predicted values (mean of predictions).

Return type:

Array

predict_with_credibility_eti(X, a=0.95)

Make predictions with an equal-tailed credible interval (ETI).

Parameters:
X

The input data.

a

Size of the credible interval, as a probability between 0 and 1.

Returns:
Tuple[Array, Array, Array]

Three arrays containing:

  • The predicted values (median of samples).

  • The lower bound of the credible interval for each prediction.

  • The upper bound of the credible interval for each prediction.
Return type:

Tuple[Array, Array, Array]
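An equal-tailed interval puts probability (1 - a)/2 in each tail. The NumPy sketch below works over a sample axis, taken to be the last axis to match the layout described for sample_from_samples. equal_tailed_interval is a hypothetical helper, not the method's actual code.

```python
import numpy as np


def equal_tailed_interval(samples, a=0.95):
    """Median and equal-tailed a-credible interval over the last axis (hypothetical)."""
    lower_q = (1.0 - a) / 2.0  # probability mass in the lower tail
    upper_q = 1.0 - lower_q    # upper tail has the same mass
    median = np.median(samples, axis=-1)
    lower = np.quantile(samples, lower_q, axis=-1)
    upper = np.quantile(samples, upper_q, axis=-1)
    return median, lower, upper


samples = np.arange(101.0).reshape(1, 101)  # one prediction with 101 samples
median, lower, upper = equal_tailed_interval(samples, a=0.95)
```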

sample_from_samples(x, n_devices=1, batch_size=-1)

Take samples from sampled Gaussian distributions based on model params.

The mean and std predicted by the sampled models define Gaussian distributions. Take samples from these distributions.

Parameters:
x

The input data.

n_devices

Number of devices to use for parallelization.

batch_size

Returns:
Array

The last dimension lists the samples taken from each predicted distribution in each batch. The shape is (b_1, b_2, ..., b_n, output_size / 2, self.n_samples * self.n_chains).

Return type:

Array
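The NumPy sketch below illustrates the sampling step. It assumes each sampled model outputs a vector whose first half holds the means and whose second half holds the (already positive) standard deviations, consistent with the output_size / 2 in the shape above. The helper draw_from_predicted_gaussians and this layout are assumptions; the real method may transform the predicted scale before sampling.

```python
import numpy as np


def draw_from_predicted_gaussians(preds, rng):
    """Hypothetical sketch.

    preds: array of shape (n_models, batch, output_size), one slice per sampled
    model (n_models = n_samples * n_chains). Returns an array of shape
    (batch, output_size // 2, n_models): one draw from each model's Gaussian.
    """
    half = preds.shape[-1] // 2
    mean, std = preds[..., :half], preds[..., half:]
    draws = rng.normal(loc=mean, scale=std)  # broadcast: one draw per entry
    # Move the model axis to the end so the last dimension lists the samples.
    return np.moveaxis(draws, 0, -1)


rng = np.random.default_rng(0)
preds = np.concatenate([np.zeros((6, 4, 1)), np.ones((6, 4, 1))], axis=-1)
samples = draw_from_predicted_gaussians(preds, rng)
print(samples.shape)  # (4, 1, 6)
```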

save(path)

Save the estimator to a file.

Return type:

None

Parameters:

path (str | Path)

set_fit_request(*, n_devices='$UNCHANGED$')

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:
n_devices : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for n_devices parameter in fit.

Returns:
self : object

The updated object.

Return type:

BDEEstimator

set_params(**params)

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters:
**params : dict

Estimator parameters.

Returns:
self : estimator instance

Estimator instance.

tree_flatten()

Specify how to serialize the estimator into a JAX pytree.

Returns:
Tuple[Sequence[ArrayLike], Any]
A tuple with 2 elements:
  • The children, containing arrays & pytrees.

  • The aux_data, containing static and hashable data.

Return type:

Tuple[Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex, Dict]], Any]

classmethod tree_unflatten(aux_data, children)

Specify how to build an estimator from a JAX pytree.

Parameters:
aux_data

Contains static, hashable data.

children

Contains arrays & pytrees.

Returns:
BDEEstimator

Reconstructed estimator.

Return type:

BDEEstimator

class bde.ml.models.BasicModule(n_output_params, n_input_params=None, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module, ABC

An abstract base class for easy inheritance and API implementation.

Attributes:
n_output_params : Union[int, list[int]]

The number of output parameters or the shape of the output tensor(s). Similar to n_input_params, this can be an integer or a list.

n_input_params : Optional[Union[int, list[int]]]

The number of input parameters or the shape of the input tensor(s). This can be an integer for models with a single input or a list of integers for multi-input models.

Parameters:
  • n_output_params (int | list[int])

  • n_input_params (int | list[int] | None)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

Methods

__call__(*args, **kwargs)

Abstract method to be implemented by subclasses, defining the API of a forward pass of the module.

apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)

Applies a module method to variables and returns output and modified variables.

Note that method should be set if one would like to call apply on a different class method than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Transformer(nn.Module):
...   def encode(self, x):
...     ...

>>> x = jnp.ones((16, 9))
>>> model = Transformer()
>>> variables = model.init(jax.random.key(0), x, method=Transformer.encode)

>>> encoded = model.apply(variables, x, method=Transformer.encode)

If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:

>>> encoded = model.apply(variables, x, method=model.encode)

You can also pass a string to a callable attribute of the module. For example, the previous can be written as:

>>> encoded = model.apply(variables, x, method='encode')

Note that method can also be a function that is not defined in Transformer. In that case, the function should have at least one argument representing an instance of the Module class:

>>> def other_fn(instance, x):
...   # instance.some_module_attr(...)
...   instance.encode
...   ...

>>> model.apply(variables, x, method=other_fn)

If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to apply. If self.make_rng(name) is called on an RNG stream name that isn’t passed by the user, it will default to using the 'params' RNG stream.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, add_noise=False):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     if add_noise:
...       # Add gaussian noise
...       noise_key = self.make_rng('noise')
...       x = x + jax.random.normal(noise_key, x.shape)
...
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)}
>>> variables = module.init(rngs, x)
>>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs)

>>> rngs['noise'] = jax.random.key(0)
>>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # different output (key(1) vs key(0))
>>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1)

>>> del rngs['noise']
>>> # self.make_rng('noise') will default to using the 'params' RNG stream
>>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # same output (key(0))
>>> np.testing.assert_allclose(out1, out2)

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0))
>>> # same output (key(0))
>>> np.testing.assert_allclose(out2, out3)
Return type:

Any | tuple[Any, FrozenDict[str, Mapping[str, Any]] | dict[str, Any]]

Parameters:
  • variables: A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args: Positional arguments passed to the specified apply method.

  • rngs: A dict of PRNGKeys to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.

  • method: A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module. A string can also be provided to specify a method by name.

  • mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

  • capture_intermediates: If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs: Keyword arguments passed to the specified apply method.

Returns:

If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.

bind(variables, *args, rngs=None, mutable=False)

Creates an interactive Module instance by binding variables and RNGs.

bind provides an “interactive” instance of a Module directly without transforming a function with apply. This is particularly useful for debugging and interactive use cases like notebooks where a function would limit the ability to split up code into different cells.

Once the variables (and optionally RNGs) are bound to a Module it becomes a stateful object. Note that idiomatic JAX is functional and therefore an interactive instance does not mix well with vanilla JAX APIs. bind() should only be used for interactive experimentation, and in all other cases we strongly encourage users to use apply() instead.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = nn.Dense(3)
...     self.decoder = nn.Dense(5)
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> x = jnp.ones((16, 9))
>>> ae = AutoEncoder()
>>> variables = ae.init(jax.random.key(0), x)
>>> model = ae.bind(variables)
>>> z = model.encoder(x)
>>> x_reconstructed = model.decoder(z)
Return type:

TypeVar(M, bound= Module)

Parameters:
  • variables: A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args: Positional arguments (not used).

  • rngs: A dict of PRNGKeys to initialize the PRNG sequences.

  • mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

Returns:

A copy of this instance with bound variables and RNGs.

clone(*, parent=None, _deep_clone=False, _reset_names=False, **updates)

Creates a clone of this Module, with optionally updated arguments.

Return type:

TypeVar(M, bound= Module)

NOTE: End users are encouraged to use the copy method. clone is used primarily for internal routines, and copy offers simpler arguments and better defaults.

Parameters:
  • parent: The parent of the clone. The clone will have no parent if no explicit parent is specified.

  • _deep_clone: A boolean or a weak value dictionary to control deep cloning of submodules. If True, submodules will be cloned recursively. If a weak value dictionary is passed, it will be used to cache cloned submodules. This flag is used by init/apply/bind to avoid scope leakage.

  • _reset_names: If True, name=None is also passed to submodules when cloning. Resetting names in submodules is necessary when calling .unbind.

  • **updates: Attribute updates.

Returns:

A clone of this Module with the updated attributes and parent.

copy(*, parent=<flax.linen.module._Sentinel object>, name=None, **updates)

Creates a copy of this Module, with optionally updated arguments.

Return type:

TypeVar(M, bound= Module)

Parameters:
  • self (M)

  • parent (Scope | Module | _Sentinel | None)

  • name (str | None)

Args:
  • parent: The parent of the copy. By default, the current module is taken as the parent if not explicitly specified.

  • name: A new name for the copied Module. By default, a new automatic name will be given.

  • **updates: Attribute updates.

Returns:

A copy of this Module with the updated name, parent, and attributes.

get_variable(col, name, default=None)

Retrieves the value of a Variable.

Return type:

TypeVar(T)

Parameters:
  • col (str)

  • name (str)

  • default (T | None)

Args:
  • col: The variable collection.

  • name: The name of the variable.

  • default: The default value to return if the variable does not exist in this scope.

Returns:

The value of the input variable, or the default value if the variable doesn’t exist in this scope.

has_rng(name)

Returns True if a PRNGSequence with the given name exists.

Return type:

bool

Parameters:

name (str)

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

See flax.core.variables for more explanation on variables and collections.

Return type:

bool

Parameters:
  • col: The variable collection name.

  • name: The name of the variable.

Returns:

True if the variable exists.

init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)

Initializes a module method with variables and returns modified variables.

init takes as its first argument either a single PRNGKey or a dictionary mapping variable collection names to their PRNGKeys. It calls method (the module’s __call__ function by default), passing *args and **kwargs, and returns a dictionary of initialized variables.

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> key = jax.random.key(0)
>>> variables = module.init(key, x, train=False)

If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to init. If self.make_rng(name) is called on an RNG stream name that isn’t passed by the user, it will default to using the 'params' RNG stream.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     other_variable = self.variable(
...       'other_collection',
...       'other_variable',
...       lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape),
...       x,
...     )
...     x = x + other_variable.value
...
...     return nn.Dense(1)(x)

>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)}
>>> variables0 = module.init(rngs, x)

>>> rngs['other_rng'] = jax.random.key(0)
>>> variables1 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables0['params'], variables1['params']
... )
>>> # different other_variable (key(1) vs key(0))
>>> np.testing.assert_raises(
...   AssertionError,
...   np.testing.assert_allclose,
...   variables0['other_collection']['other_variable'],
...   variables1['other_collection']['other_variable'],
... )

>>> del rngs['other_rng']
>>> # self.make_rng('other_rng') will default to using the 'params' RNG stream
>>> variables2 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables1['params'], variables2['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables1['other_collection']['other_variable'],
...   variables2['other_collection']['other_variable'],
... )

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> variables3 = module.init(jax.random.key(0), x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables2['params'], variables3['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables2['other_collection']['other_variable'],
...   variables3['other_collection']['other_variable'],
... )

Jitting init initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:

>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.key(0), x)

init is a light wrapper over apply, so other apply arguments like method, mutable, and capture_intermediates are also available.

Return type:

FrozenDict[str, Mapping[str, Any]] | dict[str, Any]

Parameters:
  • rngs: The rngs for the variable collections.

  • *args: Positional arguments passed to the init function.

  • method: An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.

  • mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except “intermediates” are mutable.

  • capture_intermediates: If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs: Keyword arguments passed to the init function.

Returns:

The initialized variable dict.

init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)#

Initializes a module method with variables and returns output and modified variables.

Return type:

tuple[Any, FrozenDict[str, Mapping[str, Any]] | dict[str, Any]]

Args:

  • rngs: The rngs for the variable collections.

  • *args: Positional arguments passed to the init function.

  • method: An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.

  • mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable. bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except “intermediates” are mutable.

  • capture_intermediates: If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs: Keyword arguments passed to the init function.

Returns:

(output, vars), where vars is a dict of the modified collections.

is_initializing()#

Returns True if running under self.init(…) or nn.init(…)().

This is a helper method to handle the common case of simple initialization where we wish to have setup logic occur when only called under module.init or nn.init. For more complicated multi-phase initialization scenarios it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized.

Return type:

bool

is_mutable_collection(col)#

Returns True if the collection col is mutable.

Return type:

bool

Parameters:

col (str)

lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)#

Initializes a module without computing on an actual input.

lazy_init will initialize the variables without doing unnecessary compute. The input data should be passed as a jax.ShapeDtypeStruct which specifies the shape and dtype of the input but no concrete data.

Example:

>>> model = nn.Dense(features=256)
>>> variables = model.lazy_init(
...     jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))

The args and kwargs passed to lazy_init can be a mix of concrete (JAX arrays, scalars, bools) and abstract (ShapeDtypeStruct) values. Concrete values are only necessary for arguments that affect the initialization of variables. For example, the model might expect a keyword arg that enables/disables a subpart of the model. In this case, an explicit value (True/False) should be passed; otherwise lazy_init cannot infer which variables should be initialized.

Return type:

FrozenDict[str, Mapping[str, Any]]

Args:

  • rngs: The rngs for the variable collections.

  • *args: Arguments passed to the init function.

  • method: An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable. bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except “intermediates” are mutable.

  • **kwargs: Keyword arguments passed to the init function.

Returns:

The initialized variable dict.

make_rng(name='params')#

Returns a new RNG key from a given RNG sequence for this Module.

The new RNG key is split from the previous one. Thus, every call to make_rng returns a new RNG key, while still guaranteeing full reproducibility.

Note

If an invalid name is passed (i.e. no RNG key was passed by the user in .init or .apply for this name), then name will default to 'params'.

Example:

>>> import jax
>>> import flax.linen as nn

>>> class ParamsModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('params')
>>> class OtherModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('other')

>>> key = jax.random.key(0)
>>> params_out, _ = ParamsModule().init_with_output({'params': key})
>>> # self.make_rng('other') will default to using the 'params' RNG stream
>>> other_out, _ = OtherModule().init_with_output({'params': key})
>>> assert params_out == other_out

Learn more about RNGs by reading the Flax RNG guide: https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html

Args:

name: The RNG sequence name.

Returns:

The newly generated RNG key.

Parameters:

name (str)

Return type:

Array

module_paths(rngs, *args, show_repeated=False, mutable=DenyList(deny='intermediates'), **kwargs)#

Returns a dictionary mapping module paths to module instances.

This method has the same signature and internally calls Module.init, but instead of returning the variables, it returns a dictionary mapping module paths to unbound copies of module instances that were used at runtime. module_paths uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> modules = Foo().module_paths(jax.random.key(0), x)
>>> print({
...     p: type(m).__name__ for p, m in modules.items()
... })
{'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'}
Return type:

dict[str, Module]

Args:

  • rngs: The rngs for the variable collections as passed to Module.init.

  • *args: The arguments to the forward computation.

  • show_repeated: If True, repeated calls to the same module will be shown in the output; otherwise only the first call will be shown. Default is False.

  • mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable. bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except ‘intermediates’ are mutable.

  • **kwargs: Keyword arguments to pass to the forward computation.

Returns:

A dictionary mapping module paths to module instances.

n_input_params: Union[int, list[int], None] = None#
n_output_params: Union[int, list[int]]#
name: Optional[str] = None#
param(name, init_fn, *init_args, unbox=True, **init_kwargs)#

Declares and returns a parameter in this Module.

Parameters are read-only variables in the collection named “params”. See flax.core.variables for more details on variables.

The first argument of init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args or init_kwargs:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     mean = self.param('mean', nn.initializers.lecun_normal(), x.shape)
...     ...
...     return x * mean
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}

In the example above, the function lecun_normal expects two arguments: key and shape, but only shape has to be provided explicitly; key is set automatically using the PRNG for params that is passed when initializing the module using init().

Return type:

Union[TypeVar(T), AxisMetadata[TypeVar(T)]]

Args:

  • name: The parameter name.

  • init_fn: The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.

  • *init_args: The positional arguments to pass to init_fn.

  • unbox: If True, AxisMetadata instances are replaced by their unboxed value, see flax.nn.meta.unbox (default: True).

  • **init_kwargs: The keyword arguments to pass to init_fn.

Returns:

The value of the initialized parameter. Throws an error if the parameter exists already.

parent: Union[Module, Scope, _Sentinel, None] = None#
property path#

Get the path of this Module. Top-level root modules have an empty path (). Note that this method can only be used on bound modules that have a valid scope.

Example usage:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class SubModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'SubModel path: {self.path}')
...     return x

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'Model path: {self.path}')
...     return SubModel()(x)

>>> model = Model()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 2)))
Model path: ()
SubModel path: ('SubModel_0',)

perturb(name, value, collection='perturbations')#

Add a zero-value variable (‘perturbation’) to the intermediate value.

The gradient of value would be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation argument.

Note

This is an experimental API and may be tweaked later for better performance and usability. At its current stage, it creates extra dummy variables that occupy extra memory space. Use it only to debug gradients in training.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = self.perturb('dense3', x)
...     return nn.Dense(2)(x)

>>> def loss(variables, inputs, targets):
...   preds = model.apply(variables, inputs)
...   return jnp.square(preds - targets).mean()

>>> x = jnp.ones((2, 9))
>>> y = jnp.ones((2, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y)
>>> print(intm_grads['perturbations']['dense3'])
[[-1.456924   -0.44332537  0.02422847]
 [-1.456924   -0.44332537  0.02422847]]

If perturbations are not passed to apply, perturb behaves like a no-op so you can easily disable the behavior when not needed:

>>> model.apply(variables, x) # works as expected
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> model.apply({'params': variables['params']}, x) # behaves like a no-op
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y)
>>> 'perturbations' not in intm_grads
True
Parameters:
  • name (str)

  • value (T)

  • collection (str)

Return type:

T

put_variable(col, name, value)#

Updates the value of the given variable if it is mutable, or raises an error otherwise.

Args:

  • col: The variable collection.

  • name: The name of the variable.

  • value: The new value of the variable.

scope: Scope | None = None#
setup()#

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_with_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(features=32, kernel_size=(3, 3))
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes its setup() to be called once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

Return type:

None

sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)#

Stores a value in a collection.

Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.

If the target collection is not mutable, sow behaves like a no-op and returns False.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     self.sow('intermediates', 'h', h)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(variables, x, mutable=['intermediates'])
>>> jax.tree.map(jnp.shape, state['intermediates'])
{'h': ((16, 4),)}

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:

>>> class Foo2(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     init_fn = lambda: 0
...     reduce_fn = lambda a, b: a + b
...     self.sow('intermediates', 'h', x,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     self.sow('intermediates', 'h', x * 2,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     return x

>>> x = jnp.ones((1, 1))
>>> model = Foo2()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(
...     variables, x, mutable=['intermediates'])
>>> print(state['intermediates'])
{'h': Array([[3.]], dtype=float32)}
Return type:

bool

Args:

  • col: The name of the variable collection.

  • name: The name of the variable.

  • value: The value of the variable.

  • reduce_fn: The function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn: For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.

Returns:

True if the value has been stored successfully, False otherwise.

tabulate(rngs, *args, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)#

Creates a summary of the Module represented as a table.

This method has the same signature and internally calls Module.init, but instead of returning the variables, it returns the string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Additional arguments can be passed into the console_kwargs argument, for example, {'width': 120}. For a full list of console_kwargs arguments, see: https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))

>>> # print(Foo().tabulate(
>>> #     jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True))

This gives the following output:

                                      Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs        ┃ outputs       ┃ flops ┃ vjp_flops ┃ params          ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│         │ Foo    │ float32[16,9] │ float32[16,2] │ 1504  │ 4460      │                 │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_0 │ Dense  │ float32[16,9] │ float32[16,4] │ 1216  │ 3620      │ bias:           │
│         │        │               │               │       │           │ float32[4]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[9,4]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 40 (160 B)      │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_1 │ Dense  │ float32[16,4] │ float32[16,2] │ 288   │ 840       │ bias:           │
│         │        │               │               │       │           │ float32[2]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[4,2]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 10 (40 B)       │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│         │        │               │               │       │     Total │ 50 (200 B)      │
└─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘

                              Total Parameters: 50 (200 B)

Note: the row order in the table does not represent execution order; instead, it follows the order of keys in variables, which are sorted alphabetically.

Note: vjp_flops returns 0 if the module is not differentiable.
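The parameter, byte, and forward-FLOP columns in the Foo table can be reproduced by hand. The sketch below is plain Python arithmetic, not a call into tabulate. It assumes float32 parameters (4 bytes each) and a cost model of 2 FLOPs per multiply-add in the matrix product plus 1 FLOP per bias addition; this model is our assumption, which happens to match the flops column above. The vjp_flops column is not modeled.

```python
BYTES_PER_FLOAT32 = 4

def dense_stats(batch, d_in, d_out):
    """Parameter count, bytes, and forward FLOPs of a Dense layer (assumed cost model)."""
    n_params = d_in * d_out + d_out                    # kernel + bias
    n_bytes = n_params * BYTES_PER_FLOAT32
    flops = 2 * batch * d_in * d_out + batch * d_out   # matmul multiply-adds + bias adds
    return n_params, n_bytes, flops

dense_0 = dense_stats(batch=16, d_in=9, d_out=4)
dense_1 = dense_stats(batch=16, d_in=4, d_out=2)
total = tuple(a + b for a, b in zip(dense_0, dense_1))

print(dense_0)  # (40, 160, 1216)
print(dense_1)  # (10, 40, 288)
print(total)    # (50, 200, 1504)
```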

Return type:

str

Args:

  • rngs: The rngs for the variable collections as passed to Module.init.

  • *args: The arguments to the forward computation.

  • depth: Controls how many submodules deep the summary can go. By default, it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor such that the sum of all rows always adds up to the total number of parameters of the Module.

  • show_repeated: If True, repeated calls to the same module will be shown in the table; otherwise only the first call will be shown. Default is False.

  • mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable. bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except ‘intermediates’ are mutable.

  • console_kwargs: An optional dictionary with additional keyword arguments that are passed to rich.console.Console when rendering the table. Default arguments are {'force_terminal': True, 'force_jupyter': False}.

  • table_kwargs: An optional dictionary with additional keyword arguments that are passed to the rich.table.Table constructor.

  • column_kwargs: An optional dictionary with additional keyword arguments that are passed to rich.table.Table.add_column when adding columns to the table.

  • compute_flops: Whether to include a flops column in the table listing the estimated FLOPs cost of each module forward pass. Does not incur actual on-device computation / compilation / memory allocation, but still introduces overhead for large modules (e.g. an extra 20 seconds for a Stable Diffusion UNet, whereas otherwise tabulation would finish in 5 seconds).

  • compute_vjp_flops: Whether to include a vjp_flops column in the table listing the estimated FLOPs cost of each module backward pass. Introduces a compute overhead of about 2-3X that of compute_flops.

  • **kwargs: Keyword arguments to pass to the forward computation.

Returns:

A string summarizing the Module.

tree_flatten()#

Specify how to serialize module into a JAX pytree.

Returns:
Tuple[Sequence[ArrayLike], Any]
A tuple with 2 elements:
  • The children, containing arrays & pytrees

  • The aux_data, containing static and hashable data.

Return type:

Tuple[Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]], Any]

abstract classmethod tree_unflatten(aux_data, children)#

Specify how to build a module from a JAX pytree.

Parameters:
aux_data

Contains static, hashable data.

children

Contains arrays & pytrees.

Returns:
FullyConnectedModule

Reconstructed Module.

Return type:

FullyConnectedModule
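The tree_flatten / tree_unflatten pair follows JAX's pytree protocol. The sketch below shows that protocol on a hypothetical class Scaled (not part of bde): array children are exposed to JAX transformations, while hashable aux_data is carried through unchanged.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Scaled:  # hypothetical class, not part of bde
    def __init__(self, weights, name):
        self.weights = weights  # array leaf
        self.name = name        # static metadata

    def tree_flatten(self):
        children = (self.weights,)
        aux_data = (self.name,)  # must be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (name,) = aux_data
        (weights,) = children
        return cls(weights, name)

s = Scaled(jnp.ones(3), "w")
# tree_map rebuilds a Scaled via tree_unflatten, keeping `name` untouched.
doubled = jax.tree_util.tree_map(lambda a: 2 * a, s)
# Registered objects can also be passed straight into jit-compiled functions.
total = jax.jit(lambda obj: obj.weights.sum())(s)
```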

unbind()#

Returns an unbound copy of a Module and its variables.

unbind helps create a stateless version of a bound Module.

An example of a common use case: to extract a sub-Module defined inside setup() and its corresponding variables: 1) temporarily bind the parent Module; and then 2) unbind the desired sub-Module. (Recall that setup() is only called when the Module is bound.):

>>> class Encoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(256)(x)

>>> class Decoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(784)(x)

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = Encoder()
...     self.decoder = Decoder()
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> module = AutoEncoder()
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 784)))

>>> # Extract the Encoder sub-Module and its variables
>>> encoder, encoder_vars = module.bind(variables).encoder.unbind()
Return type:

tuple[TypeVar(M, bound= Module), Mapping[str, Mapping[str, Any]]]

Parameters:

self (M)

Returns:

A tuple with an unbound copy of this Module and its variables.

variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs)#

Declares and returns a variable in this Module.

See flax.core.variables for more information. See also param() for a shorthand way to define read-only variables in the “params” collection.

Unlike param(), all arguments to init_fn must be passed explicitly:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     key = self.make_rng('stats')
...     mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape)
...     ...
...     return x * mean.value
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}

In the example above, the function lecun_normal expects two arguments: key and shape, and both have to be passed on. The PRNG for stats has to be provided explicitly when calling init() and apply().

Return type:

Union[Variable[TypeVar(T)], Variable[AxisMetadata[TypeVar(T)]]]

Args:

  • col: The variable collection name.

  • name: The variable name.

  • init_fn: The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module. If None, the variable must already be initialized; otherwise an error is raised.

  • *init_args: The positional arguments to pass to init_fn.

  • unbox: If True, AxisMetadata instances are replaced by their unboxed value, see flax.nn.meta.unbox (default: True).

  • **init_kwargs: The keyword arguments to pass to init_fn.

Returns:

A flax.core.variables.Variable that can be read or set via “.value” attribute. Throws an error if the variable exists already.

property variables: Mapping[str, Mapping[str, Any]]#

Returns the variables in this module.

class bde.ml.models.FullyConnectedEstimator(model_class=<class 'bde.ml.models.FullyConnectedModule'>, model_kwargs=None, optimizer_class=<function adam>, optimizer_kwargs=None, loss=<bde.ml.loss.LossMSE object>, batch_size=1, epochs=1, metrics=None, validation_size=None, seed=42, **kwargs)#

Bases: BaseEstimator

SKlearn-compatible estimator for training fully connected neural networks with JAX.

The FullyConnectedEstimator class wraps a Flax-based neural network model into an SKlearn-style estimator, providing a compatible interface for fitting, predicting, and evaluating models.


Methods

fit(X, y=None)

Fit the model to the training data.

predict(X)

Predict the output for the given input data using the trained model.

_more_tags()

Used by the SKlearn API to set model tags.

fit(X, y=None)#

Fit the model to the given data.

Parameters:
X

The input data.

y

The labels. If y is None, X is assumed to include the labels as well.

Returns:
FullyConnectedEstimator

The fitted estimator.

Return type:

FullyConnectedEstimator

get_metadata_routing()#

Get metadata routing of this object.

Please check the scikit-learn User Guide on how the routing mechanism works.

Returns:
routing : MetadataRequest

A MetadataRequest encapsulating routing information.

get_params(deep=True)#

Get parameters for this estimator.

Parameters:
deep : bool, default=True

If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns:
params : dict

Parameter names mapped to their values.

history_description()#

Make a readable version of the training history.

Returns:
Dict

Each key corresponds to an evaluation metric or loss, and each value is an array of its values across epochs.

Raises:
AssertionError

If the model is not fitted.

Return type:

Dict[str, Array]

init_inner_params(n_features, optimizer, rng_key)#

Create trainable model state.

Parameters:
n_features

Number of input features.

optimizer

Optimization algorithm used for training.

rng_key

Randomness key.

Returns:
train_state.TrainState

Initialized training state.

Return type:

TrainState

classmethod load(path)#

Load estimator from file.

Return type:

FullyConnectedEstimator

Parameters:

path (str | Path)

predict(X)#

Apply the fitted model to the input data.

Parameters:
X

The input data.

Returns:
Array

Predicted labels.

Return type:

Array

Parameters:

X (Array | ndarray | bool | number | bool | int | float | complex)

save(path)#

Save estimator to file.

Return type:

None

Parameters:

path (str | Path)

set_params(**params)#

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters:
**params : dict

Estimator parameters.

Returns:
self : estimator instance

Estimator instance.
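The nested <component>__<parameter> convention can be seen with scikit-learn's own estimators. StandardScaler and Ridge below are only stand-ins. If a FullyConnectedEstimator were placed in a pipeline under a step name such as "est" (a hypothetical name), its constructor arguments would presumably be addressed the same way, for example est__epochs.

```python
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipe = Pipeline([("scale", StandardScaler()), ("reg", Ridge())])

# Route each parameter to its step with the <step>__<parameter> syntax.
pipe.set_params(reg__alpha=0.5, scale__with_mean=False)

params = pipe.get_params(deep=True)
print(params["reg__alpha"], params["scale__with_mean"])  # 0.5 False
```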

tree_flatten()#

Specify how to serialize estimator into a JAX pytree.

Returns:
Tuple[Sequence[ArrayLike], Any]
A tuple with 2 elements:
  • The children, containing arrays & pytrees.

  • The aux_data, containing static and hashable data.

Return type:

Tuple[Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex, Dict]], Any]

classmethod tree_unflatten(aux_data, children)#

Specify how to build an estimator from a JAX pytree.

Parameters:
aux_data

Contains static, hashable data.

children

Contains arrays & pytrees.

Returns:
FullyConnectedEstimator

Reconstructed estimator.

Return type:

FullyConnectedEstimator

class bde.ml.models.FullyConnectedModule(n_output_params, n_input_params=None, layer_sizes=None, do_final_activation=True, parent=<flax.linen.module._Sentinel object>, name=None)#

Bases: BasicModule

A class for easy initialization of fully connected neural networks with Flax.

This class allows for the creation of fully connected neural networks with a variable number of layers and neurons per layer. This class implements the API defined by BasicModule.

Attributes:
n_output_params : int

The number of output features or neurons in the output layer.

n_input_params : Optional[int]

The number of input features or neurons in the input layer. If None, the number is determined based on the used params (usually determined by the data used for fitting).

layer_sizes : Optional[Union[Iterable[int], int]], optional

The number of neurons in each hidden layer. If an integer is provided, a single hidden layer with that many neurons is created. If an iterable of integers is provided, multiple hidden layers are created with the specified number of neurons. Default is None, which implies no hidden layers (only an input layer and an output layer).

do_final_activation : bool, optional

Whether to apply an activation function to the output layer. Default is True, meaning the final layer will have an activation function (softmax).

Parameters:
  • n_output_params (int)

  • n_input_params (int | None)

  • layer_sizes (Iterable[int] | int | None)

  • do_final_activation (bool)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

Methods

__call__(x)

Define the forward pass of the fully connected network.

apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)#

Applies a module method to variables and returns output and modified variables.

Note that method should be set if one would like to call apply on a different class method than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Transformer(nn.Module):
...   def encode(self, x):
...     ...

>>> x = jnp.ones((16, 9))
>>> model = Transformer()
>>> variables = model.init(jax.random.key(0), x, method=Transformer.encode)

>>> encoded = model.apply(variables, x, method=Transformer.encode)

If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:

>>> encoded = model.apply(variables, x, method=model.encode)

You can also pass a string to a callable attribute of the module. For example, the previous can be written as:

>>> encoded = model.apply(variables, x, method='encode')

Note method can also be a function that is not defined in Transformer. In that case, the function should have at least one argument representing an instance of the Module class:

>>> def other_fn(instance, x):
...   # instance.some_module_attr(...)
...   instance.encode
...   ...

>>> model.apply(variables, x, method=other_fn)

If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to apply. If self.make_rng(name) is called on an RNG stream name that isn’t passed by the user, it will default to using the 'params' RNG stream.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, add_noise=False):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     if add_noise:
...       # Add gaussian noise
...       noise_key = self.make_rng('noise')
...       x = x + jax.random.normal(noise_key, x.shape)
...
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)}
>>> variables = module.init(rngs, x)
>>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs)

>>> rngs['noise'] = jax.random.key(0)
>>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # different output (key(1) vs key(0))
>>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1)

>>> del rngs['noise']
>>> # self.make_rng('noise') will default to using the 'params' RNG stream
>>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # same output (key(0))
>>> np.testing.assert_allclose(out1, out2)

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0))
>>> # same output (key(0))
>>> np.testing.assert_allclose(out2, out3)

Return type:

Any | tuple[Any, FrozenDict[str, Mapping[str, Any]] | dict[str, Any]]

Args:

variables: A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

*args: Positional arguments passed to the specified apply method.

rngs: A dict of PRNGKeys to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.

method: A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module. A string can also be provided to specify a method by name.

mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

capture_intermediates: If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

**kwargs: Keyword arguments passed to the specified apply method.

Returns:

If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.

bind(variables, *args, rngs=None, mutable=False)#

Creates an interactive Module instance by binding variables and RNGs.

bind provides an “interactive” instance of a Module directly without transforming a function with apply. This is particularly useful for debugging and interactive use cases like notebooks where a function would limit the ability to split up code into different cells.

Once the variables (and optionally RNGs) are bound to a Module it becomes a stateful object. Note that idiomatic JAX is functional and therefore an interactive instance does not mix well with vanilla JAX APIs. bind() should only be used for interactive experimentation, and in all other cases we strongly encourage users to use apply() instead.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = nn.Dense(3)
...     self.decoder = nn.Dense(5)
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> x = jnp.ones((16, 9))
>>> ae = AutoEncoder()
>>> variables = ae.init(jax.random.key(0), x)
>>> model = ae.bind(variables)
>>> z = model.encoder(x)
>>> x_reconstructed = model.decoder(z)

Return type:

TypeVar(M, bound= Module)

Args:

variables: A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

*args: Positional arguments (not used).

rngs: A dict of PRNGKeys to initialize the PRNG sequences.

mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

Returns:

A copy of this instance with bound variables and RNGs.

clone(*, parent=None, _deep_clone=False, _reset_names=False, **updates)#

Creates a clone of this Module, with optionally updated arguments.

Return type:

TypeVar(M, bound= Module)

NOTE: End users are encouraged to use the copy method. clone is used primarily for internal routines, and copy offers simpler arguments and better defaults.

Args:

parent: The parent of the clone. The clone will have no parent if no explicit parent is specified.

_deep_clone: A boolean or a weak value dictionary to control deep cloning of submodules. If True, submodules will be cloned recursively. If a weak value dictionary is passed, it will be used to cache cloned submodules. This flag is used by init/apply/bind to avoid scope leakage.

_reset_names: If True, name=None is also passed to submodules when cloning. Resetting names in submodules is necessary when calling .unbind.

**updates: Attribute updates.

Returns:

A clone of this Module with the updated attributes and parent.

copy(*, parent=<flax.linen.module._Sentinel object>, name=None, **updates)#

Creates a copy of this Module, with optionally updated arguments.

Return type:

TypeVar(M, bound= Module)

Parameters:
  • self (M)

  • parent (Scope | Module | _Sentinel | None)

  • name (str | None)

Args:

parent: The parent of the copy. By default the current module is taken as parent if not explicitly specified.

name: A new name for the copied Module. By default a new automatic name will be given.

**updates: Attribute updates.

Returns:

A copy of this Module with the updated name, parent, and attributes.

do_final_activation: bool = True#
get_variable(col, name, default=None)#

Retrieves the value of a Variable.

Return type:

TypeVar(T)

Parameters:
  • col (str)

  • name (str)

  • default (T | None)

Args:

col: The variable collection.

name: The name of the variable.

default: The default value to return if the variable does not exist in this scope.

Returns:

The value of the input variable, or the default value if the variable doesn’t exist in this scope.

has_rng(name)#

Returns True if a PRNGSequence named name exists.

Return type:

bool

Parameters:

name (str)

has_variable(col, name)#

Checks if a variable of given collection and name exists in this Module.

See flax.core.variables for more explanation on variables and collections.

Return type:

bool

Args:

col: The variable collection name.

name: The name of the variable.

Returns:

True if the variable exists.

init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)#

Initializes a module method with variables and returns modified variables.

init takes as its first argument either a single PRNGKey or a dictionary mapping variable collection names to their PRNGKeys, calls method (the module’s __call__ function by default) with *args and **kwargs, and returns a dictionary of initialized variables.

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> key = jax.random.key(0)
>>> variables = module.init(key, x, train=False)

If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to init. If self.make_rng(name) is called on an RNG stream name that isn’t passed by the user, it will default to using the 'params' RNG stream.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     other_variable = self.variable(
...       'other_collection',
...       'other_variable',
...       lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape),
...       x,
...     )
...     x = x + other_variable.value
...
...     return nn.Dense(1)(x)

>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)}
>>> variables0 = module.init(rngs, x)

>>> rngs['other_rng'] = jax.random.key(0)
>>> variables1 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables0['params'], variables1['params']
... )
>>> # different other_variable (key(1) vs key(0))
>>> np.testing.assert_raises(
...   AssertionError,
...   np.testing.assert_allclose,
...   variables0['other_collection']['other_variable'],
...   variables1['other_collection']['other_variable'],
... )

>>> del rngs['other_rng']
>>> # self.make_rng('other_rng') will default to using the 'params' RNG stream
>>> variables2 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables1['params'], variables2['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables1['other_collection']['other_variable'],
...   variables2['other_collection']['other_variable'],
... )

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> variables3 = module.init(jax.random.key(0), x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables2['params'], variables3['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables2['other_collection']['other_variable'],
...   variables3['other_collection']['other_variable'],
... )

Jitting init initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:

>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.key(0), x)

init is a light wrapper over apply, so other apply arguments like method, mutable, and capture_intermediates are also available.

Return type:

FrozenDict[str, Mapping[str, Any]] | dict[str, Any]

Args:

rngs: The rngs for the variable collections.

*args: Positional arguments passed to the init function.

method: An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.

mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

capture_intermediates: If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

**kwargs: Keyword arguments passed to the init function.

Returns:

The initialized variable dict.

init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)#

Initializes a module method with variables and returns output and modified variables.

Return type:

tuple[Any, FrozenDict[str, Mapping[str, Any]] | dict[str, Any]]

Args:

rngs: The rngs for the variable collections.

*args: Positional arguments passed to the init function.

method: An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.

mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except “intermediates” are mutable.

capture_intermediates: If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

**kwargs: Keyword arguments passed to the init function.

Returns:

(output, vars), where vars is a dict of the modified collections.

is_initializing()#

Returns True if running under self.init(…) or nn.init(…)().

This is a helper method for the common case of simple initialization, where setup logic should run only when called under module.init or nn.init. For more complicated multi-phase initialization scenarios, it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized.

Return type:

bool

is_mutable_collection(col)#

Returns true if the collection col is mutable.

Return type:

bool

Parameters:

col (str)

layer_sizes: Union[Iterable[int], int, None] = None#
lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)#

Initializes a module without computing on an actual input.

lazy_init will initialize the variables without doing unnecessary compute. The input data should be passed as a jax.ShapeDtypeStruct which specifies the shape and dtype of the input but no concrete data.

Example:

>>> model = nn.Dense(features=256)
>>> variables = model.lazy_init(
...     jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))

The args and kwargs passed to lazy_init can be a mix of concrete (JAX arrays, scalars, bools) and abstract (ShapeDtypeStruct) values. Concrete values are only necessary for arguments that affect the initialization of variables. For example, the model might expect a keyword arg that enables/disables a subpart of the model. In this case, an explicit value (True/False) should be passed; otherwise lazy_init cannot infer which variables should be initialized.

Return type:

FrozenDict[str, Mapping[str, Any]]

Args:

rngs: The rngs for the variable collections.

*args: Arguments passed to the init function.

method: An optional method. If provided, applies this method. If not provided, applies the __call__ method.

mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

**kwargs: Keyword arguments passed to the init function.

Returns:

The initialized variable dict.

make_rng(name='params')#

Returns a new RNG key from a given RNG sequence for this Module.

The new RNG key is split from the previous one. Thus, every call to make_rng returns a new RNG key, while still guaranteeing full reproducibility.

Note

If an invalid name is passed (i.e. no RNG key was passed by the user in .init or .apply for this name), then name will default to 'params'.

Example:

>>> import jax
>>> import flax.linen as nn

>>> class ParamsModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('params')
>>> class OtherModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('other')

>>> key = jax.random.key(0)
>>> params_out, _ = ParamsModule().init_with_output({'params': key})
>>> # self.make_rng('other') will default to using the 'params' RNG stream
>>> other_out, _ = OtherModule().init_with_output({'params': key})
>>> assert params_out == other_out

Learn more about RNGs by reading the Flax RNG guide: https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html

Args:

name: The RNG sequence name.

Returns:

The newly generated RNG key.

Parameters:

name (str)

Return type:

Array

module_paths(rngs, *args, show_repeated=False, mutable=DenyList(deny='intermediates'), **kwargs)#

Returns a dictionary mapping module paths to module instances.

This method has the same signature as Module.init and calls it internally, but instead of returning the variables, it returns a dictionary mapping module paths to unbound copies of the module instances that were used at runtime. module_paths uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> modules = Foo().module_paths(jax.random.key(0), x)
>>> print({
...     p: type(m).__name__ for p, m in modules.items()
... })
{'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'}

Return type:

dict[str, Module]

Args:

rngs: The rngs for the variable collections as passed to Module.init.

*args: The arguments to the forward computation.

show_repeated: If True, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is False.

mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except ‘intermediates’ are mutable.

**kwargs: Keyword arguments to pass to the forward computation.

Returns:

A dictionary mapping module paths to module instances.

n_input_params: Optional[int] = None#
n_output_params: int#
name: Optional[str] = None#
param(name, init_fn, *init_args, unbox=True, **init_kwargs)#

Declares and returns a parameter in this Module.

Parameters are read-only variables in the collection named “params”. See flax.core.variables for more details on variables.

The first argument of init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args or init_kwargs:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     mean = self.param('mean', nn.initializers.lecun_normal(), x.shape)
...     ...
...     return x * mean
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}

In the example above, the function lecun_normal expects two arguments: key and shape, but only shape has to be provided explicitly; key is set automatically using the PRNG for params that is passed when initializing the module using init().

Return type:

Union[TypeVar(T), AxisMetadata[TypeVar(T)]]

Args:

name: The parameter name.

init_fn: The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.

*init_args: The positional arguments to pass to init_fn.

unbox: If True, AxisMetadata instances are replaced by their unboxed value, see flax.core.meta.unbox (default: True).

**init_kwargs: The keyword arguments to pass to init_fn.

Returns:

The value of the initialized parameter. Throws an error if the parameter exists already.

parent: Union[Module, Scope, _Sentinel, None] = None#
property path#

Get the path of this Module. Top-level root modules have an empty path (). Note that this method can only be used on bound modules that have a valid scope.

Example usage:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class SubModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'SubModel path: {self.path}')
...     return x

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'Model path: {self.path}')
...     return SubModel()(x)

>>> model = Model()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 2)))
Model path: ()
SubModel path: ('SubModel_0',)
perturb(name, value, collection='perturbations')#

Adds a zero-valued variable (‘perturbation’) to the intermediate value.

The gradient of value is the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation argument.

Note

This is an experimental API and may be tweaked later for better performance and usability. At its current stage, it creates extra dummy variables that occupy extra memory. Use it only to debug gradients in training.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = self.perturb('dense3', x)
...     return nn.Dense(2)(x)

>>> def loss(variables, inputs, targets):
...   preds = model.apply(variables, inputs)
...   return jnp.square(preds - targets).mean()

>>> x = jnp.ones((2, 9))
>>> y = jnp.ones((2, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y)
>>> print(intm_grads['perturbations']['dense3'])
[[-1.456924   -0.44332537  0.02422847]
 [-1.456924   -0.44332537  0.02422847]]

If perturbations are not passed to apply, perturb behaves like a no-op so you can easily disable the behavior when not needed:

>>> model.apply(variables, x) # works as expected
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> model.apply({'params': variables['params']}, x) # behaves like a no-op
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y)
>>> 'perturbations' not in intm_grads
True

Parameters:
  • name (str)

  • value (T)

  • collection (str)

Return type:

T

put_variable(col, name, value)#

Updates the value of the given variable if it is mutable, or raises an error otherwise.

Args:

col: The variable collection.

name: The name of the variable.

value: The new value of the variable.

scope: Scope | None = None#
setup()#

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_with_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes setup() to be called once (if it has not run yet).

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

Return type:

None

sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)#

Stores a value in a collection.

Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.

If the target collection is not mutable sow behaves like a no-op and returns False.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     self.sow('intermediates', 'h', h)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(variables, x, mutable=['intermediates'])
>>> jax.tree.map(jnp.shape, state['intermediates'])
{'h': ((16, 4),)}

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:

>>> class Foo2(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     init_fn = lambda: 0
...     reduce_fn = lambda a, b: a + b
...     self.sow('intermediates', 'h', x,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     self.sow('intermediates', 'h', x * 2,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     return x

>>> x = jnp.ones((1, 1))
>>> model = Foo2()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(
...     variables, x, mutable=['intermediates'])
>>> print(state['intermediates'])
{'h': Array([[3.]], dtype=float32)}

Return type:

bool

Args:

col: The name of the variable collection.

name: The name of the variable.

value: The value of the variable.

reduce_fn: The function used to combine the existing value with the new value. The default is to append the value to a tuple.

init_fn: For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.

Returns:

True if the value has been stored successfully, False otherwise.

tabulate(rngs, *args, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)#

Creates a summary of the Module represented as a table.

This method has the same signature as Module.init and calls it internally, but instead of returning the variables, it returns a string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Additional arguments can be passed into the console_kwargs argument, for example, {'width': 120}. For a full list of console_kwargs arguments, see: https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))

>>> # print(Foo().tabulate(
>>> #     jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True))

This gives the following output:

                                      Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs        ┃ outputs       ┃ flops ┃ vjp_flops ┃ params          ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│         │ Foo    │ float32[16,9] │ float32[16,2] │ 1504  │ 4460      │                 │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_0 │ Dense  │ float32[16,9] │ float32[16,4] │ 1216  │ 3620      │ bias:           │
│         │        │               │               │       │           │ float32[4]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[9,4]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 40 (160 B)      │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_1 │ Dense  │ float32[16,4] │ float32[16,2] │ 288   │ 840       │ bias:           │
│         │        │               │               │       │           │ float32[2]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[4,2]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 10 (40 B)       │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│         │        │               │               │       │     Total │ 50 (200 B)      │
└─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘

                              Total Parameters: 50 (200 B)

Note: the row order in the table does not represent execution order; instead, it follows the order of keys in variables, which are sorted alphabetically.

Note: vjp_flops returns 0 if the module is not differentiable.

Return type:

str

Args:

rngs: The rngs for the variable collections as passed to Module.init.

*args: The arguments to the forward computation.

depth: Controls how many submodules deep the summary can go. By default, it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor such that the sum of all rows always adds up to the total number of parameters of the Module.

show_repeated: If True, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is False.

mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except ‘intermediates’ are mutable.

console_kwargs: An optional dictionary with additional keyword arguments that are passed to rich.console.Console when rendering the table. Default arguments are {'force_terminal': True, 'force_jupyter': False}.

table_kwargs: An optional dictionary with additional keyword arguments that are passed to the rich.table.Table constructor.

column_kwargs: An optional dictionary with additional keyword arguments that are passed to rich.table.Table.add_column when adding columns to the table.

compute_flops: Whether to include a flops column in the table listing the estimated FLOPs cost of each module forward pass. This does not incur actual on-device computation, compilation, or memory allocation, but still introduces overhead for large modules (e.g. an extra 20 seconds for a Stable Diffusion UNet, whereas otherwise tabulation would finish in 5 seconds).

compute_vjp_flops: Whether to include a vjp_flops column in the table listing the estimated FLOPs cost of each module backward pass. Introduces a compute overhead of about 2-3X of compute_flops.

**kwargs: Keyword arguments to pass to the forward computation.

Returns:

A string summarizing the Module.

tree_flatten()#

Specifies how to flatten the module into a JAX pytree.

Returns:
Tuple[Sequence[ArrayLike], Any]
A tuple with 2 elements:
  • The children, containing arrays & pytrees (empty).

  • The aux_data, containing static and hashable data (4 items).

Return type:

Tuple[Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]], Any]

classmethod tree_unflatten(aux_data, children)#

Specifies how to rebuild a module from a JAX pytree.

Parameters:
aux_data

Contains static, hashable data (4 elements).

children

Contains arrays and pytrees. Not used by this class; should be empty.

Returns:
FullyConnectedModule

Reconstructed Module.

Return type:

FullyConnectedModule
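The tree_flatten/tree_unflatten contract can be illustrated with a hypothetical stand-in class. It assumes, for illustration only, that the four static aux_data items are the configuration attributes documented on this page (n_output_params, layer_sizes, n_input_params, do_final_activation); the actual fields and their order in FullyConnectedModule may differ. Because there are no array leaves, such an object can be passed to a jitted function and its fields stay static Python values.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyModuleConfig:
    """Hypothetical class mimicking the flatten/unflatten contract (not bde code)."""

    def __init__(self, n_output_params, layer_sizes, n_input_params, do_final_activation):
        self.n_output_params = n_output_params
        self.layer_sizes = tuple(layer_sizes)  # tuples keep aux_data hashable
        self.n_input_params = n_input_params
        self.do_final_activation = do_final_activation

    def tree_flatten(self):
        children = ()  # no array leaves
        aux_data = (
            self.n_output_params,
            self.layer_sizes,
            self.n_input_params,
            self.do_final_activation,
        )
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # children is expected to be empty; everything lives in aux_data
        return cls(*aux_data)


cfg = ToyModuleConfig(1, [8, 8], 3, True)
leaves, treedef = jax.tree_util.tree_flatten(cfg)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)


@jax.jit
def scale_by_first_width(config, x):
    # config has no leaves, so its fields are plain Python values under jit
    return x * config.layer_sizes[0]


out = scale_by_first_width(cfg, jnp.ones(2))
```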

unbind()#

Returns an unbound copy of a Module and its variables.

unbind helps create a stateless version of a bound Module.

An example of a common use case: to extract a sub-Module defined inside setup() and its corresponding variables: 1) temporarily bind the parent Module; and then 2) unbind the desired sub-Module. (Recall that setup() is only called when the Module is bound.):

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Encoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(256)(x)

>>> class Decoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(784)(x)

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = Encoder()
...     self.decoder = Decoder()
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> module = AutoEncoder()
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 784)))

>>> # Extract the Encoder sub-Module and its variables
>>> encoder, encoder_vars = module.bind(variables).encoder.unbind()

Return type:

tuple[TypeVar(M, bound= Module), Mapping[str, Mapping[str, Any]]]

Parameters:

self (M)

Returns:

A tuple with an unbound copy of this Module and its variables.

variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs)

Declares and returns a variable in this Module.

See flax.core.variables for more information. See also param() for a shorthand way to define read-only variables in the “params” collection.

Unlike param(), all arguments for init_fn must be passed explicitly:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     key = self.make_rng('stats')
...     mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape)
...     ...
...     return x * mean.value
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}

In the example above, the initializer returned by lecun_normal() expects two arguments, key and shape, and both have to be passed explicitly. The PRNG key for “stats” has to be provided explicitly when calling init() and apply().

Return type:

Union[Variable[TypeVar(T)], Variable[AxisMetadata[TypeVar(T)]]]

Parameters:
  • col: The variable collection name.

  • name: The variable name.

  • init_fn: The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module. If None, the variable must already be initialized, otherwise an error is raised.

  • *init_args: The positional arguments to pass to init_fn.

  • unbox: If True, AxisMetadata instances are replaced by their unboxed value, see flax.nn.meta.unbox (default: True).

  • **init_kwargs: The keyword arguments to pass to init_fn.

Returns:

A flax.core.variables.Variable that can be read or set via its .value attribute. Throws an error if the variable already exists.

property variables: Mapping[str, Mapping[str, Any]]

Returns the variables in this module.

bde.ml.models.init_dense_model(model, batch_size=1, n_features=None, seed=42)

Fast initialization for a fully connected dense network.

Parameters:
model

A model object.

batch_size

The batch size for training.

n_features

The size of the input layer. If it is set to None, it is inferred based on the provided model.

seed

A seed or a PRNGKey for initialization.

Returns:
Tuple[dict, Array]
A tuple with:
  • A parameters dict,

  • The input used for the initialization.

Return type:

Tuple[dict, Array]
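A minimal sketch of the initialization pattern described above, written in plain JAX rather than with a Flax model. The names DenseSpec and init_dense_sketch, and the rule that infers n_features from DenseSpec.n_input, are hypothetical stand-ins for "the provided model". The sketch shows two documented details: seed may be an int or a PRNGKey, and the returned input is the dummy batch used for initialization.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class DenseSpec:
    """Hypothetical description of a dense network (input width + layer widths)."""

    n_input: int
    layer_sizes: tuple


def init_dense_sketch(model, batch_size=1, n_features=None, seed=42):
    """Hypothetical stand-in for init_dense_model; returns (params, dummy_input)."""
    # Accept either an integer seed or an existing PRNG key.
    key = jax.random.PRNGKey(seed) if isinstance(seed, int) else seed
    # Hypothetical inference rule: take the input width from the model description.
    if n_features is None:
        n_features = model.n_input
    x = jnp.ones((batch_size, n_features))  # dummy batch used for initialization
    params, fan_in = {}, n_features
    for i, width in enumerate(model.layer_sizes):
        key, subkey = jax.random.split(key)
        params[f"Dense_{i}"] = {
            "kernel": jax.random.normal(subkey, (fan_in, width)) / jnp.sqrt(fan_in),
            "bias": jnp.zeros((width,)),
        }
        fan_in = width
    return params, x


params, x = init_dense_sketch(DenseSpec(n_input=3, layer_sizes=(8, 1)), batch_size=4)
```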

bde.ml.models.init_dense_model_jitted(model, rng_key, batch_size=1, n_features=1)

Fast initialization for a fully connected dense network.

A jitted version of init_dense_model().

Parameters:
model

A model object.

rng_key

A PRNGKey used for randomness in initialization.

batch_size

The batch size for training.

n_features

The size of the input layer. If it is set to None, it is inferred based on the provided model.

Returns:
Tuple[dict, Array]
A tuple with:
  • A parameters dict,

  • The input used for the initialization.

Return type:

Tuple[dict, Array]
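Under jax.jit, array shapes must be known at trace time, which is presumably why the jitted variant takes an explicit rng_key and defaults n_features to an integer. A sketch of that constraint, assuming a single hypothetical dense layer (init_single_dense, with a hypothetical n_out argument): the shape-determining arguments are marked static, so each new (batch_size, n_features) combination triggers a fresh compilation.

```python
import jax
import jax.numpy as jnp


def init_single_dense(rng_key, batch_size=1, n_features=1, n_out=1):
    """Hypothetical one-layer initializer; returns (params, dummy_input)."""
    x = jnp.ones((batch_size, n_features))  # shape depends on static arguments
    kernel = jax.random.normal(rng_key, (n_features, n_out)) / jnp.sqrt(n_features)
    params = {"kernel": kernel, "bias": jnp.zeros((n_out,))}
    return params, x


# These arguments determine array shapes, so they must be static under jit.
init_single_dense_jitted = jax.jit(
    init_single_dense, static_argnames=("batch_size", "n_features", "n_out")
)

params, x = init_single_dense_jitted(jax.random.PRNGKey(0), batch_size=2, n_features=3)
```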

bde.ml.training module

Training Utilities for Bayesian Neural Networks.

This module provides functionality for training Bayesian Neural Networks within the Bayesian Deep Ensembles (BDE) framework.

Functions

  • train_step: Executes a single optimization step for the neural network.

  • jitted_training: Fits a model to the data for a single parameter set.

  • jitted_training_epoch: Performs 1 training epoch for model training (parameter optimization + metrics evaluation + validation).

bde.ml.training.jitted_evaluation_for_a_metric(model_state, batches, metrics, history, idx_metric, idx_history, idx_epoch)

Evaluate a training epoch for 1 metric.

Parameters:
TODO: Complete
Returns:
TODO: Complete
Parameters:
  • model_state (TrainState)

  • batches (BasicDataset)

  • history (Array)

  • idx_metric (int)

  • idx_history (int)

  • idx_epoch (int)
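The layout of history is not yet documented. Purely as an illustration, and assuming history is a 2D array indexed as history[idx_history, idx_epoch], a jitted function can record one metric value with JAX's functional .at[...].set(...) update, since JAX arrays cannot be modified in place. The helper record_metric and this layout are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def record_metric(history, value, idx_history, idx_epoch):
    """Hypothetical helper: return a copy of history with one entry set."""
    return history.at[idx_history, idx_epoch].set(value)


# Hypothetical layout: 2 rows (e.g. a training and a validation metric) x 5 epochs.
history = jnp.zeros((2, 5))
history = record_metric(history, 0.25, 1, 3)
```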

bde.ml.training.jitted_evaluation_over_batch(model_state, batches, f_eval, num_batch, m_val)

Perform intermediate evaluation over a metric for 1 batch of data.

Parameters:
TODO: Complete
Returns:
TODO: Complete
Return type:

float

Parameters:
  • model_state (TrainState)

  • batches (BasicDataset)

  • num_batch (int)

  • m_val (float)
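The trailing (num_batch, m_val) parameters match the (index, carry) signature of a jax.lax.fori_loop body, which suggests this function accumulates a metric batch by batch. A minimal sketch of that pattern, assuming batches stacked along a leading axis, a hypothetical linear model, and a running mean over batches; the names evaluate_over_batches and mse are hypothetical, and the actual aggregation in bde may differ.

```python
import jax
import jax.numpy as jnp


def mse(y_true, y_pred):
    return jnp.mean((y_true - y_pred) ** 2)


def evaluate_over_batches(weights, xs, ys, f_eval):
    """Average f_eval over batches; xs: (n_batches, batch, features), ys: (n_batches, batch)."""
    n_batches = xs.shape[0]

    def body(num_batch, m_val):
        # Hypothetical linear model; a real network would be applied here instead.
        y_pred = xs[num_batch] @ weights
        return m_val + f_eval(ys[num_batch], y_pred) / n_batches

    return jax.lax.fori_loop(0, n_batches, body, jnp.zeros((), dtype=jnp.float32))


# f_eval is a Python callable, so it is marked static for jit.
evaluate_jitted = jax.jit(evaluate_over_batches, static_argnames="f_eval")
result = evaluate_jitted(jnp.zeros(3), jnp.ones((4, 2, 3)), jnp.ones((4, 2)), f_eval=mse)
```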

bde.ml.training.jitted_training(model_state, epochs, f_loss, metrics, train, valid)

Train a model on a single parameter set.

A jitted training loop for a model using a single parameter set.

Parameters:
model_state

The training state, containing the model architecture, the training parameters and the optimizer.

epochs

An array with the indices of the training epochs.

f_loss

The loss function to optimize, given as a Loss instance.

metrics

An array of metric classes.

train

The training dataset.

valid

The validation dataset.

Returns:
Tuple[TrainState, Array]

Updated training state and an array describing the metrics over the training epochs.

Return type:

Tuple[TrainState, Array]

Parameters:
  • model_state (TrainState)

  • epochs (Array)

  • f_loss (Loss)

  • metrics (Array)

  • train (BasicDataset)

  • valid (BasicDataset)
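Because epochs is passed as an array of epoch indices, the whole loop can run inside one compiled function, for example via jax.lax.scan, which also stacks the per-epoch metrics into an array. A minimal sketch of that pattern, assuming full-batch gradient descent on a hypothetical linear model (train_scan) and a single tracked metric, the training loss, instead of a TrainState, an optimizer, and metric classes:

```python
import jax
import jax.numpy as jnp


def mse(y_true, y_pred):
    return jnp.mean((y_true - y_pred) ** 2)


@jax.jit
def train_scan(weights, epochs, x, y):
    """Hypothetical full-batch training loop; returns (weights, loss history)."""

    def loss_fn(w):
        return mse(y, x @ w)

    def one_epoch(w, epoch_idx):
        # epoch_idx is unused; scanning over the epochs array sets the loop length.
        loss, grads = jax.value_and_grad(loss_fn)(w)
        w = w - 0.1 * grads  # plain gradient descent with a fixed step size
        return w, loss  # the pre-update loss is stacked into the history

    return jax.lax.scan(one_epoch, weights, epochs)


x = jnp.array([[1.0], [2.0], [3.0]])
y = 2.0 * x[:, 0]
weights, history = train_scan(jnp.zeros(1), jnp.arange(50), x, y)
```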

bde.ml.training.jitted_training_epoch(model_state, train, valid, f_loss, metrics)

Train a model for 1 epoch.

A jitted training loop for a model over a single epoch. Performs training, metrics evaluation and validation.

Parameters:
model_state

The training state, containing the model architecture, the training parameters and the optimizer.

train

The training dataset.

valid

The validation dataset.

f_loss

The loss function to optimize, given as a Loss instance.

metrics

An array of metric classes.

Returns:
Tuple[Tuple[TrainState, BasicDataset, BasicDataset], Array]

Two items are returned:

  • A triplet containing:

      • The updated model state.

      • The updated training dataset (with updated shuffling).

      • The updated validation dataset (with updated shuffling).

  • A 1D array describing the evaluation of all metrics over this epoch.

Return type:

Tuple[Tuple[TrainState, BasicDataset, BasicDataset], Array]

Parameters:
  • model_state (TrainState)

  • train (BasicDataset)

  • valid (BasicDataset)

  • f_loss (Loss)

  • metrics (Array)
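One epoch, as described above, reshuffles the data, runs a pass of optimization steps, and then evaluates metrics on the training and validation sets. A minimal sketch under stated assumptions: the datasets are plain (x, y) arrays, reshuffling uses jax.random.permutation driven by an explicit key (a stand-in for the shuffling state that the BasicDataset objects carry), the model is a hypothetical linear one trained with plain gradient descent, and the returned 1D array holds [training loss, validation loss]. The name train_one_epoch is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


def mse(y_true, y_pred):
    return jnp.mean((y_true - y_pred) ** 2)


@partial(jax.jit, static_argnames="batch_size")
def train_one_epoch(w, key, x_tr, y_tr, x_va, y_va, batch_size=2):
    """Hypothetical epoch: reshuffle, optimize over mini-batches, then evaluate."""
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, x_tr.shape[0])  # new order every epoch
    n_batches = x_tr.shape[0] // batch_size  # a trailing partial batch is dropped
    batch_idx = perm[: n_batches * batch_size].reshape(n_batches, batch_size)

    def step(w, idx):
        xb, yb = x_tr[idx], y_tr[idx]
        loss, grads = jax.value_and_grad(lambda v: mse(yb, xb @ v))(w)
        return w - 0.5 * grads, loss  # plain gradient descent, fixed step size

    w, _ = jax.lax.scan(step, w, batch_idx)
    # 1D metrics array: [training loss, validation loss] after this epoch.
    metrics = jnp.array([mse(y_tr, x_tr @ w), mse(y_va, x_va @ w)])
    return w, key, metrics


x_tr = jnp.arange(1.0, 9.0).reshape(8, 1) / 8.0
y_tr = 3.0 * x_tr[:, 0]
x_va = jnp.array([[0.3], [0.6]])
y_va = 3.0 * x_va[:, 0]
w, key = jnp.zeros(1), jax.random.PRNGKey(0)
for _ in range(20):
    w, key, metrics = train_one_epoch(w, key, x_tr, y_tr, x_va, y_va)
```

Returning the split key lets the next epoch draw a different permutation, mirroring how the documented function returns updated datasets.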

bde.ml.training.train_step(state, batch, f_loss)

Perform an optimization step for the network.

This function updates the model parameters by performing a single optimization step using the provided loss function.

Parameters:
state

The training-state of the network.

batch
Input data-points for the training set, containing 2 items:
  • A set of training data-points.

  • The corresponding labels.

f_loss
The loss function used while training. Should have the following signature:

(y_true, y_pred)

Returns:
Tuple[TrainState, float]

Updated state of the network and the loss.

Return type:

Tuple[TrainState, float]
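A minimal sketch of a single optimization step using the (y_true, y_pred) loss signature described above. It is written in plain JAX with a hypothetical linear model (apply_model) and plain gradient descent; the documented function instead operates on a TrainState, whose optimizer presumably performs the parameter update. The names train_step_sketch and mse_loss are hypothetical.

```python
import jax
import jax.numpy as jnp


def mse_loss(y_true, y_pred):
    """Loss with the documented (y_true, y_pred) signature."""
    return jnp.mean((y_true - y_pred) ** 2)


def apply_model(params, x):
    # Hypothetical linear model standing in for the network's forward pass.
    return x @ params["w"] + params["b"]


def train_step_sketch(params, batch, f_loss, lr=0.1):
    """One gradient-descent step on a batch (x, y); returns (params, loss)."""
    x, y = batch

    def loss_fn(p):
        return f_loss(y, apply_model(p, x))

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Update every leaf of the parameter pytree with its gradient.
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


step = jax.jit(train_step_sketch, static_argnames="f_loss")
params = {"w": jnp.zeros(2), "b": jnp.zeros(())}
batch = (jnp.array([[1.0, 0.0], [0.0, 1.0]]), jnp.array([1.0, -1.0]))
params, loss0 = step(params, batch, f_loss=mse_loss)
params, loss1 = step(params, batch, f_loss=mse_loss)
```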

Module contents

Machine Learning Module for Bayesian Deep Ensembles (BDE).

The bde.ml module provides the core machine learning components required for building and training Bayesian Neural Networks within the Bayesian Deep Ensembles (BDE) framework.

This module includes submodules for managing datasets, defining loss functions and neural network models, and running training procedures, which together form the building blocks of BDE models.

Submodules

  • datasets: Handles data and dataset management.

  • loss: Contains loss function implementations and loss-function-related utilities.

  • models: Defines the neural network architectures supported by the BDE framework.

  • training: Implements the training algorithms and routines used for model optimization.

Example Usage

# TODO: Provide examples

>>> # TODO: Provide an example