GaussianNLLLoss

class bde.ml.loss.GaussianNLLLoss(epsilon=1e-06, mean_weight=1.0, is_full=False)

Gaussian negative log likelihood loss.

A callable, JAX-compatible class for computing the negative log-likelihood loss of a Gaussian distribution. This loss is commonly used in probabilistic models: it scores a predicted distribution (a mean and a variance) against the true labels, penalizing both errors in the predicted mean and poorly calibrated predicted variances.

Mathematically, for a predicted mean \(\hat\mu\), a true label \(\mu\) and a predicted variance \(var\), it is defined as:

\[\ell_{\text{Gaussian NLLLoss}} = \frac{1}{2}[ \log{(var)} + \frac{(\hat\mu - \mu)^2}{var} + \log{(2\pi)} ]\]
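As a sanity check of the formula above, the following sketch evaluates it term by term with `jax.numpy` and compares the result with `jax.scipy.stats.norm.logpdf`, which returns the Gaussian log-density. The helper name `gaussian_nll` is hypothetical and is not part of `bde`.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def gaussian_nll(mu_hat, mu, var):
    """Hypothetical helper: per-element Gaussian NLL, written term by term as in the formula."""
    return 0.5 * (jnp.log(var) + (mu_hat - mu) ** 2 / var + jnp.log(2 * jnp.pi))


mu_hat = jnp.array([1.0, 0.5, -2.0])  # predicted means
mu = jnp.array([0.0, 0.5, -1.0])      # true labels
var = jnp.array([1.0, 0.25, 4.0])     # predicted variances

nll = gaussian_nll(mu_hat, mu, var)
# The NLL is the negated log-density of the label under N(mu_hat, var).
reference = -norm.logpdf(mu, loc=mu_hat, scale=jnp.sqrt(var))
print(jnp.allclose(nll, reference, atol=1e-5))  # True
```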

The implementation extends this definition with three parameters:

\[\begin{split}\ell_{\text{Gaussian NLLLoss}} = \frac{1}{2}\left[ \log{(var)} + \omega_{\text{mean weight}} \cdot \frac{(\hat\mu - \mu)^2}{var} + \begin{cases} \log{(2\pi)} & \text{if "is full" is True} \\ 0 & \text{if "is full" is False} \end{cases} \right]\end{split}\]

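where \(\sigma^2\) is the predicted variance, clamped from below by \(\epsilon\) for numerical stability: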

\[var = \max(\sigma^2, \epsilon)\]
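The parameterized formula can be sketched as a plain function. This is a hypothetical re-implementation for illustration, not the library's code; it assumes that `sigma` is the predicted standard deviation (so `sigma ** 2` is the variance) and reuses the constructor's parameter names `epsilon`, `mean_weight` and `is_full`.

```python
import jax.numpy as jnp


def gaussian_nll_weighted(mu_hat, mu, sigma, epsilon=1e-6, mean_weight=1.0, is_full=False):
    """Hypothetical helper mirroring the parameterized formula.

    Assumes `sigma` is the predicted standard deviation, so sigma**2 is the variance.
    """
    var = jnp.maximum(sigma ** 2, epsilon)  # var = max(sigma^2, epsilon)
    loss = jnp.log(var) + mean_weight * (mu_hat - mu) ** 2 / var
    if is_full:  # plain Python flag: the constant is added or skipped at trace time
        loss = loss + jnp.log(2 * jnp.pi)
    return 0.5 * loss


mu_hat = jnp.array([1.0, 0.0])  # predicted means
mu = jnp.array([0.0, 0.0])      # true labels
sigma = jnp.array([1.0, 0.0])   # second entry has zero predicted spread

partial = gaussian_nll_weighted(mu_hat, mu, sigma)  # is_full=False
full = gaussian_nll_weighted(mu_hat, mu, sigma, is_full=True)
weighted = gaussian_nll_weighted(mu_hat, mu, sigma, mean_weight=2.0)
print(partial)  # first entry 0.5; second entry 0.5 * log(1e-6), since var is clamped to epsilon
```

The second entry shows why the floor matters: without it, a zero predicted spread would give `log(0)` and a division by zero.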
Attributes:
params : dict[str, …]

Defines loss-related parameters:

  • epsilon : float

    A lower bound on the variance, used for numerical stability.

  • mean_weight : float

    A scale factor applied to the squared-error (mean) term.

  • is_full : bool

    If True, include the constant term log(2π) in the loss; otherwise omit it.


Methods

__call__(y_true, y_pred)

Computes the Gaussian negative log-likelihood loss for the given predictions and labels.

_split_pred(y_true, y_pred)

Splits the predicted values into predictions and their corresponding uncertainties.

apply_reduced(y_true, y_pred, **kwargs)

Evaluates the reduced loss (inherited from the base class).

tree_flatten()

Specify how to serialize the module into a JAX pytree.

tree_unflatten(aux_data, children)

Specify how to build a module from a JAX pytree.
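The exact layout that `_split_pred` expects is not documented on this page. A minimal sketch, assuming a hypothetical layout in which the last axis of `y_pred` holds the predicted means followed by the predicted uncertainties (one of each per target), and using `y_true` only to read off the number of targets; `split_pred` is a hypothetical name:

```python
import jax.numpy as jnp


def split_pred(y_true, y_pred):
    """Hypothetical splitter: assumes the last axis of y_pred stacks
    [means for each target, uncertainties for each target]."""
    n_targets = y_true.shape[-1]        # number of label columns
    mu_hat = y_pred[..., :n_targets]    # predicted means
    sigma = y_pred[..., n_targets:]     # predicted uncertainties
    return mu_hat, sigma


y_true = jnp.zeros((4, 1))  # batch of 4 items, one target each
y_pred = jnp.concatenate([jnp.ones((4, 1)), 0.5 * jnp.ones((4, 1))], axis=-1)
mu_hat, sigma = split_pred(y_true, y_pred)
print(mu_hat.shape, sigma.shape)  # (4, 1) (4, 1)
```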

apply_reduced(y_true, y_pred, **kwargs)

Evaluate the reduced loss.

The loss is evaluated separately for each item in the batch, and the mean of these values is returned.

Parameters:
y_true

The ground truth.

y_pred

The prediction.

**kwargs

Other keywords that may be passed to the unreduced loss function.

Returns:
Array

The reduced loss value.

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex]
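The reduction described above (evaluate each batch item separately, then average) can be sketched with `jax.vmap` followed by `jnp.mean`. The per-item loss here is a hypothetical stand-in that assumes each row of `y_pred` is `[predicted mean, predicted std]`; `per_item_loss` and `apply_reduced_sketch` are illustrative names, not `bde` functions.

```python
import jax
import jax.numpy as jnp


def per_item_loss(y_true_i, y_pred_i, epsilon=1e-6):
    """Hypothetical unreduced loss for a single batch item.

    Assumes y_pred_i = [predicted mean, predicted std] and y_true_i = [label].
    """
    mu_hat, sigma = y_pred_i[0], y_pred_i[1]
    var = jnp.maximum(sigma ** 2, epsilon)
    return 0.5 * (jnp.log(var) + (mu_hat - y_true_i[0]) ** 2 / var)


def apply_reduced_sketch(y_true, y_pred):
    """Evaluate the loss separately for each batch item, then return the mean."""
    per_item = jax.vmap(per_item_loss)(y_true, y_pred)  # shape: (batch,)
    return jnp.mean(per_item)


y_true = jnp.array([[0.0], [1.0], [2.0]])
y_pred = jnp.array([[0.0, 1.0], [2.0, 1.0], [2.0, 1.0]])
print(apply_reduced_sketch(y_true, y_pred))  # ≈ 0.1667, the mean of 0.0, 0.5 and 0.0
```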

tree_flatten()

Specify how to serialize the module into a JAX pytree.

Returns:
A tuple with 2 elements:
  • The children, containing arrays & pytrees

  • The aux_data, containing static and hashable data.

Return type:

Tuple[Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]], Any]

classmethod tree_unflatten(aux_data, children)

Specify how to build a module from a JAX pytree.

Parameters:
aux_data

Contains static, hashable data.

children

Contain arrays & pytrees.

Returns:
GaussianNLLLoss

Reconstructed loss function.

Return type:

GaussianNLLLoss
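The `tree_flatten` / `tree_unflatten` pair follows the protocol expected by `jax.tree_util.register_pytree_node_class`. The sketch below uses a hypothetical stand-in class, `ToyGaussianNLL`, not `bde`'s implementation; it assumes that the numeric parameters go into the children and the boolean flag into the static `aux_data`, which this page does not confirm for the real class.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyGaussianNLL:
    """Hypothetical stand-in: numeric params are pytree children, is_full is static aux_data."""

    def __init__(self, epsilon=1e-6, mean_weight=1.0, is_full=False):
        self.epsilon = epsilon
        self.mean_weight = mean_weight
        self.is_full = is_full

    def tree_flatten(self):
        children = (self.epsilon, self.mean_weight)  # leaves that JAX may trace
        aux_data = (self.is_full,)                   # static, hashable data
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        epsilon, mean_weight = children
        (is_full,) = aux_data
        return cls(epsilon=epsilon, mean_weight=mean_weight, is_full=is_full)

    def __call__(self, mu_hat, mu, sigma):
        var = jnp.maximum(sigma ** 2, self.epsilon)
        loss = jnp.log(var) + self.mean_weight * (mu_hat - mu) ** 2 / var
        if self.is_full:  # resolved at trace time because is_full is static
            loss = loss + jnp.log(2 * jnp.pi)
        return 0.5 * loss


loss_fn = ToyGaussianNLL(mean_weight=2.0, is_full=True)
leaves, treedef = jax.tree_util.tree_flatten(loss_fn)    # leaves: [1e-6, 2.0]
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)  # calls ToyGaussianNLL.tree_unflatten


@jax.jit
def batch_loss(fn, mu_hat, mu, sigma):
    # fn crosses the jit boundary as a pytree: children are traced, aux_data stays static
    return jnp.mean(fn(mu_hat, mu, sigma))


value = batch_loss(rebuilt, jnp.array([1.0]), jnp.array([0.0]), jnp.array([1.0]))
```

Keeping numeric parameters as children lets their values change without forcing recompilation, while a flag in `aux_data` selects a code path once per trace.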

Examples using bde.ml.loss.GaussianNLLLoss

Plot Template Estimator.