Base Module

Abstract base class and data utilities shared by every MaldiDeepKit classifier. Users implementing a new architecture only need to inherit from BaseSpectralClassifier and override _build_model().

BaseSpectralClassifier

class maldideepkit.BaseSpectralClassifier(input_dim=None, n_classes=2, learning_rate=0.001, weight_decay=0.0, grad_clip_norm=None, label_smoothing=0.0, loss='cross_entropy', focal_gamma=2.0, use_amp=False, swa_start_epoch=None, tune_threshold=False, threshold_metric='balanced_accuracy', calibrate_temperature=False, min_val_auroc_for_threshold_tune=0.6, use_sam=False, sam_rho=0.05, batch_size=32, epochs=100, early_stopping_patience=10, val_fraction=0.1, warmup_epochs=0, standardize=False, input_transform=None, warping=None, metrics_log_path=None, track_train_metrics=False, augment=None, mixup_alpha=0.0, cutmix_alpha=0.0, ema_decay=None, retry_on_val_auroc_below=None, max_retries=2, class_weight=None, device='auto', random_state=0, verbose=False)

Bases: ClassifierMixin, BaseEstimator

Abstract base for all MaldiDeepKit classifiers.

Concrete subclasses only need to override _build_model(), which should return a torch.nn.Module that maps an input of shape (batch, input_dim) to logits of shape (batch, n_classes). Everything else (device placement, validation split, early stopping, checkpointing, predict / predict_proba, save / load) is provided here.

Parameters:
  • input_dim (int | None) – Number of input bins. If None, inferred from X at fit() time and stored as input_dim_.

  • n_classes (int) – Number of output classes. Overwritten with the true number of classes found in y at fit() time.

  • learning_rate (float) – Initial learning rate for the optimizer (Adam by default; AdamW when weight_decay > 0).

  • weight_decay (float) – Weight-decay coefficient. When > 0 the optimizer switches from Adam to AdamW, which applies the decay in decoupled form (directly to the weights rather than as an L2 term in the loss).

  • grad_clip_norm (float | None) – If set, clip the global L2 norm of the gradients to this value before every optimizer step. 1.0 is a common default for transformers.

  • label_smoothing (float) – Label smoothing factor in [0, 1) passed to the loss. Applied to both cross-entropy and focal-loss paths.

  • loss (str) – Classification loss: "cross_entropy" (default) or "focal". "focal" uses FocalLoss with gamma=focal_gamma and is well suited to highly imbalanced problems.

  • focal_gamma (float) – Focal-loss focusing parameter. Ignored when loss="cross_entropy".

  • use_amp (bool) – If True and the resolved device is CUDA, run the forward pass and loss under torch.autocast() and use torch.amp.GradScaler for the backward pass. This can give up to roughly a 2x wall-time speedup on recent NVIDIA GPUs. On CPU this is a no-op.

  • swa_start_epoch (int | None) – If set, start Stochastic Weight Averaging at this epoch. The SWA average replaces the best-val checkpoint at the end of fit. Typical value: 60-80% of epochs.

  • tune_threshold (bool) – (Binary classification only.) After fit, sweep decision thresholds on the validation split and store the one that maximises threshold_metric as threshold_. predict() then applies this threshold to the positive-class probability instead of the default 0.5 cutoff (argmax).

  • threshold_metric (str) – Metric maximised by the tune_threshold sweep. Defaults to "balanced_accuracy".

  • calibrate_temperature (bool) – If True, after fit run LBFGS-based temperature scaling on held-out validation logits (Guo et al. 2017). The fitted temperature is stored as temperature_ and applied in predict_proba() to sharpen / smooth probabilities without changing the argmax.

  • min_val_auroc_for_threshold_tune (float) – Binary-classification guardrail on tune_threshold=True: if the validation AUROC is below this value, the threshold sweep is skipped and threshold_ falls back to 0.5. Set to 0.0 to disable.

  • use_sam (bool) – If True, wrap the base optimizer in SAMOptimizer and run the two-step Sharpness-Aware Minimization update. Doubles forward / backward compute per step; typically helps generalisation on small datasets.

  • sam_rho (float) – Size of the SAM ascent step. Ignored when use_sam=False.

  • batch_size (int) – Training mini-batch size.

  • epochs (int) – Maximum number of training epochs.

  • early_stopping_patience (int) – Number of epochs without validation-loss improvement before training is stopped.

  • val_fraction (float) – Fraction of the training data held out for the internal validation split.

  • warmup_epochs (int) – If positive, linearly ramp each optimizer param group’s learning rate from 0 to its configured target over the first warmup_epochs epochs. Useful for transformer architectures that can diverge at full learning rate during the first few steps.

  • standardize (bool) – Shorthand for input_transform="standardize" (when True) or "none" (when False). Kept for backwards compatibility; input_transform is the modern interface and wins when both are supplied.

  • input_transform (str | None) – One of {"none", "standardize", "log1p", "robust", "log1p+standardize"}. Fit on the (warped) training split only and stored as input_transform_state_; reapplied at predict() / predict_proba() time.

  • warping (Any | None) – Spectral alignment / warping transformer applied before standardization. Fitted on the training split only, then used to transform both splits during training and new data at predict() / predict_proba() time. The fitted transformer is stored as warper_.

  • metrics_log_path (str | Path | None) – If set, write a per-epoch metrics CSV to this path during fit(). One row per epoch with columns epoch, train_loss, val_loss, lr, mean_grad_norm, n_grad_updates (+ train_auroc, val_auroc when track_train_metrics=True).

  • track_train_metrics (bool) – Only used when metrics_log_path is set. If True, after every epoch run a no-grad forward pass over the full training split and record train_auroc + val_auroc alongside the losses. Adds one extra pass per epoch; binary classification only.

  • augment (Optional[Callable[[Tensor], Tensor]]) – Per-batch augmentation applied to training batches only. The usual choice is SpectrumAugment.

  • mixup_alpha (float) – If positive, apply MixUp augmentation per training batch with a Beta(mixup_alpha, mixup_alpha) mixing coefficient. 0.0 disables MixUp. Composable with cutmix_alpha.

  • cutmix_alpha (float) – If positive, apply CutMix augmentation per training batch with a Beta(cutmix_alpha, cutmix_alpha) mixing coefficient. 0.0 disables CutMix.

  • ema_decay (float | None) – If set (typically 0.999), maintain an exponential moving average of model weights during training and use the EMA weights at inference time.

  • retry_on_val_auroc_below (float | None) – Binary-classification guardrail. If set and the post-fit validation AUROC is below this threshold, retrain with a different RNG seed up to max_retries times. Useful for unstable small-data fits.

  • max_retries (int) – Maximum number of automatic refits triggered by retry_on_val_auroc_below. Ignored when that guardrail is unset.

  • class_weight (str | ndarray | list | None) – Per-class weights applied to CrossEntropyLoss. "balanced" uses n_samples / (n_classes * class_count).

  • device (str | device) – Device used for training and inference.

  • random_state (int) – Seeds Python, NumPy, and PyTorch RNGs and the validation split.

  • verbose (bool) – If True, prints one line per training epoch.
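For intuition about loss="focal" and focal_gamma, the sketch below computes the standard multiclass focal loss of Lin et al. (2017), -(1 - p_t)^gamma · log(p_t), in NumPy. It is not MaldiDeepKit's FocalLoss: it ignores label smoothing and class weights, and the helper names are hypothetical. With gamma=0 it reduces to ordinary cross-entropy, which gives an easy sanity check.

```python
import numpy as np


def log_softmax(logits):
    """Row-wise log-softmax, shifted by the row max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def focal_loss(logits, targets, gamma=2.0):
    """Mean multiclass focal loss: -(1 - p_t)**gamma * log(p_t)."""
    log_p = log_softmax(logits)
    log_p_t = log_p[np.arange(len(targets)), targets]  # log-prob of the true class
    p_t = np.exp(log_p_t)
    return float(np.mean(-((1.0 - p_t) ** gamma) * log_p_t))


logits = np.array([[3.0, 0.0], [0.2, 0.1], [0.0, 1.0]])
targets = np.array([0, 0, 1])
ce = focal_loss(logits, targets, gamma=0.0)  # gamma=0 is plain cross-entropy
fl = focal_loss(logits, targets, gamma=2.0)  # easy examples are down-weighted
print(ce, fl)
```

Because (1 - p_t)^gamma shrinks toward 0 for confidently correct samples, the hard, misclassified spectra dominate the gradient, which is why the option helps on imbalanced data.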
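The warmup_epochs ramp can be pictured with a small pure-Python schedule. This is a sketch under stated assumptions: the learning rate is updated once per epoch, the first warmup epoch uses target / warmup_epochs rather than exactly 0, and plain dicts stand in for a torch optimizer's param_groups. warmup_lr and apply_warmup are hypothetical names.

```python
def warmup_lr(target_lr, epoch, warmup_epochs):
    """Linearly warmed-up learning rate for a 0-indexed epoch.

    Assumption: stepped once per epoch and reaching target_lr on the last
    warmup epoch; MaldiDeepKit's exact indexing may differ.
    """
    if warmup_epochs <= 0 or epoch >= warmup_epochs:
        return target_lr
    return target_lr * (epoch + 1) / warmup_epochs


def apply_warmup(param_groups, target_lrs, epoch, warmup_epochs):
    """Set each group's "lr" in place, ramping every group to its own target."""
    for group, target in zip(param_groups, target_lrs):
        group["lr"] = warmup_lr(target, epoch, warmup_epochs)


# Two plain dicts stand in for optimizer.param_groups with different targets.
groups = [{"lr": 1e-3}, {"lr": 1e-4}]
targets = [g["lr"] for g in groups]
history = []
for epoch in range(6):
    apply_warmup(groups, targets, epoch, warmup_epochs=4)
    history.append(groups[0]["lr"])
print(history)
```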
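The mixup_alpha option refers to MixUp, which replaces each training batch with a convex combination of itself and a shuffled copy. A minimal NumPy sketch of the standard formulation follows, assuming one-hot (or soft) targets; mixup_batch is a hypothetical helper, and how MaldiDeepKit combines MixUp with CutMix is not shown.

```python
import numpy as np


def mixup_batch(x, y_onehot, alpha, rng):
    """Standard MixUp on one batch.

    lam ~ Beta(alpha, alpha); alpha <= 0 disables mixing (like mixup_alpha=0.0).
    """
    if alpha <= 0:
        return x, y_onehot
    lam = rng.beta(alpha, alpha)
    perm = rng.permutation(len(x))  # pair each spectrum with a random partner
    x_mix = lam * x + (1.0 - lam) * x[perm]
    y_mix = lam * y_onehot + (1.0 - lam) * y_onehot[perm]
    return x_mix, y_mix


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 6000)).astype(np.float32)  # 8 spectra, 6000 bins
y = np.eye(2)[rng.integers(0, 2, size=8)]               # one-hot labels
x_mix, y_mix = mixup_batch(x, y, alpha=0.2, rng=rng)
```

Small alpha values (such as 0.2) make Beta(alpha, alpha) concentrate near 0 and 1, so most mixed spectra stay close to one of the two originals.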
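With ema_decay set, a shadow copy of the weights is presumably updated after each optimizer step with the usual exponential-moving-average rule. The NumPy sketch below applies that rule to a dict of arrays standing in for a state dict; ema_update is hypothetical, and whether buffers such as BatchNorm statistics are also averaged is not documented here.

```python
import numpy as np


def ema_update(ema_params, params, decay):
    """One EMA step per tensor: ema <- decay * ema + (1 - decay) * current."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


# Toy "model" whose single weight tensor stays at 1.0.
params = {"w": np.ones(3)}
ema = {"w": np.zeros(3)}  # zeros only to make the arithmetic easy to follow
for _ in range(2):
    ema_update(ema, params, decay=0.5)
print(ema["w"])  # [0.75 0.75 0.75]
```

In practice the shadow copy starts from the initial weights, and with decay=0.999 it effectively averages over on the order of 1 / (1 - decay) = 1000 recent updates.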
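The "balanced" rule for class_weight, n_samples / (n_classes * class_count), is the same heuristic scikit-learn uses for class_weight="balanced" and can be checked by hand. balanced_class_weight below is a hypothetical NumPy helper, not part of MaldiDeepKit, and it assumes the weights follow the sorted order of np.unique(y).

```python
import numpy as np


def balanced_class_weight(y):
    """Hypothetical helper: n_samples / (n_classes * class_count) per class.

    Weights are returned in the sorted order of np.unique(y).
    """
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return classes, weights


# 3:1 imbalance: the minority class gets three times the majority weight.
y = np.array([0, 0, 0, 1])
classes, weights = balanced_class_weight(y)
print(classes, weights)
```

Each class's total weight (weight × count) equals n_samples / n_classes, so every class contributes equally to the weighted loss.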

Variables:
  • model_ (torch.nn.Module) – The fitted PyTorch model.

  • classes_ (ndarray of shape (n_classes,)) – Original class labels seen during fit().

  • input_dim_ (int) – Resolved number of input features.

  • n_classes_ (int) – Resolved number of classes.

  • feature_mean_ (ndarray or None) – Per-feature mean used when standardize=True.

  • feature_std_ (ndarray or None) – Per-feature std used when standardize=True.

  • n_features_in_ (int) – Number of features seen at fit() (sklearn convention).


fit(X, y, *, warm_start=False)

Fit the model on (X, y).

Parameters:
  • X (Any) – Training spectra. NumPy arrays, pandas DataFrames, and objects with a DataFrame-like .X attribute are accepted.

  • y (Any) – Integer or string class labels. Re-encoded to 0..n_classes-1 internally; original labels are preserved in classes_.

  • warm_start (bool) – When True and the estimator already has a fitted model_, the underlying torch.nn.Module is reused as the starting point for training instead of being rebuilt from scratch via _build_model(). This enables federated-learning, continual-learning, and fine-tuning workflows in which fit() must resume from the current weights rather than reinitialise them. warm_start applies only to the first training attempt; retries triggered by retry_on_val_auroc_below always rebuild via _build_model(), because the warm-start weights have already failed once. When warm_start=True but no fitted model_ exists, fit() silently falls back to a fresh build (sklearn convention).
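The label handling described for y can be sketched with np.unique, which sorts the distinct labels and returns each sample's integer code. Whether fit() orders classes_ exactly this way is an assumption here (it is the ordering scikit-learn's LabelEncoder uses).

```python
import numpy as np

y = np.array(["S", "R", "R", "S", "I"])  # e.g. susceptible / resistant / intermediate
classes_, y_encoded = np.unique(y, return_inverse=True)
print(classes_)   # ['I' 'R' 'S']
print(y_encoded)  # [2 1 1 2 0]

# Integer predictions map back to the original labels by indexing classes_.
y_decoded = classes_[y_encoded]
```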

Returns:

self – The fitted estimator.

Return type:

BaseSpectralClassifier

predict_proba(X)

Return softmax class probabilities of shape (n_samples, n_classes).

Parameters:

X (Any) – Spectra to score. Must have the same number of features as the training matrix.

Returns:

Softmax probabilities that sum to 1 along the class axis.

Return type:

ndarray

Raises:

ValueError – If X.shape[1] != input_dim_.
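Temperature scaling (calibrate_temperature) divides the logits by a single scalar T before the softmax, so predict_proba() changes confidence but not the argmax. The sketch fits T by minimising validation negative log-likelihood. It swaps the documented LBFGS optimiser for a simple grid search to stay dependency-free, and softmax / fit_temperature are hypothetical helpers.

```python
import numpy as np


def softmax(logits, temperature=1.0):
    """Row-wise softmax of logits / temperature."""
    z = np.asarray(logits, dtype=float) / temperature
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def fit_temperature(logits, y, grid=None):
    """Return the grid temperature with the lowest validation NLL."""
    grid = np.logspace(-1, 1, 201) if grid is None else grid
    rows = np.arange(len(y))
    nll = [-np.mean(np.log(softmax(logits, t)[rows, y])) for t in grid]
    return float(grid[int(np.argmin(nll))])


# Over-confident validation logits: four confident hits, one confident miss.
val_logits = 10.0 * np.array([[1, 0], [1, 0], [1, 0], [0, 1], [1, 0]])
val_y = np.array([0, 0, 0, 1, 1])
T = fit_temperature(val_logits, val_y)
proba = softmax(val_logits, T)
```

T > 1 softens over-confident probabilities and T < 1 sharpens them; in both cases the predicted class is unchanged.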

predict(X)

Return hard class predictions.

Parameters:

X (Any) – Spectra to classify.

Returns:

Predicted labels, drawn from classes_.

Return type:

ndarray

Notes

For binary classifiers fit with tune_threshold=True, the decision uses the fitted threshold_ on the positive class probability instead of argmax.
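The tune_threshold sweep and the resulting predict() rule can be sketched as follows, assuming the candidate thresholds are the distinct validation scores, threshold_metric="balanced_accuracy", and a sample is assigned to the positive class when its probability is at least threshold_. All helper names are hypothetical.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred):
    """Mean of sensitivity and specificity for 0/1 labels."""
    tpr = np.mean(y_pred[y_true == 1] == 1)
    tnr = np.mean(y_pred[y_true == 0] == 0)
    return 0.5 * (tpr + tnr)


def tune_threshold(y_val, p_pos):
    """Return the candidate threshold with the best balanced accuracy."""
    candidates = np.unique(p_pos)  # every distinct validation score
    scores = [balanced_accuracy(y_val, (p_pos >= t).astype(int)) for t in candidates]
    return float(candidates[int(np.argmax(scores))])


y_val = np.array([0, 0, 0, 0, 1, 1])
p_pos = np.array([0.05, 0.10, 0.20, 0.35, 0.30, 0.40])  # positive-class probabilities
threshold = tune_threshold(y_val, p_pos)

# Decision rule: positive label when p >= threshold, else negative.
classes_ = np.array(["negative", "positive"])
y_pred = np.where(p_pos >= threshold, classes_[1], classes_[0])
print(threshold)  # 0.3
```

On this toy validation set the default 0.5 cutoff predicts everything as negative (balanced accuracy 0.5), while the tuned threshold of 0.3 reaches 0.875.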

score(X, y)

Return mean accuracy on (X, y).

Parameters:
  • X (Any) – Spectra to classify.

  • y (Any) – True class labels for X.

Returns:

Accuracy in [0, 1].

Return type:

float

save(path)

Persist the fitted estimator to <path>.pt and <path>.json.

The PyTorch state dict is written to <path>.pt and the hyperparameters plus fitted metadata to <path>.json. A single .pt or .json suffix on path is stripped so clf.save("model") and clf.save("model.pt") produce the same pair of files.

Parameters:

path (str | Path) – Base path. A single trailing .pt or .json suffix is stripped, so the extension can be omitted.

Return type:

None
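The naming rule in save() can be expressed with pathlib. artifact_paths is a hypothetical helper; the sketch assumes that any other suffix (for example ".v2") is kept and the extensions are appended, which the description above does not spell out.

```python
from pathlib import Path


def artifact_paths(path):
    """Hypothetical helper mirroring save()'s documented naming rule."""
    base = Path(path)
    if base.suffix in (".pt", ".json"):  # strip a single .pt / .json suffix
        base = base.with_suffix("")
    # Append rather than replace, so "run.v2" keeps its ".v2".
    return base.with_name(base.name + ".pt"), base.with_name(base.name + ".json")


print(artifact_paths("model"))
print(artifact_paths("model.pt") == artifact_paths("model"))  # True
```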

classmethod load(path)

Load a saved estimator from a save()-produced pair of files.

Parameters:

path (str | Path) – Base path (.pt/.json suffix optional).

Returns:

Fitted estimator ready for predict() / predict_proba().

Return type:

BaseSpectralClassifier

Raises:

classmethod __init_subclass__(**kwargs)

Set the set_{method}_request methods.

This uses PEP 487 to set the set_{method}_request methods. It looks for the information available in the default values, which are set using __metadata_request__* class attributes or inferred from method signatures.

The __metadata_request__* class attributes are used when a method does not explicitly accept a given piece of metadata through its arguments, or when the developer wants to specify a request value for that metadata that differs from the default None.

get_metadata_routing()

Get metadata routing of this object.

See the scikit-learn User Guide for how the routing mechanism works.

Returns:

routing – A MetadataRequest encapsulating routing information.

Return type:

MetadataRequest

get_params(deep=True)

Get parameters for this estimator.

Parameters:

deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns:

params – Parameter names mapped to their values.

Return type:

dict

set_fit_request(*, warm_start='$UNCHANGED$')

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). See the scikit-learn User Guide for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • warm_start (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for warm_start parameter in fit.

  • self (BaseSpectralClassifier)

Returns:

self – The updated object.

Return type:

object

set_params(**params)

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters:

**params (dict) – Estimator parameters.

Returns:

self – Estimator instance.

Return type:

estimator instance

SpectralDataset

class maldideepkit.SpectralDataset(X, y=None, *, standardize=False, mean=None, std=None)

Bases: Dataset

PyTorch Dataset wrapping a binned MALDI-TOF feature matrix.

The dataset stores its spectra as a single float32 tensor in memory and optionally standardizes each feature on the fly using statistics computed once at construction time.

Parameters:
  • X (Any) – Feature matrix of shape (n_samples, n_bins). A NumPy array, a pandas DataFrame, or any object with a DataFrame-like .X attribute is accepted.

  • y (Any | None) – Integer class labels of shape (n_samples,). When None (inference usage) the dataset yields only features.

  • standardize (bool) – If True, subtract the per-column mean and divide by the per-column standard deviation computed from X. Columns with zero variance are left untouched.

  • mean (ndarray | None) – Pre-computed per-feature means. Used together with std to apply an external standardization (e.g. one fitted on a training fold). Ignored when standardize=False.

  • std (ndarray | None) – Pre-computed per-feature standard deviations. Ignored when standardize=False.

Variables:
  • X (torch.Tensor) – Stored features as a float32 tensor.

  • y (torch.Tensor or None) – Stored labels as a long tensor, or None for inference.

  • mean (torch.Tensor or None) – Feature-wise mean used for standardization.

  • std (torch.Tensor or None) – Feature-wise standard deviation used for standardization.
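The standardize=True behaviour, and the reuse of training-fold statistics via mean and std, can be sketched in NumPy. Treating "left untouched" as "neither centred nor scaled" and using the population standard deviation (ddof=0) are assumptions; fit_standardizer and standardize are hypothetical helpers.

```python
import numpy as np


def fit_standardizer(X):
    """Per-column mean/std; zero-variance columns get mean 0, std 1 (pass-through)."""
    X = np.asarray(X, dtype=np.float32)
    mean = X.mean(axis=0)
    std = X.std(axis=0)  # population std (ddof=0), an assumption
    constant = std == 0
    mean[constant] = 0.0
    std[constant] = 1.0
    return mean, std


def standardize(X, mean, std):
    """Apply previously fitted statistics to any matrix with the same columns."""
    return (np.asarray(X, dtype=np.float32) - mean) / std


X_train = np.array([[1.0, 5.0, 3.0], [3.0, 5.0, 1.0], [5.0, 5.0, 2.0]])
X_test = np.array([[2.0, 5.0, 2.0]])
mean, std = fit_standardizer(X_train)    # fit on the training fold only
Z_train = standardize(X_train, mean, std)
Z_test = standardize(X_test, mean, std)  # reuse the training statistics
```

With the real class, the training statistics would be reused as SpectralDataset(X_test, standardize=True, mean=mean, std=std).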


SpectralDataset accepts NumPy arrays, pandas DataFrames, and any object with a DataFrame-like .X attribute (e.g. maldiamrkit.MaldiSet):

import numpy as np
import pandas as pd
from maldideepkit import SpectralDataset

ds_array = SpectralDataset(np.zeros((10, 6000)))
ds_frame = SpectralDataset(pd.DataFrame(np.zeros((10, 6000))))

make_loaders

maldideepkit.make_loaders(X, y, *, batch_size=32, val_size=0.1, random_state=0, standardize=False, input_transform=None, stratify=True, num_workers=0, warper=None)

Build a (by default stratified) train / validation DataLoader pair together with the fitted preprocessing statistics.

Pipeline order, applied after the train/val split so nothing from the validation split leaks into training statistics:

  1. Spectral warping / alignment (if warper is given): fit on the training split, then transform both splits.

  2. Input transform (if input_transform is given, or standardize=True): fit its statistics on the (warped) training split, then apply it to both splits.
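The leakage-free ordering above (split first, then fit every preprocessing step on the training rows only) can be sketched in NumPy. For brevity the split is a plain random permutation rather than a stratified one, "robust" is left out because its exact definition is not given here, the warper is assumed to expose fit(X) and transform(X) as described for the warper parameter, and all helper names are hypothetical.

```python
import numpy as np

SUPPORTED = {"none", "log1p", "standardize", "log1p+standardize"}


def fit_input_transform(X_train, kind):
    """Fit transform state on the training split only (hypothetical helper)."""
    if kind not in SUPPORTED:
        raise ValueError(f"unsupported input_transform in this sketch: {kind!r}")
    state = {"kind": kind}
    X = np.log1p(X_train) if kind.startswith("log1p") else X_train
    if kind.endswith("standardize"):
        std = X.std(axis=0)
        state["mean"] = X.mean(axis=0)
        state["std"] = np.where(std == 0, 1.0, std)  # guard constant columns
    return state


def apply_input_transform(X, state):
    """Reapply a fitted state to any split (or to new data at predict time)."""
    if state["kind"].startswith("log1p"):
        X = np.log1p(X)
    if "mean" in state:
        X = (X - state["mean"]) / state["std"]
    return X


def split_then_transform(X, y, *, val_size=0.1, kind="standardize", warper=None, seed=0):
    """Split first, then fit the warper and the transform on train only."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_val = max(1, int(round(val_size * len(X))))
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    X_tr, X_va = X[train_idx], X[val_idx]
    if warper is not None:  # step 1: alignment fitted on the training split
        warper.fit(X_tr)
        X_tr, X_va = warper.transform(X_tr), warper.transform(X_va)
    state = fit_input_transform(X_tr, kind)  # step 2: fitted on (warped) train
    return (apply_input_transform(X_tr, state), y[train_idx],
            apply_input_transform(X_va, state), y[val_idx], state)


rng = np.random.default_rng(0)
X = np.abs(rng.standard_normal((50, 4)))  # non-negative, so log1p is safe
y = rng.integers(0, 2, size=50)
Z_tr, y_tr, Z_va, y_va, state = split_then_transform(X, y, kind="log1p+standardize")
```

Only the training rows influence state, so the validation split is transformed with statistics it never contributed to.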

Parameters:
  • X (Any) – Feature matrix of shape (n_samples, n_bins).

  • y (Any) – Integer class labels of shape (n_samples,).

  • batch_size (int) – Mini-batch size for the training loader.

  • val_size (float) – Fraction of the input held out for validation.

  • random_state (int | None) – Seed for the split.

  • standardize (bool) – Shorthand for input_transform="standardize" (when True) or input_transform="none" (when False). Kept for backwards compatibility; the modern interface is input_transform. Ignored whenever input_transform is given explicitly.

  • input_transform (str | None) – One of {"none", "standardize", "log1p", "robust", "log1p+standardize"}. Fitted on the (warped) training split only and applied to both splits. Overrides standardize when both are given.

  • stratify (bool) – If True and every class has at least two samples, stratify the split on y. Otherwise fall back to a plain random split.

  • num_workers (int) – DataLoader worker count.

  • warper (Any | None) – Unfitted spectral-alignment transformer with fit(X) -> self + transform(X) -> X. Fitted on the training split only and used to transform both splits. The fitted object is returned in stats["warper"].

Return type:

tuple[DataLoader, DataLoader, dict[str, Any]]

Returns:

  • train_loader (DataLoader) – Shuffling training loader. Drops the last batch when it would contain a single sample (avoids BatchNorm issues).

  • val_loader (DataLoader) – Non-shuffling validation loader.

  • stats (dict) – {"mean": array or None, "std": array or None, "warper": fitted warper or None, "input_transform_state": dict}.
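The stratification fallback and the single-sample drop rule reduce to two small checks. can_stratify and should_drop_last are hypothetical helpers; treating a training split smaller than one batch as never dropped is an assumption.

```python
import numpy as np


def can_stratify(y):
    """True when every class has at least two samples (the documented condition)."""
    _, counts = np.unique(y, return_counts=True)
    return bool(np.all(counts >= 2))


def should_drop_last(n_train, batch_size):
    """Drop the final batch only if it would hold exactly one sample.

    Assumption: a training split smaller than one batch is kept as-is.
    """
    return n_train > batch_size and n_train % batch_size == 1


# A 180-sample training split gives a last batch of 180 % 32 == 20 samples.
print(can_stratify([0, 0, 1, 1]), can_stratify([0, 0, 0, 1]))  # True False
print(should_drop_last(180, 32), should_drop_last(33, 32))     # False True
```

In a PyTorch pipeline, should_drop_last(...) would map onto the drop_last argument of torch.utils.data.DataLoader, and can_stratify(y) would decide whether stratify=y is passed to a splitter such as scikit-learn's train_test_split.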

Loader Example

import numpy as np
from maldideepkit import make_loaders

rng = np.random.default_rng(0)  # one generator for both arrays
X = rng.standard_normal((200, 6000)).astype("float32")
y = rng.integers(0, 2, size=200)

train, val, stats = make_loaders(
    X, y, batch_size=32, val_size=0.1, standardize=True, random_state=0,
)