"""Post-hoc calibration helpers used by :class:`BaseSpectralClassifier`.
- :func:`tune_threshold` picks the binary decision threshold on a
validation set that maximises a chosen metric.
- :func:`fit_temperature` optimises a single temperature scalar by
LBFGS on held-out logits for probability calibration.
"""
from __future__ import annotations
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import balanced_accuracy_score, f1_score, roc_curve
[docs]
def tune_threshold(
y_true: np.ndarray,
y_proba: np.ndarray,
metric: str = "balanced_accuracy",
) -> float:
"""Pick the binary decision threshold that maximises ``metric``.
Sweeps the unique observed probabilities (capped at 1000 quantiles)
so severely-imbalanced settings still resolve. Falls back to a
99-point ``linspace(0.01, 0.99)`` only when no probability lies
strictly inside ``(0, 1)``.
Parameters
----------
y_true : array-like of shape (n_samples,)
Binary ground-truth labels in ``{0, 1}``.
y_proba : array-like of shape (n_samples,) or (n_samples, 2)
Predicted positive-class probabilities. If a 2-D array is
given, column index ``1`` is used.
metric : {"balanced_accuracy", "f1", "youden"}, default="balanced_accuracy"
Which metric to maximise. ``"youden"`` = TPR - FPR.
Returns
-------
float
Threshold in ``(0, 1)``. Use as ``y_pred = (y_proba >= t)``.
"""
y_true = np.asarray(y_true).ravel().astype(int)
y_proba_arr = np.asarray(y_proba, dtype=float)
if y_proba_arr.ndim == 2:
if y_proba_arr.shape[1] != 2:
raise ValueError(
"tune_threshold is binary-only; "
f"got y_proba with {y_proba_arr.shape[1]} columns."
)
y_proba_arr = y_proba_arr[:, 1]
y_proba_arr = y_proba_arr.ravel()
if metric == "youden":
fpr, tpr, thr = roc_curve(y_true, y_proba_arr)
valid = (thr > 0) & (thr < 1)
if not valid.any():
return 0.5
j = tpr[valid] - fpr[valid]
return float(thr[valid][int(np.argmax(j))])
unique = np.unique(y_proba_arr)
unique = unique[(unique > 0) & (unique < 1)]
if unique.size == 0:
candidates = np.linspace(0.01, 0.99, 99)
elif unique.size > 1000:
candidates = np.quantile(unique, np.linspace(0.0, 1.0, 1000))
else:
candidates = unique
best_t, best_score = 0.5, -np.inf
for t in candidates:
pred = (y_proba_arr >= t).astype(int)
if metric == "balanced_accuracy":
score = balanced_accuracy_score(y_true, pred)
elif metric == "f1":
score = f1_score(y_true, pred, zero_division=0)
else:
raise ValueError(
f"Unknown metric={metric!r}; "
"expected 'balanced_accuracy', 'f1', or 'youden'."
)
if score > best_score:
best_score, best_t = score, float(t)
return best_t
[docs]
def fit_temperature(
logits: torch.Tensor | np.ndarray,
y_true: torch.Tensor | np.ndarray,
max_iter: int = 200,
lr: float = 1e-1,
) -> float:
"""Fit a scalar temperature by LBFGS minimisation of NLL.
Applies to raw logits (not probabilities). Returns the temperature
``T`` such that ``softmax(logits / T)`` is better-calibrated than
the unscaled softmax.
Parameters
----------
logits : torch.Tensor or ndarray of shape (n_samples, n_classes)
Held-out logits.
y_true : torch.Tensor or ndarray of shape (n_samples,)
Ground-truth class indices.
max_iter : int, default=200
LBFGS max iterations.
lr : float, default=1e-1
LBFGS step size.
Returns
-------
float
Fitted temperature; strictly positive.
"""
if not isinstance(logits, torch.Tensor):
logits = torch.as_tensor(logits, dtype=torch.float32)
if not isinstance(y_true, torch.Tensor):
y_true = torch.as_tensor(np.asarray(y_true).ravel(), dtype=torch.long)
else:
y_true = y_true.to(torch.long).view(-1)
log_temperature = torch.zeros(1, device=logits.device, requires_grad=True)
optimizer = torch.optim.LBFGS([log_temperature], lr=lr, max_iter=max_iter)
def _closure():
optimizer.zero_grad()
t = torch.exp(log_temperature)
loss = F.cross_entropy(logits / t, y_true)
loss.backward()
return loss
optimizer.step(_closure)
return float(torch.exp(log_temperature).detach().item())