02 - Model comparison
Train all four MaldiDeepKit classifiers on the MALDI-Kleb-AI Amikacin task and compare their accuracy and AUROC. See notebook 01 for the dataset cache; the first run downloads 370 MB, and later runs reuse the cached copy.
Each classifier exposes a from_spectrum(bin_width, input_dim, **overrides) factory that scales architectural defaults (kernel sizes, etc.) for a given binning.
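To make the idea concrete, here is a toy sketch of a factory in that style. Everything in it is hypothetical: `ToyCNNConfig`, the 21 Da reference span, its defaults and the odd-kernel rule are illustrative assumptions, not MaldiDeepKit's actual defaults or scaling logic. The sketch keeps a convolution's receptive field roughly fixed in m/z units, so wider bins get shorter kernels, and lets explicit keyword overrides win over the scaled defaults.

```python
from dataclasses import dataclass, replace
from typing import Optional

# Hypothetical: receptive field (in Da) the kernel should roughly cover.
REFERENCE_KERNEL_DA = 21


@dataclass(frozen=True)
class ToyCNNConfig:
    """Hypothetical config object for illustration; not part of MaldiDeepKit."""
    input_dim: int
    kernel_size: int = 7
    epochs: int = 100
    random_state: Optional[int] = None

    @classmethod
    def from_spectrum(cls, bin_width: float, input_dim: int, **overrides) -> "ToyCNNConfig":
        # Same Da span -> fewer bins per kernel as bin_width grows (minimum 3 bins).
        kernel = max(3, round(REFERENCE_KERNEL_DA / bin_width))
        if kernel % 2 == 0:
            kernel += 1  # odd length keeps 'same' padding symmetric
        scaled = cls(input_dim=input_dim, kernel_size=kernel)
        # Explicit keyword arguments replace the scaled defaults;
        # an unknown key raises TypeError, which catches typos early.
        return replace(scaled, **overrides)


for bw in (1, 3, 6):
    cfg = ToyCNNConfig.from_spectrum(bin_width=bw, input_dim=6000, epochs=30, random_state=0)
    print(bw, cfg.kernel_size, cfg.epochs)
```

With these made-up constants the loop yields kernels of 21, 7 and 5 bins for bin widths 1, 3 and 6 (Python rounds 3.5 to 4, which the odd-length rule bumps to 5), while `epochs=30` passes through unchanged, mirroring how the `common` overrides are used in cell [3].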
[1]:
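import sys, pathlib
# The notebook runs from notebooks/, so add its parent to sys.path to import notebooks._demo.
sys.path.insert(0, str(pathlib.Path.cwd().parent))
from notebooks._demo import binary_labels, load_maldi_kleb_ai
demo = load_maldi_kleb_ai(antibiotic='Amikacin', verbose=True)
X, y = binary_labels(demo)  # feature matrix and 0/1 labels (1 = resistant)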
print(f'X: {X.shape} | prevalence(R): {y.mean():.2%}')
X: (741, 6000) | prevalence(R): 49.80%
[2]:
from sklearn.model_selection import train_test_split
# Hold out 25% of the spectra; stratify=y keeps the R/S ratio approximately equal across splits.
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, stratify=y, random_state=0)
input_dim = X.shape[1]
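With 741 spectra and `test_size=0.25`, scikit-learn rounds the test set up to ceil(0.25 × 741) = 186 spectra, leaving 555 for training. The standalone check below reproduces those sizes with synthetic labels (369 of 741 resistant, the only count that matches the 49.80% prevalence printed above; the feature matrix is a placeholder) and shows that `stratify=y` keeps the resistant fraction nearly identical across the two splits.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in labels: 369 resistant (1) and 372 susceptible (0).
y = np.array([1] * 369 + [0] * 372)
X = np.zeros((len(y), 1))  # placeholder features; only the row count matters here

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, stratify=y, random_state=0)

print(len(y_tr), len(y_te))  # 555 186
print(f'overall {y.mean():.3f} | train {y_tr.mean():.3f} | test {y_te.mean():.3f}')
```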
[3]:
from sklearn.metrics import roc_auc_score
from maldideepkit import (
    MaldiCNNClassifier,
    MaldiMLPClassifier,
    MaldiResNetClassifier,
    MaldiTransformerClassifier,
)
# Settings shared by every model, passed to from_spectrum as keyword overrides.
common = dict(epochs=30, random_state=0)
classifiers = {
    'MLP': MaldiMLPClassifier.from_spectrum(bin_width=3, input_dim=input_dim, **common),
    'CNN': MaldiCNNClassifier.from_spectrum(bin_width=3, input_dim=input_dim, **common),
    'ResNet': MaldiResNetClassifier.from_spectrum(bin_width=3, input_dim=input_dim, **common),
    'Transformer': MaldiTransformerClassifier.from_spectrum(bin_width=3, input_dim=input_dim, **common),
}
results = {}
for name, clf in classifiers.items():
    clf.fit(X_tr, y_tr)
    proba = clf.predict_proba(X_te)  # column 1 = probability of the resistant class
    results[name] = {
        'accuracy': clf.score(X_te, y_te),
        'auroc': roc_auc_score(y_te, proba[:, 1]),
    }
    print(f'{name:>11s} acc={results[name]["accuracy"]:.3f} auroc={results[name]["auroc"]:.3f}')
MLP acc=0.747 auroc=0.826
CNN acc=0.715 auroc=0.781
ResNet acc=0.500 auroc=0.472
Transformer acc=0.500 auroc=0.499
Note: ResNet and the Transformer both finish at chance level in this run (accuracy 0.500, AUROC ≈ 0.5). With only ~555 training spectra, 30 epochs leave the Transformer well short of convergence; its training recipe (lr=3e-4, 5-epoch warmup, weight decay 0.05) pays off on larger datasets and longer schedules.
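To see why a short run penalises this recipe, the sketch below computes a per-epoch learning rate under a generic linear warmup to lr=3e-4 over 5 epochs. It is not MaldiDeepKit's scheduler: the linear ramp and holding the rate constant afterwards are assumptions, since the note only gives the base rate, warmup length and weight decay. Under these assumptions a 30-epoch budget spends a sixth of its epochs warming up, versus a twentieth at 100 epochs.

```python
BASE_LR = 3e-4      # from the note above
WARMUP_EPOCHS = 5   # from the note above


def warmup_lr(epoch: int, base_lr: float = BASE_LR, warmup_epochs: int = WARMUP_EPOCHS) -> float:
    """Learning rate for a 0-indexed epoch.

    Assumed shape: ramp linearly to base_lr over warmup_epochs, then hold it constant.
    """
    return base_lr * min(1.0, (epoch + 1) / warmup_epochs)


for total_epochs in (30, 100):
    schedule = [warmup_lr(e) for e in range(total_epochs)]
    warmup_share = WARMUP_EPOCHS / total_epochs  # fraction of the run inside the warmup window
    print(total_epochs, [f'{lr:.1e}' for lr in schedule[:6]], f'warmup share={warmup_share:.0%}')
```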
Mean-of-probabilities ensemble
SpectralEnsemble averages predict_proba across its fitted members. The shallow MLP, CNN and ResNet families have different inductive biases, so their errors are often decorrelated enough for the averaged prediction to beat any single member.
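The averaging itself is simple. The sketch below uses a hypothetical stand-in, `MeanProbaEnsemble`, which is not SpectralEnsemble's code: it fits each member and takes the element-wise mean of their predict_proba outputs. scikit-learn's LogisticRegression, GaussianNB and a shallow DecisionTreeClassifier play the role of differently biased members on a synthetic make_classification dataset.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier


class MeanProbaEnsemble:
    """Hypothetical mean-of-probabilities ensemble (illustration only)."""

    def __init__(self, members):
        self.members = members

    def fit(self, X, y):
        for member in self.members:
            member.fit(X, y)
        return self

    def predict_proba(self, X):
        # Stack to (n_members, n_samples, n_classes), then average over members.
        return np.mean([m.predict_proba(X) for m in self.members], axis=0)

    def predict(self, X):
        # All members are fit on the same y, so they share the same classes_ ordering.
        return self.members[0].classes_[self.predict_proba(X).argmax(axis=1)]


X, y = make_classification(n_samples=400, n_features=20, n_informative=5, random_state=0)
X_tr, X_te, y_tr, y_te = X[:300], X[300:], y[:300], y[300:]

ens = MeanProbaEnsemble([
    LogisticRegression(max_iter=1000),
    GaussianNB(),
    DecisionTreeClassifier(max_depth=3, random_state=0),
]).fit(X_tr, y_tr)

for m in ens.members:
    print(type(m).__name__, round(roc_auc_score(y_te, m.predict_proba(X_te)[:, 1]), 3))
print('Ensemble', round(roc_auc_score(y_te, ens.predict_proba(X_te)[:, 1]), 3))
```

Averaging probabilities rather than counting hard votes keeps a continuous score per sample, which is what AUROC ranks; a three-member majority vote would collapse every sample onto one of four vote fractions.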
[4]:
from maldideepkit.utils import SpectralEnsemble
# Three architecturally different members, each with its own seed.
ensemble = SpectralEnsemble([
    MaldiMLPClassifier(input_dim=input_dim, epochs=30, random_state=0),
    MaldiCNNClassifier(input_dim=input_dim, epochs=30, random_state=1),
    MaldiResNetClassifier(input_dim=input_dim, epochs=30, random_state=2),
]).fit(X_tr, y_tr)
proba = ensemble.predict_proba(X_te)  # mean of the members' predict_proba
# Labels are 0/1, so the argmax column index is the predicted label.
print(f' Ensemble acc={(proba.argmax(axis=1) == y_te).mean():.3f} auroc={roc_auc_score(y_te, proba[:, 1]):.3f}')
Ensemble acc=0.710 auroc=0.785
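Both cells rank models by AUROC, which equals the probability that a randomly chosen resistant spectrum receives a higher score than a randomly chosen susceptible one (ties count half). That is why the ResNet and Transformer rows, at AUROC 0.472 and 0.499, amount to uninformative scores. The sketch below computes this pairwise definition directly with NumPy; `auroc_pairwise` is a hypothetical helper written for illustration, checked against scikit-learn's roc_auc_score on synthetic scores.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def auroc_pairwise(y_true, scores):
    """AUROC as P(score of a random positive > score of a random negative); ties count 1/2."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # every positive-negative pair
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size


rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=300)
informative = y_true + rng.normal(size=300)  # positives shifted upward on average
uninformative = rng.normal(size=300)         # scores unrelated to the label

for name, s in [('informative', informative), ('uninformative', uninformative)]:
    print(name, round(auroc_pairwise(y_true, s), 3), round(roc_auc_score(y_true, s), 3))
```

The pairwise form costs O(n_pos × n_neg) memory, which is fine for a test set of a few hundred spectra; roc_auc_score reaches the same value by sorting scores instead.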