Source code for maldideepkit.resnet.resnet

"""1-D ResNet classifier for binned MALDI-TOF spectra.

ResNet-18 residual-block template adapted to 1-D spectral input: a
stem Conv1D, four stages of BasicBlock1D pairs with strided
downsampling between stages, global average pooling, and a single
linear head.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import torch
from torch import nn

from .._bin_scaling import scale_odd_kernel
from ..base.classifier import BaseSpectralClassifier


class BasicBlock1D(nn.Module):
    """Two Conv1D layers with a residual shortcut.

    The main path is ``Conv1d -> BatchNorm1d -> ReLU -> Conv1d ->
    BatchNorm1d``; its output is added to the shortcut and passed through a
    final ReLU.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int, default=1
        Stride of the first Conv1D. A stride ``!= 1`` (or a channel
        mismatch) triggers a 1x1 projection on the shortcut path.
    kernel_size : int, default=3
        Kernel size of both Conv1D layers. Must be a positive odd integer:
        the padding is ``kernel_size // 2``, which keeps the main path and
        the shortcut at the same length only for odd kernels.

    Raises
    ------
    ValueError
        If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        kernel_size: int = 3,
    ) -> None:
        super().__init__()
        # With an even kernel, ``padding = kernel_size // 2`` makes every
        # convolution emit one extra sample, so the main path would no
        # longer line up with the shortcut in ``forward``. Fail early with
        # an explicit message instead of a shape error at call time.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer; got {kernel_size!r}."
            )
        padding = kernel_size // 2

        self.conv1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # The shortcut must match the main path in both channel count and
        # length; project with a strided 1x1 convolution whenever either
        # one changes, otherwise pass the input through untouched.
        if stride != 1 or in_channels != out_channels:
            self.shortcut: nn.Module = nn.Sequential(
                nn.Conv1d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm1d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual block to a ``(N, C, L)`` activation tensor."""
        # Compute the shortcut first: ``self.relu`` is in-place, so the
        # untouched input must be consumed before the main path runs.
        identity = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + identity
        return self.relu(out)
[docs] class SpectralResNet1D(nn.Module): """1-D ResNet-18 style backbone with a linear classification head. Parameters ---------- input_dim : int Number of input bins. n_classes : int, default=2 Number of output logits. stem_channels : int, default=32 Output channels of the initial ``Conv1d`` + pool stem. stage_channels : sequence of int, default=(64, 128, 256, 512) Output channels of each stage. The first stage keeps the stem resolution; later stages downsample by 2. blocks_per_stage : sequence of int, default=(2, 2, 2, 2) Number of BasicBlock1D pairs per stage. stem_kernel_size : int, default=7 Kernel size of the stem Conv1D. stem_stride : int, default=1 Stride of the stem Conv1D. Defaults to ``1`` (vs literal ResNet-18's ``2``) to preserve peak-scale features through the first stage on ~6000-bin spectra. block_kernel_size : int, default=7 Kernel size inside every :class:`BasicBlock1D`. Widened from ResNet-18's 3 to give each block enough local context for peak-scale features. use_stem_pool : bool, default=False If ``True``, append the literal ResNet-18 ``MaxPool1d(kernel=3, stride=2)`` after the stem Conv. Defaults to ``False`` because the combined stem-Conv(stride=2) + MaxPool(stride=2) initial 4x downsampling collapses peak-scale features. Set to ``True`` to reproduce the literal ResNet-18 backbone. dropout : float, default=0.2 Dropout applied after global average pooling. """
[docs] def __init__( self, input_dim: int, n_classes: int = 2, stem_channels: int = 32, stage_channels: tuple[int, ...] = (64, 128, 256, 512), blocks_per_stage: tuple[int, ...] = (2, 2, 2, 2), stem_kernel_size: int = 7, stem_stride: int = 1, block_kernel_size: int = 7, use_stem_pool: bool = False, dropout: float = 0.2, ) -> None: super().__init__() if len(stage_channels) != len(blocks_per_stage): raise ValueError( "stage_channels and blocks_per_stage must have the same length." ) stem_layers: list[nn.Module] = [ nn.Conv1d( 1, stem_channels, kernel_size=stem_kernel_size, stride=stem_stride, padding=stem_kernel_size // 2, bias=False, ), nn.BatchNorm1d(stem_channels), nn.ReLU(inplace=True), ] if use_stem_pool: stem_layers.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1)) self.stem = nn.Sequential(*stem_layers) stages: list[nn.Module] = [] prev = stem_channels for i, (ch, n_blocks) in enumerate( zip(stage_channels, blocks_per_stage, strict=True) ): stride = 1 if i == 0 else 2 layer = [ BasicBlock1D(prev, ch, stride=stride, kernel_size=block_kernel_size) ] layer += [ BasicBlock1D(ch, ch, stride=1, kernel_size=block_kernel_size) for _ in range(n_blocks - 1) ] stages.append(nn.Sequential(*layer)) prev = ch self.stages = nn.Sequential(*stages) self.gap = nn.AdaptiveAvgPool1d(1) self.dropout = nn.Dropout(dropout) self.fc = nn.Linear(prev, n_classes) self.input_dim = input_dim
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Map ``(batch, input_dim)`` to ``(batch, n_classes)`` logits.""" x = x.unsqueeze(1) x = self.stem(x) x = self.stages(x) x = self.gap(x).squeeze(-1) x = self.dropout(x) return self.fc(x)
class MaldiResNetClassifier(BaseSpectralClassifier):
    """sklearn-compatible 1-D ResNet classifier for MALDI-TOF spectra.

    Parameters
    ----------
    stem_channels : int, default=32
        Output channels of the stem.
    stage_channels : sequence of int, default=(64, 128, 256, 512)
        Output channels of each residual stage.
    blocks_per_stage : sequence of int, default=(2, 2, 2, 2)
        Number of residual blocks per stage (ResNet-18 topology).
    stem_kernel_size : int, default=7
        Kernel size of the stem Conv1D. Calibrated for ``bin_width=3``;
        :meth:`from_spectrum` scales it for other bin widths.
    stem_stride : int, default=1
        Stride of the stem Conv1D. Defaults to ``1`` so the first
        residual stage sees full-resolution input.
    block_kernel_size : int, default=7
        Kernel size inside every :class:`BasicBlock1D`. Widened from
        ResNet-18's 3 so each block has enough local context for
        peak-scale features.
    use_stem_pool : bool, default=False
        If ``True``, append the literal ResNet-18 MaxPool after the stem
        Conv. Defaults to ``False`` because the combined stem(stride=2) +
        MaxPool(stride=2) 4x downsampling is too aggressive for MALDI-TOF
        peak structure. Set to ``True`` to reproduce the literal
        ResNet-18 backbone.
    dropout : float, default=0.2
        Dropout before the final linear layer.

    Notes
    -----
    Every parameter accepted by
    :class:`~maldideepkit.base.classifier.BaseSpectralClassifier` (e.g.
    ``learning_rate``, ``batch_size``, ``epochs``, ``warping``,
    ``calibrate_temperature``, ``device``, ``random_state``, ...) is
    forwarded to the base class. The ResNet defaults pre-set
    ``weight_decay=1e-4``, ``grad_clip_norm=1.0``, ``warmup_epochs=5``,
    and ``input_transform="log1p"``; override any of them at construction
    time.

    Defaults deviate from literal ResNet-18 (He et al., 2016) in three
    ways for MALDI-TOF: ``stem_stride=1`` (was 2),
    ``block_kernel_size=7`` (was 3), and ``use_stem_pool=False`` (was
    True). The literal backbone is reproducible via ``stem_stride=2,
    block_kernel_size=3, use_stem_pool=True``.

    The final head is a ``Linear(stage_channels[-1], n_classes)`` after
    global average pooling, so the parameter count is **independent of
    ``input_dim``**.

    Examples
    --------
    >>> import numpy as np
    >>> from maldideepkit import MaldiResNetClassifier
    >>> rng = np.random.default_rng(0)
    >>> X = rng.standard_normal((32, 512)).astype("float32")
    >>> y = rng.integers(0, 2, size=32)
    >>> clf = MaldiResNetClassifier(
    ...     epochs=2, batch_size=8, stage_channels=(16, 32),
    ...     blocks_per_stage=(1, 1), random_state=0
    ... ).fit(X, y)
    >>> clf.predict(X).shape
    (32,)
    """

    def __init__(
        self,
        input_dim: int | None = None,
        n_classes: int = 2,
        stem_channels: int = 32,
        stage_channels: tuple[int, ...] = (64, 128, 256, 512),
        blocks_per_stage: tuple[int, ...] = (2, 2, 2, 2),
        stem_kernel_size: int = 7,
        stem_stride: int = 1,
        block_kernel_size: int = 7,
        use_stem_pool: bool = False,
        dropout: float = 0.2,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-4,
        grad_clip_norm: float | None = 1.0,
        label_smoothing: float = 0.0,
        loss: str = "cross_entropy",
        focal_gamma: float = 2.0,
        use_amp: bool = False,
        swa_start_epoch: int | None = None,
        tune_threshold: bool = False,
        threshold_metric: str = "balanced_accuracy",
        calibrate_temperature: bool = False,
        min_val_auroc_for_threshold_tune: float = 0.6,
        use_sam: bool = False,
        sam_rho: float = 0.05,
        batch_size: int = 32,
        epochs: int = 100,
        early_stopping_patience: int = 10,
        val_fraction: float = 0.1,
        warmup_epochs: int = 5,
        standardize: bool = False,
        input_transform: str | None = "log1p",
        warping: Any | None = None,
        metrics_log_path: str | Path | None = None,
        track_train_metrics: bool = False,
        augment: Any | None = None,
        mixup_alpha: float = 0.0,
        cutmix_alpha: float = 0.0,
        ema_decay: float | None = None,
        retry_on_val_auroc_below: float | None = None,
        max_retries: int = 2,
        class_weight: str | np.ndarray | list | None = None,
        device: str | torch.device = "auto",
        random_state: int = 0,
        verbose: bool = False,
    ) -> None:
        # Everything that is not architecture-specific is handed to the
        # base class unchanged, grouped here by concern for readability.
        shared_params: dict[str, Any] = {
            # problem shape
            "input_dim": input_dim,
            "n_classes": n_classes,
            # optimisation and loss
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "grad_clip_norm": grad_clip_norm,
            "label_smoothing": label_smoothing,
            "loss": loss,
            "focal_gamma": focal_gamma,
            "use_amp": use_amp,
            "swa_start_epoch": swa_start_epoch,
            "use_sam": use_sam,
            "sam_rho": sam_rho,
            "ema_decay": ema_decay,
            "class_weight": class_weight,
            # decision threshold / calibration
            "tune_threshold": tune_threshold,
            "threshold_metric": threshold_metric,
            "calibrate_temperature": calibrate_temperature,
            "min_val_auroc_for_threshold_tune": min_val_auroc_for_threshold_tune,
            # training schedule
            "batch_size": batch_size,
            "epochs": epochs,
            "early_stopping_patience": early_stopping_patience,
            "val_fraction": val_fraction,
            "warmup_epochs": warmup_epochs,
            "retry_on_val_auroc_below": retry_on_val_auroc_below,
            "max_retries": max_retries,
            # input handling and augmentation
            "standardize": standardize,
            "input_transform": input_transform,
            "warping": warping,
            "augment": augment,
            "mixup_alpha": mixup_alpha,
            "cutmix_alpha": cutmix_alpha,
            # logging and runtime
            "metrics_log_path": metrics_log_path,
            "track_train_metrics": track_train_metrics,
            "device": device,
            "random_state": random_state,
            "verbose": verbose,
        }
        super().__init__(**shared_params)

        # Architecture hyper-parameters are stored exactly as received;
        # any coercion happens later in ``_build_model``.
        self.stem_channels = stem_channels
        self.stage_channels = stage_channels
        self.blocks_per_stage = blocks_per_stage
        self.stem_kernel_size = stem_kernel_size
        self.stem_stride = stem_stride
        self.block_kernel_size = block_kernel_size
        self.use_stem_pool = use_stem_pool
        self.dropout = dropout

    def _build_model(self) -> nn.Module:
        """Instantiate the :class:`SpectralResNet1D` backbone.

        Reads ``input_dim_`` and ``n_classes_`` from the estimator
        (presumably populated by the base class during fitting - confirm
        against ``BaseSpectralClassifier``) and coerces the stored
        hyper-parameters to plain Python types.
        """
        architecture = {
            "stem_channels": int(self.stem_channels),
            "stage_channels": tuple(self.stage_channels),
            "blocks_per_stage": tuple(self.blocks_per_stage),
            "stem_kernel_size": int(self.stem_kernel_size),
            "stem_stride": int(self.stem_stride),
            "block_kernel_size": int(self.block_kernel_size),
            "use_stem_pool": bool(self.use_stem_pool),
            "dropout": float(self.dropout),
        }
        return SpectralResNet1D(
            input_dim=self.input_dim_,
            n_classes=self.n_classes_,
            **architecture,
        )

    @classmethod
    def from_spectrum(
        cls, bin_width: int, input_dim: int, **overrides
    ) -> "MaldiResNetClassifier":
        """Construct a peak-friendly ResNet for a given spectrum layout.

        ``stem_kernel_size`` is obtained from
        ``scale_odd_kernel(bin_width)``, which is intended to scale the
        kernel inversely with ``bin_width`` relative to the package
        reference (``bin_width=3``, ``stem_kernel_size=7``).
        ``stem_stride=1`` and ``use_stem_pool=False`` (the class
        defaults) are set explicitly as well.

        Any keyword in ``**overrides`` wins over the auto-selected
        values.
        """
        auto_selected: dict[str, Any] = dict(
            input_dim=input_dim,
            stem_kernel_size=scale_odd_kernel(bin_width),
            stem_stride=1,
            use_stem_pool=False,
        )
        # Later mapping wins, so caller overrides replace the presets.
        return cls(**{**auto_selected, **overrides})