# pyright: basic from __future__ import annotations import torch from data.dataset import ADNIDataset def configure_bayesian_sampling_mode( model: torch.nn.Module, *, stochastic: bool, freeze_batchnorm: bool = True, ) -> None: # if stochastic: # model.train() # if freeze_batchnorm: # for module in model.modules(): # if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): # module.eval() # else: model.eval() def _compute_mri_intensity_range( dataset: ADNIDataset, ) -> tuple[float, float, float]: global_min = float("inf") global_max = float("-inf") for idx in range(len(dataset)): mri, _, _, _ = dataset[idx] sample_min = float(mri.detach().min().cpu()) sample_max = float(mri.detach().max().cpu()) if sample_min < global_min: global_min = sample_min if sample_max > global_max: global_max = sample_max intensity_range = global_max - global_min if not np.isfinite(intensity_range) or intensity_range <= 0: intensity_range = 1.0 print( "Computed MRI intensity range across the full dataset: " f"min={global_min:.6g}, max={global_max:.6g}, range={intensity_range:.6g}" ) return float(global_min), float(global_max), float(intensity_range)