Base Module
Abstract base class and data utilities shared by every MaldiDeepKit
classifier. Users implementing a new architecture only need to inherit
from BaseSpectralClassifier and override
_build_model().
BaseSpectralClassifier
- class maldideepkit.BaseSpectralClassifier(input_dim=None, n_classes=2, learning_rate=0.001, weight_decay=0.0, grad_clip_norm=None, label_smoothing=0.0, loss='cross_entropy', focal_gamma=2.0, use_amp=False, swa_start_epoch=None, tune_threshold=False, threshold_metric='balanced_accuracy', calibrate_temperature=False, min_val_auroc_for_threshold_tune=0.6, use_sam=False, sam_rho=0.05, batch_size=32, epochs=100, early_stopping_patience=10, val_fraction=0.1, warmup_epochs=0, standardize=False, input_transform=None, warping=None, metrics_log_path=None, track_train_metrics=False, augment=None, mixup_alpha=0.0, cutmix_alpha=0.0, ema_decay=None, retry_on_val_auroc_below=None, max_retries=2, class_weight=None, device='auto', random_state=0, verbose=False)
Bases: ClassifierMixin, BaseEstimator
Abstract base for all MaldiDeepKit classifiers.
Concrete subclasses only need to override _build_model(), which should return a torch.nn.Module that maps an input of shape (batch, input_dim) to logits of shape (batch, n_classes). Everything else (device placement, validation split, early stopping, checkpointing, predict / predict_proba, save / load) is provided here.
- Parameters:
input_dim (int | None) – Number of input bins. If None, inferred from X at fit() time and stored as input_dim_.
n_classes (int) – Number of output classes. Overwritten with the true number of classes found in y at fit() time.
learning_rate (float) – Initial learning rate for the optimizer (Adam by default; AdamW when weight_decay > 0).
weight_decay (float) – Weight-decay coefficient. When > 0 the optimizer switches from Adam to AdamW, which applies decoupled weight decay rather than an L2 penalty added to the loss.
grad_clip_norm (float | None) – If set, clip the global L2 norm of the gradients to this value before every optimizer step. 1.0 is a common choice for transformers.
label_smoothing (float) – Label smoothing factor in [0, 1) passed to the loss. Applied to both the cross-entropy and focal-loss paths.
loss (str) – Classification loss; the default is "cross_entropy". "focal" uses FocalLoss with gamma=focal_gamma, which is useful for highly imbalanced problems.
focal_gamma (float) – Focal-loss focusing parameter. Ignored when loss="cross_entropy".
use_amp (bool) – If True and the resolved device is CUDA, run the forward pass and loss under torch.autocast() and use torch.amp.GradScaler for the backward pass. This can cut wall time substantially (up to roughly 2x) on recent NVIDIA GPUs. On CPU this is a no-op.
swa_start_epoch (int | None) – If set, start Stochastic Weight Averaging (SWA) at this epoch. The SWA average replaces the best-validation checkpoint at the end of fit. Typical value: 60-80% of epochs.
tune_threshold (bool) – (Binary classification only.) After fit, sweep decision thresholds on the validation split and store the one that maximises threshold_metric as threshold_. predict() then applies this threshold to the positive-class probability instead of taking the argmax (equivalent to a 0.5 cutoff).
threshold_metric (str) – Metric maximised by the threshold sweep when tune_threshold=True (default "balanced_accuracy").
calibrate_temperature (bool) – If True, after fit run LBFGS-based temperature scaling on held-out validation logits (Guo et al. 2017). The fitted temperature is stored as temperature_ and applied in predict_proba() to sharpen or smooth probabilities without changing the argmax.
min_val_auroc_for_threshold_tune (float) – Binary-classification guardrail on tune_threshold=True: if the validation AUROC is below this value, the threshold sweep is skipped and threshold_ falls back to 0.5. Set to 0.0 to disable.
use_sam (bool) – If True, wrap the base optimizer in SAMOptimizer and run the two-step Sharpness-Aware Minimization (SAM) update. This doubles the forward/backward compute per step and typically helps generalisation on small datasets.
sam_rho (float) – Size of the SAM ascent step. Ignored when use_sam=False.
batch_size (int) – Training mini-batch size.
epochs (int) – Maximum number of training epochs.
early_stopping_patience (int) – Number of epochs without validation-loss improvement before training is stopped.
val_fraction (float) – Fraction of the training data held out for the internal validation split.
warmup_epochs (int) – If positive, linearly ramp each optimizer param group's learning rate from 0 to its configured target over the first warmup_epochs epochs. Useful for transformer architectures that can diverge at full learning rate during the first few steps.
standardize (bool) – Shorthand for input_transform="standardize" (when True) or "none" (when False). Kept for backwards compatibility; input_transform is the modern interface and wins when both are supplied.
input_transform (str | None) – One of {"none", "standardize", "log1p", "robust", "log1p+standardize"}. Fitted on the (warped) training split only, stored as input_transform_state_, and reapplied at predict() / predict_proba() time.
warping (Any | None) – Spectral alignment/warping transformer applied before the input transform (e.g. standardization). Fitted on the training split only, then used to transform both splits during training and new data at predict() / predict_proba() time. The fitted transformer is stored as warper_.
metrics_log_path (str | Path | None) – If set, write a per-epoch metrics CSV to this path during fit(): one row per epoch with columns epoch, train_loss, val_loss, lr, mean_grad_norm, n_grad_updates (plus train_auroc and val_auroc when track_train_metrics=True).
track_train_metrics (bool) – Only used when metrics_log_path is set. If True, run a no-grad forward pass over the full training split after every epoch and record train_auroc and val_auroc alongside the losses. Adds one extra pass per epoch; binary classification only.
augment (Optional[Callable[[Tensor], Tensor]]) – Per-batch augmentation applied to training batches only. The usual choice is SpectrumAugment.
mixup_alpha (float) – If positive, apply MixUp to each training batch with a mixing coefficient drawn from Beta(mixup_alpha, mixup_alpha). 0.0 disables MixUp. Can be combined with cutmix_alpha.
cutmix_alpha (float) – If positive, apply CutMix to each training batch with a mixing coefficient drawn from Beta(cutmix_alpha, cutmix_alpha). 0.0 disables CutMix.
ema_decay (float | None) – If set (typically 0.999), maintain an exponential moving average (EMA) of the model weights during training and use the EMA weights at inference time.
retry_on_val_auroc_below (float | None) – Binary-classification guardrail. If set and the post-fit validation AUROC is below this threshold, retrain with a different RNG seed up to max_retries times. Useful for unstable small-data fits.
max_retries (int) – Maximum number of automatic refits triggered by retry_on_val_auroc_below. Ignored when that guardrail is unset.
class_weight (str | ndarray | list | None) – Per-class weights applied to CrossEntropyLoss. "balanced" computes each class weight as n_samples / (n_classes * class_count).
device (str | torch.device) – Device used for training and inference (default "auto").
random_state (int) – Seeds the Python, NumPy, and PyTorch RNGs and the validation split.
verbose (bool) – If True, print one line per training epoch.
- Variables:
model_ (torch.nn.Module) – The fitted PyTorch model.
classes_ (ndarray of shape (n_classes,)) – Original class labels seen during fit().
input_dim_ (int) – Resolved number of input features.
n_classes_ (int) – Resolved number of classes.
feature_mean_ (ndarray or None) – Per-feature mean used when standardize=True.
feature_std_ (ndarray or None) – Per-feature standard deviation used when standardize=True.
n_features_in_ (int) – Number of features seen at fit() (sklearn convention).
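Two of the hyperparameters above reduce to short formulas: the "balanced" rule of class_weight and the linear learning-rate ramp of warmup_epochs. The NumPy sketch below illustrates both; it is not MaldiDeepKit's code. The helper names are hypothetical, and expressing the ramp as a per-step multiplier (rather than per epoch) is an assumption about granularity.

```python
import numpy as np


def balanced_class_weights(y):
    """Hypothetical helper: the documented "balanced" rule,
    n_samples / (n_classes * class_count), one weight per class."""
    classes, counts = np.unique(np.asarray(y), return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return classes, weights


def warmup_factor(step, warmup_steps):
    """Hypothetical helper: learning-rate multiplier that rises linearly
    from 0 to 1 over warmup_steps, then stays at 1."""
    if warmup_steps <= 0:
        return 1.0  # warmup disabled
    return min(1.0, step / warmup_steps)


classes, weights = balanced_class_weights([0, 0, 0, 1])
# Rare class 1 gets weight 2.0; common class 0 gets 4 / 6 (about 0.667).

lrs = [1e-3 * warmup_factor(s, 4) for s in range(6)]
# Ramps 0 -> 1e-3 over 4 steps, then stays at the target learning rate.
```

With these weights each class contributes the same total weight (weight times count is 2 for both classes here), which is the point of the "balanced" setting.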
- __init__(input_dim=None, n_classes=2, learning_rate=0.001, weight_decay=0.0, grad_clip_norm=None, label_smoothing=0.0, loss='cross_entropy', focal_gamma=2.0, use_amp=False, swa_start_epoch=None, tune_threshold=False, threshold_metric='balanced_accuracy', calibrate_temperature=False, min_val_auroc_for_threshold_tune=0.6, use_sam=False, sam_rho=0.05, batch_size=32, epochs=100, early_stopping_patience=10, val_fraction=0.1, warmup_epochs=0, standardize=False, input_transform=None, warping=None, metrics_log_path=None, track_train_metrics=False, augment=None, mixup_alpha=0.0, cutmix_alpha=0.0, ema_decay=None, retry_on_val_auroc_below=None, max_retries=2, class_weight=None, device='auto', random_state=0, verbose=False)
- Parameters:
n_classes (int)
learning_rate (float)
weight_decay (float)
label_smoothing (float)
loss (str)
focal_gamma (float)
use_amp (bool)
tune_threshold (bool)
threshold_metric (str)
calibrate_temperature (bool)
min_val_auroc_for_threshold_tune (float)
use_sam (bool)
sam_rho (float)
batch_size (int)
epochs (int)
early_stopping_patience (int)
val_fraction (float)
warmup_epochs (int)
standardize (bool)
track_train_metrics (bool)
mixup_alpha (float)
cutmix_alpha (float)
max_retries (int)
random_state (int)
verbose (bool)
- Return type:
None
- fit(X, y, *, warm_start=False)
Fit the model on (X, y).
- Parameters:
X (Any) – Training spectra. NumPy arrays, pandas DataFrames, and objects with a DataFrame-like .X attribute are accepted.
y (Any) – Integer or string class labels. Re-encoded to 0..n_classes-1 internally; original labels are preserved in classes_.
warm_start (bool) – When True and the estimator already has a fitted model_, the underlying torch.nn.Module is reused as the starting point of training instead of being rebuilt from scratch via _build_model(). This enables federated learning, continual learning, and fine-tuning workflows that need fit() to resume from the current weights rather than reinitialise. warm_start applies only to the first training attempt; retries triggered by retry_on_val_auroc_below always rebuild via _build_model() (the warm-start weights already failed once). When warm_start=True but no prior model_ exists, fit() silently falls back to a fresh build (sklearn convention).
- Returns:
self – The fitted estimator.
- Return type:
BaseSpectralClassifier
- predict_proba(X)
Return softmax class probabilities of shape (n_samples, n_classes).
- Parameters:
X (Any) – Spectra to score. Must have the same number of features as the training matrix.
- Returns:
Softmax probabilities that sum to 1 along the class axis. When calibrate_temperature=True, the fitted temperature_ is applied to the logits before the softmax.
- Return type:
ndarray
- Raises:
ValueError – If X.shape[1] != input_dim_.
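Temperature scaling in the sense of Guo et al. (2017) divides the logits by a single positive scalar T before the softmax. The NumPy sketch below assumes that form for temperature_; the function name is illustrative, and fitting T with LBFGS is not shown. It demonstrates why calibration changes the probabilities but not the argmax, and hence not predict().

```python
import numpy as np


def softmax_with_temperature(logits, temperature=1.0):
    """Row-wise softmax of logits / temperature."""
    z = np.asarray(logits, dtype=float) / temperature
    z -= z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.3, 0.2]])
p_raw = softmax_with_temperature(logits)           # T = 1: plain softmax
p_smooth = softmax_with_temperature(logits, 2.0)   # T > 1 flattens the probabilities
p_sharp = softmax_with_temperature(logits, 0.5)    # T < 1 sharpens them

# Dividing by a positive T preserves the ordering of the logits,
# so the predicted class is identical in all three cases.
assert (p_raw.argmax(axis=1) == p_smooth.argmax(axis=1)).all()
assert (p_raw.argmax(axis=1) == p_sharp.argmax(axis=1)).all()
```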
- predict(X)
Return hard class predictions.
- Parameters:
X (Any) – Spectra to classify.
- Returns:
Predicted labels, drawn from classes_.
- Return type:
ndarray
Notes
For binary classifiers fit with tune_threshold=True, the decision applies the fitted threshold_ to the positive-class probability instead of taking the argmax. If the min_val_auroc_for_threshold_tune guardrail skipped the sweep, threshold_ is 0.5.
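The threshold logic described for tune_threshold, min_val_auroc_for_threshold_tune, and predict() can be sketched in NumPy as below. This is an illustration under assumptions, not MaldiDeepKit's implementation: the helper names are hypothetical, the candidate cutoffs are the distinct validation probabilities, positives are predicted with a `>=` comparison, and AUROC is computed with the rank (Mann-Whitney) formulation. Balanced accuracy and the 0.6 guardrail mirror the documented defaults.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred):
    """Mean per-class recall."""
    return float(np.mean([np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]))


def auroc(y_true, scores):
    """Probability that a random positive outscores a random negative (ties count 1/2)."""
    diff = scores[y_true == 1][:, None] - scores[y_true == 0][None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))


def tune_threshold(y_val, p_val, min_auroc=0.6):
    """Hypothetical sweep: cutoff on P(class 1) that maximises balanced accuracy."""
    if auroc(y_val, p_val) < min_auroc:
        return 0.5  # guardrail: weak validation AUROC, keep the default cutoff
    candidates = np.unique(p_val)  # every distinct validation probability
    scores = [balanced_accuracy(y_val, (p_val >= t).astype(int)) for t in candidates]
    return float(candidates[int(np.argmax(scores))])


def predict_with_threshold(p_pos, threshold):
    """Binary decision on the positive-class probability instead of argmax."""
    return (np.asarray(p_pos) >= threshold).astype(int)


y_val = np.array([0, 0, 0, 1, 1, 1])
p_val = np.array([0.10, 0.20, 0.30, 0.35, 0.40, 0.90])
threshold = tune_threshold(y_val, p_val)  # 0.35 separates the two classes perfectly
```

Reversing the scores (so the model ranks the classes backwards) drops the validation AUROC to 0, and the guardrail then returns 0.5 instead of tuning.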
- save(path)
Persist the fitted estimator to path.pt + path.json.
The PyTorch state dict is written to <path>.pt and the hyperparameters plus fitted metadata to <path>.json. A single .pt or .json suffix on path is stripped, so clf.save("model") and clf.save("model.pt") produce the same pair of files.
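The suffix rule for save() can be expressed with pathlib. The helper below is hypothetical and only mirrors the documented behaviour (strip one trailing .pt or .json, then write a .pt/.json pair). Appending the suffixes rather than calling with_suffix() is a design choice assumed here so that a name such as "model.v2" keeps its inner dot; the library may handle that case differently.

```python
from pathlib import Path


def artifact_paths(path):
    """Hypothetical helper: map a user-supplied path to the (.pt, .json) pair."""
    p = Path(path)
    if p.suffix in (".pt", ".json"):
        p = p.with_suffix("")  # strip a single trailing .pt or .json
    # Append (rather than replace) suffixes so "model.v2" keeps its dot.
    return p.with_name(p.name + ".pt"), p.with_name(p.name + ".json")


# "model" and "model.pt" resolve to the same pair of files.
assert artifact_paths("model") == artifact_paths("model.pt")
```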
- classmethod load(path)
Load a saved estimator from a save()-produced pair of files.
- Parameters:
path – Base path of the saved pair, as passed to save().
- Returns:
Fitted estimator ready for predict() / predict_proba().
- Return type:
BaseSpectralClassifier
- Raises:
ValueError – If the JSON file identifies a different class from cls.
FileNotFoundError – If either the .pt or the .json file is missing.
- classmethod __init_subclass__(**kwargs)
Set the set_{method}_request methods.
This uses PEP-487 [1] to set the set_{method}_request methods. It looks for the information available in the set default values, which are set using __metadata_request__* class attributes, or inferred from method signatures.
The __metadata_request__* class attributes are used when a method does not explicitly accept metadata through its arguments, or when the developer wants to specify a request value for that metadata that differs from the default None.
References
[1] PEP 487 – Simpler customisation of class creation.
- get_metadata_routing()
Get metadata routing of this object.
Please check the User Guide on how the routing mechanism works.
- Returns:
routing – A MetadataRequest encapsulating routing information.
- Return type:
MetadataRequest
- get_params(deep=True)
Get parameters for this estimator.
- set_fit_request(*, warm_start='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
warm_start (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for the warm_start parameter in fit.
self (BaseSpectralClassifier)
- Returns:
self – The updated object.
- Return type:
BaseSpectralClassifier
- set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it's possible to update each component of a nested object.
- Parameters:
**params (dict) – Estimator parameters.
- Returns:
self – Estimator instance.
- Return type:
estimator instance
SpectralDataset
- class maldideepkit.SpectralDataset(X, y=None, *, standardize=False, mean=None, std=None)
Bases: Dataset
PyTorch Dataset wrapping a binned MALDI-TOF feature matrix.
The dataset stores its spectra as a single float32 tensor in memory and optionally standardizes each feature on the fly using statistics computed once at construction time.
- Parameters:
X (Any) – Feature matrix of shape (n_samples, n_bins). A NumPy array, a pandas DataFrame, or any object with a DataFrame-like .X attribute is accepted.
y (Any | None) – Integer class labels of shape (n_samples,). When None (inference usage), the dataset yields only features.
standardize (bool) – If True, subtract the per-column mean and divide by the per-column standard deviation. These statistics are computed from X unless mean and std are supplied. Columns with zero variance are left untouched.
mean (ndarray | None) – Pre-computed per-feature means. Used together with std to apply an external standardization (e.g. one fitted on a training fold). Ignored when standardize=False.
std (ndarray | None) – Pre-computed per-feature standard deviations. Ignored when standardize=False.
- Variables:
X (torch.Tensor) – Stored features as a float32 tensor.
y (torch.Tensor or None) – Stored labels as a long tensor, or None for inference.
mean (torch.Tensor or None) – Feature-wise mean used for standardization.
std (torch.Tensor or None) – Feature-wise standard deviation used for standardization.
SpectralDataset accepts NumPy arrays, pandas DataFrames, and any
object with a DataFrame-like .X attribute (e.g.
maldiamrkit.MaldiSet):
```python
import numpy as np
import pandas as pd
from maldideepkit import SpectralDataset

# Both build an unlabeled (inference) dataset of 10 spectra x 6000 bins.
ds_array = SpectralDataset(np.zeros((10, 6000)))
ds_frame = SpectralDataset(pd.DataFrame(np.zeros((10, 6000))))
```
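The standardization rule described for SpectralDataset can be sketched in NumPy. This is an illustration, not the class's implementation: standardize_features is a hypothetical name, and "left untouched" is read literally here (zero-variance columns pass through unchanged), whereas the real class might only skip the division. Passing training-fold statistics back in mirrors the external-standardization use of the mean and std parameters.

```python
import numpy as np


def standardize_features(X, mean=None, std=None):
    """Hypothetical sketch: per-column z-scoring with optional external statistics."""
    X = np.asarray(X, dtype=np.float32)
    if mean is None or std is None:
        mean = X.mean(axis=0)  # statistics computed once, from X itself
        std = X.std(axis=0)
    safe_std = np.where(std > 0, std, 1.0)  # avoid division by zero
    Z = (X - mean) / safe_std
    # Zero-variance columns are returned exactly as they were.
    Z = np.where(std > 0, Z, X).astype(np.float32)
    return Z, mean, std


X_train = np.array([[1.0, 5.0, 0.0], [3.0, 5.0, 2.0], [5.0, 5.0, 4.0]])
Z_train, mu, sigma = standardize_features(X_train)  # column 1 is constant
# Reuse the training statistics on new data, as with the mean/std parameters.
Z_new, _, _ = standardize_features([[3.0, 5.0, 2.0]], mean=mu, std=sigma)
```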
make_loaders
- maldideepkit.make_loaders(X, y, *, batch_size=32, val_size=0.1, random_state=0, standardize=False, input_transform=None, stratify=True, num_workers=0, warper=None)
Build a train / validation DataLoader pair (stratified by default).
Pipeline order, applied after the train/val split so nothing from the validation split leaks into training statistics:
1. Spectral warping / alignment (if warper is given): fit on the training split, then transform both splits.
2. Input transform (per-feature standardization when standardize=True, or the transform named by input_transform): fit on the (warped) training split, then apply to both splits.
- Parameters:
X (Any) – Feature matrix of shape (n_samples, n_bins).
y (Any) – Integer class labels of shape (n_samples,).
batch_size (int) – Mini-batch size for the training loader.
val_size (float) – Fraction of the input held out for validation.
standardize (bool) – Shorthand for input_transform="standardize" (when True) or input_transform="none" (when False). Kept for backwards compatibility; the modern interface is input_transform. Ignored whenever input_transform is given explicitly.
input_transform (str | None) – One of {"none", "standardize", "log1p", "robust", "log1p+standardize"}. Fitted on the (warped) training split only and applied to both splits. Overrides standardize when both are given.
stratify (bool) – If True and every class has at least two samples, stratify the split on y. Otherwise falls back to an unstratified random split.
num_workers (int) – DataLoader worker count.
warper (Any | None) – Unfitted spectral-alignment transformer exposing fit(X) -> self and transform(X) -> X. Fitted on the training split only and used to transform both splits. The fitted object is returned in stats["warper"].
- Return type:
tuple[DataLoader, DataLoader, dict[str, Any]]
- Returns:
train_loader (DataLoader) – Shuffling training loader. Drops the last batch when it would contain a single sample, because a single-sample batch gives degenerate BatchNorm statistics in training mode.
val_loader (DataLoader) – Non-shuffling validation loader.
stats (dict) – {"mean": array or None, "std": array or None, "warper": fitted warper or None, "input_transform_state": dict}.
Loader Example
```python
import numpy as np
from maldideepkit import make_loaders

# Synthetic data: 200 spectra with 6000 bins and random binary labels.
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 6000)).astype("float32")
y = rng.integers(0, 2, size=200)

train, val, stats = make_loaders(
    X, y, batch_size=32, val_size=0.1, standardize=True, random_state=0,
)
```
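The pipeline order documented for make_loaders (split first, then fit every transform on the training split only) can be sketched without PyTorch. This NumPy version is illustrative only: stratified_split is a hypothetical helper, rounding val_size per class is an assumption, the warper step is omitted, plain standardization stands in for input_transform, and no DataLoaders are built.

```python
import numpy as np


def stratified_split(y, val_size=0.1, random_state=0):
    """Hypothetical helper: per-class random split with a plain random fallback."""
    rng = np.random.default_rng(random_state)
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    if counts.min() < 2:
        # Some class is too small to stratify: fall back to a random split.
        idx = rng.permutation(len(y))
        n_val = max(1, int(round(val_size * len(y))))
        return np.sort(idx[n_val:]), np.sort(idx[:n_val])
    train_idx, val_idx = [], []
    for c in classes:
        idx = rng.permutation(np.flatnonzero(y == c))
        n_val = max(1, int(round(val_size * len(idx))))
        val_idx.extend(idx[:n_val])
        train_idx.extend(idx[n_val:])
    return np.sort(train_idx), np.sort(val_idx)


rng = np.random.default_rng(0)
X = rng.standard_normal((200, 50)) * 3.0 + 1.0
y = np.repeat([0, 1], 100)

train_idx, val_idx = stratified_split(y, val_size=0.1, random_state=0)
# Fit the statistics on the training split only ...
mean = X[train_idx].mean(axis=0)
std = X[train_idx].std(axis=0)
std[std == 0] = 1.0  # guard against constant columns
# ... then apply the same statistics to both splits (no validation leakage).
X_train = (X[train_idx] - mean) / std
X_val = (X[val_idx] - mean) / std
```

Because the validation rows never contribute to mean and std, X_val is generally not exactly zero-mean; that small offset is the expected, leak-free behaviour.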