"""MLP classifier with optional sigmoid-gated attention.
The architecture is a dense network with a learned per-feature gate on
the first hidden layer that doubles as an interpretable attention map
over the projected bin representation.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import torch
from sklearn.utils.validation import check_is_fitted
from torch import nn
from ..base.classifier import BaseSpectralClassifier
from ..base.data import _to_numpy
class SpectralAttentionMLP(nn.Module):
    """Projection + optional sigmoid-gated attention + deep MLP head.

    The input is projected to ``hidden_dim`` features
    (``Linear -> BatchNorm1d -> ReLU -> Dropout``). When attention is
    enabled, a ``Linear -> Sigmoid`` gate computed from the projected
    features yields per-feature weights in ``[0, 1]``. These weights
    multiply the projected features element-wise. The (optionally gated)
    representation then passes through a stack of
    ``Linear -> BatchNorm1d -> ReLU -> Dropout`` layers. A final linear
    layer produces the logits.

    Parameters
    ----------
    input_dim : int
        Number of input bins. The first linear layer projects this down
        to :attr:`hidden_dim`.
    n_classes : int, default=2
        Number of output logits.
    hidden_dim : int, default=512
        Width of the projection layer and attention gate.
    head_dims : sequence of int, default=(256, 128)
        Widths of the hidden layers between the gated representation and
        the output logits. May be empty, in which case the head is a
        single linear layer.
    use_attention : bool, default=True
        If ``True``, apply a sigmoid-gated element-wise attention on the
        projected features. If ``False``, the gate is replaced by a
        parameter-free identity and the model reduces to a plain MLP of
        the same depth.
    dropout_high : float, default=0.3
        Dropout applied after the projection and after the first head
        layer.
    dropout_low : float, default=0.2
        Dropout applied after every head layer except the first. With the
        default ``head_dims`` this is the layer feeding the output logits.
        With a single head layer it is unused.

    Attributes
    ----------
    last_attention : torch.Tensor or None
        Detached attention weights from the most recent forward pass,
        shape ``(batch, hidden_dim)``. ``None`` before the first forward
        pass and whenever ``use_attention=False``.
    """

    def __init__(
        self,
        input_dim: int,
        n_classes: int = 2,
        hidden_dim: int = 512,
        head_dims: tuple[int, ...] = (256, 128),
        use_attention: bool = True,
        dropout_high: float = 0.3,
        dropout_low: float = 0.2,
    ) -> None:
        super().__init__()
        self.use_attention = use_attention
        self.hidden_dim = hidden_dim
        # Submodules are created in the order proj -> attn -> head.
        # Parameters are initialised as they are constructed, so keeping this
        # order keeps seeded runs (and state_dict keys) stable.
        self.proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_high),
        )
        if use_attention:
            self.attn: nn.Module = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim), nn.Sigmoid()
            )
        else:
            # Parameter-free placeholder so ``self.attn`` always exists.
            self.attn = nn.Identity()

        head_layers: list[nn.Module] = []
        prev = hidden_dim
        for i, width in enumerate(head_dims):
            # Only the first head layer uses the stronger dropout.
            head_layers.extend(
                (
                    nn.Linear(prev, width),
                    nn.BatchNorm1d(width),
                    nn.ReLU(),
                    nn.Dropout(dropout_high if i == 0 else dropout_low),
                )
            )
            prev = width
        head_layers.append(nn.Linear(prev, n_classes))
        self.head = nn.Sequential(*head_layers)
        self.last_attention: torch.Tensor | None = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, input_dim)`` to ``(batch, n_classes)`` logits.

        Side effect: updates :attr:`last_attention` with the detached gate
        output, or resets it to ``None`` when attention is disabled.
        """
        projected = self.proj(x)
        if self.use_attention:
            weights = self.attn(projected)
            # Detached copy for inspection; the gradient still flows
            # through ``weights`` in the product below.
            self.last_attention = weights.detach()
            gated = projected * weights
        else:
            self.last_attention = None
            gated = projected
        return self.head(gated)
