# evaluate_models.py
  1. # pyright: basic
  2. from __future__ import annotations
  3. import argparse
  4. from pathlib import Path
  5. from typing import Any
  6. import pandas as pd
  7. from tqdm.auto import tqdm
  8. from .analysis_modules import (
  9. run_calibration,
  10. run_longitudinal,
  11. run_performance,
  12. run_physician,
  13. )
  14. from .data_access import load_backend_evaluation, load_clinical_table
  15. from .dataset_summary import run_dataset_summary
  16. from .defaults import (
  17. DEFAULT_BACKENDS,
  18. DEFAULT_BAYESIAN_MC_PASSES,
  19. DEFAULT_CALIBRATION_BINS,
  20. DEFAULT_DECISION_THRESHOLD,
  21. DEFAULT_POSITIVE_CLASS_INDEX,
  22. noise_factor_grid,
  23. threshold_grid,
  24. )
  25. from .holdout_evaluation import ensure_backend_netcdf
  26. from .longitudinal_audit import run_longitudinal_breakdown_audit
  27. from .noise_analysis import run_noise_analysis
  28. from .noise_correlation import run_noise_accuracy_uncertainty_analysis
  29. from .runtime import backend_dir, init_runtime_paths, load_config, write_json
  30. def _plot_description(filename: str) -> str:
  31. descriptions = {
  32. "performance_threshold_accuracy.png": "Accuracy as the decision threshold varies.",
  33. "performance_threshold_f1.png": "F1 score as the decision threshold varies.",
  34. "performance_threshold_accuracy_f1.png": "Accuracy and F1 shown side-by-side as the decision threshold varies.",
  35. "performance_uncertainty_cutoff_accuracy.png": "Accuracy while progressively restricting to higher-confidence and uncertainty-metric subsets.",
  36. "performance_uncertainty_cutoff_f1.png": "F1 score while progressively restricting to higher-confidence and uncertainty-metric subsets.",
  37. "performance_uncertainty_cutoff_accuracy_f1.png": "Accuracy and F1 shown side-by-side across uncertainty-cutoff restriction levels.",
  38. "performance_uncertainty_percentile_cutoff_accuracy.png": "Accuracy from least to most restricted percentile-wise subset selection.",
  39. "performance_uncertainty_percentile_cutoff_f1.png": "F1 score from least to most restricted percentile-wise subset selection.",
  40. "performance_uncertainty_percentile_cutoff_accuracy_f1.png": "Accuracy and F1 shown side-by-side across percentile-floor restriction levels.",
  41. "calibration_reliability.png": "Reliability diagram comparing predicted probability to empirical outcome frequency.",
  42. "physician_confidence_boxplot.png": "Confidence grouped by physician confidence ratings.",
  43. "physician_std_boxplot.png": "Standard deviation grouped by physician confidence ratings.",
  44. "physician_predictive_entropy_boxplot.png": "Predictive uncertainty grouped by physician confidence ratings.",
  45. "longitudinal_cohort_confidence.png": "Longitudinal cohort comparison using confidence.",
  46. "longitudinal_cohort_std.png": "Longitudinal cohort comparison using standard deviation.",
  47. "longitudinal_cohort_predictive_entropy.png": "Longitudinal cohort comparison using predictive uncertainty.",
  48. "noise_sensitivity_accuracy.png": "Accuracy trend across increasing Gaussian noise factors.",
  49. "noise_sensitivity_f1.png": "F1 trend across increasing Gaussian noise factors.",
  50. "noise_sensitivity_accuracy_f1.png": "Accuracy and F1 shown side-by-side across increasing Gaussian noise factors.",
  51. "noise_confidence.png": "Confidence trend across increasing Gaussian noise factors.",
  52. "noise_standard_deviation.png": "Standard deviation trend across increasing Gaussian noise factors.",
  53. "noise_confidence_standard_deviation.png": "Confidence and standard deviation shown side-by-side across increasing Gaussian noise factors.",
  54. "noise_predictive_uncertainty.png": "Predictive uncertainty trend across increasing Gaussian noise factors.",
  55. "noise_confidence_predictive_uncertainty.png": "Confidence and predictive uncertainty shown side-by-side across increasing Gaussian noise factors.",
  56. "noise_accuracy_uncertainty_2d.png": "2D uncertainty-vs-accuracy relationship with linear fit (noise factor encoded by color).",
  57. "ensemble_noise_examples.png": "Representative noisy image slices across selected Gaussian noise factors.",
  58. "bayesian_noise_examples.png": "Representative noisy image slices across selected Gaussian noise factors.",
  59. "ensemble_clean_scan_example.png": "Example clean scan image with no added noise.",
  60. "bayesian_clean_scan_example.png": "Example clean scan image with no added noise.",
  61. }
  62. return descriptions.get(filename, "Generated analysis plot.")
  63. def _write_backend_plot_report(backend: str, out_dir: Path) -> Path:
  64. plots_dir = out_dir / "plots"
  65. images = sorted(plots_dir.rglob("*.png")) if plots_dir.exists() else []
  66. report_path = out_dir / "plots_report.md"
  67. lines = [
  68. f"# {backend.title()} Analysis Plot Report",
  69. "",
  70. "This document lists generated analysis plots with brief descriptions.",
  71. "",
  72. ]
  73. if not images:
  74. lines.append("No plot images were generated for this backend run.")
  75. else:
  76. for image_path in images:
  77. rel = image_path.relative_to(out_dir).as_posix()
  78. title = image_path.stem.replace("_", " ").title()
  79. lines.append(f"## {title}")
  80. lines.append(_plot_description(image_path.name))
  81. lines.append("")
  82. lines.append(f"![{title}]({rel})")
  83. lines.append("")
  84. report_path.write_text("\n".join(lines), encoding="utf-8")
  85. return report_path
  86. def _parse_args() -> argparse.Namespace:
  87. parser = argparse.ArgumentParser(
  88. description=(
  89. "Run modular evaluation analyses for ensemble and bayesian models. "
  90. "All outputs are written to alnn_rewrite/analysis_output."
  91. )
  92. )
  93. parser.add_argument(
  94. "--backend",
  95. nargs="+",
  96. choices=["ensemble", "bayesian"],
  97. default=DEFAULT_BACKENDS,
  98. help="Backends to evaluate.",
  99. )
  100. parser.add_argument(
  101. "--run-name",
  102. default=None,
  103. help="Optional run directory name under analysis_output.",
  104. )
  105. parser.add_argument(
  106. "--skip-noise",
  107. action="store_true",
  108. help="Skip Gaussian noise sensitivity analysis.",
  109. )
  110. parser.add_argument(
  111. "--longitudinal-breakdown-only",
  112. action="store_true",
  113. help=(
  114. "Run only longitudinal cohort breakdown audit from existing model "
  115. "evaluation outputs (no full analysis rerun)."
  116. ),
  117. )
  118. parser.add_argument(
  119. "--noise-correlation-only",
  120. action="store_true",
  121. help=(
  122. "Run only the noise uncertainty-vs-accuracy correlation/regression "
  123. "analysis from an existing noise_sensitivity.csv per backend."
  124. ),
  125. )
  126. parser.add_argument(
  127. "--dataset-summary-only",
  128. action="store_true",
  129. help=(
  130. "Generate only dataset composition summary documentation "
  131. "(overall and train/validation/test class breakdown)."
  132. ),
  133. )
  134. args = parser.parse_args()
  135. only_modes = [
  136. bool(args.longitudinal_breakdown_only),
  137. bool(args.noise_correlation_only),
  138. bool(args.dataset_summary_only),
  139. ]
  140. if sum(only_modes) > 1:
  141. parser.error(
  142. "Only one of --longitudinal-breakdown-only, "
  143. "--noise-correlation-only, and --dataset-summary-only may be used at once."
  144. )
  145. return args
  146. def _run_longitudinal_breakdown_only(
  147. config: dict[str, Any],
  148. backend: str,
  149. clinical_df: pd.DataFrame,
  150. out_dir: Path,
  151. ) -> dict[str, Any]:
  152. evaluation = load_backend_evaluation(
  153. config=config,
  154. backend=backend,
  155. class_index=DEFAULT_POSITIVE_CLASS_INDEX,
  156. )
  157. summary = run_longitudinal_breakdown_audit(
  158. evaluation=evaluation,
  159. clinical_df=clinical_df,
  160. output_dir=out_dir,
  161. )
  162. write_json(out_dir / "longitudinal_breakdown_backend_summary.json", summary)
  163. return summary
  164. def _run_noise_correlation_only(
  165. backend: str,
  166. out_dir: Path,
  167. ) -> dict[str, Any]:
  168. noise_table_path = out_dir / "noise_sensitivity.csv"
  169. if not noise_table_path.exists():
  170. raise FileNotFoundError(
  171. f"Expected existing noise table for --noise-correlation-only: {noise_table_path}"
  172. )
  173. noise_df = pd.read_csv(noise_table_path)
  174. summary = run_noise_accuracy_uncertainty_analysis(
  175. noise_df=noise_df,
  176. backend=backend,
  177. output_dir=out_dir,
  178. )
  179. write_json(out_dir / "noise_accuracy_uncertainty_backend_summary.json", summary)
  180. return summary
  181. def _run_backend(
  182. config: dict[str, Any],
  183. root_dir: Path,
  184. backend: str,
  185. clinical_df: pd.DataFrame,
  186. skip_noise: bool,
  187. out_dir: Path,
  188. ) -> dict[str, Any]:
  189. netcdf_path = ensure_backend_netcdf(
  190. config=config,
  191. root_dir=root_dir,
  192. backend=backend,
  193. bayesian_mc_passes=DEFAULT_BAYESIAN_MC_PASSES,
  194. )
  195. evaluation = load_backend_evaluation(
  196. config=config,
  197. backend=backend,
  198. class_index=DEFAULT_POSITIVE_CLASS_INDEX,
  199. )
  200. thresholds = threshold_grid()
  201. noise_factors = noise_factor_grid()
  202. summary: dict[str, Any] = {
  203. "backend": backend,
  204. "netcdf": str(netcdf_path),
  205. "source_file": str(evaluation.source_file),
  206. "uncertainty_metric": evaluation.uncertainty_metric,
  207. }
  208. n_stages = 4 + (0 if skip_noise else 2)
  209. stage_bar = tqdm(
  210. total=n_stages,
  211. desc=f"[{backend}] analysis stages",
  212. unit="stage",
  213. leave=False,
  214. )
  215. try:
  216. stage_bar.set_postfix_str("performance")
  217. summary["performance"] = run_performance(
  218. evaluation=evaluation,
  219. output_dir=out_dir,
  220. thresholds=thresholds,
  221. )
  222. stage_bar.update(1)
  223. stage_bar.set_postfix_str("calibration")
  224. summary["calibration"] = run_calibration(
  225. evaluation=evaluation,
  226. output_dir=out_dir,
  227. bins=DEFAULT_CALIBRATION_BINS,
  228. )
  229. stage_bar.update(1)
  230. stage_bar.set_postfix_str("physician")
  231. summary["physician"] = run_physician(
  232. evaluation=evaluation,
  233. clinical_df=clinical_df,
  234. output_dir=out_dir,
  235. )
  236. stage_bar.update(1)
  237. stage_bar.set_postfix_str("longitudinal")
  238. summary["longitudinal"] = run_longitudinal(
  239. evaluation=evaluation,
  240. clinical_df=clinical_df,
  241. output_dir=out_dir,
  242. )
  243. stage_bar.update(1)
  244. if skip_noise:
  245. summary["noise"] = {"skipped": True, "reason": "--skip-noise supplied"}
  246. summary["noise_accuracy_uncertainty"] = {
  247. "skipped": True,
  248. "reason": "Noise analysis skipped, so no noise table available.",
  249. }
  250. else:
  251. try:
  252. stage_bar.set_postfix_str("noise")
  253. summary["noise"] = run_noise_analysis(
  254. config=config,
  255. root_dir=root_dir,
  256. backend=backend,
  257. output_dir=out_dir,
  258. class_index=DEFAULT_POSITIVE_CLASS_INDEX,
  259. noise_sigmas=noise_factors,
  260. threshold=DEFAULT_DECISION_THRESHOLD,
  261. calibration_bins=DEFAULT_CALIBRATION_BINS,
  262. bayesian_mc_passes=DEFAULT_BAYESIAN_MC_PASSES,
  263. )
  264. stage_bar.update(1)
  265. stage_bar.set_postfix_str("noise-correlation")
  266. noise_table_path = Path(str(summary["noise"]["table"]))
  267. noise_df = pd.read_csv(noise_table_path)
  268. summary["noise_accuracy_uncertainty"] = (
  269. run_noise_accuracy_uncertainty_analysis(
  270. noise_df=noise_df,
  271. backend=backend,
  272. output_dir=out_dir,
  273. )
  274. )
  275. stage_bar.update(1)
  276. except Exception as exc:
  277. summary["noise"] = {
  278. "skipped": True,
  279. "reason": f"Noise analysis failed: {exc}",
  280. }
  281. summary["noise_accuracy_uncertainty"] = {
  282. "skipped": True,
  283. "reason": f"Noise relationship analysis failed: {exc}",
  284. }
  285. stage_bar.update(2)
  286. finally:
  287. stage_bar.close()
  288. report_path = _write_backend_plot_report(backend=backend, out_dir=out_dir)
  289. summary["plots_report"] = str(report_path)
  290. write_json(out_dir / "backend_summary.json", summary)
  291. return summary
  292. def main() -> None:
  293. args = _parse_args()
  294. analysis_dir = Path(__file__).resolve().parent
  295. paths = init_runtime_paths(analysis_dir=analysis_dir, run_name=args.run_name)
  296. config = load_config(paths.root_dir)
  297. clinical_df = load_clinical_table(config=config, root_dir=paths.root_dir)
  298. manifest: dict[str, Any] = {
  299. "run_dir": str(paths.run_dir),
  300. "output_root": str(paths.output_root),
  301. "mode": (
  302. "dataset_summary_only"
  303. if bool(args.dataset_summary_only)
  304. else (
  305. "longitudinal_breakdown_only"
  306. if bool(args.longitudinal_breakdown_only)
  307. else (
  308. "noise_correlation_only"
  309. if bool(args.noise_correlation_only)
  310. else "full"
  311. )
  312. )
  313. ),
  314. "positive_class_index": DEFAULT_POSITIVE_CLASS_INDEX,
  315. "threshold_sweep": {
  316. "values": [float(v) for v in threshold_grid().tolist()],
  317. },
  318. "calibration_bins": DEFAULT_CALIBRATION_BINS,
  319. "noise_factors": noise_factor_grid(),
  320. "bayesian_mc_passes": DEFAULT_BAYESIAN_MC_PASSES,
  321. "decision_threshold": DEFAULT_DECISION_THRESHOLD,
  322. "backends": {},
  323. }
  324. if args.dataset_summary_only:
  325. manifest["dataset_summary"] = run_dataset_summary(
  326. config=config,
  327. root_dir=paths.root_dir,
  328. output_dir=paths.run_dir,
  329. positive_class_index=DEFAULT_POSITIVE_CLASS_INDEX,
  330. )
  331. write_json(paths.run_dir / "run_manifest.json", manifest)
  332. print(f"Dataset summary complete. Results saved to {paths.run_dir}")
  333. return
  334. backend_iter = tqdm(args.backend, desc="Backends", unit="backend")
  335. for backend in backend_iter:
  336. out_dir = backend_dir(paths, backend)
  337. backend_iter.set_postfix_str(backend)
  338. if args.longitudinal_breakdown_only:
  339. manifest["backends"][backend] = _run_longitudinal_breakdown_only(
  340. config=config,
  341. backend=backend,
  342. clinical_df=clinical_df,
  343. out_dir=out_dir,
  344. )
  345. elif args.noise_correlation_only:
  346. manifest["backends"][backend] = _run_noise_correlation_only(
  347. backend=backend,
  348. out_dir=out_dir,
  349. )
  350. else:
  351. manifest["backends"][backend] = _run_backend(
  352. config=config,
  353. root_dir=paths.root_dir,
  354. backend=backend,
  355. clinical_df=clinical_df,
  356. skip_noise=bool(args.skip_noise),
  357. out_dir=out_dir,
  358. )
  359. write_json(paths.run_dir / "run_manifest.json", manifest)
  360. print(f"Analysis complete. Results saved to {paths.run_dir}")
  361. if __name__ == "__main__":
  362. main()