Source code for maldideepkit.uncertainty.laplace

"""Laplace-approximation uncertainty estimator.

Wraps the ``laplace-torch`` package to fit a (last-layer or full)
Gaussian approximation of the posterior over the network's weights and
turn the predictive variance into an epistemic / aleatoric
decomposition.
"""

from __future__ import annotations

import warnings
from typing import Any

import numpy as np
import torch
from torch.utils.data import DataLoader

from ..base.classifier import BaseSpectralClassifier
from ..base.data import SpectralDataset
from ._base import BaseUncertaintyEstimator, _apply_classifier_preprocessing
from ._result import UncertaintyResult, _entropy


class LaplaceEstimator(BaseUncertaintyEstimator):
    """Laplace-approximation uncertainty estimator.

    A thin wrapper around the ``laplace-torch`` package
    (`<https://github.com/aleximmer/Laplace>`_) that fits a Gaussian
    posterior over the classifier's weights and turns the predictive
    variance into a per-sample uncertainty estimate.

    Parameters
    ----------
    classifier : BaseSpectralClassifier
        A fitted classifier; its ``model_`` is reused as the network
        whose posterior is approximated.
    subset_of_weights : {"last_layer", "all"}, default="last_layer"
        Which subset of weights to model. ``"last_layer"`` is the
        standard, cheap default and works for any of the MaldiDeepKit
        backbones whose final layer is a :class:`torch.nn.Linear`.
    hessian_structure : {"diag", "kron"}, default="diag"
        Approximation structure for the Hessian. ``"diag"`` is the
        cheapest; ``"kron"`` (Kronecker-factored) is more accurate at a
        moderate compute cost.
    sigma_noise : float, default=1.0
        Forwarded unchanged to ``laplace-torch``. It controls the
        regression noise scale and is exposed for interface symmetry.
        NOTE(review): ``laplace-torch`` may reject values other than
        ``1.0`` for a classification likelihood when :meth:`calibrate`
        builds the approximation -- confirm against the installed
        version before passing a non-default value.

    Attributes
    ----------
    la_ : object or None
        The fitted ``laplace-torch`` object; ``None`` until
        :meth:`calibrate` has been called.

    Raises
    ------
    ValueError
        If ``subset_of_weights`` or ``hessian_structure`` is not one of
        the accepted values.
    ImportError
        If ``laplace-torch`` is not installed.
    """

    def __init__(
        self,
        classifier: BaseSpectralClassifier,
        subset_of_weights: str = "last_layer",
        hessian_structure: str = "diag",
        sigma_noise: float = 1.0,
    ) -> None:
        super().__init__(classifier)
        if subset_of_weights not in {"last_layer", "all"}:
            raise ValueError(
                f"subset_of_weights must be 'last_layer' or 'all'; "
                f"got {subset_of_weights!r}."
            )
        if hessian_structure not in {"diag", "kron"}:
            raise ValueError(
                f"hessian_structure must be 'diag' or 'kron'; "
                f"got {hessian_structure!r}."
            )
        # Fail at construction time rather than at calibration time when
        # the optional dependency is missing.
        try:
            import laplace  # noqa: F401
        except ImportError as exc:  # pragma: no cover - guard tested via importorskip
            raise ImportError(
                "laplace-torch is required for LaplaceEstimator. "
                "Install it with: pip install laplace-torch"
            ) from exc
        self.subset_of_weights = subset_of_weights
        self.hessian_structure = hessian_structure
        self.sigma_noise = float(sigma_noise)
        self.la_: Any | None = None

    def calibrate(
        self,
        X_cal: Any,
        y_cal: Any,
        *,
        batch_size: int | None = None,
    ) -> "LaplaceEstimator":
        """Fit the Laplace approximation on ``(X_cal, y_cal)``.

        Applies the classifier's preprocessing to ``X_cal``, builds an
        internal :class:`~torch.utils.data.DataLoader`, fits the Laplace
        approximation, and runs marginal-likelihood prior precision
        optimisation.

        Parameters
        ----------
        X_cal : array-like or MaldiSet of shape (n_samples, n_bins)
            Calibration spectra.
        y_cal : array-like of shape (n_samples,)
            Calibration labels using the original label space stored in
            ``classifier.classes_``. Re-encoded to ``0..n_classes-1``
            (the position of each label in ``classifier.classes_``)
            internally.
        batch_size : int or None, default=None
            DataLoader batch size. When ``None``, falls back to
            ``classifier.batch_size``. Values below 1 are raised to 1.

        Returns
        -------
        LaplaceEstimator
            ``self``, with the fitted Laplace approximation stored on
            :attr:`la_`.

        Raises
        ------
        ValueError
            If ``X_cal`` and ``y_cal`` disagree on the number of
            samples, if the calibration set is empty, or if ``y_cal``
            contains a label that is not in ``classifier.classes_``.
        """
        from laplace import Laplace

        X_proc = _apply_classifier_preprocessing(self.classifier, X_cal)
        if hasattr(y_cal, "to_numpy"):
            y_cal = y_cal.to_numpy()
        y_np = np.asarray(y_cal).ravel()
        if X_proc.shape[0] != y_np.shape[0]:
            raise ValueError(
                f"X_cal has {X_proc.shape[0]} rows but y_cal has {y_np.shape[0]}."
            )
        # Reject an empty calibration set up front with a clear message
        # instead of handing an empty DataLoader to the fitting code.
        if y_np.shape[0] == 0:
            raise ValueError(
                "Calibration set is empty; at least one sample is required "
                "to fit the Laplace approximation."
            )

        classes = np.asarray(self.classifier.classes_)
        if not np.all(np.isin(y_np, classes)):
            unknown = np.setdiff1d(np.unique(y_np), classes).tolist()
            raise ValueError(f"y_cal contains labels not seen at fit time: {unknown}.")
        # Map every label to its position in ``classes``. Searching through
        # an explicit sort order keeps the mapping correct even when
        # ``classes`` is not sorted; for sorted ``classes`` this is exactly
        # ``np.searchsorted(classes, y_np)``.
        order = np.argsort(classes, kind="stable")
        y_encoded = order[np.searchsorted(classes, y_np, sorter=order)].astype(
            np.int64
        )

        bs = max(
            1, int(batch_size if batch_size is not None else self.classifier.batch_size)
        )
        # ``standardize=False``: ``X_proc`` has already been through the
        # classifier's preprocessing above.
        dataset = SpectralDataset(X_proc, y_encoded, standardize=False)
        loader = DataLoader(dataset, batch_size=bs, shuffle=False, drop_last=False)

        device = self.classifier._device_
        model = self.classifier.model_
        model.eval()

        la = Laplace(
            model,
            likelihood="classification",
            subset_of_weights=self.subset_of_weights,
            hessian_structure=self.hessian_structure,
            sigma_noise=self.sigma_noise,
        )
        la.fit(loader)
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=r"By default `link_approx` is `probit`.*",
                category=UserWarning,
            )
            try:
                la.optimize_prior_precision(method="marglik", link_approx="probit")
            except TypeError:
                # Presumably a laplace-torch release whose signature does
                # not accept these keywords; fall back to its defaults.
                la.optimize_prior_precision()

        self.la_ = la
        self._device_ = device
        return self

    def predict_with_uncertainty(self, X: Any) -> UncertaintyResult:
        """Return predictions and Laplace-derived uncertainty for ``X``.

        Uses ``pred_type="glm"`` and ``link_approx="probit"`` to map the
        weight-space posterior into a softmax-domain predictive
        distribution. Per-sample predictive variance is summarised as
        the mean diagonal entry, then squashed into ``[0, 1]`` via
        ``1 - exp(-v)`` and stored as ``epistemic``. The scalar
        :attr:`UncertaintyResult.uncertainty` field is the normalised
        entropy of the predictive mean; ``aleatoric`` is the
        non-negative residual ``uncertainty - epistemic``.

        Parameters
        ----------
        X : array-like or MaldiSet
            Spectra to predict on; passed through the classifier's
            preprocessing before inference.

        Returns
        -------
        UncertaintyResult
            Predictions in the original label space together with the
            predictive mean and the uncertainty decomposition. The
            unsquashed variance is kept under
            ``metadata["predictive_variance"]``.

        Raises
        ------
        RuntimeError
            If :meth:`calibrate` has not been called.
        """
        if self.la_ is None:
            raise RuntimeError(
                "LaplaceEstimator has not been calibrated. "
                "Call calibrate(X_cal, y_cal) before predict_with_uncertainty."
            )

        X_proc = _apply_classifier_preprocessing(self.classifier, X)
        device = self.classifier._device_
        X_t = torch.from_numpy(X_proc.astype(np.float32)).to(device)

        with torch.no_grad():
            probs = self.la_(X_t, pred_type="glm", link_approx="probit")
        # Some call signatures return a tuple; the first element is used
        # as the predictive mean.
        if isinstance(probs, tuple):
            probs = probs[0]
        proba_mean = probs.detach().cpu().numpy().astype(np.float64)

        epistemic_raw = self._epistemic_variance(X_t)
        # Squash the unbounded variance into [0, 1].
        epistemic = 1.0 - np.exp(-epistemic_raw)
        epistemic = np.clip(epistemic, 0.0, 1.0)

        total = _entropy(proba_mean)
        aleatoric = np.clip(total - epistemic, 0.0, 1.0)

        idx = np.argmax(proba_mean, axis=1)
        predictions = np.asarray(self.classifier.classes_)[idx]

        metadata: dict[str, Any] = {
            "subset_of_weights": self.subset_of_weights,
            "hessian_structure": self.hessian_structure,
            "predictive_variance": epistemic_raw.astype(np.float64, copy=False),
        }
        return UncertaintyResult(
            predictions=predictions,
            proba_mean=proba_mean,
            uncertainty=total.astype(np.float64, copy=False),
            epistemic=epistemic.astype(np.float64, copy=False),
            aleatoric=aleatoric.astype(np.float64, copy=False),
            method="laplace",
            metadata=metadata,
        )

    def _epistemic_variance(self, X_t: torch.Tensor) -> np.ndarray:
        """Per-sample mean predictive variance from the Laplace posterior.

        Tries the public sampling API first (works on every
        ``laplace-torch`` release that exposes
        :meth:`predictive_samples`) and falls back to the GLM predictive
        distribution otherwise.

        Parameters
        ----------
        X_t : torch.Tensor
            Preprocessed inputs, already on the classifier's device.

        Returns
        -------
        numpy.ndarray
            One ``float64`` value per sample: the variance averaged over
            the last axis.
        """
        with torch.no_grad():
            try:
                samples = self.la_.predictive_samples(
                    X_t, pred_type="glm", n_samples=100
                )
                # Variance across the drawn samples, then averaged over
                # the last axis.
                var = samples.var(dim=0)
                return var.mean(dim=-1).detach().cpu().numpy().astype(np.float64)
            except (AttributeError, TypeError):
                # Best-effort: sampling API missing or incompatible; use
                # the GLM predictive distribution below instead.
                pass
            f_mu, f_var = self.la_._glm_predictive_distribution(X_t)
            if f_var.dim() == 3:
                # Full covariance per sample: keep only the diagonal.
                diag = torch.diagonal(f_var, dim1=-2, dim2=-1)
            else:
                diag = f_var
            return diag.mean(dim=-1).detach().cpu().numpy().astype(np.float64)
# Names exported by ``from <this module> import *``.
__all__ = ["LaplaceEstimator"]