Source code for maldideepkit.transformer.transformer
"""1-D Vision Transformer for binned MALDI-TOF spectra.
A plain ViT backbone adapted to 1-D spectra:
non-overlapping patch embedding, learned positional embedding,
pre-LayerNorm residual blocks with LayerScale and stochastic depth,
global self-attention in every block, and mean-pool aggregation by
default (CLS token optional).
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import torch
from torch import nn
from .._blocks import DropPath, PatchEmbed1D
from ..base.classifier import BaseSpectralClassifier
[docs]
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with QK-norm, computed via SDPA.

    A :class:`~torch.nn.LayerNorm` over the per-head feature dimension is
    applied to the query and key tensors before
    :func:`torch.nn.functional.scaled_dot_product_attention` is called.
    QK-norm is always enabled; there is no switch to turn it off.

    Parameters
    ----------
    dim : int
        Token embedding dimension. Must be divisible by ``num_heads``.
    num_heads : int
        Number of attention heads. Must be positive.
    attention_dropout : float, default=0.0
        Dropout probability handed to the attention kernel; only used
        while the module is in training mode.
    proj_dropout : float, default=0.0
        Dropout applied to the output of the final projection.

    Raises
    ------
    ValueError
        If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        attention_dropout: float = 0.0,
        proj_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        # Reject non-positive head counts up front: 0 would otherwise surface
        # as a ZeroDivisionError and negative values as an obscure failure
        # when building a LayerNorm with a negative size.
        if num_heads <= 0:
            raise ValueError(
                f"num_heads must be a positive integer; got {num_heads!r}."
            )
        if dim % num_heads != 0:
            raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}.")
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(dim, 3 * dim)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.attention_dropout = float(attention_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_dropout)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply global self-attention to ``(B, N, C)`` tokens.

        Parameters
        ----------
        x : torch.Tensor
            Token tensor of shape ``(B, N, C)``.
        key_padding_mask : torch.Tensor or None, default=None
            Optional ``(B, N)`` mask where a truthy entry marks a real
            token and a falsy entry marks padding to ignore. Boolean
            masks are used as-is; other dtypes (e.g. 0/1 integers) are
            converted to ``bool``. Every sample should keep at least one
            real token: a row whose keys are all masked receives ``-inf``
            for every key, which makes the softmax undefined.

        Returns
        -------
        torch.Tensor
            Attended tokens of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(dim=0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        attn_mask: torch.Tensor | None = None
        if key_padding_mask is not None:
            # Coerce to a bool tensor on q's device so that `~` is a logical
            # NOT and `masked_fill` accepts it, whatever dtype the caller used.
            is_padding = ~key_padding_mask.to(device=q.device, dtype=torch.bool)
            # Additive mask broadcast over heads and query positions:
            # 0 for real keys, -inf for padded keys.
            attn_mask = torch.zeros((B, 1, 1, N), dtype=q.dtype, device=q.device)
            attn_mask = attn_mask.masked_fill(
                is_padding[:, None, None, :], float("-inf")
            )

        # Attention dropout is a training-time regulariser only.
        dropout_p = self.attention_dropout if self.training else 0.0
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False
        )
        # (B, heads, N, head_dim) -> (B, N, C)
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
[docs]
class TransformerBlock(nn.Module):
    """Pre-norm transformer block with LayerScale and stochastic depth.

    Residual pattern::

        x = x + drop_path(γ_1 * Attn(LN(x)))
        x = x + drop_path(γ_2 * MLP(LN(x)))

    The second line operates on the ``x`` produced by the first, i.e. the
    MLP sub-block sees the tokens *after* the attention residual has been
    added. ``γ_*`` are per-channel learnable scales; with a small initial
    value each block starts close to the identity map.

    Parameters
    ----------
    dim : int
        Token dimension.
    num_heads : int
        Attention heads.
    mlp_ratio : int, default=4
        MLP hidden-dim multiplier (hidden width is ``int(mlp_ratio * dim)``).
    dropout : float, default=0.0
        Dropout used inside the MLP and on the attention output projection.
    attention_dropout : float, default=0.0
        Attention-matrix dropout.
    drop_path : float, default=0.0
        Stochastic-depth probability for this block's residuals.
    layerscale_init : float or None, default=1e-4
        Initial value of the LayerScale gammas. Set to ``None`` to
        disable LayerScale entirely.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: int = 4,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        drop_path: float = 0.0,
        layerscale_init: float | None = 1e-4,
    ) -> None:
        super().__init__()
        # Attention sub-block.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(
            dim, num_heads, attention_dropout=attention_dropout, proj_dropout=dropout
        )
        self.drop_path1 = DropPath(drop_path)

        # MLP sub-block.
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(mlp_ratio * dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )
        self.drop_path2 = DropPath(drop_path)

        # LayerScale parameters only exist when the feature is enabled.
        self.use_layerscale = layerscale_init is not None
        if self.use_layerscale:
            self.gamma1 = nn.Parameter(torch.full((dim,), float(layerscale_init)))
            self.gamma2 = nn.Parameter(torch.full((dim,), float(layerscale_init)))

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the attention and MLP residual sub-blocks in sequence.

        Parameters
        ----------
        x : torch.Tensor
            Token tensor; its last dimension is normalised by the
            ``LayerNorm(dim)`` layers.
        key_padding_mask : torch.Tensor or None, default=None
            Forwarded unchanged to the attention module.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``x``.
        """
        # Attention sub-block: x <- x + drop_path(γ_1 * Attn(LN(x)))
        attn_out = self.attn(self.norm1(x), key_padding_mask=key_padding_mask)
        if self.use_layerscale:
            attn_out = attn_out * self.gamma1
        x = x + self.drop_path1(attn_out)

        # MLP sub-block: norm2 must see the *updated* x. Normalising the
        # pre-attention tokens here would make the two branches parallel
        # instead of sequential and hide the attention output from the MLP.
        mlp_out = self.mlp(self.norm2(x))
        if self.use_layerscale:
            mlp_out = mlp_out * self.gamma2
        return x + self.drop_path2(mlp_out)
[docs]
class SpectralTransformer1D(nn.Module):
    """ViT-style encoder plus classification head for a binned 1-D spectrum.

    The spectrum is split into non-overlapping patches, a learned
    positional embedding is added, the tokens pass through a stack of
    :class:`TransformerBlock` layers, and the pooled result is mapped to
    class logits by a small MLP head.

    Parameters
    ----------
    input_dim : int
        Number of input bins.
    n_classes : int, default=2
        Number of output logits.
    patch_size : int, default=4
        Non-overlapping patch width. Token count is
        ``ceil(input_dim / patch_size)``.
    embed_dim : int, default=64
        Token embedding dimension.
    depth : int, default=6
        Number of transformer blocks.
    num_heads : int, default=4
        Attention heads per block. ``embed_dim`` must be divisible by
        ``num_heads``.
    mlp_ratio : int, default=4
        MLP hidden-dim multiplier.
    dropout : float, default=0.1
        Dropout used after the positional embedding, inside every block
        and inside the head.
    attention_dropout : float, default=0.0
        Attention-matrix dropout.
    drop_path_rate : float, default=0.1
        End-of-stack stochastic-depth rate. Per-block rates are taken from
        ``torch.linspace(0, drop_path_rate, depth)``.
    layerscale_init : float or None, default=1e-4
        LayerScale initial value. ``None`` disables LayerScale.
    pool : {"cls", "mean"}, default="mean"
        Aggregation strategy for classification. ``"mean"`` averages
        over patch tokens; ``"cls"`` prepends a learned token and uses
        its output.
    head_dim : int, default=128
        Width of the hidden dense layer in the classification head.

    Raises
    ------
    ValueError
        If ``pool`` is not ``"mean"``/``"cls"``, if ``embed_dim`` is not
        divisible by ``num_heads``, or if ``depth`` is smaller than 1.
    """

    def __init__(
        self,
        input_dim: int,
        n_classes: int = 2,
        patch_size: int = 4,
        embed_dim: int = 64,
        depth: int = 6,
        num_heads: int = 4,
        mlp_ratio: int = 4,
        dropout: float = 0.1,
        attention_dropout: float = 0.0,
        drop_path_rate: float = 0.1,
        layerscale_init: float | None = 1e-4,
        pool: str = "mean",
        head_dim: int = 128,
    ) -> None:
        super().__init__()

        # --- argument validation -------------------------------------
        if pool not in {"mean", "cls"}:
            raise ValueError(f"pool must be 'mean' or 'cls'; got {pool!r}.")
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim={embed_dim} must be divisible by num_heads={num_heads}."
            )
        if depth < 1:
            raise ValueError(f"depth must be >= 1; got {depth!r}.")

        self.pool = pool
        self.patch_size = patch_size

        # --- patch embedding -----------------------------------------
        self.embed = PatchEmbed1D(
            patch_size=patch_size, in_channels=1, embed_dim=embed_dim
        )
        # Ceiling division: a trailing partial patch still yields a token.
        whole_patches, leftover_bins = divmod(input_dim, patch_size)
        self.n_tokens = whole_patches + (1 if leftover_bins else 0)

        # --- optional CLS token and positional embedding ---------------
        self.cls_token: nn.Parameter | None
        with_cls = pool == "cls"
        if with_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)
        else:
            # Registering None keeps the attribute present on the module.
            self.register_parameter("cls_token", None)
        # The CLS token, when present, needs its own position slot.
        seq_len = self.n_tokens + 1 if with_cls else self.n_tokens
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.pos_drop = nn.Dropout(dropout)

        # --- transformer stack ---------------------------------------
        # One stochastic-depth rate per block, ramping from 0 upwards.
        block_rates = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList(
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attention_dropout=attention_dropout,
                drop_path=rate,
                layerscale_init=layerscale_init,
            )
            for rate in block_rates
        )
        self.norm = nn.LayerNorm(embed_dim)

        # --- classification head -------------------------------------
        self.head = nn.Sequential(
            nn.Linear(embed_dim, head_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_dim, n_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, input_dim)`` to ``(batch, n_classes)`` logits.

        Notes
        -----
        Inputs whose length is not a multiple of ``patch_size`` are
        right-padded with zeros before the patch embedding.
        """
        # Insert a channel axis for the patch embedding.
        signal = x.unsqueeze(1)
        missing = (-signal.shape[-1]) % self.patch_size
        if missing:
            signal = torch.nn.functional.pad(signal, (0, missing))

        seq = self.embed(signal)
        if self.cls_token is not None:
            batch_cls = self.cls_token.expand(seq.shape[0], -1, -1)
            seq = torch.cat([batch_cls, seq], dim=1)

        # Use only as many positions as there are tokens in this batch.
        positions = self.pos_embed[:, : seq.shape[1]]
        seq = self.pos_drop(seq + positions)

        for layer in self.blocks:
            seq = layer(seq)
        seq = self.norm(seq)

        if self.pool == "cls":
            return self.head(seq[:, 0])
        # Mean pooling skips the CLS slot if one happens to be present.
        first_patch = 0 if self.cls_token is None else 1
        return self.head(seq[:, first_patch:].mean(dim=1))
