MLP Module

Multilayer perceptron classifier with an optional sigmoid-gated attention layer. Engineering adaptation of a standard MLP with a per-unit gate on the first hidden layer; not a novel architecture.

MaldiMLPClassifier

class maldideepkit.MaldiMLPClassifier(input_dim=None, n_classes=2, hidden_dim=512, head_dims=(256, 128), use_attention=True, dropout_high=0.3, dropout_low=0.2, learning_rate=0.001, weight_decay=0.0, grad_clip_norm=None, label_smoothing=0.0, loss='cross_entropy', focal_gamma=2.0, use_amp=False, swa_start_epoch=None, tune_threshold=False, threshold_metric='balanced_accuracy', calibrate_temperature=False, min_val_auroc_for_threshold_tune=0.6, use_sam=False, sam_rho=0.05, batch_size=32, epochs=100, early_stopping_patience=10, val_fraction=0.1, warmup_epochs=0, standardize=False, input_transform=None, warping=None, metrics_log_path=None, track_train_metrics=False, augment=None, mixup_alpha=0.0, cutmix_alpha=0.0, ema_decay=None, retry_on_val_auroc_below=None, max_retries=2, class_weight=None, device='auto', random_state=0, verbose=False)

Bases: BaseSpectralClassifier

sklearn-compatible MLP classifier with optional attention gating.

Parameters:
  • hidden_dim (int) – Width of the projection and attention gate.

  • head_dims (tuple[int, ...]) – Widths of the hidden layers of the classification head.

  • use_attention (bool) – Toggle the sigmoid-gated attention. When False, the model is a plain MLP of the same depth.

  • dropout_high (float) – Dropout rate applied after the projection and after the first head layer.

  • dropout_low (float) – Dropout rate applied before the output logits.

  • input_dim (int | None)

  • n_classes (int)

  • learning_rate (float)

  • weight_decay (float)

  • grad_clip_norm (float | None)

  • label_smoothing (float)

  • loss (str)

  • focal_gamma (float)

  • use_amp (bool)

  • swa_start_epoch (int | None)

  • tune_threshold (bool)

  • threshold_metric (str)

  • calibrate_temperature (bool)

  • min_val_auroc_for_threshold_tune (float)

  • use_sam (bool)

  • sam_rho (float)

  • batch_size (int)

  • epochs (int)

  • early_stopping_patience (int)

  • val_fraction (float)

  • warmup_epochs (int)

  • standardize (bool)

  • input_transform (str | None)

  • warping (Any | None)

  • metrics_log_path (str | Path | None)

  • track_train_metrics (bool)

  • augment (Any | None)

  • mixup_alpha (float)

  • cutmix_alpha (float)

  • ema_decay (float | None)

  • retry_on_val_auroc_below (float | None)

  • max_retries (int)

  • class_weight (str | np.ndarray | list | None)

  • device (str | torch.device)

  • random_state (int)

  • verbose (bool)

Notes

Every parameter accepted by BaseSpectralClassifier (e.g. learning_rate, batch_size, epochs, warping, calibrate_temperature, device, random_state, …) is forwarded to the base class. See its docstring for the full list.

Variables:

attention_weights_ (ndarray or None) – Attention weights cached from the last fit() or predict() forward pass, with shape (n_samples_last_call, hidden_dim). None when use_attention=False.

Examples

>>> import numpy as np
>>> from maldideepkit import MaldiMLPClassifier
>>> rng = np.random.default_rng(0)
>>> X = rng.standard_normal((64, 256)).astype("float32")
>>> y = rng.integers(0, 2, size=64)
>>> clf = MaldiMLPClassifier(epochs=2, batch_size=16, random_state=0).fit(X, y)
>>> clf.predict(X).shape
(64,)
>>> weights = clf.get_attention_weights(X[:4])
>>> weights.shape
(4, 512)

__init__(input_dim=None, n_classes=2, hidden_dim=512, head_dims=(256, 128), use_attention=True, dropout_high=0.3, dropout_low=0.2, learning_rate=0.001, weight_decay=0.0, grad_clip_norm=None, label_smoothing=0.0, loss='cross_entropy', focal_gamma=2.0, use_amp=False, swa_start_epoch=None, tune_threshold=False, threshold_metric='balanced_accuracy', calibrate_temperature=False, min_val_auroc_for_threshold_tune=0.6, use_sam=False, sam_rho=0.05, batch_size=32, epochs=100, early_stopping_patience=10, val_fraction=0.1, warmup_epochs=0, standardize=False, input_transform=None, warping=None, metrics_log_path=None, track_train_metrics=False, augment=None, mixup_alpha=0.0, cutmix_alpha=0.0, ema_decay=None, retry_on_val_auroc_below=None, max_retries=2, class_weight=None, device='auto', random_state=0, verbose=False)

Return type:

None

fit(X, y, *, warm_start=False)

Fit the model and cache attention weights from the final batch.

See BaseSpectralClassifier.fit() for shared parameters, including warm_start.

Return type:

MaldiMLPClassifier

get_attention_weights(X)

Return the attention weights for X as an array of shape (len(X), hidden_dim).

Parameters:

X (Any) – Spectra to inspect. Must match input_dim_.

Returns:

Sigmoid-gated attention weights.

Return type:

ndarray

Raises:

RuntimeError – If the classifier was built with use_attention=False.

classmethod from_spectrum(bin_width, input_dim, **overrides)

Construct a classifier for a given (bin_width, input_dim) layout.

The MLP is architecturally scale-agnostic, so this factory only forwards input_dim and any **overrides. Provided for API symmetry with the other classifiers.

Parameters:
  • bin_width (int)

  • input_dim (int)

Return type:

MaldiMLPClassifier
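The pass-through behaviour described for from_spectrum can be illustrated with a stand-in class. This is a minimal sketch, not the library's source: SketchClassifier and its constructor are hypothetical, and the only behaviour taken from the description above is that bin_width is accepted but does not influence the architecture, while input_dim and any overrides reach the constructor.

```python
class SketchClassifier:
    """Hypothetical stand-in for a scale-agnostic classifier (not the library class)."""

    def __init__(self, input_dim=None, hidden_dim=512, **kwargs):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.extra = kwargs

    @classmethod
    def from_spectrum(cls, bin_width, input_dim, **overrides):
        # bin_width is accepted for signature symmetry only; it is not used
        # to size any layer. Everything else goes straight to the constructor.
        return cls(input_dim=input_dim, **overrides)


default = SketchClassifier.from_spectrum(bin_width=3, input_dim=6000)
narrow = SketchClassifier.from_spectrum(bin_width=6, input_dim=6000, hidden_dim=256)
```

Under this sketch, changing bin_width alone yields the same configuration, and keyword overrides behave exactly as they would in a direct constructor call.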

set_fit_request(*, warm_start='$UNCHANGED$')

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • warm_start (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for warm_start parameter in fit.

  • self (MaldiMLPClassifier)

Returns:

self – The updated object.

Return type:

object

SpectralAttentionMLP

Low-level nn.Module wrapped by MaldiMLPClassifier. Exposed for users embedding the architecture into a larger network.

class maldideepkit.attention.mlp.SpectralAttentionMLP(input_dim, n_classes=2, hidden_dim=512, head_dims=(256, 128), use_attention=True, dropout_high=0.3, dropout_low=0.2)

Bases: Module

Projection + optional sigmoid-gated attention + deep MLP head.

Parameters:
  • input_dim (int) – Number of input bins. The first linear layer projects this down to hidden_dim.

  • n_classes (int) – Number of output logits.

  • hidden_dim (int) – Width of the projection layer and attention gate.

  • head_dims (tuple[int, ...]) – Widths of the hidden layers between the gated representation and the output logits.

  • use_attention (bool) – If True, apply a sigmoid-gated element-wise attention on the projected features. If False, the model reduces to a plain MLP of the same depth.

  • dropout_high (float) – Dropout applied after the projection and the first dense layer.

  • dropout_low (float) – Dropout applied before the output logits.

Variables:

last_attention (torch.Tensor or None) – Attention weights from the most recent forward pass ((batch, hidden_dim)). None when use_attention=False.
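The data flow described above (projection, optional per-unit sigmoid gate, MLP head) can be sketched in plain NumPy. This is an illustrative inference-time approximation under stated assumptions, not the module's implementation: the ReLU activations, the separate gate weight matrix W_gate, the weight initialisation, and the helper names init_params and gated_mlp_forward are all hypothetical, and dropout is left out because it is the identity at evaluation time.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_params(rng, input_dim, n_classes=2, hidden_dim=512, head_dims=(256, 128)):
    """Small random weights for the sketch (initialisation scheme is an assumption)."""
    def dense(n_in, n_out):
        return rng.standard_normal((n_in, n_out)) * 0.01, np.zeros(n_out)

    params = {"proj": dense(input_dim, hidden_dim),
              "gate": dense(hidden_dim, hidden_dim),
              "head": []}
    width = hidden_dim
    for d in head_dims:
        params["head"].append(dense(width, d))
        width = d
    params["out"] = dense(width, n_classes)
    return params


def gated_mlp_forward(x, params, use_attention=True):
    """Map (batch, input_dim) to (batch, n_classes) logits; also return the gate."""
    W, b = params["proj"]
    h = np.maximum(x @ W + b, 0.0)          # projection; ReLU is an assumption
    gate = None
    if use_attention:
        W_gate, b_gate = params["gate"]
        gate = sigmoid(h @ W_gate + b_gate)  # per-unit gate in (0, 1), shape (batch, hidden_dim)
        h = h * gate                         # element-wise gating of projected features
    for W, b in params["head"]:
        h = np.maximum(h @ W + b, 0.0)       # hidden layers of the head
    W, b = params["out"]
    return h @ W + b, gate


rng = np.random.default_rng(0)
params = init_params(rng, input_dim=256)
x = rng.standard_normal((4, 256))
logits, gate = gated_mlp_forward(x, params)
plain_logits, no_gate = gated_mlp_forward(x, params, use_attention=False)
```

With use_attention=False the gate branch is skipped and None is returned in its place, which mirrors how last_attention is documented to be None in that configuration.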

__init__(input_dim, n_classes=2, hidden_dim=512, head_dims=(256, 128), use_attention=True, dropout_high=0.3, dropout_low=0.2)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Return type:

None

forward(x)

Map (batch, input_dim) to (batch, n_classes) logits.

Parameters:

x (Tensor)

Return type:

Tensor

Attention Inspection Example

import numpy as np
from maldideepkit import MaldiMLPClassifier

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 6000)).astype("float32")
y = rng.integers(0, 2, size=200)

clf = MaldiMLPClassifier(random_state=0).fit(X, y)

# Per-sample attention gates cached at the end of fit:
cached = clf.attention_weights_                 # (N, hidden_dim)

# Recompute for arbitrary inputs:
weights = clf.get_attention_weights(X[:10])     # (10, hidden_dim)

# Disable attention to get a plain MLP baseline:
plain = MaldiMLPClassifier(use_attention=False, random_state=0).fit(X, y)
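
Once a gate matrix of shape (n_samples, hidden_dim) is available, one possible way to summarise it is to average over samples and rank the hidden units. The sketch below uses a synthetic array as a stand-in for the output of get_attention_weights so that it runs without a fitted model; the variable names and the choice of mean and standard deviation as summaries are illustrative assumptions. Because the gate acts on the projected features, the column indices refer to hidden units, not to m/z bins.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for clf.get_attention_weights(X[:10]):
# sigmoid outputs, so every entry lies strictly between 0 and 1.
weights = 1.0 / (1.0 + np.exp(-rng.standard_normal((10, 512))))

mean_gate = weights.mean(axis=0)               # (hidden_dim,) average openness per unit
spread = weights.std(axis=0)                   # (hidden_dim,) sample-to-sample variation
top_units = np.argsort(mean_gate)[::-1][:5]    # indices of the five most open units

for unit in top_units:
    print(f"unit {unit}: mean gate {mean_gate[unit]:.3f}, std {spread[unit]:.3f}")
```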