GaussianNLLLoss
- class bde.ml.loss.GaussianNLLLoss(epsilon=1e-06, mean_weight=1.0, is_full=False)
Gaussian negative log-likelihood loss.
A callable, JAX-compatible class for computing the negative log-likelihood (NLL) loss of a Gaussian distribution. This loss is commonly used in probabilistic models to measure how well a predicted distribution (a mean and a variance per output) explains the true labels.
Mathematically, it is defined as:
\[\ell_{\text{Gaussian NLLLoss}} = \frac{1}{2}\left[ \log{(var)} + \frac{(\hat\mu - \mu)^2}{var} + \log{(2\pi)} \right]\]
This implementation includes the following parameters:
\[\ell_{\text{Gaussian NLLLoss}} = \frac{1}{2}\left[ \log{(var)} + \omega_{\text{mean weight}} \cdot \frac{(\hat\mu - \mu)^2}{var} + \begin{cases} \log{(2\pi)} & \text{if "is full" is True} \\ 0 & \text{if "is full" is False} \end{cases} \right]\]
where
\[var = \max(\sigma^2, \epsilon)\]
- Attributes:
- params : dict[str, …]
Defines the loss-related parameters:
- epsilon : float
A stability factor: the lower bound applied to the predicted variance.
- mean_weight : float
A scale factor applied to the squared-error (mean) term.
- is_full : bool
If True, include the constant term log(2π) in the loss; otherwise it is omitted.
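The formula and its three parameters can be sketched directly in JAX. This is a minimal standalone illustration of the math, not the `bde` implementation: it assumes the predicted mean and predicted variance \(\sigma^2\) arrive as two separate arrays, and the function name `gaussian_nll` is hypothetical.

```python
import jax.numpy as jnp


def gaussian_nll(y_true, mu_pred, var_pred, epsilon=1e-6, mean_weight=1.0, is_full=False):
    """Element-wise Gaussian NLL (hypothetical helper, not the bde API)."""
    # Floor the predicted variance for stability: var = max(sigma^2, epsilon).
    var = jnp.maximum(var_pred, epsilon)
    # log(var) plus the weighted squared error between predicted mean and label.
    loss = jnp.log(var) + mean_weight * (mu_pred - y_true) ** 2 / var
    if is_full:
        # Optional constant term log(2*pi); it does not change the gradients.
        loss = loss + jnp.log(2.0 * jnp.pi)
    return 0.5 * loss
```

For a perfect mean prediction with unit variance, this sketch returns 0 when `is_full=False` and \(\tfrac{1}{2}\log(2\pi) \approx 0.919\) when `is_full=True`.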
Methods
__call__(y_true, y_pred)
Computes the negative log-likelihood loss for the given predictions and labels.
_split_pred(y_true, y_pred)
Splits the predicted values into predictions and their corresponding uncertainties.
apply_reduced(y_true, y_pred, **kwargs)
Evaluates the reduced loss (inherited from base class).
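The exact layout that `_split_pred` expects is not documented here. As a hedged sketch, assume the model emits means and variances concatenated along the last axis, so `y_pred` has twice as many features as `y_true`; the helper `split_pred` below is hypothetical and only illustrates the idea of separating predictions from their uncertainties.

```python
import jax.numpy as jnp


def split_pred(y_true, y_pred):
    """Hypothetical split: first half of the last axis = means, second half = variances."""
    n = y_true.shape[-1]
    # Assumed layout along the last axis: [mu_1, ..., mu_n, var_1, ..., var_n].
    mu = y_pred[..., :n]
    var = y_pred[..., n:2 * n]
    return mu, var
```

Under this assumption, a batch with `y_true` of shape `(4, 1)` and `y_pred` of shape `(4, 2)` yields means and variances of shape `(4, 1)` each.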
Methods
- apply_reduced(y_true, y_pred, **kwargs): Evaluate the reduced loss.
- tree_flatten(): Specify how to serialize the module into a JAX pytree.
- tree_unflatten(aux_data, children): Specify how to build a module from a JAX pytree.
- apply_reduced(y_true, y_pred, **kwargs)
Evaluate the reduced loss.
The loss is evaluated separately for each item in the batch, and the mean of these values is returned.
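The reduction described above (a per-item loss followed by a batch mean) can be sketched with `jax.vmap`. Both functions below are hypothetical illustrations, not the `bde` code; `per_item_loss` assumes each sample's prediction is a pair `[mu, var]` and uses the default `epsilon` and `mean_weight` with `is_full=False`.

```python
import jax
import jax.numpy as jnp


def per_item_loss(y_true, y_pred):
    """Hypothetical per-sample Gaussian NLL; y_pred is assumed to be [mu, var]."""
    mu = y_pred[0]
    var = jnp.maximum(y_pred[1], 1e-6)  # variance floor, epsilon = 1e-6
    return 0.5 * (jnp.log(var) + (mu - y_true[0]) ** 2 / var)


def apply_reduced_sketch(loss_fn, y_true, y_pred):
    # Evaluate the loss separately for each item in the batch ...
    per_item = jax.vmap(loss_fn)(y_true, y_pred)
    # ... and return the mean of these per-item values.
    return jnp.mean(per_item)
```

With labels of 0 and unit variances, predicted means of 0, 1 and 2 give per-item losses of 0, 0.5 and 2, so the reduced loss is their mean, 2.5 / 3.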
- tree_flatten()
Specify how to serialize the module into a JAX pytree.
- classmethod tree_unflatten(aux_data, children)
Specify how to build a module from a JAX pytree.
- Parameters:
- aux_data
Contains static, hashable data.
- children
Contains arrays and pytrees.
- Returns:
- GaussianNLLLoss
Reconstructed loss function.
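The `tree_flatten` / `tree_unflatten` pair follows JAX's custom-pytree protocol, which lets a loss object be passed through transformations such as `jax.jit`. The sketch below shows that pattern on a hypothetical `ToyGaussianLoss` class registered with `jax.tree_util.register_pytree_node_class`. Which attributes `bde` treats as children versus static `aux_data` is an assumption here: all three hyperparameters are stored as hashable aux data and there are no array children.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyGaussianLoss:
    """Hypothetical stand-in for a pytree-compatible loss (not the bde class)."""

    def __init__(self, epsilon=1e-6, mean_weight=1.0, is_full=False):
        self.params = {"epsilon": epsilon, "mean_weight": mean_weight, "is_full": is_full}

    def __call__(self, y_true, mu, var):
        var = jnp.maximum(var, self.params["epsilon"])
        loss = jnp.log(var) + self.params["mean_weight"] * (mu - y_true) ** 2 / var
        if self.params["is_full"]:
            loss = loss + jnp.log(2.0 * jnp.pi)
        return 0.5 * loss

    def tree_flatten(self):
        # No array leaves: the hyperparameters become static, hashable aux_data.
        children = ()
        aux_data = tuple(sorted(self.params.items()))
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the instance from the static data; children is empty here.
        return cls(**dict(aux_data))


@jax.jit
def evaluate(loss_fn, y_true, mu, var):
    # The loss object is flattened on entry and rebuilt inside the traced function.
    return jnp.mean(loss_fn(y_true, mu, var))
```

Keeping the hyperparameters in `aux_data` means they stay ordinary Python values inside `jit`, so the `is_full` branch is resolved at trace time rather than becoming a traced boolean.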
Examples using bde.ml.loss.GaussianNLLLoss
- sphx_glr_auto_examples_example01.py