[docs]
class MaldiTransformerClassifier(BaseSpectralClassifier):
    """sklearn-compatible 1-D ViT classifier for MALDI-TOF spectra.

    The constructor only records hyperparameters; the network itself
    (a :class:`SpectralTransformer1D`) is created by ``_build_model``.

    Parameters
    ----------
    patch_size : int, default=4
        Patch width of the initial patch embedding. Token count is
        ``ceil(input_dim / patch_size)``.
    embed_dim : int, default=64
        Token embedding dimension. Must be divisible by ``num_heads``.
    depth : int, default=6
        Number of transformer blocks.
    num_heads : int, default=4
        Attention heads per block.
    mlp_ratio : int, default=4
        MLP hidden-dim multiplier inside each block.
    dropout : float, default=0.1
        Dropout used inside every block and inside the head.
    attention_dropout : float, default=0.0
        Attention-matrix dropout.
    drop_path_rate : float, default=0.1
        Linearly ramped stochastic-depth rate (0 at block 0, this
        value at the final block).
    layerscale_init : float or None, default=1e-4
        LayerScale initial value. ``None`` disables LayerScale.
    pool : {"mean", "cls"}, default="mean"
        Token aggregation for classification.
    head_dim : int, default=128
        Width of the hidden dense layer in the classification head.

    Notes
    -----
    Every other constructor argument (e.g. ``learning_rate``,
    ``batch_size``, ``epochs``, ``warping``, ``calibrate_temperature``,
    ``device``, ``random_state``, ...) is forwarded unchanged to
    :class:`~maldideepkit.base.classifier.BaseSpectralClassifier`. See
    its docstring for their meaning.

    Training defaults chosen here: ``learning_rate=3e-4``,
    ``weight_decay=0.05``, ``grad_clip_norm=1.0``, ``warmup_epochs=5``.

    Examples
    --------
    >>> import numpy as np
    >>> from maldideepkit import MaldiTransformerClassifier
    >>> rng = np.random.default_rng(0)
    >>> X = rng.standard_normal((32, 256)).astype("float32")
    >>> y = rng.integers(0, 2, size=32)
    >>> clf = MaldiTransformerClassifier(
    ...     epochs=2, batch_size=8, embed_dim=32, depth=2,
    ...     num_heads=2, patch_size=2, random_state=0,
    ... ).fit(X, y)
    >>> clf.predict(X).shape
    (32,)
    """

    def __init__(
        self,
        input_dim: int | None = None,
        n_classes: int = 2,
        patch_size: int = 4,
        embed_dim: int = 64,
        depth: int = 6,
        num_heads: int = 4,
        mlp_ratio: int = 4,
        dropout: float = 0.1,
        attention_dropout: float = 0.0,
        drop_path_rate: float = 0.1,
        layerscale_init: float | None = 1e-4,
        pool: str = "mean",
        head_dim: int = 128,
        learning_rate: float = 3e-4,
        weight_decay: float = 0.05,
        grad_clip_norm: float | None = 1.0,
        label_smoothing: float = 0.0,
        loss: str = "cross_entropy",
        focal_gamma: float = 2.0,
        use_amp: bool = False,
        swa_start_epoch: int | None = None,
        tune_threshold: bool = False,
        threshold_metric: str = "balanced_accuracy",
        calibrate_temperature: bool = False,
        min_val_auroc_for_threshold_tune: float = 0.6,
        use_sam: bool = False,
        sam_rho: float = 0.05,
        batch_size: int = 32,
        epochs: int = 100,
        early_stopping_patience: int = 10,
        val_fraction: float = 0.1,
        warmup_epochs: int = 5,
        standardize: bool = False,
        input_transform: str | None = None,
        warping: Any | None = None,
        metrics_log_path: str | Path | None = None,
        track_train_metrics: bool = False,
        augment: Any | None = None,
        mixup_alpha: float = 0.0,
        cutmix_alpha: float = 0.0,
        ema_decay: float | None = None,
        retry_on_val_auroc_below: float | None = None,
        max_retries: int = 2,
        class_weight: str | np.ndarray | list | None = None,
        device: str | torch.device = "auto",
        random_state: int = 0,
        verbose: bool = False,
    ) -> None:
        # Everything that is not architecture-specific belongs to the base
        # class and is handed over by keyword, untouched.
        base_options: dict[str, Any] = {
            "input_dim": input_dim,
            "n_classes": n_classes,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "grad_clip_norm": grad_clip_norm,
            "label_smoothing": label_smoothing,
            "loss": loss,
            "focal_gamma": focal_gamma,
            "use_amp": use_amp,
            "swa_start_epoch": swa_start_epoch,
            "tune_threshold": tune_threshold,
            "threshold_metric": threshold_metric,
            "calibrate_temperature": calibrate_temperature,
            "min_val_auroc_for_threshold_tune": min_val_auroc_for_threshold_tune,
            "use_sam": use_sam,
            "sam_rho": sam_rho,
            "batch_size": batch_size,
            "epochs": epochs,
            "early_stopping_patience": early_stopping_patience,
            "val_fraction": val_fraction,
            "warmup_epochs": warmup_epochs,
            "standardize": standardize,
            "input_transform": input_transform,
            "warping": warping,
            "metrics_log_path": metrics_log_path,
            "track_train_metrics": track_train_metrics,
            "augment": augment,
            "mixup_alpha": mixup_alpha,
            "cutmix_alpha": cutmix_alpha,
            "ema_decay": ema_decay,
            "retry_on_val_auroc_below": retry_on_val_auroc_below,
            "max_retries": max_retries,
            "class_weight": class_weight,
            "device": device,
            "random_state": random_state,
            "verbose": verbose,
        }
        super().__init__(**base_options)

        # Architecture hyperparameters are stored exactly as given; any
        # type coercion is deferred to ``_build_model``.
        # Tokenisation / width.
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # Encoder stack.
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        # Regularisation.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.drop_path_rate = drop_path_rate
        self.layerscale_init = layerscale_init
        # Pooling and head.
        self.pool = pool
        self.head_dim = head_dim

    def _build_model(self) -> nn.Module:
        """Instantiate the backbone from the stored hyperparameters.

        Reads ``self.input_dim_`` and ``self.n_classes_`` (not set in this
        class; presumably populated by the base class before this is
        called) and casts the remaining hyperparameters to plain Python
        types before passing them on.
        """
        model_options: dict[str, Any] = {
            "input_dim": self.input_dim_,
            "n_classes": self.n_classes_,
            "patch_size": int(self.patch_size),
            "embed_dim": int(self.embed_dim),
            "depth": int(self.depth),
            "num_heads": int(self.num_heads),
            "mlp_ratio": int(self.mlp_ratio),
            "dropout": float(self.dropout),
            "attention_dropout": float(self.attention_dropout),
            "drop_path_rate": float(self.drop_path_rate),
            # Passed through as-is because ``None`` is a meaningful value.
            "layerscale_init": self.layerscale_init,
            "pool": str(self.pool),
            "head_dim": int(self.head_dim),
        }
        return SpectralTransformer1D(**model_options)

    @classmethod
    def from_spectrum(
        cls, bin_width: int, input_dim: int, **overrides
    ) -> "MaldiTransformerClassifier":
        """Build a classifier for a given ``(bin_width, input_dim)`` layout.

        ``bin_width`` is accepted but not used; only ``input_dim`` and any
        ``**overrides`` reach the constructor. A key in ``overrides`` takes
        precedence over the positional ``input_dim``. The original notes
        that this factory exists for API symmetry with the other
        classifiers.
        """
        del bin_width  # intentionally ignored
        return cls(**{"input_dim": input_dim, **overrides})