class MaldiMLPClassifier(BaseSpectralClassifier):
    """sklearn-compatible MLP classifier with optional attention gating.

    Parameters
    ----------
    input_dim : int or None, default=None
        Number of input bins, forwarded to the base class (which
        presumably infers it from ``X`` at fit time when ``None`` —
        see the base class docstring).
    n_classes : int, default=2
        Number of output classes, forwarded to the base class.
    hidden_dim : int, default=512
        Width of the projection and attention gate.
    head_dims : sequence of int, default=(256, 128)
        Widths of the hidden layers of the classification head.
    use_attention : bool, default=True
        Toggle the sigmoid-gated attention. When ``False``, the model
        is a plain MLP of the same depth.
    dropout_high : float, default=0.3
        Dropout after the projection and after the first head layer.
    dropout_low : float, default=0.2
        Dropout after every head layer except the first (with the
        default ``head_dims``, the layer feeding the output logits).

    Notes
    -----
    Every parameter accepted by
    :class:`~maldideepkit.base.classifier.BaseSpectralClassifier`
    (e.g. ``learning_rate``, ``batch_size``, ``epochs``, ``warping``,
    ``calibrate_temperature``, ``device``, ``random_state``, ...) is
    forwarded to the base class. See its docstring for the full list.

    Attributes
    ----------
    attention_weights_ : ndarray or None
        Attention weights cached by the most recent forward pass through
        ``_forward_logits`` (e.g. :meth:`get_attention_weights`, or
        prediction calls routed through it). Shape
        ``(n_samples_last_call, hidden_dim)``. After :meth:`fit` it holds
        the weights for the first (up to 64) training samples. ``None``
        before any such pass or when ``use_attention=False``.

    Examples
    --------
    >>> import numpy as np
    >>> from maldideepkit import MaldiMLPClassifier
    >>> rng = np.random.default_rng(0)
    >>> X = rng.standard_normal((64, 256)).astype("float32")
    >>> y = rng.integers(0, 2, size=64)
    >>> clf = MaldiMLPClassifier(epochs=2, batch_size=16, random_state=0).fit(X, y)
    >>> clf.predict(X).shape
    (64,)
    >>> weights = clf.get_attention_weights(X[:4])
    >>> weights.shape
    (4, 512)
    """

    # Class-level default: ``clf.attention_weights_`` reads as ``None``
    # before any forward pass, while ``__init__`` stays free of
    # trailing-underscore instance attributes. Under the scikit-learn
    # estimator contract, ``__init__`` only stores hyper-parameters, and
    # ``check_is_fitted`` treats instance attributes ending in ``_`` as
    # evidence of fitting.
    attention_weights_: np.ndarray | None = None

    # Number of leading training rows re-run at the end of ``fit`` to
    # populate ``attention_weights_``.
    _FIT_ATTENTION_PREVIEW = 64

    def __init__(
        self,
        input_dim: int | None = None,
        n_classes: int = 2,
        hidden_dim: int = 512,
        head_dims: tuple[int, ...] = (256, 128),
        use_attention: bool = True,
        dropout_high: float = 0.3,
        dropout_low: float = 0.2,
        learning_rate: float = 1e-3,
        weight_decay: float = 0.0,
        grad_clip_norm: float | None = None,
        label_smoothing: float = 0.0,
        loss: str = "cross_entropy",
        focal_gamma: float = 2.0,
        use_amp: bool = False,
        swa_start_epoch: int | None = None,
        tune_threshold: bool = False,
        threshold_metric: str = "balanced_accuracy",
        calibrate_temperature: bool = False,
        min_val_auroc_for_threshold_tune: float = 0.6,
        use_sam: bool = False,
        sam_rho: float = 0.05,
        batch_size: int = 32,
        epochs: int = 100,
        early_stopping_patience: int = 10,
        val_fraction: float = 0.1,
        warmup_epochs: int = 0,
        standardize: bool = False,
        input_transform: str | None = None,
        warping: Any | None = None,
        metrics_log_path: str | Path | None = None,
        track_train_metrics: bool = False,
        augment: Any | None = None,
        mixup_alpha: float = 0.0,
        cutmix_alpha: float = 0.0,
        ema_decay: float | None = None,
        retry_on_val_auroc_below: float | None = None,
        max_retries: int = 2,
        class_weight: str | np.ndarray | list | None = None,
        device: str | torch.device = "auto",
        random_state: int = 0,
        verbose: bool = False,
    ) -> None:
        super().__init__(
            input_dim=input_dim,
            n_classes=n_classes,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            grad_clip_norm=grad_clip_norm,
            label_smoothing=label_smoothing,
            loss=loss,
            focal_gamma=focal_gamma,
            use_amp=use_amp,
            swa_start_epoch=swa_start_epoch,
            tune_threshold=tune_threshold,
            threshold_metric=threshold_metric,
            calibrate_temperature=calibrate_temperature,
            min_val_auroc_for_threshold_tune=min_val_auroc_for_threshold_tune,
            use_sam=use_sam,
            sam_rho=sam_rho,
            batch_size=batch_size,
            epochs=epochs,
            early_stopping_patience=early_stopping_patience,
            val_fraction=val_fraction,
            warmup_epochs=warmup_epochs,
            standardize=standardize,
            input_transform=input_transform,
            warping=warping,
            metrics_log_path=metrics_log_path,
            track_train_metrics=track_train_metrics,
            augment=augment,
            mixup_alpha=mixup_alpha,
            cutmix_alpha=cutmix_alpha,
            ema_decay=ema_decay,
            retry_on_val_auroc_below=retry_on_val_auroc_below,
            max_retries=max_retries,
            class_weight=class_weight,
            device=device,
            random_state=random_state,
            verbose=verbose,
        )
        # Hyper-parameters only; fitted state is created by ``fit``.
        self.hidden_dim = hidden_dim
        self.head_dims = head_dims
        self.use_attention = use_attention
        self.dropout_high = dropout_high
        self.dropout_low = dropout_low

    def _build_model(self) -> nn.Module:
        """Instantiate the network from the current hyper-parameters."""
        return SpectralAttentionMLP(
            input_dim=self.input_dim_,
            n_classes=self.n_classes_,
            hidden_dim=int(self.hidden_dim),
            head_dims=tuple(self.head_dims),
            use_attention=bool(self.use_attention),
            dropout_high=float(self.dropout_high),
            dropout_low=float(self.dropout_low),
        )

    def _forward_logits(self, X: Any) -> np.ndarray:
        """Run the base forward pass and cache attention weights for ``X``.

        A forward hook on the model's gate records its output for every
        call made during the base pass. This keeps all rows if the base
        class runs inference in mini-batches, instead of only the last
        batch's ``last_attention``. If the recorded row count does not
        match the number of logits (e.g. the base made extra passes), the
        previous behaviour of using ``model_.last_attention`` is kept.
        """
        model = self.model_
        # Check the model's own flag too: after ``set_params(use_attention=True)``
        # on a model built without attention, ``attn`` is an Identity.
        capture = bool(self.use_attention) and bool(
            getattr(model, "use_attention", False)
        )
        chunks: list[torch.Tensor] = []
        handle = None
        if capture:

            def _record(
                _module: nn.Module, _inputs: Any, output: torch.Tensor
            ) -> None:
                # Move to CPU per batch so device memory does not grow
                # with the number of batches.
                chunks.append(output.detach().cpu())

            handle = model.attn.register_forward_hook(_record)
        try:
            logits = super()._forward_logits(X)
        finally:
            if handle is not None:
                handle.remove()

        if not capture or model.last_attention is None:
            self.attention_weights_ = None
            return logits

        weights = torch.cat(chunks, dim=0) if chunks else None
        if weights is None or weights.shape[0] != len(logits):
            weights = model.last_attention.detach().cpu()
        self.attention_weights_ = weights.numpy()
        return logits

    def fit(  # type: ignore[override]
        self, X: Any, y: Any, *, warm_start: bool = False
    ) -> MaldiMLPClassifier:
        """Fit the model, then cache attention weights for a preview of ``X``.

        After training, the first (up to 64) rows of ``X`` are run through
        the fitted network once more. ``attention_weights_`` then holds
        their attention weights. When ``use_attention=False``,
        ``attention_weights_`` is set to ``None``.
        See :meth:`BaseSpectralClassifier.fit` for shared parameters,
        including ``warm_start``.
        """
        super().fit(X, y, warm_start=warm_start)
        if self.use_attention:
            # Slicing clamps to len(X), so short inputs need no special case.
            preview = _to_numpy(X)[: self._FIT_ATTENTION_PREVIEW]
            self._forward_logits(preview)
        else:
            self.attention_weights_ = None
        return self

    def get_attention_weights(self, X: Any) -> np.ndarray:
        """Return attention weights for ``X`` of shape ``(len(X), hidden_dim)``.

        Parameters
        ----------
        X : array-like or MaldiSet of shape (n_samples, n_bins)
            Spectra to inspect. Must match ``input_dim_``.

        Returns
        -------
        ndarray of shape (n_samples, hidden_dim)
            Sigmoid-gated attention weights. Also stored in
            ``attention_weights_``.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If the classifier has not been fitted.
        RuntimeError
            If ``use_attention=False``, or if the fitted model produced no
            attention weights (e.g. it was built without attention).
        """
        check_is_fitted(self, "model_")
        if not self.use_attention:
            raise RuntimeError(
                "get_attention_weights is only available when use_attention=True."
            )
        self._forward_logits(X)
        if self.attention_weights_ is None:
            raise RuntimeError(
                "Attention weights were not captured during forward; "
                "ensure the model was built with use_attention=True."
            )
        return self.attention_weights_

    @classmethod
    def from_spectrum(
        cls, bin_width: int, input_dim: int, **overrides
    ) -> "MaldiMLPClassifier":
        """Construct a classifier for a given ``(bin_width, input_dim)`` layout.

        The MLP is architecturally scale-agnostic, so this factory
        ignores ``bin_width`` and only forwards ``input_dim`` and any
        ``**overrides``. Provided for API symmetry with the other
        classifiers.
        """
        del bin_width  # unused; accepted for signature parity
        kwargs: dict[str, Any] = {"input_dim": input_dim}
        kwargs.update(overrides)
        return cls(**kwargs)