model_utils.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. # pyright: basic
  2. from __future__ import annotations
  3. import torch
  4. from data.dataset import ADNIDataset
  5. def configure_bayesian_sampling_mode(
  6. model: torch.nn.Module,
  7. *,
  8. stochastic: bool,
  9. freeze_batchnorm: bool = True,
  10. ) -> None:
  11. # if stochastic:
  12. # model.train()
  13. # if freeze_batchnorm:
  14. # for module in model.modules():
  15. # if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  16. # module.eval()
  17. # else:
  18. model.eval()
  19. def _compute_mri_intensity_range(
  20. dataset: ADNIDataset,
  21. ) -> tuple[float, float, float]:
  22. global_min = float("inf")
  23. global_max = float("-inf")
  24. for idx in range(len(dataset)):
  25. mri, _, _, _ = dataset[idx]
  26. sample_min = float(mri.detach().min().cpu())
  27. sample_max = float(mri.detach().max().cpu())
  28. if sample_min < global_min:
  29. global_min = sample_min
  30. if sample_max > global_max:
  31. global_max = sample_max
  32. intensity_range = global_max - global_min
  33. if not np.isfinite(intensity_range) or intensity_range <= 0:
  34. intensity_range = 1.0
  35. print(
  36. "Computed MRI intensity range across the full dataset: "
  37. f"min={global_min:.6g}, max={global_max:.6g}, range={intensity_range:.6g}"
  38. )
  39. return float(global_min), float(global_max), float(intensity_range)