"""Abstract base class for sklearn-compatible spectral classifiers."""
from __future__ import annotations
import json
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Any, Callable
import numpy as np
import torch
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_is_fitted
from torch import nn
from torch.utils.data import DataLoader
from ..utils.loss import FocalLoss
from ..utils.reproducibility import resolve_device, seed_everything
from ..utils.training import EarlyStopping, train_loop
from .data import SpectralDataset, _to_numpy, make_loaders
def _serialise_transform_state(state: dict[str, Any] | None) -> dict[str, Any] | None:
"""Convert numpy arrays inside an input-transform state to lists (JSON-safe)."""
if state is None:
return None
out: dict[str, Any] = {}
for k, v in state.items():
if isinstance(v, np.ndarray):
out[k] = v.tolist()
else:
out[k] = v
return out
def _deserialise_transform_state(
state: dict[str, Any] | None,
) -> dict[str, Any] | None:
"""Inverse of :func:`_serialise_transform_state`; numpy-ifies list values."""
if state is None:
return None
out: dict[str, Any] = {}
for k, v in state.items():
if k == "mode":
out[k] = v
elif isinstance(v, list):
out[k] = np.asarray(v, dtype=np.float32)
else:
out[k] = v
return out
[docs]
class BaseSpectralClassifier(ClassifierMixin, BaseEstimator, metaclass=ABCMeta): # type: ignore[misc]
"""Abstract base for all MaldiDeepKit classifiers.
Concrete subclasses only need to override :meth:`_build_model`,
which should return a :class:`torch.nn.Module` that maps an input of
shape ``(batch, input_dim)`` to logits of shape ``(batch, n_classes)``.
Everything else (device placement, validation split, early stopping,
checkpointing, predict / predict_proba, save / load) is provided here.
Parameters
----------
input_dim : int or None, default=None
Number of input bins. If ``None``, inferred from ``X`` at
:meth:`fit` time and stored as :attr:`input_dim_`.
n_classes : int, default=2
Number of output classes. Overwritten with the true number of
classes found in ``y`` at :meth:`fit` time.
learning_rate : float, default=1e-3
Initial learning rate for the optimizer (Adam by default; AdamW
when ``weight_decay > 0``).
weight_decay : float, default=0.0
L2 penalty applied via decoupled weight decay. When ``> 0`` the
optimizer switches from ``Adam`` to ``AdamW``.
grad_clip_norm : float or None, default=None
If set, clip gradient global L2 norm to this value before every
optimizer step. ``1.0`` is a common default for transformers.
label_smoothing : float, default=0.0
Label smoothing factor in ``[0, 1)`` passed to the loss. Applied
to both cross-entropy and focal-loss paths.
loss : {"cross_entropy", "focal"}, default="cross_entropy"
Classification loss. ``"focal"`` uses
:class:`~maldideepkit.utils.FocalLoss` with ``gamma=focal_gamma``.
Good for highly imbalanced problems.
focal_gamma : float, default=2.0
Focal-loss focusing parameter. Ignored when
``loss="cross_entropy"``.
use_amp : bool, default=False
If ``True`` and the resolved device is CUDA, run forward + loss
under :func:`torch.autocast` and use :class:`torch.amp.GradScaler`
for backward. ~2x wall-time speedup on recent NVIDIA GPUs. On CPU
this is a no-op.
swa_start_epoch : int or None, default=None
If set, start Stochastic Weight Averaging at this epoch. The SWA
average replaces the best-val checkpoint at the end of fit.
Typical value: 60-80% of ``epochs``.
tune_threshold : bool, default=False
(Binary classification only.) After fit, sweep thresholds on
the validation split and store the one that maximises
``threshold_metric``. :meth:`predict` uses this threshold
instead of ``argmax @ 0.5``.
threshold_metric : {"balanced_accuracy", "f1", "youden"}, default="balanced_accuracy"
Metric used by ``tune_threshold``.
calibrate_temperature : bool, default=False
If ``True``, after fit run LBFGS-based temperature scaling on
held-out validation logits (Guo et al. 2017). The fitted
temperature is stored as :attr:`temperature_` and applied in
:meth:`predict_proba` to sharpen / smooth probabilities
without changing the argmax.
min_val_auroc_for_threshold_tune : float, default=0.6
Binary-classification guardrail on ``tune_threshold=True``: if
the validation AUROC is below this value, the threshold sweep
is skipped and ``threshold_`` falls back to ``0.5``. Set to
``0.0`` to disable.
use_sam : bool, default=False
If ``True``, wrap the base optimizer in
:class:`~maldideepkit.utils.SAMOptimizer` and run the two-step
Sharpness-Aware Minimization update. Doubles forward / backward
compute per step; typically helps generalisation on small
datasets.
sam_rho : float, default=0.05
Size of the SAM ascent step. Ignored when ``use_sam=False``.
batch_size : int, default=32
Training mini-batch size.
epochs : int, default=100
Maximum number of training epochs.
early_stopping_patience : int, default=10
Number of epochs without validation-loss improvement before
training is stopped.
val_fraction : float, default=0.1
Fraction of the training data held out for the internal
validation split.
warmup_epochs : int, default=0
If positive, linearly ramp each optimizer param group's learning
rate from ``0`` to its configured target over the first
``warmup_epochs`` epochs. Useful for transformer architectures
that can diverge at full learning rate during the first few steps.
standardize : bool, default=False
Shorthand for ``input_transform="standardize"`` (when True) or
``"none"`` (when False). Kept for backwards compatibility;
``input_transform`` is the modern interface and wins when
both are supplied.
input_transform : str, optional
One of ``{"none", "standardize", "log1p", "robust",
"log1p+standardize"}``. Fit on the (warped) training split
only and stored as :attr:`input_transform_state_`; reapplied
at :meth:`predict` / :meth:`predict_proba` time.
warping : sklearn-style transformer, optional
Spectral alignment / warping transformer applied **before**
standardization. Fitted on the training split only, then
used to transform both splits during training and new data
at :meth:`predict` / :meth:`predict_proba` time. The fitted
transformer is stored as :attr:`warper_`.
metrics_log_path : str or Path, optional
If set, write a per-epoch metrics CSV to this path during
:meth:`fit`. One row per epoch with columns ``epoch,
train_loss, val_loss, lr, mean_grad_norm, n_grad_updates``
(+ ``train_auroc, val_auroc`` when ``track_train_metrics=True``).
track_train_metrics : bool, default=False
Only used when ``metrics_log_path`` is set. If ``True``, after
every epoch run a no-grad forward pass over the full training
split and record ``train_auroc`` + ``val_auroc`` alongside the
losses. Adds one extra pass per epoch; binary classification
only.
augment : callable, optional
Per-batch augmentation applied to training batches only. The
usual choice is :class:`~maldideepkit.augment.SpectrumAugment`.
mixup_alpha : float, default=0.0
If positive, apply MixUp augmentation per training batch with a
Beta(``mixup_alpha``, ``mixup_alpha``) mixing coefficient.
``0.0`` disables MixUp. Composable with ``cutmix_alpha``.
cutmix_alpha : float, default=0.0
If positive, apply CutMix augmentation per training batch with
a Beta(``cutmix_alpha``, ``cutmix_alpha``) mixing coefficient.
``0.0`` disables CutMix.
ema_decay : float or None, default=None
If set (typically ``0.999``), maintain an exponential moving
average of model weights during training and use the EMA weights
at inference time.
retry_on_val_auroc_below : float or None, default=None
Binary-classification guardrail. If set and the post-fit
validation AUROC is below this threshold, retrain with a
different RNG seed up to ``max_retries`` times. Useful for
unstable small-data fits.
max_retries : int, default=2
Maximum number of automatic refits triggered by
``retry_on_val_auroc_below``. Ignored when that guardrail is
unset.
class_weight : {"balanced", None} or array-like, default=None
Per-class weights applied to :class:`~torch.nn.CrossEntropyLoss`.
``"balanced"`` uses ``n_samples / (n_classes * class_count)``.
device : {"auto", "cpu", "cuda"} or torch.device, default="auto"
Device used for training and inference.
random_state : int, default=0
Seeds Python, NumPy, and PyTorch RNGs and the validation split.
verbose : bool, default=False
If ``True``, prints one line per training epoch.
Attributes
----------
model_ : torch.nn.Module
The fitted PyTorch model.
classes_ : ndarray of shape (n_classes,)
Original class labels seen during :meth:`fit`.
input_dim_ : int
Resolved number of input features.
n_classes_ : int
Resolved number of classes.
feature_mean_ : ndarray or None
Per-feature mean used when ``standardize=True``.
feature_std_ : ndarray or None
Per-feature std used when ``standardize=True``.
n_features_in_ : int
Number of features seen at :meth:`fit` (sklearn convention).
"""
[docs]
def __init__(
self,
input_dim: int | None = None,
n_classes: int = 2,
learning_rate: float = 1e-3,
weight_decay: float = 0.0,
grad_clip_norm: float | None = None,
label_smoothing: float = 0.0,
loss: str = "cross_entropy",
focal_gamma: float = 2.0,
use_amp: bool = False,
swa_start_epoch: int | None = None,
tune_threshold: bool = False,
threshold_metric: str = "balanced_accuracy",
calibrate_temperature: bool = False,
min_val_auroc_for_threshold_tune: float = 0.6,
use_sam: bool = False,
sam_rho: float = 0.05,
batch_size: int = 32,
epochs: int = 100,
early_stopping_patience: int = 10,
val_fraction: float = 0.1,
warmup_epochs: int = 0,
standardize: bool = False,
input_transform: str | None = None,
warping: Any | None = None,
metrics_log_path: str | Path | None = None,
track_train_metrics: bool = False,
augment: Callable[[torch.Tensor], torch.Tensor] | None = None,
mixup_alpha: float = 0.0,
cutmix_alpha: float = 0.0,
ema_decay: float | None = None,
retry_on_val_auroc_below: float | None = None,
max_retries: int = 2,
class_weight: str | np.ndarray | list | None = None,
device: str | torch.device = "auto",
random_state: int = 0,
verbose: bool = False,
) -> None:
self.input_dim = input_dim
self.n_classes = n_classes
self.learning_rate = learning_rate
self.weight_decay = weight_decay
self.grad_clip_norm = grad_clip_norm
self.label_smoothing = label_smoothing
self.loss = loss
self.focal_gamma = focal_gamma
self.use_amp = use_amp
self.swa_start_epoch = swa_start_epoch
self.tune_threshold = tune_threshold
self.threshold_metric = threshold_metric
self.calibrate_temperature = calibrate_temperature
self.min_val_auroc_for_threshold_tune = min_val_auroc_for_threshold_tune
self.use_sam = use_sam
self.sam_rho = sam_rho
self.batch_size = batch_size
self.epochs = epochs
self.early_stopping_patience = early_stopping_patience
self.val_fraction = val_fraction
self.warmup_epochs = warmup_epochs
self.standardize = standardize
self.input_transform = input_transform
self.warping = warping
self.metrics_log_path = metrics_log_path
self.track_train_metrics = track_train_metrics
self.augment = augment
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.ema_decay = ema_decay
self.retry_on_val_auroc_below = retry_on_val_auroc_below
self.max_retries = max_retries
self.class_weight = class_weight
self.device = device
self.random_state = random_state
self.verbose = verbose
@abstractmethod
def _build_model(self) -> nn.Module:
"""Return a fresh :class:`nn.Module` for the current hyperparameters.
Implementations should use :attr:`input_dim_` and
:attr:`n_classes_` rather than the constructor arguments, since
those are the values resolved at :meth:`fit` time.
"""
def _optimizer_param_groups(self, model: nn.Module) -> list[dict[str, Any]]:
"""Return parameter groups for the optimizer.
Default: a single group containing every parameter of ``model``.
Override in a subclass to give specific parameters a different
learning rate or weight decay.
Parameters
----------
model : nn.Module
The freshly built, on-device training model.
Returns
-------
list of dict
Param-group dicts suitable for :class:`torch.optim.Adam`.
Groups without an explicit ``lr`` or ``weight_decay``
inherit the defaults from :meth:`fit`.
"""
return [{"params": list(model.parameters())}]
def _resolve_device(self) -> torch.device:
return resolve_device(self.device)
def _compute_class_weight(self, y: np.ndarray) -> torch.Tensor | None:
if self.class_weight is None:
return None
if isinstance(self.class_weight, str):
if self.class_weight != "balanced":
raise ValueError(
f"Unknown class_weight={self.class_weight!r}; "
"use 'balanced', None, or an array."
)
counts = np.bincount(y, minlength=self.n_classes_)
if np.any(counts == 0):
missing = np.flatnonzero(counts == 0).tolist()
raise ValueError(
"class_weight='balanced' requires every class to be "
f"present in y; missing class indices: {missing}."
)
weights = len(y) / (self.n_classes_ * counts)
return torch.tensor(weights, dtype=torch.float32)
weights = np.asarray(self.class_weight, dtype=np.float32)
if weights.shape != (self.n_classes_,):
raise ValueError(
f"class_weight has shape {weights.shape}, "
f"expected ({self.n_classes_},)."
)
return torch.from_numpy(weights)
def _build_criterion(self, class_weight: torch.Tensor | None) -> nn.Module:
if self.loss == "cross_entropy":
return nn.CrossEntropyLoss(
weight=class_weight,
label_smoothing=float(self.label_smoothing),
)
if self.loss == "focal":
return FocalLoss(
weight=class_weight,
gamma=float(self.focal_gamma),
label_smoothing=float(self.label_smoothing),
)
raise ValueError(
f"Unknown loss={self.loss!r}; expected 'cross_entropy' or 'focal'."
)
def _make_mix_generator(self) -> torch.Generator | None:
if float(self.mixup_alpha) <= 0 and float(self.cutmix_alpha) <= 0:
return None
gen = torch.Generator()
gen.manual_seed(int(self.random_state))
return gen
def _prepare_inputs(self, X: Any, y: Any) -> tuple[np.ndarray, np.ndarray]:
X_np = _to_numpy(X)
if hasattr(y, "to_numpy"):
y = y.to_numpy()
y_np = np.asarray(y).ravel()
if X_np.shape[0] != y_np.shape[0]:
raise ValueError(f"X has {X_np.shape[0]} rows but y has {y_np.shape[0]}.")
self.classes_ = unique_labels(y_np)
self.n_classes_ = int(len(self.classes_))
if self.n_classes_ < 2:
raise ValueError(
f"{type(self).__name__} needs at least 2 classes in y; "
f"got {self.n_classes_}."
)
self.input_dim_ = (
int(X_np.shape[1]) if self.input_dim is None else int(self.input_dim)
)
if self.input_dim_ != X_np.shape[1]:
raise ValueError(
f"input_dim={self.input_dim_} does not match X.shape[1]={X_np.shape[1]}."
)
self.n_features_in_ = self.input_dim_
y_encoded = np.searchsorted(self.classes_, y_np).astype(np.int64)
return X_np, y_encoded
[docs]
def fit(
self, X: Any, y: Any, *, warm_start: bool = False
) -> BaseSpectralClassifier:
"""Fit the model on ``(X, y)``.
Parameters
----------
X : array-like or MaldiSet of shape (n_samples, n_bins)
Training spectra. NumPy arrays, pandas DataFrames, and
objects with a DataFrame-like ``.X`` attribute are accepted.
y : array-like of shape (n_samples,)
Integer or string class labels. Re-encoded to ``0..n_classes-1``
internally; original labels are preserved in :attr:`classes_`.
warm_start : bool, default=False
When ``True`` and the estimator already has a fitted
:attr:`model_`, the underlying :class:`torch.nn.Module` is
reused as the starting point of training instead of being
rebuilt from scratch via :meth:`_build_model`. This unblocks
federated learning, continual learning, and fine-tuning
workflows that need ``fit()`` to *resume* from the current
weights rather than reinitialise. ``warm_start`` applies only
to the first training attempt; retries triggered by
``retry_on_val_auroc_below`` always rebuild via
:meth:`_build_model` (the warm-start weights already failed
once). When ``warm_start=True`` but no prior ``model_``
exists, falls back silently to a fresh build (sklearn
convention).
Returns
-------
self : BaseSpectralClassifier
The fitted estimator.
"""
seed_everything(int(self.random_state))
X_np, y_encoded = self._prepare_inputs(X, y)
device = self._resolve_device()
warper = clone(self.warping) if self.warping is not None else None
train_loader, val_loader, stats = make_loaders(
X_np,
y_encoded,
batch_size=int(self.batch_size),
val_size=float(self.val_fraction),
random_state=int(self.random_state),
standardize=bool(self.standardize),
input_transform=self.input_transform,
warper=warper,
)
self.feature_mean_ = stats["mean"]
self.feature_std_ = stats["std"]
self.warper_ = stats["warper"]
self.input_transform_state_ = stats["input_transform_state"]
class_weight = self._compute_class_weight(y_encoded)
if class_weight is not None:
class_weight = class_weight.to(device)
criterion = self._build_criterion(class_weight)
X_val_t, y_val_t = self._collect_validation(val_loader, device)
# Retry-on-collapse: when ``retry_on_val_auroc_below`` is set,
# reseed and retrain up to ``max_retries`` more times if the
# final val AUROC is below threshold.
retry_threshold = self.retry_on_val_auroc_below
max_retries = max(0, int(self.max_retries))
total_attempts = 1 + (max_retries if retry_threshold is not None else 0)
final_val_auroc = float("nan")
base_seed = int(self.random_state)
fitted_model: nn.Module | None = None
for attempt in range(total_attempts):
if attempt > 0:
seed_everything(base_seed + 1_000_003 * attempt)
if (
warm_start
and attempt == 0
and getattr(self, "model_", None) is not None
):
# Resume from the previously-fitted module. Skip
# _build_model entirely so the federated / continual
# learning caller's pre-loaded weights are not wiped.
model = self.model_.to(device)
else:
model = self._build_model().to(device)
opt_cls = (
torch.optim.AdamW if float(self.weight_decay) > 0 else torch.optim.Adam
)
param_groups = self._optimizer_param_groups(model)
if bool(self.use_sam):
from ..utils.sam import SAMOptimizer
optimizer = SAMOptimizer(
param_groups,
base_optimizer=opt_cls,
rho=float(self.sam_rho),
lr=float(self.learning_rate),
weight_decay=float(self.weight_decay),
)
else:
optimizer = opt_cls(
param_groups,
lr=float(self.learning_rate),
weight_decay=float(self.weight_decay),
)
warmup = max(0, int(self.warmup_epochs))
t_max = max(1, int(self.epochs) - warmup)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=t_max, eta_min=1e-6
)
metrics_recorder = self._build_metrics_recorder(
model, train_loader, X_val_t, y_val_t, device
)
early = EarlyStopping(patience=int(self.early_stopping_patience))
fitted_model = train_loop(
model,
train_loader,
(X_val_t, y_val_t),
criterion,
optimizer,
scheduler,
device,
int(self.epochs),
early,
verbose=bool(self.verbose),
warmup_epochs=int(self.warmup_epochs),
grad_clip_norm=(
float(self.grad_clip_norm)
if self.grad_clip_norm is not None
else None
),
use_amp=bool(self.use_amp),
swa_start_epoch=(
int(self.swa_start_epoch)
if self.swa_start_epoch is not None
else None
),
use_sam=bool(self.use_sam),
metrics_recorder=metrics_recorder,
augment=self.augment,
mixup_alpha=float(self.mixup_alpha),
cutmix_alpha=float(self.cutmix_alpha),
n_classes=int(self.n_classes_),
mix_generator=self._make_mix_generator(),
ema_decay=(
float(self.ema_decay) if self.ema_decay is not None else None
),
)
if retry_threshold is None or attempt == total_attempts - 1:
break
final_val_auroc = self._attempt_val_auroc(
fitted_model, X_val_t, y_val_t, device
)
if not np.isfinite(final_val_auroc) or final_val_auroc >= float(
retry_threshold
):
break
assert fitted_model is not None
self.model_ = fitted_model
self._device_ = device
self.threshold_ = None
self.temperature_ = None
if self.metrics_log_path is not None:
self._write_post_fit_sidecar(X_val_t, y_val_t)
if bool(self.tune_threshold) or bool(self.calibrate_temperature):
self._fit_post_hoc_calibration(X_val_t, y_val_t)
return self
def _attempt_val_auroc(
self,
model: nn.Module,
X_val: torch.Tensor,
y_val: torch.Tensor,
device: torch.device,
) -> float:
"""Return binary val AUROC of a freshly-fitted model, or NaN.
Returns ``NaN`` for multi-class or when the val split has only
one class present.
"""
from sklearn.metrics import roc_auc_score
model.eval()
with torch.no_grad():
logits = model(X_val)
probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()
y_np = y_val.detach().cpu().numpy()
if probs.shape[1] != 2 or len(np.unique(y_np)) < 2:
return float("nan")
try:
return float(roc_auc_score(y_np, probs[:, 1]))
except ValueError:
return float("nan")
def _write_post_fit_sidecar(self, X_val: torch.Tensor, y_val: torch.Tensor) -> None:
"""Compute val loss + AUROC on the deployed model and write a JSON sidecar."""
from sklearn.metrics import roc_auc_score
log_path = Path(self.metrics_log_path)
sidecar_path = log_path.with_suffix(log_path.suffix + ".post_fit.json")
self.model_.eval()
with torch.no_grad():
logits = self.model_(X_val)
probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()
y_np = y_val.detach().cpu().numpy()
ce = nn.CrossEntropyLoss()(logits, y_val)
val_loss = float(ce.item())
val_auroc: float | None = None
n_classes = probs.shape[1]
try:
if n_classes == 2:
val_auroc = float(roc_auc_score(y_np, probs[:, 1]))
else:
val_auroc = float(
roc_auc_score(y_np, probs, multi_class="ovr", average="macro")
)
except ValueError:
pass
if self.ema_decay is not None:
source = "ema"
elif self.swa_start_epoch is not None:
source = "swa"
else:
source = "best_val"
payload = {
"val_loss": val_loss,
"val_auroc": val_auroc,
"weights_source": source,
"n_classes": int(n_classes),
}
sidecar_path.write_text(json.dumps(payload, indent=2))
def _build_metrics_recorder(
self,
model: nn.Module,
train_loader: DataLoader,
X_val_t: torch.Tensor,
y_val_t: torch.Tensor,
device: torch.device,
) -> Callable[[dict[str, float]], None] | None:
"""Return a per-epoch recorder that appends diagnostics to a CSV.
Returns ``None`` (no recording) when ``metrics_log_path`` is unset.
"""
if self.metrics_log_path is None:
return None
log_path = Path(self.metrics_log_path)
log_path.parent.mkdir(parents=True, exist_ok=True)
if log_path.exists():
log_path.unlink()
track_train = bool(self.track_train_metrics)
binary = self.n_classes_ == 2
def _collect_train_val_auroc() -> tuple[float, float]:
from sklearn.metrics import roc_auc_score
if not binary:
import math
return math.nan, math.nan
y_val_np = y_val_t.detach().cpu().numpy()
y_tr_parts: list[np.ndarray] = []
proba_tr_parts: list[np.ndarray] = []
model.eval()
with torch.no_grad():
for xb, yb in train_loader:
xb = xb.to(device, non_blocking=True)
logits = model(xb).detach()
proba = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
proba_tr_parts.append(proba)
y_tr_parts.append(np.asarray(yb).ravel())
val_logits = model(X_val_t).detach()
val_proba = torch.softmax(val_logits, dim=-1)[:, 1].cpu().numpy()
model.train()
y_tr = np.concatenate(y_tr_parts)
proba_tr = np.concatenate(proba_tr_parts)
try:
train_auroc = float(roc_auc_score(y_tr, proba_tr))
except ValueError:
train_auroc = float("nan")
try:
val_auroc = float(roc_auc_score(y_val_np, val_proba))
except ValueError:
val_auroc = float("nan")
return train_auroc, val_auroc
header_written = False
def recorder(payload: dict[str, float]) -> None:
nonlocal header_written
row: dict[str, float | int] = dict(payload)
if track_train:
train_auroc, val_auroc = _collect_train_val_auroc()
row["train_auroc"] = train_auroc
row["val_auroc"] = val_auroc
columns = [
"epoch",
"train_loss",
"val_loss",
"lr",
"mean_grad_norm",
"n_grad_updates",
]
if track_train:
columns += ["train_auroc", "val_auroc"]
with open(log_path, "a") as fh:
if not header_written:
fh.write(",".join(columns) + "\n")
header_written = True
fh.write(",".join(str(row.get(k, "")) for k in columns) + "\n")
return recorder
def _fit_post_hoc_calibration(
self, X_val_t: torch.Tensor, y_val_t: torch.Tensor
) -> None:
"""Collect held-out logits once, then fit threshold / temperature."""
from ..utils.calibration import fit_temperature, tune_threshold
self.model_.eval()
with torch.no_grad():
val_logits = self.model_(X_val_t).detach().cpu()
y_val_np = y_val_t.detach().cpu().numpy()
if bool(self.calibrate_temperature):
self.temperature_ = float(fit_temperature(val_logits, y_val_np))
if bool(self.tune_threshold):
if self.n_classes_ != 2:
self.threshold_ = None
else:
from sklearn.metrics import roc_auc_score
logits_np = val_logits.numpy()
if self.temperature_ is not None:
logits_np = logits_np / float(self.temperature_)
logits_np = logits_np - logits_np.max(axis=1, keepdims=True)
exp = np.exp(logits_np)
proba = exp / exp.sum(axis=1, keepdims=True)
try:
val_auroc = float(roc_auc_score(y_val_np, proba[:, 1]))
except ValueError:
val_auroc = float("nan")
gate = float(self.min_val_auroc_for_threshold_tune)
if np.isfinite(val_auroc) and val_auroc >= gate:
self.threshold_ = float(
tune_threshold(
y_val_np, proba[:, 1], metric=self.threshold_metric
)
)
else:
import logging
logging.getLogger(__name__).info(
"tune_threshold skipped: val AUROC=%.3f < %.2f; "
"falling back to threshold_=0.5",
val_auroc,
gate,
)
self.threshold_ = 0.5
@staticmethod
def _collect_validation(
val_loader: DataLoader, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
xs, ys = [], []
for xb, yb in val_loader:
xs.append(xb)
ys.append(yb)
X_val = torch.cat(xs, dim=0).to(device)
y_val = torch.cat(ys, dim=0).to(device)
return X_val, y_val
def _check_input_dim(self, X: np.ndarray) -> None:
if X.shape[1] != self.input_dim_:
raise ValueError(
f"X has {X.shape[1]} features but estimator was fitted with "
f"input_dim={self.input_dim_}. Retrain or re-bin your data "
"to match the original resolution."
)
def _forward_logits(self, X: Any) -> np.ndarray:
check_is_fitted(self, "model_")
X_np = _to_numpy(X)
self._check_input_dim(X_np)
if getattr(self, "warper_", None) is not None:
from .data import _warp_numpy
X_np = _warp_numpy(self.warper_, X_np)
state = getattr(self, "input_transform_state_", None)
if state is not None and state.get("mode", "none") != "none":
from .data import apply_input_transform
X_np = apply_input_transform(X_np, state)
elif self.standardize and self.feature_mean_ is not None:
from .data import _STD_FLOOR
safe_std = np.maximum(self.feature_std_, _STD_FLOOR).astype(np.float32)
X_np = (X_np - self.feature_mean_) / safe_std
device = self._device_
self.model_.eval()
# Batch inference so large test folds don't OOM on attention
# architectures.
X_t_full = torch.from_numpy(X_np.astype(np.float32)).to(device)
chunk = int(getattr(self, "batch_size", 32))
chunk = max(1, chunk)
logits_chunks: list[np.ndarray] = []
with torch.no_grad():
for start in range(0, X_t_full.shape[0], chunk):
X_t = X_t_full[start : start + chunk]
logits_chunks.append(self.model_(X_t).detach().cpu().numpy())
logits = (
np.concatenate(logits_chunks, axis=0)
if logits_chunks
else np.empty((0, self.n_classes_), dtype=np.float32)
)
if logits.ndim == 1:
logits = logits.reshape(-1, 1)
return logits
[docs]
def predict_proba(self, X: Any) -> np.ndarray:
"""Return softmax class probabilities of shape ``(n_samples, n_classes)``.
Parameters
----------
X : array-like or MaldiSet of shape (n_samples, n_bins)
Spectra to score. Must have the same number of features as
the training matrix.
Returns
-------
ndarray of shape (n_samples, n_classes)
Softmax probabilities that sum to 1 along the class axis.
Raises
------
ValueError
If ``X.shape[1] != input_dim_``.
"""
logits = self._forward_logits(X)
temperature = getattr(self, "temperature_", None)
if temperature is not None:
logits = logits / float(temperature)
logits = logits - logits.max(axis=1, keepdims=True)
exp = np.exp(logits)
return exp / exp.sum(axis=1, keepdims=True)
[docs]
def predict(self, X: Any) -> np.ndarray:
"""Return hard class predictions.
Parameters
----------
X : array-like or MaldiSet of shape (n_samples, n_bins)
Spectra to classify.
Returns
-------
ndarray of shape (n_samples,)
Predicted labels, drawn from :attr:`classes_`.
Notes
-----
For binary classifiers fit with ``tune_threshold=True``, the
decision uses the fitted :attr:`threshold_` on the positive
class probability instead of ``argmax``.
"""
proba = self.predict_proba(X)
threshold = getattr(self, "threshold_", None)
if threshold is not None and proba.shape[1] == 2:
idx = (proba[:, 1] >= float(threshold)).astype(int)
else:
idx = np.argmax(proba, axis=1)
return self.classes_[idx]
[docs]
def score(self, X: Any, y: Any) -> float:
"""Return mean accuracy on ``(X, y)``.
Parameters
----------
X : array-like of shape (n_samples, n_bins)
y : array-like of shape (n_samples,)
Returns
-------
float
Accuracy in ``[0, 1]``.
"""
if hasattr(y, "to_numpy"):
y = y.to_numpy()
y = np.asarray(y).ravel()
preds = self.predict(X)
return float(np.mean(preds == y))
def _hparam_dict(self) -> dict[str, Any]:
params = self.get_params(deep=False)
if isinstance(params.get("device"), torch.device):
params["device"] = str(params["device"])
if isinstance(params.get("class_weight"), np.ndarray):
params["class_weight"] = params["class_weight"].tolist()
params["warping"] = None if params.get("warping") is None else "<provided>"
return params
[docs]
def save(self, path: str | Path) -> None:
"""Persist the fitted estimator to ``path.pt`` + ``path.json``.
The PyTorch state dict is written to ``<path>.pt`` and the
hyperparameters plus fitted metadata to ``<path>.json``. A
single ``.pt`` or ``.json`` suffix on ``path`` is stripped
so ``clf.save("model")`` and ``clf.save("model.pt")`` produce
the same pair of files.
Parameters
----------
path : str or Path
Base path without extension.
"""
check_is_fitted(self, "model_")
base = Path(path)
if base.suffix in {".pt", ".pth", ".json"}:
base = base.with_suffix("")
base.parent.mkdir(parents=True, exist_ok=True)
torch.save(self.model_.state_dict(), base.with_suffix(".pt"))
warper = getattr(self, "warper_", None)
warper_path = base.with_suffix(".warper.pkl")
if warper is not None:
import joblib
joblib.dump(warper, warper_path)
elif warper_path.exists():
warper_path.unlink()
meta: dict[str, Any] = {
"class_name": type(self).__name__,
"version": 2,
"hparams": self._hparam_dict(),
"fitted": {
"input_dim_": int(self.input_dim_),
"n_classes_": int(self.n_classes_),
"classes_": np.asarray(self.classes_).tolist(),
"n_features_in_": int(self.n_features_in_),
"feature_mean_": (
None
if self.feature_mean_ is None
else np.asarray(self.feature_mean_).tolist()
),
"feature_std_": (
None
if self.feature_std_ is None
else np.asarray(self.feature_std_).tolist()
),
"threshold_": getattr(self, "threshold_", None),
"temperature_": getattr(self, "temperature_", None),
"has_warper": warper is not None,
"input_transform_state_": _serialise_transform_state(
getattr(self, "input_transform_state_", None)
),
},
}
with open(base.with_suffix(".json"), "w") as fh:
json.dump(meta, fh, indent=2)
[docs]
@classmethod
def load(cls, path: str | Path) -> BaseSpectralClassifier:
"""Load a saved estimator from a ``save()``-produced pair of files.
Parameters
----------
path : str or Path
Base path (``.pt``/``.json`` suffix optional).
Returns
-------
BaseSpectralClassifier
Fitted estimator ready for :meth:`predict` /
:meth:`predict_proba`.
Raises
------
ValueError
If the JSON file identifies a different class from ``cls``.
FileNotFoundError
If either ``.pt`` or ``.json`` is missing.
"""
base = Path(path)
if base.suffix in {".pt", ".pth", ".json"}:
base = base.with_suffix("")
pt_path = base.with_suffix(".pt")
json_path = base.with_suffix(".json")
if not pt_path.exists():
raise FileNotFoundError(pt_path)
if not json_path.exists():
raise FileNotFoundError(json_path)
with open(json_path) as fh:
meta = json.load(fh)
if cls is not BaseSpectralClassifier and meta["class_name"] != cls.__name__:
raise ValueError(
f"Saved class is {meta['class_name']!r} but load() was called "
f"on {cls.__name__!r}."
)
target_cls = cls
if cls is BaseSpectralClassifier:
from .. import (
MaldiCNNClassifier,
MaldiMLPClassifier,
MaldiResNetClassifier,
MaldiTransformerClassifier,
)
registry = {
c.__name__: c
for c in (
MaldiMLPClassifier,
MaldiCNNClassifier,
MaldiResNetClassifier,
MaldiTransformerClassifier,
)
}
if meta["class_name"] not in registry:
raise ValueError(f"Unknown saved class: {meta['class_name']!r}")
target_cls = registry[meta["class_name"]]
hparams = dict(meta["hparams"])
hparams.pop("warping", None)
instance = target_cls(**hparams)
fitted = meta["fitted"]
instance.input_dim_ = int(fitted["input_dim_"])
instance.n_classes_ = int(fitted["n_classes_"])
instance.classes_ = np.asarray(fitted["classes_"])
instance.n_features_in_ = int(fitted["n_features_in_"])
instance.feature_mean_ = (
None
if fitted["feature_mean_"] is None
else np.asarray(fitted["feature_mean_"], dtype=np.float32)
)
instance.feature_std_ = (
None
if fitted["feature_std_"] is None
else np.asarray(fitted["feature_std_"], dtype=np.float32)
)
instance.threshold_ = fitted.get("threshold_")
instance.temperature_ = fitted.get("temperature_")
instance.input_transform_state_ = _deserialise_transform_state(
fitted.get("input_transform_state_")
)
instance.warper_ = None
if fitted.get("has_warper"):
import joblib
warper_path = base.with_suffix(".warper.pkl")
if not warper_path.exists():
raise FileNotFoundError(warper_path)
instance.warper_ = joblib.load(warper_path)
device = resolve_device(instance.device)
model = instance._build_model().to(device)
state = torch.load(pt_path, map_location=device, weights_only=True)
model.load_state_dict(state)
model.eval()
instance.model_ = model
instance._device_ = device
return instance
def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "model_")
def __sklearn_tags__(self): # pragma: no cover - sklearn >=1.6 tag plumbing
try:
tags = super().__sklearn_tags__()
except AttributeError:
return None
tags.input_tags.two_d_array = True
tags.input_tags.sparse = False
tags.classifier_tags.multi_class = True
tags.classifier_tags.poor_score = True
tags.non_deterministic = False
return tags
def _more_tags(self) -> dict[str, Any]: # pragma: no cover - sklearn <1.6
return {
"binary_only": False,
"multioutput": False,
"poor_score": True,
"requires_positive_X": False,
"X_types": ["2darray"],
}
__all__ = ["BaseSpectralClassifier", "SpectralDataset"]