02 - Model comparison

Train all four MaldiDeepKit classifiers on the MALDI-Kleb-AI Amikacin task and compare their accuracy and AUROC. See notebook 01 for how the dataset is cached: the first run downloads 370 MB, and later runs reuse the local copy.

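Each classifier exposes a from_spectrum(bin_width, input_dim, **overrides) factory that scales its architectural defaults (kernel sizes, etc.) to the given bin width. Extra keyword arguments, such as epochs and random_state in cell [3], override those defaults.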
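To make the idea of a bin-width-aware factory concrete, here is a toy sketch. Everything in it is a hypothetical assumption for illustration, not MaldiDeepKit's actual scaling rule: the class `ToyConvClassifier`, a reference bin width of 3 with a default kernel of 7 bins, and the rule of keeping the kernel's span along the m/z axis roughly constant as the bin width changes.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical reference point: defaults "tuned" for 3-unit bins.
REFERENCE_BIN_WIDTH = 3
REFERENCE_KERNEL_SIZE = 7


@dataclass
class ToyConvClassifier:
    """Hypothetical stand-in; not a MaldiDeepKit class."""
    input_dim: int
    kernel_size: int = REFERENCE_KERNEL_SIZE
    epochs: int = 100
    random_state: Optional[int] = None

    @classmethod
    def from_spectrum(cls, bin_width, input_dim, **overrides):
        # Keep the kernel's span on the m/z axis roughly constant:
        # narrower bins -> more bins per kernel, wider bins -> fewer.
        span = REFERENCE_KERNEL_SIZE * REFERENCE_BIN_WIDTH
        kernel = max(3, round(span / bin_width))
        if kernel % 2 == 0:  # odd kernels keep 'same' padding symmetric
            kernel += 1
        params = {"kernel_size": kernel}
        params.update(overrides)  # explicit keywords win over scaled defaults
        return cls(input_dim=input_dim, **params)


print(ToyConvClassifier.from_spectrum(bin_width=3, input_dim=6000).kernel_size)   # 7
print(ToyConvClassifier.from_spectrum(bin_width=1, input_dim=18000).kernel_size)  # 21
```

The last two lines show the intended behaviour of such a factory: the same physical receptive field is expressed as 7 bins at width 3 and 21 bins at width 1, while any keyword passed explicitly (for example `kernel_size=9`) would replace the scaled value.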

[1]:
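import pathlib
import sys
# Make the repository root importable so notebooks._demo resolves
# (assumes the notebook is run from inside the notebooks/ directory).
sys.path.insert(0, str(pathlib.Path.cwd().parent))
from notebooks._demo import binary_labels, load_maldi_kleb_ai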

demo = load_maldi_kleb_ai(antibiotic='Amikacin', verbose=True)
X, y = binary_labels(demo)
print(f'X: {X.shape} | prevalence(R): {y.mean():.2%}')
X: (741, 6000) | prevalence(R): 49.80%
[2]:
from sklearn.model_selection import train_test_split

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, stratify=y, random_state=0)
input_dim = X.shape[1]
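The "~555 training spectra" mentioned later follows from how scikit-learn sizes the split: the test set gets ceil(0.25 × 741) = 186 samples and the training set the remaining 555. The sketch below checks this on synthetic labels with the same counts (369 resistant of 741, which is the printed 49.80%); the features are placeholders, since only the labels matter for the split sizes.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in with the same class counts: 369 / 741 = 49.80% resistant.
n, n_pos = 741, 369
y_demo = np.r_[np.ones(n_pos, dtype=int), np.zeros(n - n_pos, dtype=int)]
X_demo = np.zeros((n, 1))  # features are irrelevant to the split sizes

X_tr_d, X_te_d, y_tr_d, y_te_d = train_test_split(
    X_demo, y_demo, test_size=0.25, stratify=y_demo, random_state=0
)
print(len(y_tr_d), len(y_te_d))                           # 555 186
print(int(y_te_d.sum()), len(y_te_d) - int(y_te_d.sum()))  # 93 93
```

With these counts, stratification leaves an exactly balanced test set, so 0.500 is the accuracy of a model that predicts a single class for every spectrum.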
[3]:
from sklearn.metrics import roc_auc_score

from maldideepkit import (
    MaldiCNNClassifier,
    MaldiMLPClassifier,
    MaldiResNetClassifier,
    MaldiTransformerClassifier,
)

# Training overrides shared by all four from_spectrum calls below.
common = dict(epochs=30, random_state=0)

classifiers = {
    'MLP':         MaldiMLPClassifier.from_spectrum(bin_width=3, input_dim=input_dim, **common),
    'CNN':         MaldiCNNClassifier.from_spectrum(bin_width=3, input_dim=input_dim, **common),
    'ResNet':      MaldiResNetClassifier.from_spectrum(bin_width=3, input_dim=input_dim, **common),
    'Transformer': MaldiTransformerClassifier.from_spectrum(bin_width=3, input_dim=input_dim, **common),
}

results = {}
for name, clf in classifiers.items():
    clf.fit(X_tr, y_tr)
    # Column 1 of predict_proba is P(class 1), i.e. P(resistant), since y == 1 marks R.
    proba = clf.predict_proba(X_te)
    results[name] = {
        'accuracy': clf.score(X_te, y_te),
        'auroc': roc_auc_score(y_te, proba[:, 1]),
    }
    print(f'{name:>11s}  acc={results[name]["accuracy"]:.3f}  auroc={results[name]["auroc"]:.3f}')
        MLP  acc=0.747  auroc=0.826
        CNN  acc=0.715  auroc=0.781
     ResNet  acc=0.500  auroc=0.472
Transformer  acc=0.500  auroc=0.499
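The ResNet and Transformer rows sit exactly at chance. On a balanced 186-spectrum test set, a model that outputs one label for everything scores 0.500 accuracy, and scores that carry no ranking information give an AUROC near 0.5. The sketch below reproduces both baselines on synthetic labels; it illustrates what chance level looks like, not why those two models failed here.

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

# Balanced synthetic test set: 93 resistant (1) and 93 susceptible (0).
y_true = np.r_[np.ones(93, dtype=int), np.zeros(93, dtype=int)]

# A collapsed model: the same label and the same score for every spectrum.
const_pred = np.zeros_like(y_true)
const_score = np.full(y_true.shape, 0.5)
acc_const = accuracy_score(y_true, const_pred)
auc_const = roc_auc_score(y_true, const_score)
print(acc_const, auc_const)  # 0.5 0.5

# Uninformative random scores also land near 0.5.
rng = np.random.default_rng(0)
auc_random = roc_auc_score(y_true, rng.random(y_true.size))
print(round(auc_random, 3))
```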

Note: with only ~555 training spectra and 30 epochs, the Transformer stays well short of convergence and scores at chance, as does the ResNet in this run. The Transformer's training recipe (lr=3e-4, 5-epoch warmup, weight decay 0.05) pays off on larger datasets and longer schedules.
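One way to see why a short schedule hurts: five warmup epochs are a sixth of the 30-epoch budget used above. The sketch below assumes a linear ramp up to lr=3e-4 followed by a constant rate; the note only states "5-epoch warmup", so the ramp shape and the post-warmup behaviour are assumptions, and weight decay is not modeled.

```python
def warmup_lr(epoch: int, base_lr: float = 3e-4, warmup_epochs: int = 5) -> float:
    """Linear warmup to base_lr, then constant (post-warmup shape is an assumption)."""
    if epoch < warmup_epochs:
        # Epoch 0 starts at base_lr / warmup_epochs; epoch warmup_epochs - 1 reaches base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    return base_lr


schedule = [warmup_lr(e) for e in range(30)]
print([f"{lr:.1e}" for lr in schedule[:6]])
# ['6.0e-05', '1.2e-04', '1.8e-04', '2.4e-04', '3.0e-04', '3.0e-04']
```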

Mean-of-probabilities ensemble

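SpectralEnsemble averages predict_proba across its fitted members. The shallow MLP, CNN, and ResNet families have different inductive biases, so their errors can be decorrelated enough for the averaged prediction to beat any single member. This is not guaranteed, so compare the ensemble's scores below with the individual results above.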
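A mean-of-probabilities ensemble is simple enough to sketch with plain scikit-learn. The class below, `MeanProbaEnsemble`, is a hypothetical analogue written for illustration, not the source of SpectralEnsemble; the members are ordinary scikit-learn models on synthetic data, chosen only because they fit in seconds and have visibly different inductive biases. Variable names are suffixed so they do not clobber the notebook's X_tr, y_te, etc.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier


class MeanProbaEnsemble:
    """Hypothetical minimal mean-of-probabilities ensemble."""

    def __init__(self, members):
        self.members = members

    def fit(self, X, y):
        # Fit an independent copy of each member on the same training data.
        self.fitted_ = [clone(m).fit(X, y) for m in self.members]
        self.classes_ = self.fitted_[0].classes_
        return self

    def predict_proba(self, X):
        # Unweighted average of the members' class-probability matrices.
        return np.mean([m.predict_proba(X) for m in self.fitted_], axis=0)

    def predict(self, X):
        # Most probable class under the averaged probabilities.
        return self.classes_[self.predict_proba(X).argmax(axis=1)]


Xs, ys = make_classification(n_samples=400, n_features=20, random_state=0)
Xs_tr, Xs_te, ys_tr, ys_te = train_test_split(
    Xs, ys, test_size=0.25, stratify=ys, random_state=0
)

members = [
    LogisticRegression(max_iter=1000),
    GaussianNB(),
    DecisionTreeClassifier(max_depth=4, random_state=0),
]
ens = MeanProbaEnsemble(members).fit(Xs_tr, ys_tr)
for m in ens.fitted_:
    print(type(m).__name__, round(roc_auc_score(ys_te, m.predict_proba(Xs_te)[:, 1]), 3))
print("ensemble", round(roc_auc_score(ys_te, ens.predict_proba(Xs_te)[:, 1]), 3))
```

Printing each member's AUROC next to the ensemble's is the quickest way to check whether averaging helps on a given dataset; it depends on how correlated the members' errors are and on whether any member is close to chance.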

[4]:
from maldideepkit.utils import SpectralEnsemble

ensemble = SpectralEnsemble([
    MaldiMLPClassifier(input_dim=input_dim, epochs=30, random_state=0),
    MaldiCNNClassifier(input_dim=input_dim, epochs=30, random_state=1),
    MaldiResNetClassifier(input_dim=input_dim, epochs=30, random_state=2),
]).fit(X_tr, y_tr)

proba = ensemble.predict_proba(X_te)
# With two classes, argmax over the averaged probabilities amounts to thresholding P(R) at 0.5.
print(f'   Ensemble  acc={(proba.argmax(axis=1) == y_te).mean():.3f}  auroc={roc_auc_score(y_te, proba[:, 1]):.3f}')
   Ensemble  acc=0.710  auroc=0.785