Source code for maldideepkit.utils.reproducibility

"""Seeding and device-placement helpers for deterministic training."""

from __future__ import annotations

import os
import random

import numpy as np
import torch


[docs] def seed_everything(seed: int, deterministic: bool = False) -> None: """Seed Python, NumPy, and PyTorch (CPU + CUDA) RNGs in one call. Parameters ---------- seed : int Non-negative integer used for every RNG. Also fixes ``PYTHONHASHSEED`` in the current process environment. deterministic : bool, default=False When ``True``, additionally enable PyTorch's deterministic algorithm mode. Sets ``torch.use_deterministic_algorithms(True, warn_only=True)``, ``torch.backends.cudnn.deterministic = True``, ``torch.backends.cudnn.benchmark = False``, and ``CUBLAS_WORKSPACE_CONFIG=:4096:8``. The env-var must be set before the first CUDA context is created. Once enabled, determinism is **sticky** - subsequent plain ``seed_everything(seed)`` calls do not turn it off. """ os.environ["PYTHONHASHSEED"] = str(seed) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) if deterministic: os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") torch.use_deterministic_algorithms(True, warn_only=True) if torch.cuda.is_available(): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False
[docs] def resolve_device(device: str | torch.device | None) -> torch.device: """Resolve a user-facing device specifier to a :class:`torch.device`. Parameters ---------- device : {"auto", "cpu", "cuda"} or torch.device or None ``"auto"`` (or ``None``) picks ``cuda`` when available and falls back to ``cpu``. Returns ------- torch.device The resolved device. Raises ------ ValueError If ``device`` is an unknown string. """ if device is None or device == "auto": return torch.device("cuda" if torch.cuda.is_available() else "cpu") if isinstance(device, torch.device): return device if isinstance(device, str): return torch.device(device) raise ValueError(f"Unsupported device specifier: {device!r}")