| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- # pyright: basic
- from __future__ import annotations
- import torch
- from data.dataset import ADNIDataset
def configure_bayesian_sampling_mode(
    model: torch.nn.Module,
    *,
    stochastic: bool,
    freeze_batchnorm: bool = True,
) -> None:
    """Set ``model``'s train/eval flags for drawing samples.

    At present this unconditionally switches the entire module tree to
    evaluation mode by calling ``model.eval()``. Neither keyword argument
    changes that outcome.

    Args:
        model: Module whose mode flags are updated in place.
        stochastic: Accepted for API compatibility; currently ignored.
        freeze_batchnorm: Accepted for API compatibility; currently ignored.
    """
    # NOTE(review): a stochastic branch used to live here (model.train(), with
    # BatchNorm layers optionally put back into eval mode) but it was commented
    # out. Both flags are therefore no-ops; confirm this is intentional before
    # relying on `stochastic=True` to make repeated forward passes differ.
    model.eval()
- def _compute_mri_intensity_range(
- dataset: ADNIDataset,
- ) -> tuple[float, float, float]:
- global_min = float("inf")
- global_max = float("-inf")
- for idx in range(len(dataset)):
- mri, _, _, _ = dataset[idx]
- sample_min = float(mri.detach().min().cpu())
- sample_max = float(mri.detach().max().cpu())
- if sample_min < global_min:
- global_min = sample_min
- if sample_max > global_max:
- global_max = sample_max
- intensity_range = global_max - global_min
- if not np.isfinite(intensity_range) or intensity_range <= 0:
- intensity_range = 1.0
- print(
- "Computed MRI intensity range across the full dataset: "
- f"min={global_min:.6g}, max={global_max:.6g}, range={intensity_range:.6g}"
- )
- return float(global_min), float(global_max), float(intensity_range)
|