evaluate_models.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # pyright: basic
  2. from __future__ import annotations
  3. import argparse
  4. from pathlib import Path
  5. from typing import Any
  6. import numpy as np
  7. import pandas as pd
  8. from analysis.analysis_modules import (
  9. run_calibration,
  10. run_longitudinal,
  11. run_performance,
  12. run_physician,
  13. )
  14. from analysis.data_access import load_backend_evaluation, load_clinical_table
  15. from analysis.holdout_evaluation import ensure_backend_netcdf
  16. from analysis.noise_analysis import run_noise_analysis
  17. from analysis.runtime import backend_dir, init_runtime_paths, load_config, write_json
  18. def _parse_args() -> argparse.Namespace:
  19. parser = argparse.ArgumentParser(
  20. description=(
  21. "Run modular evaluation analyses for ensemble and bayesian models. "
  22. "All outputs are written to alnn_rewrite/analysis_output."
  23. )
  24. )
  25. parser.add_argument(
  26. "--backend",
  27. nargs="+",
  28. choices=["ensemble", "bayesian"],
  29. default=["ensemble", "bayesian"],
  30. help="Backends to evaluate.",
  31. )
  32. parser.add_argument(
  33. "--run-name",
  34. default=None,
  35. help="Optional run directory name under analysis_output.",
  36. )
  37. parser.add_argument(
  38. "--threshold-start",
  39. type=float,
  40. default=0.5,
  41. help="Threshold sweep start.",
  42. )
  43. parser.add_argument(
  44. "--threshold-stop",
  45. type=float,
  46. default=0.95,
  47. help="Threshold sweep stop (inclusive).",
  48. )
  49. parser.add_argument(
  50. "--threshold-step",
  51. type=float,
  52. default=0.05,
  53. help="Threshold sweep step.",
  54. )
  55. parser.add_argument(
  56. "--decision-threshold",
  57. type=float,
  58. default=0.5,
  59. help="Decision threshold used in noise analysis.",
  60. )
  61. parser.add_argument(
  62. "--positive-class-index",
  63. type=int,
  64. default=0,
  65. help="Positive class index in both predictions.img_class and labels.label.",
  66. )
  67. parser.add_argument(
  68. "--calibration-bins",
  69. type=int,
  70. default=10,
  71. help="Number of reliability bins used for ECE/MCE.",
  72. )
  73. parser.add_argument(
  74. "--skip-noise",
  75. action="store_true",
  76. help="Skip Gaussian noise sensitivity analysis.",
  77. )
  78. parser.add_argument(
  79. "--noise-sigmas",
  80. nargs="+",
  81. type=float,
  82. default=[
  83. 0.0,
  84. 0.01,
  85. 0.03,
  86. 0.05,
  87. 0.1,
  88. 0.2,
  89. 0.3,
  90. 0.4,
  91. 0.5,
  92. 0.6,
  93. 0.75,
  94. 1.0,
  95. ],
  96. help="Gaussian noise sigmas for sensitivity analysis.",
  97. )
  98. parser.add_argument(
  99. "--bayesian-mc-passes",
  100. type=int,
  101. default=20,
  102. help="MC forward passes for bayesian noise analysis.",
  103. )
  104. return parser.parse_args()
  105. def _threshold_array(start: float, stop: float, step: float) -> np.ndarray:
  106. if step <= 0:
  107. raise ValueError("threshold-step must be > 0")
  108. if stop < start:
  109. raise ValueError("threshold-stop must be >= threshold-start")
  110. # Include stop when it lands on a step boundary.
  111. n = int(round((stop - start) / step))
  112. return np.array([start + i * step for i in range(n + 1)], dtype=float)
  113. def _run_backend(
  114. config: dict[str, Any],
  115. root_dir: Path,
  116. backend: str,
  117. clinical_df: pd.DataFrame,
  118. args: argparse.Namespace,
  119. out_dir: Path,
  120. ) -> dict[str, Any]:
  121. netcdf_path = ensure_backend_netcdf(
  122. config=config,
  123. root_dir=root_dir,
  124. backend=backend,
  125. bayesian_mc_passes=int(args.bayesian_mc_passes),
  126. )
  127. evaluation = load_backend_evaluation(
  128. config=config,
  129. backend=backend,
  130. class_index=int(args.positive_class_index),
  131. )
  132. thresholds = _threshold_array(
  133. start=float(args.threshold_start),
  134. stop=float(args.threshold_stop),
  135. step=float(args.threshold_step),
  136. )
  137. summary: dict[str, Any] = {
  138. "backend": backend,
  139. "netcdf": str(netcdf_path),
  140. "source_file": str(evaluation.source_file),
  141. "uncertainty_metric": evaluation.uncertainty_metric,
  142. }
  143. summary["performance"] = run_performance(
  144. evaluation=evaluation,
  145. output_dir=out_dir,
  146. thresholds=thresholds,
  147. )
  148. summary["calibration"] = run_calibration(
  149. evaluation=evaluation,
  150. output_dir=out_dir,
  151. bins=int(args.calibration_bins),
  152. )
  153. summary["physician"] = run_physician(
  154. evaluation=evaluation,
  155. clinical_df=clinical_df,
  156. output_dir=out_dir,
  157. )
  158. summary["longitudinal"] = run_longitudinal(
  159. evaluation=evaluation,
  160. clinical_df=clinical_df,
  161. output_dir=out_dir,
  162. )
  163. if args.skip_noise:
  164. summary["noise"] = {"skipped": True, "reason": "--skip-noise supplied"}
  165. else:
  166. try:
  167. summary["noise"] = run_noise_analysis(
  168. config=config,
  169. root_dir=Path(__file__).resolve().parents[1],
  170. backend=backend,
  171. output_dir=out_dir,
  172. class_index=int(args.positive_class_index),
  173. noise_sigmas=[float(x) for x in args.noise_sigmas],
  174. threshold=float(args.decision_threshold),
  175. calibration_bins=int(args.calibration_bins),
  176. bayesian_mc_passes=int(args.bayesian_mc_passes),
  177. )
  178. except Exception as exc:
  179. summary["noise"] = {
  180. "skipped": True,
  181. "reason": f"Noise analysis failed: {exc}",
  182. }
  183. write_json(out_dir / "backend_summary.json", summary)
  184. return summary
  185. def main() -> None:
  186. args = _parse_args()
  187. analysis_dir = Path(__file__).resolve().parent
  188. paths = init_runtime_paths(analysis_dir=analysis_dir, run_name=args.run_name)
  189. config = load_config(paths.root_dir)
  190. clinical_df = load_clinical_table(config=config, root_dir=paths.root_dir)
  191. manifest: dict[str, Any] = {
  192. "run_dir": str(paths.run_dir),
  193. "output_root": str(paths.output_root),
  194. "positive_class_index": int(args.positive_class_index),
  195. "threshold_sweep": {
  196. "start": float(args.threshold_start),
  197. "stop": float(args.threshold_stop),
  198. "step": float(args.threshold_step),
  199. },
  200. "calibration_bins": int(args.calibration_bins),
  201. "noise_sigmas": [float(x) for x in args.noise_sigmas],
  202. "backends": {},
  203. }
  204. for backend in args.backend:
  205. out_dir = backend_dir(paths, backend)
  206. manifest["backends"][backend] = _run_backend(
  207. config=config,
  208. root_dir=paths.root_dir,
  209. backend=backend,
  210. clinical_df=clinical_df,
  211. args=args,
  212. out_dir=out_dir,
  213. )
  214. write_json(paths.run_dir / "run_manifest.json", manifest)
  215. print(f"Analysis complete. Results saved to {paths.run_dir}")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()