Augment Module
Per-batch data augmentation for binned MALDI-TOF spectra. All
augmentations are callables that take and return a tensor of shape
(batch, n_bins). Wire them into a classifier via the augment=
keyword on BaseSpectralClassifier; they apply
to training batches only and are bypassed during validation and
inference.
The MixUp / CutMix helpers are exposed for users who roll their own
training loop; the classifier wrappers also accept mixup_alpha /
cutmix_alpha keywords that engage them automatically.
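Below is a minimal sketch of the callable contract described above. The hypothetical RandomBaselineOffset class maps a (batch, n_bins) tensor to a tensor of the same shape. The hypothetical prepare_batch helper mimics the training-only gating. Neither name is part of maldideepkit. Whether augment= accepts arbitrary user callables in addition to SpectrumAugment is an assumption.

```python
from typing import Callable, Optional

import torch


class RandomBaselineOffset:
    """Hypothetical augmentation (not part of maldideepkit).

    Adds one constant offset per sample, drawn from U(-max_offset, max_offset).
    """

    def __init__(self, max_offset: float, generator: Optional[torch.Generator] = None):
        self.max_offset = max_offset
        self.generator = generator

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_bins) -> same shape; one offset per row, broadcast over bins.
        u = torch.rand(x.shape[0], 1, generator=self.generator, dtype=x.dtype)
        return x + (2.0 * u - 1.0) * self.max_offset


def prepare_batch(
    x: torch.Tensor,
    augment: Optional[Callable[[torch.Tensor], torch.Tensor]],
    training: bool,
) -> torch.Tensor:
    """Hypothetical helper: augment training batches only."""
    if training and augment is not None:
        return augment(x)
    # Validation and inference batches pass through unmodified.
    return x
```

Gating on a training flag keeps the augmentation stateless. The same object can be reused across epochs without switching modes.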
SpectrumAugment
class maldideepkit.augment.SpectrumAugment(noise_std=0.0, intensity_jitter=0.0, peak_dropout_rate=0.0, mz_shift_max_bins=0, mz_warp_max_bins=0, mz_warp_n_knots=10, blur_sigma=0.0, random_state=None)
Bases: object
Composable per-batch spectrum augmentation.
Parameters:
- noise_std (float) – Standard deviation of additive Gaussian noise.
- intensity_jitter (float) – Half-range of the per-sample multiplicative jitter: every sample is scaled by 1 + U(-jitter, jitter). Must be in [0, 1).
- peak_dropout_rate (float) – Per-bin Bernoulli zero-out probability. Must be in [0, 1).
- mz_shift_max_bins (int) – Maximum per-sample global m/z shift, in bins. Each sample's shift is drawn uniformly from [-mz_shift_max_bins, +mz_shift_max_bins] and applied with torch.roll(). Must be non-negative.
- mz_warp_max_bins (int) – Peak amplitude (in bins) of a smooth cubic-spline warp of the m/z axis; 0 disables the warp. The warp runs on the CPU. A warning is emitted when the amplitude exceeds 5 % of n_bins (for example, 300 bins for a 6000-bin spectrum). Beyond that threshold, boundary clipping starts to dominate the augmentation distribution. Must be non-negative.
- mz_warp_n_knots (int) – Number of interior spline control points for the m/z warp. Only used when mz_warp_max_bins > 0. Must be non-negative.
- blur_sigma (float) – Standard deviation (in bins) of a 1-D Gaussian blur along the m/z axis; 0 disables the blur.
- random_state (int | None) – If provided, the transform is seeded so that batches are deterministic. When None (the default), PyTorch’s global RNG is used.
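The sketch below re-implements in plain PyTorch what the noise_std, intensity_jitter, peak_dropout_rate, mz_shift_max_bins and blur_sigma parameters describe. It is an illustration under stated assumptions, not the library's implementation. The following are choices made here:

- the order in which the transforms are applied;
- zero padding at the spectrum edges during blurring;
- the 3-sigma kernel radius.

The cubic-spline m/z warp is omitted. sketch_spectrum_augment is a hypothetical name.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def sketch_spectrum_augment(
    x: torch.Tensor,
    noise_std: float = 0.0,
    intensity_jitter: float = 0.0,
    peak_dropout_rate: float = 0.0,
    mz_shift_max_bins: int = 0,
    blur_sigma: float = 0.0,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Illustrative sketch; assumes a CPU tensor of shape (batch, n_bins)."""
    batch = x.shape[0]
    out = x.clone()
    if intensity_jitter > 0:
        # Per-sample scale factor 1 + U(-jitter, jitter).
        u = torch.rand(batch, 1, generator=generator, dtype=out.dtype)
        out = out * (1.0 + intensity_jitter * (2.0 * u - 1.0))
    if noise_std > 0:
        # Additive Gaussian noise, independent per bin.
        out = out + noise_std * torch.randn(out.shape, generator=generator, dtype=out.dtype)
    if peak_dropout_rate > 0:
        # Bernoulli mask: each bin is zeroed with probability peak_dropout_rate.
        keep = torch.rand(out.shape, generator=generator) >= peak_dropout_rate
        out = out * keep
    if mz_shift_max_bins > 0:
        # Integer shift per sample in [-max, +max]; torch.roll wraps bins around the ends.
        shifts = torch.randint(
            -mz_shift_max_bins, mz_shift_max_bins + 1, (batch,), generator=generator
        )
        out = torch.stack([torch.roll(row, int(s)) for row, s in zip(out, shifts)])
    if blur_sigma > 0:
        # 1-D Gaussian kernel truncated at 3 sigma and normalized to sum to 1.
        radius = max(1, int(round(3 * blur_sigma)))
        t = torch.arange(-radius, radius + 1, dtype=out.dtype)
        kernel = torch.exp(-0.5 * (t / blur_sigma) ** 2)
        kernel = kernel / kernel.sum()
        # padding=radius keeps the output length equal to n_bins (zero-padded edges).
        out = F.conv1d(out.unsqueeze(1), kernel.view(1, 1, -1), padding=radius).squeeze(1)
    return out
```

With every parameter at zero the sketch returns an unmodified copy. This mirrors the documented behavior for the warp and blur parameters, where 0 disables the transform.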
MixUp / CutMix helpers
maldideepkit.augment.apply_mixup(x, y_oh, alpha, generator=None)
MixUp: convex-combine two random permutations of the batch, with the mixing weight drawn from Beta(alpha, alpha).
Parameters:
- x (Tensor) – Feature tensor of shape (batch, n_bins).
- y_oh (Tensor) – One-hot (or soft) target tensor of shape (batch, n_classes).
- alpha (float) – Beta-distribution parameter (Beta(alpha, alpha)). Typical values are 0.1-0.4 for tabular-ish inputs. Must be > 0.
- generator (Generator | None) – Seeded RNG for reproducibility.
Returns:
(x_mixed, y_mixed), with the same shapes as the inputs.
Return type:
tuple[Tensor, Tensor]
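The following is a minimal sketch of standard MixUp, not the library's implementation. It rests on two assumptions: a single weight lam ~ Beta(alpha, alpha) per batch, and each sample paired with a partner from one random permutation. Whether apply_mixup draws lam per batch or per sample is not stated here. mixup_sketch is a hypothetical name. Because torch.distributions.Beta.sample() does not take a generator, lam comes from PyTorch's global RNG in this sketch. Only the permutation uses generator.

```python
from typing import Optional, Tuple

import torch


def mixup_sketch(
    x: torch.Tensor,
    y_oh: torch.Tensor,
    alpha: float,
    generator: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical MixUp: lam * sample + (1 - lam) * shuffled partner."""
    if alpha <= 0:
        raise ValueError("alpha must be > 0")
    # One mixing weight for the whole batch (drawn from the global RNG).
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    # Random partner for every sample.
    perm = torch.randperm(x.shape[0], generator=generator)
    # The same lam is applied to features and targets so that they stay consistent.
    x_mixed = lam * x + (1.0 - lam) * x[perm]
    y_mixed = lam * y_oh + (1.0 - lam) * y_oh[perm]
    return x_mixed, y_mixed
```

Small alpha values concentrate Beta(alpha, alpha) near 0 and 1. Most mixed samples therefore stay close to one of the two originals, which is why values around 0.1-0.4 give mild mixing.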
maldideepkit.augment.apply_cutmix(x, y_oh, alpha, generator=None)
CutMix on 1-D spectra: splice a contiguous m/z window.
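A window of length w = round(n_bins * (1 - lam)), with lam drawn from Beta(alpha, alpha), is placed at a uniformly random position along the m/z axis. The bins inside it are copied from the shuffled sample into the original. Labels are mixed by the window fraction: the shuffled sample's label receives weight w / n_bins.
Parameters: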
- x (Tensor) – Feature tensor of shape (batch, n_bins).
- y_oh (Tensor) – One-hot (or soft) target tensor of shape (batch, n_classes).
- alpha (float) – Beta-distribution parameter (Beta(alpha, alpha)). Typical value is 1.0 (uniform over window fractions). Must be > 0.
- generator (Generator | None) – Seeded RNG for reproducibility.
Returns:
(x_mixed, y_mixed), with the same shapes as the inputs.
Return type:
tuple[Tensor, Tensor]
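Here is a minimal sketch of the window-splicing step described above, under these assumptions:

- one lam ~ Beta(alpha, alpha) per batch;
- one window position per batch;
- one random permutation per batch.

cutmix_sketch is a hypothetical name, not the library's function. torch.distributions.Beta.sample() takes no generator, so lam comes from PyTorch's global RNG here.

```python
from typing import Optional, Tuple

import torch


def cutmix_sketch(
    x: torch.Tensor,
    y_oh: torch.Tensor,
    alpha: float,
    generator: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical 1-D CutMix over the m/z axis of a (batch, n_bins) tensor."""
    if alpha <= 0:
        raise ValueError("alpha must be > 0")
    batch, n_bins = x.shape
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    # Window length w = round(n_bins * (1 - lam)).
    w = int(round(n_bins * (1.0 - lam)))
    # Window start drawn uniformly so that the whole window fits in the spectrum.
    start = int(torch.randint(0, n_bins - w + 1, (1,), generator=generator))
    perm = torch.randperm(batch, generator=generator)
    x_mixed = x.clone()
    x_mixed[:, start:start + w] = x[perm, start:start + w]
    # Labels are mixed by the fraction of bins taken from the partner.
    frac = w / n_bins
    y_mixed = (1.0 - frac) * y_oh + frac * y_oh[perm]
    return x_mixed, y_mixed
```

Weighting labels by the realized fraction w / n_bins rather than by lam keeps the labels exact after rounding w to a whole number of bins.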
Example
import numpy as np
from maldideepkit import MaldiCNNClassifier
from maldideepkit.augment import SpectrumAugment
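# Toy data: 200 random "spectra" with 6000 bins each and binary labels.
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 6000)).astype("float32")
y = rng.integers(0, 2, size=200)
# Light Gaussian noise, +/-5 % intensity jitter and up to 2 bins of m/z shift.
augment = SpectrumAugment(
    noise_std=0.02,
    intensity_jitter=0.05,
    mz_shift_max_bins=2,
    random_state=0,
)
# The augmentation is applied to training batches only.
clf = MaldiCNNClassifier(augment=augment, random_state=0).fit(X, y)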