MLP Module
Multilayer perceptron classifier with an optional sigmoid-gated attention layer. Engineering adaptation of a standard MLP with a per-unit gate on the first hidden layer; not a novel architecture.
MaldiMLPClassifier
- class maldideepkit.MaldiMLPClassifier(input_dim=None, n_classes=2, hidden_dim=512, head_dims=(256, 128), use_attention=True, dropout_high=0.3, dropout_low=0.2, learning_rate=0.001, weight_decay=0.0, grad_clip_norm=None, label_smoothing=0.0, loss='cross_entropy', focal_gamma=2.0, use_amp=False, swa_start_epoch=None, tune_threshold=False, threshold_metric='balanced_accuracy', calibrate_temperature=False, min_val_auroc_for_threshold_tune=0.6, use_sam=False, sam_rho=0.05, batch_size=32, epochs=100, early_stopping_patience=10, val_fraction=0.1, warmup_epochs=0, standardize=False, input_transform=None, warping=None, metrics_log_path=None, track_train_metrics=False, augment=None, mixup_alpha=0.0, cutmix_alpha=0.0, ema_decay=None, retry_on_val_auroc_below=None, max_retries=2, class_weight=None, device='auto', random_state=0, verbose=False)
Bases: BaseSpectralClassifier
sklearn-compatible MLP classifier with optional attention gating.
- Parameters:
hidden_dim (int) – Width of the projection and attention gate.
head_dims (tuple[int, ...]) – Widths of the hidden layers of the classification head.
use_attention (bool) – Toggle the sigmoid-gated attention. When False, the model is a plain MLP of the same depth.
dropout_high (float) – Dropout after the projection and first head layer.
dropout_low (float) – Dropout before the output logits.
input_dim (int | None)
n_classes (int)
learning_rate (float)
weight_decay (float)
grad_clip_norm (float | None)
label_smoothing (float)
loss (str)
focal_gamma (float)
use_amp (bool)
swa_start_epoch (int | None)
tune_threshold (bool)
threshold_metric (str)
calibrate_temperature (bool)
min_val_auroc_for_threshold_tune (float)
use_sam (bool)
sam_rho (float)
batch_size (int)
epochs (int)
early_stopping_patience (int)
val_fraction (float)
warmup_epochs (int)
standardize (bool)
input_transform (str | None)
warping (Any | None)
metrics_log_path (str | Path | None)
track_train_metrics (bool)
augment (Any | None)
mixup_alpha (float)
cutmix_alpha (float)
ema_decay (float | None)
retry_on_val_auroc_below (float | None)
max_retries (int)
device (str | torch.device)
random_state (int)
verbose (bool)
Notes
Every parameter accepted by BaseSpectralClassifier (e.g. learning_rate, batch_size, epochs, warping, calibrate_temperature, device, random_state, …) is forwarded to the base class. See its docstring for the full list.
- Variables:
attention_weights (ndarray or None) – Attention weights from the last fit() or predict() forward pass. Shape (n_samples_last_call, hidden_dim). Set to None when use_attention=False.
Examples
>>> import numpy as np
>>> from maldideepkit import MaldiMLPClassifier
>>> rng = np.random.default_rng(0)
>>> X = rng.standard_normal((64, 256)).astype("float32")
>>> y = rng.integers(0, 2, size=64)
>>> clf = MaldiMLPClassifier(epochs=2, batch_size=16, random_state=0).fit(X, y)
>>> clf.predict(X).shape
(64,)
>>> weights = clf.get_attention_weights(X[:4])
>>> weights.shape
(4, 512)
- __init__(input_dim=None, n_classes=2, hidden_dim=512, head_dims=(256, 128), use_attention=True, dropout_high=0.3, dropout_low=0.2, learning_rate=0.001, weight_decay=0.0, grad_clip_norm=None, label_smoothing=0.0, loss='cross_entropy', focal_gamma=2.0, use_amp=False, swa_start_epoch=None, tune_threshold=False, threshold_metric='balanced_accuracy', calibrate_temperature=False, min_val_auroc_for_threshold_tune=0.6, use_sam=False, sam_rho=0.05, batch_size=32, epochs=100, early_stopping_patience=10, val_fraction=0.1, warmup_epochs=0, standardize=False, input_transform=None, warping=None, metrics_log_path=None, track_train_metrics=False, augment=None, mixup_alpha=0.0, cutmix_alpha=0.0, ema_decay=None, retry_on_val_auroc_below=None, max_retries=2, class_weight=None, device='auto', random_state=0, verbose=False)
- Parameters:
n_classes (int)
hidden_dim (int)
use_attention (bool)
dropout_high (float)
dropout_low (float)
learning_rate (float)
weight_decay (float)
label_smoothing (float)
loss (str)
focal_gamma (float)
use_amp (bool)
tune_threshold (bool)
threshold_metric (str)
calibrate_temperature (bool)
min_val_auroc_for_threshold_tune (float)
use_sam (bool)
sam_rho (float)
batch_size (int)
epochs (int)
early_stopping_patience (int)
val_fraction (float)
warmup_epochs (int)
standardize (bool)
track_train_metrics (bool)
mixup_alpha (float)
cutmix_alpha (float)
max_retries (int)
random_state (int)
verbose (bool)
- Return type:
None
- fit(X, y, *, warm_start=False)
Fit the model and cache attention weights from the final batch.
See BaseSpectralClassifier.fit() for shared parameters, including warm_start.
- get_attention_weights(X)
Return attention weights for X of shape (len(X), hidden_dim).
- Parameters:
X (Any) – Spectra to inspect. Must match input_dim_.
- Returns:
Sigmoid-gated attention weights.
- Raises:
RuntimeError – If the classifier was built with use_attention=False.
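Because get_attention_weights() raises RuntimeError on a model built without attention, code that handles both variants may want a small guard. The sketch below shows one way to write it. _StubClassifier and attention_or_none are hypothetical names introduced for illustration; the stub only imitates the documented error behaviour and is not the library's class.

```python
class _StubClassifier:
    """Hypothetical stand-in that mimics only the documented RuntimeError behaviour."""

    def __init__(self, use_attention=True, hidden_dim=4):
        self.use_attention = use_attention
        self.hidden_dim = hidden_dim

    def get_attention_weights(self, X):
        if not self.use_attention:
            raise RuntimeError("classifier was built with use_attention=False")
        # One constant row per input sample, just to produce the documented shape.
        return [[0.5] * self.hidden_dim for _ in X]


def attention_or_none(clf, X):
    """Return attention weights, or None when the model has no attention gate."""
    try:
        return clf.get_attention_weights(X)
    except RuntimeError:
        return None


X = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]
with_gate = attention_or_none(_StubClassifier(use_attention=True), X)
without_gate = attention_or_none(_StubClassifier(use_attention=False), X)
```

Catching only RuntimeError keeps other failures, such as a shape mismatch, visible to the caller.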
- classmethod from_spectrum(bin_width, input_dim, **overrides)
Construct a classifier for a given (bin_width, input_dim) layout.
The MLP is architecturally scale-agnostic, so this factory only forwards input_dim and any **overrides. Provided for API symmetry with the other classifiers.
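A pass-through factory of the kind described for from_spectrum could look roughly like the sketch below. _SketchClassifier is a hypothetical stand-in, not the library's implementation, and the bin_width and input_dim values are made up for the demonstration.

```python
class _SketchClassifier:
    """Hypothetical stand-in used only to illustrate the factory pattern."""

    def __init__(self, input_dim=None, hidden_dim=512, **kwargs):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.extra = kwargs

    @classmethod
    def from_spectrum(cls, bin_width, input_dim, **overrides):
        # bin_width is accepted for signature compatibility but not used:
        # nothing in a fully connected stack depends on the bin spacing.
        return cls(input_dim=input_dim, **overrides)


coarse = _SketchClassifier.from_spectrum(bin_width=6, input_dim=3000)
fine = _SketchClassifier.from_spectrum(bin_width=3, input_dim=6000, hidden_dim=256)
```

Both calls go through the same constructor, so only input_dim and explicit overrides differ between the two instances.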
- set_fit_request(*, warm_start='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
warm_start (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for warm_start parameter in fit.
self (MaldiMLPClassifier)
- Returns:
self – The updated object.
SpectralAttentionMLP
Low-level nn.Module wrapped by MaldiMLPClassifier.
Exposed for users embedding the architecture into a larger network.
- class maldideepkit.attention.mlp.SpectralAttentionMLP(input_dim, n_classes=2, hidden_dim=512, head_dims=(256, 128), use_attention=True, dropout_high=0.3, dropout_low=0.2)
Bases: Module
Projection + optional sigmoid-gated attention + deep MLP head.
- Parameters:
input_dim (int) – Number of input bins. The first linear layer projects this down to hidden_dim.
n_classes (int) – Number of output logits.
hidden_dim (int) – Width of the projection layer and attention gate.
head_dims (tuple[int, ...]) – Widths of the hidden layers between the gated representation and the output logits.
use_attention (bool) – If True, apply a sigmoid-gated element-wise attention on the projected features. If False, the model reduces to a plain MLP of the same depth.
dropout_high (float) – Dropout applied after the projection and the first dense layer.
dropout_low (float) – Dropout applied before the output logits.
- Variables:
last_attention (torch.Tensor or None) – Attention weights from the most recent forward pass, shape (batch, hidden_dim). None when use_attention=False.
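The data flow described above (projection, optional element-wise sigmoid gate, head) can be illustrated with a small NumPy sketch. This is not the module's actual code: the ReLU activation, the way the gate is computed from the projected features, the single-layer head, and every name below are assumptions made for illustration, and dropout is left out.

```python
import numpy as np


def sigmoid(z):
    """Logistic function, applied element-wise."""
    return 1.0 / (1.0 + np.exp(-z))


def gated_forward(x, params, use_attention=True):
    """Projection -> optional sigmoid gate -> one-layer head (illustrative only)."""
    # Project the input bins down to hidden_dim; ReLU is an assumed activation.
    h = np.maximum(x @ params["W_proj"] + params["b_proj"], 0.0)
    gate = None
    if use_attention:
        # One gate value in (0, 1) per sample and per hidden unit.
        gate = sigmoid(h @ params["W_gate"] + params["b_gate"])
        h = h * gate  # element-wise re-weighting of the projected features
    logits = h @ params["W_out"] + params["b_out"]
    return logits, gate


rng = np.random.default_rng(0)
input_dim, hidden_dim, n_classes = 32, 8, 2
params = {
    "W_proj": rng.standard_normal((input_dim, hidden_dim)) * 0.1,
    "b_proj": np.zeros(hidden_dim),
    "W_gate": rng.standard_normal((hidden_dim, hidden_dim)) * 0.1,
    "b_gate": np.zeros(hidden_dim),
    "W_out": rng.standard_normal((hidden_dim, n_classes)) * 0.1,
    "b_out": np.zeros(n_classes),
}
x = rng.standard_normal((4, input_dim))
logits, gate = gated_forward(x, params)
plain_logits, no_gate = gated_forward(x, params, use_attention=False)
```

The gate returned here has shape (batch, hidden_dim), matching the documented shape of last_attention, and it is None when the gate is switched off.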
Attention Inspection Example
import numpy as np
from maldideepkit import MaldiMLPClassifier
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 6000)).astype("float32")
y = rng.integers(0, 2, size=200)
clf = MaldiMLPClassifier(random_state=0).fit(X, y)
# Per-sample attention gates cached at the end of fit:
cached = clf.attention_weights_ # (N, hidden_dim)
# Recompute for arbitrary inputs:
weights = clf.get_attention_weights(X[:10]) # (10, hidden_dim)
# Disable attention to get a plain MLP baseline:
plain = MaldiMLPClassifier(use_attention=False, random_state=0).fit(X, y)
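One possible way to summarize an (n_samples, hidden_dim) array of gate values is to average over samples and rank the hidden units. The helper below is a hypothetical sketch, not part of the library, and it runs on synthetic values standing in for real attention weights.

```python
import numpy as np


def top_gated_units(weights, k=5):
    """Return indices and mean gate values of the k most strongly gated hidden units."""
    weights = np.asarray(weights)
    mean_gate = weights.mean(axis=0)          # average over samples -> (hidden_dim,)
    order = np.argsort(mean_gate)[::-1][:k]   # descending by mean gate value
    return order, mean_gate[order]


# Synthetic stand-in for an (n_samples, hidden_dim) array of gate values.
rng = np.random.default_rng(0)
fake_weights = rng.uniform(0.0, 1.0, size=(10, 16))
fake_weights[:, 3] = 0.99  # make unit 3 the clear winner for the demo

idx, values = top_gated_units(fake_weights, k=3)
```

The indices refer to hidden units of the projection, not to m/z bins, so a ranking like this describes the learned representation rather than individual peaks in the spectrum.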