{ "cells": [ { "cell_type": "markdown", "id": "intro", "metadata": {}, "source": [ "# 03 - Attention interpretation\n", "\n", "Train `MaldiMLPClassifier` with the default sigmoid-gated attention on real MALDI-TOF spectra and visualize the learned per-unit gates.\n", "\n", "Uses the **MALDI-Kleb-AI** dataset (Rocchi *et al.*, 2026; [Zenodo DOI 10.5281/zenodo.17405072](https://zenodo.org/records/17405072)); see notebook 01 for caching details." ] }, { "cell_type": "code", "execution_count": 1, "id": "load", "metadata": { "execution": { "iopub.execute_input": "2026-04-28T09:49:42.315477Z", "iopub.status.busy": "2026-04-28T09:49:42.315379Z", "iopub.status.idle": "2026-04-28T09:49:59.582878Z", "shell.execute_reply": "2026-04-28T09:49:59.582124Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "83ca5d45ef854665a13701455e94ea03", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Processing spectra: 0%| | 0/743 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", "im = ax.imshow(weights, aspect='auto', cmap='viridis', vmin=0.4, vmax=0.6)\n", "ax.set_xlabel('hidden unit')\n", "ax.set_ylabel('sample')\n", "ax.set_yticks([0, len(sus), len(sus) + len(res) - 1])\n", "ax.set_yticklabels([labels_view[0], labels_view[len(sus)], labels_view[-1]])\n", "ax.axhline(len(sus) - 0.5, color='white', lw=0.8)\n", "fig.colorbar(im, ax=ax, label='attention weight')\n", "ax.set_title('Sigmoid-gated attention on the projected hidden layer (S top / R bottom)')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "discussion", "metadata": {}, "source": [ "The attention vector is computed on the *projected* hidden representation, not on the raw m/z bins, so the column index is a learned feature axis rather than a Dalton coordinate. Units that stay close to 1 across samples are unconditionally retained; units near 0 are gated off. 
The variance across samples - and especially across the S/R split - points to the units that are class-discriminative. Note that the heatmap above clips its color scale to [0.4, 0.6], so gates outside that range saturate at the ends of the colormap and cannot be told apart visually.\n", "\n", "Setting `use_attention=False` recovers a plain MLP of the same depth and disables `get_attention_weights`." ] } ], "metadata": { "kernelspec": { "display_name": "maldideepkit (3.10.12)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": { "0391dce6ec7042b7a5be3f145e64a397": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "07f752059b5f42398bc6065b4256cfc2": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": { 
"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3d8af56e548d429f94e4de8075eb4f42": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_07f752059b5f42398bc6065b4256cfc2", "placeholder": "", "style": "IPY_MODEL_d804be807fd1479e9597b64b8b143104", "tabbable": null, "tooltip": null, "value": "Processing spectra: 100%" } }, "6a2f111e650b4cff8d316d009de1b314": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8b122c7b2cc9417faf55522e0f91aec2": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_0391dce6ec7042b7a5be3f145e64a397", "max": 743, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_c167895afe894329bdf34e82cd67afbe", "tabbable": null, "tooltip": null, "value": 743 } }, "a248398bd4734185b7326e749d0a5c8b": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, 
"align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ae43e2e23e974e128fb87a98b16d2c65": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_6a2f111e650b4cff8d316d009de1b314", "placeholder": "", "style": "IPY_MODEL_e33b6c1c82624d0db504ca551a8d3e5b", "tabbable": null, "tooltip": null, "value": " 743/743 [00:00<00:00, 4434.39spectrum/s]" } }, "c167895afe894329bdf34e82cd67afbe": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "d804be807fd1479e9597b64b8b143104": { "model_module": "@jupyter-widgets/controls", 
"model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null } }, "e33b6c1c82624d0db504ca551a8d3e5b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null } }, "ec585f296d784e10b87c09e18c0696c5": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_3d8af56e548d429f94e4de8075eb4f42", "IPY_MODEL_8b122c7b2cc9417faf55522e0f91aec2", "IPY_MODEL_ae43e2e23e974e128fb87a98b16d2c65" ], "layout": "IPY_MODEL_a248398bd4734185b7326e749d0a5c8b", "tabbable": null, "tooltip": null } } }, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 5 }