Transformer Module

1-D Vision Transformer (ViT) adapted to MALDI-TOF binned spectra: a non-overlapping Conv1d patch embedding, a learned positional embedding, an optional [CLS] aggregation token, a stack of pre-norm multi-head self-attention blocks with MLP, LayerScale and stochastic depth, and a classification head (one hidden dense layer of width head_dim) over the pooled representation.

Why a plain ViT for MALDI? Every token attends to every other token in every block, so widely separated m/z peaks interact directly in the first block rather than through several stages of local-window merging.
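The price of that global reach is an N x N score matrix per head and per block. The arithmetic below is a rough sketch for a 6000-bin spectrum at the defaults (patch_size=4, num_heads=4, depth=6); the helper name is hypothetical and the counts are plain arithmetic, not measurements of the library.

```python
import math

def attention_footprint(input_dim: int, patch_size: int, num_heads: int, depth: int):
    """Token count and attention-score counts for one spectrum in a plain 1-D ViT."""
    # Non-overlapping patches; a trailing partial patch is zero-padded, hence ceil.
    n_tokens = math.ceil(input_dim / patch_size)
    # Every token scores every other token: an N x N matrix per head, per block.
    scores_per_block = num_heads * n_tokens * n_tokens
    return n_tokens, scores_per_block, depth * scores_per_block

n_tokens, per_block, total = attention_footprint(6000, 4, 4, 6)
print(n_tokens, per_block, total)
```

Any two peaks, however far apart on the m/z axis, are one attention hop apart in the first block; the cost grows quadratically with the token count, which is why patch_size matters for long spectra.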

Design deviations from the canonical ImageNet ViT.

  • embed_dim=64 and depth=6 by default (the literature ViT-S uses embed_dim=384, depth=12). The smaller recipe is calibrated for MALDI-TOF cohort sizes (a few thousand spectra), where data efficiency matters more than raw capacity.

  • LayerScale is on by default (layerscale_init=1e-4). Each block starts as a near-identity map, so training must earn each block's contribution.

  • CLS pooling is opt-in (pool="cls"); the default is mean pooling over patch tokens, which is more robust on small data because there is no single aggregator token to overfit.

  • The transformer training recipe is baked in as defaults: learning_rate=3e-4, weight_decay=0.05, grad_clip_norm=1.0, warmup_epochs=5. Without these, the attention layers tend to diverge within the first few batches.
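The recipe names its ingredients but not the exact schedule shape. The sketch below assumes a linear per-epoch warmup to the base rate followed by a constant rate (the post-warmup schedule is an assumption), and clips gradients by their global L2 norm using the same rule as torch.nn.utils.clip_grad_norm_. Both function names are hypothetical.

```python
import math

def warmup_lr(epoch: int, base_lr: float = 3e-4, warmup_epochs: int = 5) -> float:
    """Learning rate for a 0-based epoch: linear ramp, then held at base_lr.

    Assumption: the post-warmup schedule (constant here) is not documented.
    """
    if warmup_epochs > 0 and epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    return base_lr

def clip_by_global_norm(grads: list, max_norm: float = 1.0) -> list:
    """Scale all gradients by one shared factor so their joint L2 norm is <= max_norm.

    Same rule as torch.nn.utils.clip_grad_norm_: factor = max_norm / (norm + 1e-6),
    capped at 1 so small gradients pass through unchanged.
    """
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [[g * scale for g in tensor] for tensor in grads]

print([round(warmup_lr(e), 6) for e in range(7)])
print(clip_by_global_norm([[3.0, 4.0]], max_norm=1.0))
```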

MaldiTransformerClassifier

class maldideepkit.MaldiTransformerClassifier(input_dim=None, n_classes=2, patch_size=4, embed_dim=64, depth=6, num_heads=4, mlp_ratio=4, dropout=0.1, attention_dropout=0.0, drop_path_rate=0.1, layerscale_init=0.0001, pool='mean', head_dim=128, learning_rate=0.0003, weight_decay=0.05, grad_clip_norm=1.0, label_smoothing=0.0, loss='cross_entropy', focal_gamma=2.0, use_amp=False, swa_start_epoch=None, tune_threshold=False, threshold_metric='balanced_accuracy', calibrate_temperature=False, min_val_auroc_for_threshold_tune=0.6, use_sam=False, sam_rho=0.05, batch_size=32, epochs=100, early_stopping_patience=10, val_fraction=0.1, warmup_epochs=5, standardize=False, input_transform=None, warping=None, metrics_log_path=None, track_train_metrics=False, augment=None, mixup_alpha=0.0, cutmix_alpha=0.0, ema_decay=None, retry_on_val_auroc_below=None, max_retries=2, class_weight=None, device='auto', random_state=0, verbose=False)

Bases: BaseSpectralClassifier

sklearn-compatible 1-D ViT classifier for MALDI-TOF spectra.

Parameters:
  • patch_size (int) – Width of the non-overlapping patches cut by the initial Conv1d embedding. Token count is ceil(input_dim / patch_size).

  • embed_dim (int) – Token embedding dimension. Must be divisible by num_heads.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Attention heads per block.

  • mlp_ratio (int) – MLP hidden-dim multiplier inside each block.

  • dropout (float) – MLP dropout applied inside every block and before the head.

  • attention_dropout (float) – Attention-matrix dropout.

  • drop_path_rate (float) – Linearly ramped stochastic-depth rate (0 at block 0, this value at the final block).

  • layerscale_init (float | None) – LayerScale initial value. None disables LayerScale.

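  • pool (str) – Token aggregation for classification: "mean" (default) averages the patch tokens; "cls" prepends a learned [CLS] token and uses its output.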

  • head_dim (int) – Width of the hidden dense layer in the classification head.

  • input_dim (int | None)

  • n_classes (int)

  • learning_rate (float)

  • weight_decay (float)

  • grad_clip_norm (float | None)

  • label_smoothing (float)

  • loss (str)

  • focal_gamma (float)

  • use_amp (bool)

  • swa_start_epoch (int | None)

  • tune_threshold (bool)

  • threshold_metric (str)

  • calibrate_temperature (bool)

  • min_val_auroc_for_threshold_tune (float)

  • use_sam (bool)

  • sam_rho (float)

  • batch_size (int)

  • epochs (int)

  • early_stopping_patience (int)

  • val_fraction (float)

  • warmup_epochs (int)

  • standardize (bool)

  • input_transform (str | None)

  • warping (Any | None)

  • metrics_log_path (str | Path | None)

  • track_train_metrics (bool)

  • augment (Any | None)

  • mixup_alpha (float)

  • cutmix_alpha (float)

  • ema_decay (float | None)

  • retry_on_val_auroc_below (float | None)

  • max_retries (int)

  • class_weight (str | np.ndarray | list | None)

  • device (str | torch.device)

  • random_state (int)

  • verbose (bool)
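Several architecture parameters above imply derived quantities. The hypothetical helper below computes them from the documented rules (token count ceil(input_dim / patch_size), embed_dim divisible by num_heads, one extra token for pool="cls", and a linearly ramped drop-path rate). It is a sketch for sanity-checking a configuration before a long fit, not part of the package.

```python
import math

def derived_shapes(input_dim: int, patch_size: int = 4, embed_dim: int = 64,
                   num_heads: int = 4, depth: int = 6, drop_path_rate: float = 0.1,
                   pool: str = "mean") -> dict:
    """Hypothetical helper: quantities implied by the constructor arguments."""
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    if pool not in ("mean", "cls"):
        raise ValueError("pool must be 'mean' or 'cls'")
    n_patches = math.ceil(input_dim / patch_size)
    # pool="cls" prepends one learned token to the sequence.
    seq_len = n_patches + (1 if pool == "cls" else 0)
    # Linear stochastic-depth ramp: 0 at block 0, drop_path_rate at the last block
    # (a single block gets 0.0, matching a linspace over one point).
    rates = [drop_path_rate * i / (depth - 1) if depth > 1 else 0.0 for i in range(depth)]
    return {"n_patches": n_patches, "seq_len": seq_len,
            "head_width": embed_dim // num_heads, "drop_path": rates}

print(derived_shapes(6000))
```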

Notes

Every parameter accepted by BaseSpectralClassifier (e.g. learning_rate, batch_size, epochs, warping, calibrate_temperature, device, random_state, …) is forwarded to the base class. See its docstring for the full list.

The transformer training recipe is baked in as defaults: learning_rate=3e-4, weight_decay=0.05, grad_clip_norm=1.0, warmup_epochs=5.

Examples

>>> import numpy as np
>>> from maldideepkit import MaldiTransformerClassifier
>>> rng = np.random.default_rng(0)
>>> X = rng.standard_normal((32, 256)).astype("float32")
>>> y = rng.integers(0, 2, size=32)
>>> clf = MaldiTransformerClassifier(
...     epochs=2, batch_size=8, embed_dim=32, depth=2,
...     num_heads=2, patch_size=2, random_state=0,
... ).fit(X, y)
>>> clf.predict(X).shape
(32,)
__init__(input_dim=None, n_classes=2, patch_size=4, embed_dim=64, depth=6, num_heads=4, mlp_ratio=4, dropout=0.1, attention_dropout=0.0, drop_path_rate=0.1, layerscale_init=0.0001, pool='mean', head_dim=128, learning_rate=0.0003, weight_decay=0.05, grad_clip_norm=1.0, label_smoothing=0.0, loss='cross_entropy', focal_gamma=2.0, use_amp=False, swa_start_epoch=None, tune_threshold=False, threshold_metric='balanced_accuracy', calibrate_temperature=False, min_val_auroc_for_threshold_tune=0.6, use_sam=False, sam_rho=0.05, batch_size=32, epochs=100, early_stopping_patience=10, val_fraction=0.1, warmup_epochs=5, standardize=False, input_transform=None, warping=None, metrics_log_path=None, track_train_metrics=False, augment=None, mixup_alpha=0.0, cutmix_alpha=0.0, ema_decay=None, retry_on_val_auroc_below=None, max_retries=2, class_weight=None, device='auto', random_state=0, verbose=False)
Return type:

None

classmethod from_spectrum(bin_width, input_dim, **overrides)

Construct a classifier for a given (bin_width, input_dim) layout.

The transformer is architecturally scale-agnostic, so this factory only forwards input_dim and any **overrides; bin_width is accepted but not used. It is provided for API symmetry with the other classifiers.

Parameters:
  • bin_width (int)

  • input_dim (int)

Return type:

MaldiTransformerClassifier
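Because from_spectrum ignores bin_width, two layouts with the same input_dim yield identically configured estimators. A minimal sketch of that forwarding pattern, using a hypothetical SketchClassifier stand-in rather than the real class:

```python
from typing import Optional

class SketchClassifier:
    """Hypothetical stand-in that mimics the documented factory behaviour."""

    def __init__(self, input_dim: Optional[int] = None, patch_size: int = 4,
                 random_state: int = 0) -> None:
        self.input_dim = input_dim
        self.patch_size = patch_size
        self.random_state = random_state

    @classmethod
    def from_spectrum(cls, bin_width: int, input_dim: int, **overrides) -> "SketchClassifier":
        # bin_width is accepted only for API symmetry; a scale-agnostic
        # transformer has nothing to derive from it.
        del bin_width
        return cls(input_dim=input_dim, **overrides)

a = SketchClassifier.from_spectrum(bin_width=6, input_dim=3000, random_state=0)
b = SketchClassifier.from_spectrum(bin_width=3, input_dim=3000, random_state=0)
print(vars(a) == vars(b))
```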

set_fit_request(*, warm_start='$UNCHANGED$')

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • warm_start (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for warm_start parameter in fit.

  • self (MaldiTransformerClassifier)

Returns:

self – The updated object.

Return type:

object

SpectralTransformer1D

class maldideepkit.transformer.transformer.SpectralTransformer1D(input_dim, n_classes=2, patch_size=4, embed_dim=64, depth=6, num_heads=4, mlp_ratio=4, dropout=0.1, attention_dropout=0.0, drop_path_rate=0.1, layerscale_init=0.0001, pool='mean', head_dim=128)

Bases: Module

1-D Vision Transformer backbone for binned spectra.

Parameters:
  • input_dim (int) – Number of input bins.

  • n_classes (int) – Number of output logits.

  • patch_size (int) – Non-overlapping patch width. Token count is ceil(input_dim / patch_size).

  • embed_dim (int) – Token embedding dimension.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Attention heads per block. embed_dim must be divisible by num_heads.

  • mlp_ratio (int) – MLP hidden-dim multiplier.

  • dropout (float) – MLP dropout applied inside every block and before the head.

  • attention_dropout (float) – Attention-matrix dropout.

  • drop_path_rate (float) – End-of-stack stochastic-depth rate. Linearly interpolated from 0 at block 0 to drop_path_rate at the final block.

  • layerscale_init (float | None) – LayerScale initial value. None disables LayerScale.

  • pool (str) – Aggregation strategy for classification. "mean" averages over patch tokens (more robust on small data); "cls" prepends a learned token and uses its output.

  • head_dim (int) – Width of the hidden dense layer in the classification head.
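The two pool settings differ only in how the (B, N, C) encoder output is reduced before the head. A NumPy sketch, assuming the prepended [CLS] token sits at index 0 of the sequence (a common convention, not confirmed by this page); pool_tokens is a hypothetical name.

```python
import numpy as np

def pool_tokens(tokens: np.ndarray, pool: str = "mean") -> np.ndarray:
    """Reduce encoder output (B, N, C) to one vector per spectrum, (B, C)."""
    if pool == "mean":
        return tokens.mean(axis=1)   # average over all N patch tokens
    if pool == "cls":
        return tokens[:, 0]          # assumed: [CLS] token was prepended at index 0
    raise ValueError(f"pool must be 'mean' or 'cls', got {pool!r}")

rng = np.random.default_rng(0)
tokens = rng.standard_normal((2, 5, 8))   # batch 2, 5 tokens, embed_dim 8
mean_vec = pool_tokens(tokens, "mean")
cls_vec = pool_tokens(tokens, "cls")
print(mean_vec.shape, cls_vec.shape)
```

Mean pooling spreads the gradient over every patch token, whereas "cls" routes the whole classification signal through one learned token.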

__init__(input_dim, n_classes=2, patch_size=4, embed_dim=64, depth=6, num_heads=4, mlp_ratio=4, dropout=0.1, attention_dropout=0.0, drop_path_rate=0.1, layerscale_init=0.0001, pool='mean', head_dim=128)

Construct the module. See the class docstring for parameter descriptions.

Return type:

None

forward(x)

Map (batch, input_dim) to (batch, n_classes) logits.

Notes

Inputs whose length is not a multiple of patch_size are right-padded with zeros before the patch embedding.

Parameters:

x (Tensor)

Return type:

Tensor
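Right-padding followed by a non-overlapping Conv1d (kernel size and stride both equal to patch_size) is equivalent to cutting the padded spectrum into patches and applying one shared linear map to each. A NumPy sketch of that equivalence; the function names and the (patch_size, embed_dim) weight layout are choices made here, not the module's parameter layout.

```python
import numpy as np

def patchify(x: np.ndarray, patch_size: int) -> np.ndarray:
    """(B, input_dim) -> (B, N, patch_size), right-padding with zeros first."""
    batch, length = x.shape
    pad = (-length) % patch_size               # 0 when length is already a multiple
    x = np.pad(x, ((0, 0), (0, pad)))          # zeros appended on the right
    return x.reshape(batch, -1, patch_size)    # N = ceil(length / patch_size)

def patch_embed(x: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Non-overlapping single-channel Conv1d written as a matmul.

    weight: (patch_size, embed_dim), bias: (embed_dim,). Output: (B, N, embed_dim).
    """
    return patchify(x, weight.shape[0]) @ weight + bias

x = np.arange(10, dtype=float).reshape(1, 10)   # one 10-bin "spectrum"
print(patchify(x, 4))
```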

TransformerBlock

class maldideepkit.transformer.transformer.TransformerBlock(dim, num_heads, mlp_ratio=4, dropout=0.0, attention_dropout=0.0, drop_path=0.0, layerscale_init=0.0001)

Bases: Module

Pre-norm transformer block with LayerScale and stochastic depth.

Residual pattern:

x = x + drop_path(γ_1 * Attn(LN(x)))
x = x + drop_path(γ_2 * MLP(LN(x)))

γ_1 and γ_2 are per-channel learnable scales initialized to layerscale_init (1e-4 by default), so every block starts as a near-identity map. With layerscale_init=None the scales are omitted.
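A NumPy sketch of the residual pattern above, with toy linear and tanh layers standing in for attention and the MLP (hypothetical stand-ins, not the module's sublayers). It shows why a 1e-4 initialization makes the block nearly an identity map; stochastic depth is left out because it is the identity at inference.

```python
import numpy as np

def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """LayerNorm over the channel axis, without learnable affine parameters."""
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def prenorm_block(x, attn, mlp, gamma1, gamma2):
    """x: (N, C) tokens; attn, mlp: callables (N, C) -> (N, C); gamma1, gamma2: (C,)."""
    x = x + gamma1 * attn(layer_norm(x))   # attention residual, scaled per channel
    x = x + gamma2 * mlp(layer_norm(x))    # MLP residual, scaled per channel
    return x

rng = np.random.default_rng(0)
C = 8
W_attn = rng.standard_normal((C, C))
W_mlp = rng.standard_normal((C, C))

def toy_attn(h):   # linear stand-in for self-attention
    return h @ W_attn

def toy_mlp(h):    # nonlinear stand-in for the MLP
    return np.tanh(h @ W_mlp)

x = rng.standard_normal((5, C))
y_init = prenorm_block(x, toy_attn, toy_mlp, np.full(C, 1e-4), np.full(C, 1e-4))
y_full = prenorm_block(x, toy_attn, toy_mlp, np.ones(C), np.ones(C))
print(np.abs(y_init - x).max(), np.abs(y_full - x).max())
```

Passing all-ones gammas reproduces the layerscale_init=None case, where each sub-block's output is added at full strength from the first step.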

Parameters:
  • dim (int) – Token dimension.

  • num_heads (int) – Attention heads.

  • mlp_ratio (int) – MLP hidden-dim multiplier.

  • dropout (float) – MLP dropout.

  • attention_dropout (float) – Attention-matrix dropout.

  • drop_path (float) – Stochastic-depth probability for this block’s residuals.

  • layerscale_init (float | None) – Initial value of the LayerScale gammas. Set to None to disable LayerScale entirely.
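The drop_path argument can be sketched in NumPy as per-sample stochastic depth: during training each sample's residual branch is zeroed with probability p and surviving branches are rescaled by 1 / (1 - p). This follows the common formulation (as in the timm library's DropPath); that this module uses exactly this form is an assumption.

```python
import numpy as np

def drop_path(branch: np.ndarray, p: float, training: bool,
              rng: np.random.Generator) -> np.ndarray:
    """Per-sample stochastic depth on a residual branch of shape (B, N, C)."""
    if not training or p == 0.0:
        return branch                      # identity at inference
    keep = 1.0 - p
    # One Bernoulli draw per sample, broadcast over tokens and channels.
    mask = rng.random((branch.shape[0],) + (1,) * (branch.ndim - 1)) < keep
    return branch * mask / keep            # rescale so the expectation is unchanged

rng = np.random.default_rng(0)
branch = np.ones((1000, 3, 2))
out = drop_path(branch, 0.1, training=True, rng=rng)
print(out.mean())
```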

__init__(dim, num_heads, mlp_ratio=4, dropout=0.0, attention_dropout=0.0, drop_path=0.0, layerscale_init=0.0001)

Construct the module. See the class docstring for parameter descriptions.

Return type:

None

forward(x, key_padding_mask=None)

Run pre-norm attention + MLP residual sub-blocks with optional LayerScale.

Return type:

Tensor

MultiHeadSelfAttention

class maldideepkit.transformer.transformer.MultiHeadSelfAttention(dim, num_heads, attention_dropout=0.0, proj_dropout=0.0)

Bases: Module

Multi-head self-attention with QK-norm and memory-efficient scaled dot-product attention (SDPA).

QK-normalization applies a per-head LayerNorm to the query and key tensors before the scaled dot product. This removes the scale of the incoming tokens from the attention logits, so the logits cannot grow with the input scale and the softmax does not saturate. It is always on: the stability gain comes at negligible compute overhead.
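A NumPy sketch of QK-normalized attention, assuming LayerNorm without learned affine terms (a learned gain would rescale the bound). After normalization every query and key row has norm below sqrt(d), so by Cauchy-Schwarz each logit q·k / sqrt(d) lies strictly inside (-sqrt(d), sqrt(d)) whatever the input scale. Function names are illustrative.

```python
import numpy as np

def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize over the last (per-head channel) axis; no affine terms."""
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(-1, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(-1, keepdims=True)

def qk_norm_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray):
    """q, k, v: (heads, N, d). Returns the attended values and the raw logits."""
    d = q.shape[-1]
    q, k = layer_norm(q), layer_norm(k)            # QK-norm, applied per head
    logits = q @ k.swapaxes(-1, -2) / np.sqrt(d)   # (heads, N, N)
    return softmax(logits) @ v, logits

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 6, 16)) for _ in range(3))
out, logits = qk_norm_attention(q, k, v)
_, logits_scaled = qk_norm_attention(1000 * q, 1000 * k, v)
print(np.abs(logits).max(), np.abs(logits - logits_scaled).max())
```

Without the two layer_norm calls, multiplying q and k by 1000 would multiply the logits by 10^6 and collapse each softmax row onto a single key.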

Parameters:
  • dim (int) – Token embedding dimension. Must be divisible by num_heads.

  • num_heads (int) – Number of attention heads.

  • attention_dropout (float) – Dropout applied inside the attention kernel during training.

  • proj_dropout (float) – Dropout applied to the final projection.

__init__(dim, num_heads, attention_dropout=0.0, proj_dropout=0.0)

Construct the module. See the class docstring for parameter descriptions.

Return type:

None

forward(x, key_padding_mask=None)

Global self-attention on (B, N, C) tokens.

key_padding_mask (optional, shape (B, N), dtype bool): True = real token, False = padding to ignore.

Return type:

Tensor
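The key_padding_mask semantics can be illustrated with a NumPy sketch in which padded keys get a logit of -inf before the softmax and therefore exactly zero attention weight. The function name is hypothetical, and each sample is assumed to contain at least one real token (an all-padding row would produce NaNs).

```python
import numpy as np

def masked_attention_weights(logits: np.ndarray, key_padding_mask: np.ndarray) -> np.ndarray:
    """logits: (B, N, N); key_padding_mask: (B, N) bool with True = real token."""
    allowed = key_padding_mask[:, None, :]         # broadcast over query positions
    z = np.where(allowed, logits, -np.inf)         # padded keys can never win
    z = z - z.max(-1, keepdims=True)               # stable softmax over keys
    w = np.exp(z)
    return w / w.sum(-1, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.standard_normal((1, 4, 4))
mask = np.array([[True, True, True, False]])       # last token is padding
weights = masked_attention_weights(logits, mask)
print(weights.round(3))
```

Note the polarity: True = keep matches the boolean attn_mask of torch.nn.functional.scaled_dot_product_attention, but it is the opposite of torch.nn.MultiheadAttention's key_padding_mask, where True marks keys to ignore.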

Examples

Default recipe:

import numpy as np
from maldideepkit import MaldiTransformerClassifier

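rng = np.random.default_rng(0)
# Synthetic stand-in for binned spectra: 400 spectra x 6000 bins, binary labels.
X = rng.standard_normal((400, 6000)).astype("float32")
y = rng.integers(0, 2, size=400)

# Default recipe: patch_size=4 turns each 6000-bin spectrum into 1500 tokens.
clf = MaldiTransformerClassifier(random_state=0).fit(X, y)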

Smaller recipe for tiny cohorts:

clf = MaldiTransformerClassifier(
    embed_dim=32, depth=4, num_heads=2, patch_size=4,
    random_state=0,
)

CLS-pool variant:

clf = MaldiTransformerClassifier(pool="cls", random_state=0)

Disable LayerScale for comparison with the legacy unstable recipe:

clf = MaldiTransformerClassifier(layerscale_init=None, random_state=0)

Construct a classifier for a different spectrum layout (the transformer is architecturally scale-agnostic, so from_spectrum only forwards input_dim):

clf = MaldiTransformerClassifier.from_spectrum(
    bin_width=6, input_dim=3000, random_state=0,
)