loss_fns
Loss functions.
Classes:
| Name | Description |
|---|---|
| BCELoss | |
| BCEWithLogitsLoss | |
| CrossEntropyLoss | |
| L1Loss | |
| MSELoss | |
| NLLLoss | |
| VariationalFreeEnergy | Variational Free Energy Loss. |
Functions:
| Name | Description |
|---|---|
| inputs_and_expanded_targets | Ensure the loss can be computed when the inputs contain additional dimensions of (sampled) predictions. |
Attributes:
| Name | Type | Description |
|---|---|---|
| NegativeELBO | | |
BCELoss
BCEWithLogitsLoss
CrossEntropyLoss
L1Loss
MSELoss
NLLLoss
VariationalFreeEnergy
```python
VariationalFreeEnergy(
    nll: _Loss,
    model: BNNMixin,
    prior_loc: Float[Tensor, "parameter"] | None = None,
    prior_scale: Float[Tensor, "parameter"] | None = None,
    kl_weight: float | None = 1.0,
    reduction: str = "mean",
)
```
Bases: Module
Variational Free Energy Loss.
Computes the variational free energy loss for variational inference, where the Kullback-Leibler (KL) regularization term is computed in weight space. This quantity is also known as the negative evidence lower bound (negative ELBO).
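As a reference point, the textbook form of the variational free energy for a variational posterior $q(\mathbf{w})$ over the weights, a prior $p(\mathbf{w})$ and data $\{(x_n, y_n)\}_{n=1}^{N}$ is shown below, with $\lambda$ standing in for `kl_weight` and $\mu_{p,i}, \sigma_{p,i}$ for the entries of `prior_loc` and `prior_scale`. The closed-form KL assumes a factorized (mean-field) Gaussian posterior; which variational family `model` actually uses, and whether the data term is summed or averaged (e.g. under `reduction="mean"`), is not specified in this section.

$$
\mathcal{F}(q) = \sum_{n=1}^{N} \mathbb{E}_{q(\mathbf{w})}\left[-\log p(y_n \mid x_n, \mathbf{w})\right] + \lambda\,\mathrm{KL}\left(q(\mathbf{w}) \,\|\, p(\mathbf{w})\right)
$$

$$
\mathrm{KL}(q \,\|\, p) = \sum_{i} \left[\log\frac{\sigma_{p,i}}{\sigma_{q,i}} + \frac{\sigma_{q,i}^2 + (\mu_{q,i} - \mu_{p,i})^2}{2\sigma_{p,i}^2} - \frac{1}{2}\right]
$$

For $\lambda = 1$, $\mathcal{F}(q)$ equals the negative ELBO.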
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| nll | _Loss | Loss function defining the negative log-likelihood. | required |
| model | BNNMixin | The probabilistic model. | required |
| prior_loc | Float[Tensor, 'parameter'] \| None | Location(s) of the prior Gaussian distribution. | None |
| prior_scale | Float[Tensor, 'parameter'] \| None | Scale(s) of the prior Gaussian distribution. | None |
| kl_weight | float \| None | Weight for the KL divergence term. If … | 1.0 |
| reduction | str | Specifies the reduction to apply to the output: … | 'mean' |
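The closed-form KL term in weight space and its combination with the data term can be sketched in plain Python. This is a hypothetical illustration, not this library's implementation: `diag_gaussian_kl` and `variational_free_energy` are names introduced here, parameters are plain lists rather than tensors, and the variational posterior is assumed to be a factorized Gaussian.

```python
import math
from typing import Sequence


def diag_gaussian_kl(
    mu_q: Sequence[float],
    sigma_q: Sequence[float],
    mu_p: Sequence[float],
    sigma_p: Sequence[float],
) -> float:
    """KL(q || p) between factorized Gaussians, summed over parameters (hypothetical helper)."""
    if not (len(mu_q) == len(sigma_q) == len(mu_p) == len(sigma_p)):
        raise ValueError("all parameter vectors must have the same length")
    kl = 0.0
    for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p):
        # Per-parameter closed form: log(sp/sq) + (sq^2 + (mq - mp)^2) / (2 sp^2) - 1/2
        kl += math.log(sp / sq) + (sq**2 + (mq - mp) ** 2) / (2 * sp**2) - 0.5
    return kl


def variational_free_energy(nll: float, kl: float, kl_weight: float = 1.0) -> float:
    """Data term plus weighted KL regularizer (hypothetical helper, assumed scaling)."""
    return nll + kl_weight * kl
```

When the posterior equals the prior the KL term vanishes and the loss reduces to the negative log-likelihood; a `kl_weight` below 1.0 down-weights the regularizer, e.g. `variational_free_energy(2.0, 0.5, kl_weight=0.5)` gives 2.25.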
Methods:
| Name | Description |
|---|---|
| forward | |
Attributes:
| Name | Type | Description |
|---|---|---|
| kl_weight | | |
| model | | |
| nll | | |
| numel_mean_parameters | | |
| prior_loc | | |
| prior_scale | | |
| reduction | | |
inputs_and_expanded_targets
Ensure the loss can be computed when the inputs contain additional dimensions of (sampled) predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inputs | | Inputs (predictions). | required |
| targets | | Targets. | required |
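The idea of expanding targets to match sampled predictions can be sketched for a regression loss in plain Python. Assumptions: predictions carry a single leading sample dimension, i.e. shape (num_samples, batch_size), targets have shape (batch_size,), and the names `expand_targets_sketch` and `mse` are hypothetical. Where the actual function places the sample dimensions, and how it treats class-index targets for losses such as CrossEntropyLoss, is not stated here.

```python
from typing import List, Sequence, Tuple


def expand_targets_sketch(
    inputs: Sequence[Sequence[float]],
    targets: Sequence[float],
) -> Tuple[List[float], List[float]]:
    """Flatten sampled predictions and repeat the targets once per sample.

    Hypothetical helper: assumes inputs has shape (num_samples, batch_size)
    and targets has shape (batch_size,), with the sample dimension leading.
    """
    flat_inputs: List[float] = []
    expanded_targets: List[float] = []
    for sample in inputs:
        if len(sample) != len(targets):
            raise ValueError("each sample must contain one prediction per target")
        flat_inputs.extend(sample)
        # Every sampled prediction is compared against the same targets.
        expanded_targets.extend(targets)
    return flat_inputs, expanded_targets


def mse(inputs: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error with 'mean' reduction over all pairs."""
    return sum((x - y) ** 2 for x, y in zip(inputs, targets)) / len(targets)
```

Because every sample sees the same targets, the mean over the flattened pairs equals the average of the per-sample losses: for predictions [[1.0, 2.0], [3.0, 2.0]] and targets [1.0, 2.0], the per-sample MSEs are 0.0 and 2.0, and the flattened MSE is 1.0.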