Transformer Module
1-D Vision Transformer (ViT) adapted to MALDI-TOF binned spectra:
non-overlapping Conv1d patch embedding, learned positional
embedding, optional [CLS] aggregation token, a stack of pre-norm
multi-head self-attention blocks with MLP + LayerScale + stochastic
depth, and a linear classification head over the pooled representation.
Why a plain ViT for MALDI? Every token attends to every other token in every block, so widely separated m/z peaks interact directly in the first block rather than through several stages of local-window merging.
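The direct interaction between distant peaks can be illustrated with a single-head attention layer in plain NumPy. This is a minimal sketch, not the library's implementation: the random projection matrices, the token count of 1500 (6000 bins with patch_size=4) and the helper name self_attention are assumptions made for the illustration.

```python
import numpy as np


def self_attention(x, wq, wk, wv):
    """Single-head scaled dot-product attention over a (tokens, dim) array."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
n_tokens, dim = 1500, 16  # e.g. 6000 bins with patch_size=4 (illustrative)
x = rng.standard_normal((n_tokens, dim))
wq, wk, wv = (rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(3))

out, weights = self_attention(x, wq, wk, wv)

# Perturb only the LAST token and recompute: the FIRST token's output moves,
# because row 0 of the attention matrix puts non-zero weight on every token.
x_perturbed = x.copy()
x_perturbed[-1] += 1.0
out_perturbed, _ = self_attention(x_perturbed, wq, wk, wv)
first_token_shift = float(np.abs(out_perturbed[0] - out[0]).max())
```

The attention matrix has one row and one column per token, so a change at the far end of the m/z axis reaches the first token after a single layer. The price is memory that grows quadratically with the token count.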
Design deviations from the canonical ImageNet ViT:
- embed_dim=64 and depth=6 by default (literature ViT-S is embed_dim=384, depth=12); the smaller recipe is calibrated for MALDI-TOF cohort sizes (a few thousand spectra), where data efficiency dominates over raw capacity.
- LayerScale is on by default (layerscale_init=1e-4). Each block starts as a near-identity map, so training must earn each block’s contribution.
- CLS pooling is opt-in; the default is mean pooling over patch tokens, which is more robust on small data (no single aggregator token to overfit).
- Transformer training recipe baked in as defaults: learning_rate=3e-4, weight_decay=0.05, grad_clip_norm=1.0, warmup_epochs=5. Without these the attention layers diverge on the first few batches.
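Two of the training defaults above, warmup and gradient clipping, can be sketched in plain Python. The source gives only the default values, so the linear warmup shape, the constant rate after warmup, and the helper names warmup_lr and clip_by_global_norm are assumptions of this sketch rather than the library's actual schedule.

```python
import math


def warmup_lr(epoch, base_lr=3e-4, warmup_epochs=5):
    """Linear warmup: ramp to base_lr over warmup_epochs, then hold it.

    Assumption: the source states only the defaults, not the schedule shape.
    """
    if warmup_epochs <= 0:
        return base_lr
    return base_lr * min(1.0, (epoch + 1) / warmup_epochs)


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a flat list of gradient values so their L2 norm <= max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


schedule = [warmup_lr(e) for e in range(7)]
clipped = clip_by_global_norm([3.0, 4.0])  # norm 5.0 is scaled down to 1.0
```

Both mechanisms limit the size of the earliest updates, which is when freshly initialised attention weights are most likely to blow up.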
MaldiTransformerClassifier
- class maldideepkit.MaldiTransformerClassifier(input_dim=None, n_classes=2, patch_size=4, embed_dim=64, depth=6, num_heads=4, mlp_ratio=4, dropout=0.1, attention_dropout=0.0, drop_path_rate=0.1, layerscale_init=0.0001, pool='mean', head_dim=128, learning_rate=0.0003, weight_decay=0.05, grad_clip_norm=1.0, label_smoothing=0.0, loss='cross_entropy', focal_gamma=2.0, use_amp=False, swa_start_epoch=None, tune_threshold=False, threshold_metric='balanced_accuracy', calibrate_temperature=False, min_val_auroc_for_threshold_tune=0.6, use_sam=False, sam_rho=0.05, batch_size=32, epochs=100, early_stopping_patience=10, val_fraction=0.1, warmup_epochs=5, standardize=False, input_transform=None, warping=None, metrics_log_path=None, track_train_metrics=False, augment=None, mixup_alpha=0.0, cutmix_alpha=0.0, ema_decay=None, retry_on_val_auroc_below=None, max_retries=2, class_weight=None, device='auto', random_state=0, verbose=False)
Bases: BaseSpectralClassifier
sklearn-compatible 1-D ViT classifier for MALDI-TOF spectra.
- Parameters:
patch_size (int) – Patch size of the initial Conv1d embedding. Token count is ceil(input_dim / patch_size).
embed_dim (int) – Token embedding dimension. Must be divisible by num_heads.
depth (int) – Number of transformer blocks.
num_heads (int) – Attention heads per block.
mlp_ratio (int) – MLP hidden-dim multiplier inside each block.
dropout (float) – MLP dropout applied inside every block and before the head.
attention_dropout (float) – Attention-matrix dropout.
drop_path_rate (float) – Linearly ramped stochastic-depth rate (0 at block 0, this value at the final block).
layerscale_init (float | None) – LayerScale initial value. None disables LayerScale.
pool (str) – Token aggregation for classification.
head_dim (int) – Width of the hidden dense layer in the classification head.
input_dim (int | None)
n_classes (int)
learning_rate (float)
weight_decay (float)
grad_clip_norm (float | None)
label_smoothing (float)
loss (str)
focal_gamma (float)
use_amp (bool)
swa_start_epoch (int | None)
tune_threshold (bool)
threshold_metric (str)
calibrate_temperature (bool)
min_val_auroc_for_threshold_tune (float)
use_sam (bool)
sam_rho (float)
batch_size (int)
epochs (int)
early_stopping_patience (int)
val_fraction (float)
warmup_epochs (int)
standardize (bool)
input_transform (str | None)
warping (Any | None)
metrics_log_path (str | Path | None)
track_train_metrics (bool)
augment (Any | None)
mixup_alpha (float)
cutmix_alpha (float)
ema_decay (float | None)
retry_on_val_auroc_below (float | None)
max_retries (int)
device (str | torch.device)
random_state (int)
verbose (bool)
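The token-count rule given for patch_size, ceil(input_dim / patch_size), can be checked with a small NumPy sketch. It treats the non-overlapping Conv1d as a per-patch linear projection (kernel size equal to stride) and assumes the tail is zero-padded up to a multiple of patch_size. How the library actually pads is not stated here, and patch_embed is a hypothetical helper.

```python
import math

import numpy as np


def patch_embed(spectrum, weight, bias, patch_size):
    """Split a 1-D spectrum into non-overlapping patches and project each one.

    Equivalent to a Conv1d with kernel_size == stride == patch_size and one
    input channel. Zero-padding the tail is an assumption of this sketch.
    """
    n_tokens = math.ceil(spectrum.shape[0] / patch_size)
    padded = np.zeros(n_tokens * patch_size, dtype=spectrum.dtype)
    padded[: spectrum.shape[0]] = spectrum
    patches = padded.reshape(n_tokens, patch_size)
    return patches @ weight + bias  # (n_tokens, embed_dim)


rng = np.random.default_rng(0)
patch_size, embed_dim = 4, 64
weight = rng.standard_normal((patch_size, embed_dim)).astype("float32")
bias = np.zeros(embed_dim, dtype="float32")

spectrum_even = rng.standard_normal(6000).astype("float32")
spectrum_odd = rng.standard_normal(6001).astype("float32")
tokens_even = patch_embed(spectrum_even, weight, bias, patch_size)
tokens_odd = patch_embed(spectrum_odd, weight, bias, patch_size)
```

A 6000-bin spectrum yields 1500 tokens, while a 6001-bin spectrum needs one extra, partly padded token.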
Notes
Every parameter accepted by BaseSpectralClassifier (e.g. learning_rate, batch_size, epochs, warping, calibrate_temperature, device, random_state, …) is forwarded to the base class. See its docstring for the full list.
Transformer training recipe baked in as defaults: learning_rate=3e-4, weight_decay=0.05, grad_clip_norm=1.0, warmup_epochs=5.
Examples
>>> import numpy as np
>>> from maldideepkit import MaldiTransformerClassifier
>>> rng = np.random.default_rng(0)
>>> X = rng.standard_normal((32, 256)).astype("float32")
>>> y = rng.integers(0, 2, size=32)
>>> clf = MaldiTransformerClassifier(
...     epochs=2, batch_size=8, embed_dim=32, depth=2,
...     num_heads=2, patch_size=2, random_state=0,
... ).fit(X, y)
>>> clf.predict(X).shape
(32,)
- __init__(input_dim=None, n_classes=2, patch_size=4, embed_dim=64, depth=6, num_heads=4, mlp_ratio=4, dropout=0.1, attention_dropout=0.0, drop_path_rate=0.1, layerscale_init=0.0001, pool='mean', head_dim=128, learning_rate=0.0003, weight_decay=0.05, grad_clip_norm=1.0, label_smoothing=0.0, loss='cross_entropy', focal_gamma=2.0, use_amp=False, swa_start_epoch=None, tune_threshold=False, threshold_metric='balanced_accuracy', calibrate_temperature=False, min_val_auroc_for_threshold_tune=0.6, use_sam=False, sam_rho=0.05, batch_size=32, epochs=100, early_stopping_patience=10, val_fraction=0.1, warmup_epochs=5, standardize=False, input_transform=None, warping=None, metrics_log_path=None, track_train_metrics=False, augment=None, mixup_alpha=0.0, cutmix_alpha=0.0, ema_decay=None, retry_on_val_auroc_below=None, max_retries=2, class_weight=None, device='auto', random_state=0, verbose=False)
- Return type:
None
- classmethod from_spectrum(bin_width, input_dim, **overrides)
Construct a classifier for a given (bin_width, input_dim) layout.
The transformer is architecturally scale-agnostic, so this factory only forwards input_dim and any **overrides. Provided for API symmetry with the other classifiers.
- Parameters:
- Return type:
- set_fit_request(*, warm_start='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to fit.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
warm_start (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for warm_start parameter in fit.
self (MaldiTransformerClassifier)
- Returns:
self – The updated object.
- Return type:
SpectralTransformer1D
- class maldideepkit.transformer.transformer.SpectralTransformer1D(input_dim, n_classes=2, patch_size=4, embed_dim=64, depth=6, num_heads=4, mlp_ratio=4, dropout=0.1, attention_dropout=0.0, drop_path_rate=0.1, layerscale_init=0.0001, pool='mean', head_dim=128)
Bases: Module
1-D Vision Transformer backbone for binned spectra.
- Parameters:
input_dim (int) – Number of input bins.
n_classes (int) – Number of output logits.
patch_size (int) – Non-overlapping patch width. Token count is ceil(input_dim / patch_size).
embed_dim (int) – Token embedding dimension.
depth (int) – Number of transformer blocks.
num_heads (int) – Attention heads per block. embed_dim must be divisible by num_heads.
mlp_ratio (int) – MLP hidden-dim multiplier.
dropout (float) – MLP dropout applied inside every block and before the head.
attention_dropout (float) – Attention-matrix dropout.
drop_path_rate (float) – End-of-stack stochastic-depth rate. Linearly interpolated from 0 at block 0 to drop_path_rate at the final block.
layerscale_init (float | None) – LayerScale initial value. None disables LayerScale.
pool (str) – Aggregation strategy for classification. "mean" averages over patch tokens (more robust on small data); "cls" prepends a learned token and uses its output.
head_dim (int) – Width of the hidden dense layer in the classification head.
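The linear interpolation described for drop_path_rate can be written out as a short sketch. The helper name drop_path_schedule is hypothetical, and the handling of a single-block stack is an assumption, since the description only covers the first and final blocks.

```python
def drop_path_schedule(depth, drop_path_rate):
    """Per-block stochastic-depth rates, linearly spaced from 0 to the final rate."""
    if depth == 1:
        # Assumption: a single block is treated as "block 0" and gets rate 0.
        return [0.0]
    return [drop_path_rate * i / (depth - 1) for i in range(depth)]


rates = drop_path_schedule(depth=6, drop_path_rate=0.1)
```

With the defaults (depth=6, drop_path_rate=0.1) the rates rise in equal steps of roughly 0.02, so early blocks are almost never dropped and the last block is dropped about one time in ten during training.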
TransformerBlock
- class maldideepkit.transformer.transformer.TransformerBlock(dim, num_heads, mlp_ratio=4, dropout=0.0, attention_dropout=0.0, drop_path=0.0, layerscale_init=0.0001)
Bases: Module
Pre-norm transformer block with LayerScale and stochastic depth.
Residual pattern:
x = x + drop_path(γ_1 * Attn(LN(x)))
x = x + drop_path(γ_2 * MLP(LN(x)))
γ_* are per-channel learnable scales initialised near zero, so every block starts as a near-identity map.
- Parameters:
dim (int) – Token dimension.
num_heads (int) – Attention heads.
mlp_ratio (int) – MLP hidden-dim multiplier.
dropout (float) – MLP dropout.
attention_dropout (float) – Attention-matrix dropout.
drop_path (float) – Stochastic-depth probability for this block’s residuals.
layerscale_init (float | None) – Initial value of the LayerScale gammas. Set to None to disable LayerScale entirely.
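The effect of the LayerScale gammas on the residual pattern above can be seen in a NumPy sketch of one residual update. The toy branch (a tanh-bounded linear map standing in for Attn or MLP), the affine-free layer_norm and the helper names are assumptions for illustration. drop_path is left out, which corresponds to inference mode, and disabling LayerScale is modelled as gamma = 1.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learnable affine parameters."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def residual_branch(x, branch, gamma):
    """One pre-norm residual update: x + gamma * branch(LN(x))."""
    return x + gamma * branch(layer_norm(x))


rng = np.random.default_rng(0)
n_tokens, dim = 128, 64
x = rng.standard_normal((n_tokens, dim))
w = rng.standard_normal((dim, dim)) / np.sqrt(dim)


def toy_branch(h):
    # Stand-in for Attn or MLP (hypothetical); tanh bounds its output by 1.
    return np.tanh(h @ w)


gamma = np.full(dim, 1e-4)  # per-channel LayerScale, layerscale_init=1e-4
y_scaled = residual_branch(x, toy_branch, gamma)
y_plain = residual_branch(x, toy_branch, np.ones(dim))  # LayerScale disabled

max_shift_scaled = float(np.abs(y_scaled - x).max())
max_shift_plain = float(np.abs(y_plain - x).max())
```

With gamma at 1e-4 the block output stays within about 1e-4 of its input, whereas the unscaled branch moves tokens by an amount comparable to the branch output itself.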
MultiHeadSelfAttention
- class maldideepkit.transformer.transformer.MultiHeadSelfAttention(dim, num_heads, attention_dropout=0.0, proj_dropout=0.0)
Bases: Module
Multi-head self-attention with QK-norm + memory-efficient SDPA.
QK-normalization applies a per-head LayerNorm to query and key tensors before the scaled-dot-product, bounding the softmax denominator regardless of input scale. Always on (universal stability improvement with negligible compute overhead).
- Parameters:
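The QK-normalization described above can be sketched in NumPy to show why the attention logits stop depending on input scale. This is an illustration under stated assumptions, not the library's code: the LayerNorm here has no learnable affine, the projection matrices are random, and qk_norm_logits is a hypothetical helper that returns logits only (no softmax, values or SDPA kernel).

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis; the learnable affine is left out here."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def qk_norm_logits(x, wq, wk, num_heads):
    """Attention logits with per-head LayerNorm on queries and keys."""
    n_tokens, dim = x.shape
    head_dim = dim // num_heads
    # (tokens, dim) -> (heads, tokens, head_dim)
    q = (x @ wq).reshape(n_tokens, num_heads, head_dim).transpose(1, 0, 2)
    k = (x @ wk).reshape(n_tokens, num_heads, head_dim).transpose(1, 0, 2)
    q, k = layer_norm(q), layer_norm(k)
    return q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # (heads, tokens, tokens)


rng = np.random.default_rng(0)
n_tokens, dim, num_heads = 64, 64, 4
x = rng.standard_normal((n_tokens, dim))
wq = rng.standard_normal((dim, dim)) / np.sqrt(dim)
wk = rng.standard_normal((dim, dim)) / np.sqrt(dim)

logits = qk_norm_logits(x, wq, wk, num_heads)
logits_scaled_input = qk_norm_logits(1000.0 * x, wq, wk, num_heads)
logit_bound = np.sqrt(dim // num_heads)  # |q.k| <= head_dim after LayerNorm
```

After normalization each query and key vector has squared norm at most head_dim, so every logit lies within plus or minus sqrt(head_dim). Multiplying the input by 1000 leaves the logits essentially unchanged.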
Examples
Default recipe:
import numpy as np
from maldideepkit import MaldiTransformerClassifier
rng = np.random.default_rng(0)
X = rng.standard_normal((400, 6000)).astype("float32")
y = rng.integers(0, 2, size=400)
clf = MaldiTransformerClassifier(random_state=0).fit(X, y)
Smaller recipe for tiny cohorts:
clf = MaldiTransformerClassifier(
embed_dim=32, depth=4, num_heads=2, patch_size=4,
random_state=0,
)
CLS-pool variant:
clf = MaldiTransformerClassifier(pool="cls", random_state=0)
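At the level of tensor shapes, the two pooling modes differ as in the NumPy sketch below. No encoder is run here: patch_tokens is a random stand-in for encoder output, and the zero-valued cls_token stands in for what would be a learned parameter, so only the shapes are meaningful.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_tokens, dim = 8, 1500, 64

# Stand-in for the encoder output over patch tokens (illustrative values).
patch_tokens = rng.standard_normal((batch, n_tokens, dim))

# pool="mean": average the patch tokens into one vector per spectrum.
mean_pooled = patch_tokens.mean(axis=1)  # (batch, dim)

# pool="cls": a learned token is prepended BEFORE the encoder, so the sequence
# grows by one, and the head reads the encoder output at that position.
cls_token = np.zeros((1, 1, dim))  # a learned parameter in a real model
cls_batch = np.broadcast_to(cls_token, (batch, 1, dim))
with_cls = np.concatenate([cls_batch, patch_tokens], axis=1)
cls_pooled = with_cls[:, 0]  # (batch, dim); here just the untouched token
```

Both modes hand the classification head one vector of width embed_dim per spectrum; the CLS variant lengthens the attention sequence by one token.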
Disable LayerScale for comparison with the legacy unstable recipe:
clf = MaldiTransformerClassifier(layerscale_init=None, random_state=0)
Build a classifier for a different spectrum layout (the transformer is
architecturally scale-agnostic, so from_spectrum just forwards
input_dim and any overrides):
clf = MaldiTransformerClassifier.from_spectrum(
bin_width=6, input_dim=3000, random_state=0,
)