Source code for maldideepkit.utils.loss
"""Loss functions used by :class:`BaseSpectralClassifier`."""
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn
[docs]
class FocalLoss(nn.Module):
    r"""Multi-class focal loss with optional class weighting and label smoothing.

    Implements

    .. math::

        L = - (1 - p_t)^\gamma \log p_t

    where :math:`p_t` is the predicted probability of the true class.
    At :math:`\gamma = 0` and ``label_smoothing=0`` this reduces to
    :class:`~torch.nn.CrossEntropyLoss`.

    Parameters
    ----------
    weight : torch.Tensor or None, default=None
        Per-class weight tensor of shape ``(n_classes,)``. Applied to
        every sample by gathering at its target index (matches the
        :class:`CrossEntropyLoss` convention for ``weight``). A detached
        copy is stored as the non-persistent buffer ``class_weight``, so
        it follows the module across devices but is not written to the
        ``state_dict``.
    gamma : float, default=2.0
        Focusing parameter. ``0`` degrades to cross-entropy; ``2`` is
        the value used in Lin et al. 2017.
    label_smoothing : float, default=0.0
        Target smoothing in ``[0, 1)``. At ``0.0`` the target is a
        one-hot vector; otherwise the target distribution becomes
        ``(1 - eps) * one_hot + eps / n_classes``.
    reduction : {"mean", "sum", "none"}, default="mean"
        How to reduce the per-sample loss tensor.

    Raises
    ------
    ValueError
        If ``gamma`` is negative, ``label_smoothing`` lies outside
        ``[0, 1)``, or ``reduction`` is not one of the accepted names.
    """

    def __init__(
        self,
        weight: torch.Tensor | None = None,
        gamma: float = 2.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
    ) -> None:
        super().__init__()
        if gamma < 0:
            raise ValueError(f"gamma must be >= 0; got {gamma!r}.")
        if not 0.0 <= label_smoothing < 1.0:
            raise ValueError(
                f"label_smoothing must be in [0, 1); got {label_smoothing!r}."
            )
        if reduction not in {"mean", "sum", "none"}:
            raise ValueError(
                f"reduction must be 'mean', 'sum', or 'none'; got {reduction!r}."
            )
        # Keep a private copy so later in-place edits to the caller's tensor
        # cannot change the loss; ``None`` is a valid (absent) buffer value.
        self.register_buffer(
            "class_weight",
            weight.detach().clone() if weight is not None else None,
            persistent=False,
        )
        self.gamma = float(gamma)
        self.label_smoothing = float(label_smoothing)
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        r"""Compute focal loss for ``(N, C)`` logits.

        Accepts either integer targets of shape ``(N,)`` or a soft
        probability distribution of shape ``(N, C)`` (as produced by
        MixUp / CutMix). When soft targets are passed the loss
        becomes

        .. math::

            L = - \sum_c t_c \, (1 - p_c)^\gamma \log p_c

        ``label_smoothing`` is ignored on the soft-target path.

        Class weighting follows the :class:`~torch.nn.CrossEntropyLoss`
        convention with ``reduction="mean"``: the per-sample weight is
        ``weight[y_i]`` (or ``sum_c t_c * weight_c`` for soft targets),
        and the mean reduction divides by ``sum_i sample_weight_i``
        rather than ``N``.

        Parameters
        ----------
        logits : torch.Tensor
            Unnormalised class scores of shape ``(N, C)``.
        target : torch.Tensor
            Class indices of shape ``(N,)`` or a soft distribution of
            shape ``(N, C)``.

        Returns
        -------
        torch.Tensor
            A scalar for ``reduction="mean"`` / ``"sum"``, or the
            per-sample losses of shape ``(N,)`` for ``"none"``.

        Raises
        ------
        ValueError
            If a 2-D ``target`` does not have the same shape as
            ``logits``, or a 1-D ``target`` does not have one entry per
            row of ``logits``.
        """
        self._check_shapes(logits, target)

        log_probs = F.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        sample_weight: torch.Tensor | None = None

        if target.dim() == 2:
            # Soft targets: expectation of the per-class focal term under the
            # supplied distribution; the configured smoothing is not applied.
            target_dist = target.to(dtype=log_probs.dtype)
            loss = self._expected_focal_loss(target_dist, probs, log_probs)
            if self.class_weight is not None:
                class_w = self.class_weight.to(loss.device)
                sample_weight = (target_dist * class_w).sum(dim=-1)
        else:
            index = target.unsqueeze(1)
            if self.label_smoothing == 0.0:
                # One-hot targets: only the true-class entry contributes.
                logpt = log_probs.gather(1, index).squeeze(1)
                pt = probs.gather(1, index).squeeze(1)
                loss = -self._focal_modulator(pt) * logpt
            else:
                # Smoothed targets: eps / C everywhere, plus (1 - eps) on the
                # true class.
                eps = self.label_smoothing
                n_classes = logits.shape[-1]
                target_dist = torch.full_like(probs, eps / n_classes)
                target_dist.scatter_(
                    1, index, target_dist.gather(1, index) + (1.0 - eps)
                )
                loss = self._expected_focal_loss(target_dist, probs, log_probs)
            if self.class_weight is not None:
                sample_weight = self.class_weight.to(loss.device).gather(0, target)

        if sample_weight is not None:
            loss = loss * sample_weight
        return self._reduce(loss, sample_weight)

    @staticmethod
    def _check_shapes(logits: torch.Tensor, target: torch.Tensor) -> None:
        """Reject target shapes that would otherwise give a silently wrong loss."""
        if target.dim() == 2 and target.shape != logits.shape:
            # e.g. an (N, 1) column of class indices would be cast to float
            # and broadcast as if it were a probability distribution.
            raise ValueError(
                "soft targets must have the same shape as logits; got target "
                f"shape {tuple(target.shape)} and logits shape "
                f"{tuple(logits.shape)}."
            )
        if target.dim() == 1 and target.shape[:1] != logits.shape[:1]:
            # gather() tolerates an index with fewer rows than its input, so a
            # short target would quietly score only the leading samples.
            raise ValueError(
                "integer targets must provide one class index per logits row; "
                f"got target shape {tuple(target.shape)} and logits shape "
                f"{tuple(logits.shape)}."
            )

    def _focal_modulator(self, probs: torch.Tensor) -> torch.Tensor:
        """Return ``(1 - p) ** gamma`` with ``1 - p`` floored at ``1e-12``."""
        return (1.0 - probs).clamp_min(1e-12).pow(self.gamma)

    def _expected_focal_loss(
        self,
        target_dist: torch.Tensor,
        probs: torch.Tensor,
        log_probs: torch.Tensor,
    ) -> torch.Tensor:
        """Sum the per-class focal terms weighted by ``target_dist``."""
        per_class_loss = -target_dist * self._focal_modulator(probs) * log_probs
        return per_class_loss.sum(dim=-1)

    def _reduce(
        self, loss: torch.Tensor, sample_weight: torch.Tensor | None
    ) -> torch.Tensor:
        """Apply ``self.reduction`` to the (already weighted) per-sample loss."""
        if self.reduction == "mean":
            if sample_weight is not None:
                # Weighted mean; the floor avoids 0 / 0 when every weight is 0.
                denom = sample_weight.sum().clamp_min(1e-12)
                return loss.sum() / denom
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss