GaussianNLLLoss#

class bde.ml.loss.GaussianNLLLoss(epsilon=1e-06, mean_weight=1.0, is_full=True)#

Gaussian negative log likelihood loss.

A callable, JAX-compatible class for computing the negative log-likelihood loss of a Gaussian distribution. This loss is commonly used in probabilistic models to quantify the difference between the predicted probability distribution and the true labels.

Mathematically, it is defined as:

\[\ell_{\text{Gaussian NLLLoss}} = \frac{1}{2}[ \log{(var)} + \frac{(\hat\mu - \mu)^2}{var} + \log{(2\pi)} ]\]

This implementation includes the following parameters:

\[\begin{split}\ell_{\text{Gaussian NLLLoss}} = \frac{1}{2}\left[ \log{(var)} + \omega_{\text{mean weight}} \cdot \frac{(\hat\mu - \mu)^2}{var} + \begin{cases} \log{(2\pi)} & \text{if "is full" is True} \\ 0 & \text{if "is full" is False} \end{cases} \right]\end{split}\]

where

\[var = max(\sigma^2, \epsilon)\]
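
For reference, here is a minimal standalone sketch of this computation, written directly from the formulas above; the function name and argument layout are illustrative and are not the class's API:

    import jax.numpy as jnp

    def gaussian_nll(mu_true, mu_pred, sigma_pred,
                     epsilon=1e-6, mean_weight=1.0, is_full=True):
        # Clip the variance from below for numerical stability: var = max(sigma^2, epsilon).
        var = jnp.maximum(sigma_pred ** 2, epsilon)
        loss = jnp.log(var) + mean_weight * (mu_pred - mu_true) ** 2 / var
        if is_full:
            # Optionally include the constant term log(2*pi).
            loss = loss + jnp.log(2 * jnp.pi)
        return 0.5 * loss
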
Attributes:
params : dict[str, …]

    Defines loss-related parameters:

      • epsilon : float

        A stability factor for the variance.

      • mean_weight : float

        A scale factor for the mean (squared-error) term.

      • is_full : bool

        If True, include the constant term log(2π) in the loss; otherwise it is omitted.


Methods

__call__(y_true, y_pred)

Computes the negative log-likelihood loss for the predicted parametrization of the Gaussian distribution, given the provided labels.

_split_pred(y_true, y_pred)

Splits the predicted values into predictions of mean and std of Gaussian distributions.

apply_reduced()

Evaluates and reduces the loss.

tree_flatten()

Used to turn the class into a jittable PyTree.

tree_unflatten(aux_data, children)

A class method used to recreate the class from a PyTree.


apply_reduced(y_true, y_pred, **kwargs)#

Evaluates and reduces the loss.

The loss is evaluated separately for each item in the batch, and the per-item losses are reduced to a single value by taking their arithmetic mean.

Parameters:
y_true

The ground truth labels.

y_pred

The predictions.

**kwargs

Other keywords that may be passed to the unreduced loss function.

Returns:
Array

The reduced loss value.

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex]
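
A hedged usage sketch. It assumes, following _split_pred, that y_pred stacks the predicted means followed by the predicted standard deviations along the last axis, and that calling the loss directly returns the unreduced per-item values; the exact layout in your setup may differ:

    import jax.numpy as jnp
    from bde.ml.loss import GaussianNLLLoss

    loss_fn = GaussianNLLLoss(epsilon=1e-6, mean_weight=1.0, is_full=True)

    # Toy batch: 4 items, 1 target feature.
    y_true = jnp.array([[0.1], [0.4], [-0.2], [0.0]])
    # Assumed layout: first column = predicted means, second column = predicted stds.
    y_pred = jnp.array([[0.0, 0.5],
                        [0.5, 0.4],
                        [-0.1, 0.3],
                        [0.2, 0.6]])

    per_item = loss_fn(y_true, y_pred)               # unreduced loss values
    reduced = loss_fn.apply_reduced(y_true, y_pred)  # arithmetic mean over the batch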

tree_flatten()#

Specifies how to serialize the module into a JAX PyTree.

Returns:
A tuple with 2 elements:
  • The children, containing arrays & PyTrees

  • The aux_data, containing static and hashable data.

Return type:

Tuple[Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]], Any]

classmethod tree_unflatten(aux_data, children)#

Specifies how to build the module from a JAX PyTree.

Parameters:
aux_data

Contains static, hashable data.

children

Contains arrays & PyTrees.

Returns:
GaussianNLLLoss

Reconstructed loss function.

Return type:

GaussianNLLLoss
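
Because the class implements the JAX PyTree protocol via tree_flatten / tree_unflatten, an instance can be reconstructed from its flattened form and passed through jitted code. A minimal sketch, illustrative only:

    import jax
    from bde.ml.loss import GaussianNLLLoss

    loss_fn = GaussianNLLLoss(epsilon=1e-6, mean_weight=2.0, is_full=False)

    # Manual round trip through the PyTree protocol.
    children, aux_data = loss_fn.tree_flatten()
    rebuilt = GaussianNLLLoss.tree_unflatten(aux_data, children)

    # The same machinery allows the loss object to cross a jax.jit boundary.
    @jax.jit
    def reduced_loss(loss, y_true, y_pred):
        return loss.apply_reduced(y_true, y_pred)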


Examples using bde.ml.loss.GaussianNLLLoss#

  • Plot Template Estimator (example01.py)