evaluate_models.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. # pyright: basic
  2. from __future__ import annotations
  3. import argparse
  4. from pathlib import Path
  5. from typing import Any
  6. import pandas as pd
  7. from analysis.analysis_modules import (
  8. run_calibration,
  9. run_longitudinal,
  10. run_performance,
  11. run_physician,
  12. )
  13. from analysis.dataset_summary import run_dataset_summary
  14. from analysis.data_access import load_backend_evaluation, load_clinical_table
  15. from analysis.defaults import (
  16. DEFAULT_BACKENDS,
  17. DEFAULT_BAYESIAN_MC_PASSES,
  18. DEFAULT_CALIBRATION_BINS,
  19. DEFAULT_DECISION_THRESHOLD,
  20. DEFAULT_POSITIVE_CLASS_INDEX,
  21. noise_factor_grid,
  22. threshold_grid,
  23. )
  24. from analysis.holdout_evaluation import ensure_backend_netcdf
  25. from analysis.longitudinal_audit import run_longitudinal_breakdown_audit
  26. from analysis.noise_correlation import run_noise_accuracy_uncertainty_analysis
  27. from analysis.noise_analysis import run_noise_analysis
  28. from analysis.runtime import backend_dir, init_runtime_paths, load_config, write_json
  29. def _plot_description(filename: str) -> str:
  30. descriptions = {
  31. "performance_threshold_accuracy.png": "Accuracy as the decision threshold varies.",
  32. "performance_threshold_f1.png": "F1 score as the decision threshold varies.",
  33. "performance_threshold_accuracy_f1.png": "Accuracy and F1 shown side-by-side as the decision threshold varies.",
  34. "performance_uncertainty_cutoff_accuracy.png": "Accuracy while progressively restricting to higher-confidence and uncertainty-metric subsets.",
  35. "performance_uncertainty_cutoff_f1.png": "F1 score while progressively restricting to higher-confidence and uncertainty-metric subsets.",
  36. "performance_uncertainty_cutoff_accuracy_f1.png": "Accuracy and F1 shown side-by-side across uncertainty-cutoff restriction levels.",
  37. "performance_uncertainty_percentile_cutoff_accuracy.png": "Accuracy from least to most restricted percentile-wise subset selection.",
  38. "performance_uncertainty_percentile_cutoff_f1.png": "F1 score from least to most restricted percentile-wise subset selection.",
  39. "performance_uncertainty_percentile_cutoff_accuracy_f1.png": "Accuracy and F1 shown side-by-side across percentile-floor restriction levels.",
  40. "calibration_reliability.png": "Reliability diagram comparing predicted probability to empirical outcome frequency.",
  41. "physician_confidence_boxplot.png": "Confidence grouped by physician confidence ratings.",
  42. "physician_std_boxplot.png": "Standard deviation grouped by physician confidence ratings.",
  43. "physician_predictive_entropy_boxplot.png": "Predictive uncertainty grouped by physician confidence ratings.",
  44. "longitudinal_cohort_confidence.png": "Longitudinal cohort comparison using confidence.",
  45. "longitudinal_cohort_std.png": "Longitudinal cohort comparison using standard deviation.",
  46. "longitudinal_cohort_predictive_entropy.png": "Longitudinal cohort comparison using predictive uncertainty.",
  47. "noise_sensitivity_accuracy.png": "Accuracy trend across increasing Gaussian noise factors.",
  48. "noise_sensitivity_f1.png": "F1 trend across increasing Gaussian noise factors.",
  49. "noise_sensitivity_accuracy_f1.png": "Accuracy and F1 shown side-by-side across increasing Gaussian noise factors.",
  50. "noise_confidence.png": "Confidence trend across increasing Gaussian noise factors.",
  51. "noise_standard_deviation.png": "Standard deviation trend across increasing Gaussian noise factors.",
  52. "noise_confidence_standard_deviation.png": "Confidence and standard deviation shown side-by-side across increasing Gaussian noise factors.",
  53. "noise_predictive_uncertainty.png": "Predictive uncertainty trend across increasing Gaussian noise factors.",
  54. "noise_confidence_predictive_uncertainty.png": "Confidence and predictive uncertainty shown side-by-side across increasing Gaussian noise factors.",
  55. "noise_accuracy_uncertainty_2d.png": "2D uncertainty-vs-accuracy relationship with linear fit (noise factor encoded by color).",
  56. "ensemble_noise_examples.png": "Representative noisy image slices across selected Gaussian noise factors.",
  57. "bayesian_noise_examples.png": "Representative noisy image slices across selected Gaussian noise factors.",
  58. "ensemble_clean_scan_example.png": "Example clean scan image with no added noise.",
  59. "bayesian_clean_scan_example.png": "Example clean scan image with no added noise.",
  60. }
  61. return descriptions.get(filename, "Generated analysis plot.")
  62. def _write_backend_plot_report(backend: str, out_dir: Path) -> Path:
  63. plots_dir = out_dir / "plots"
  64. images = sorted(plots_dir.rglob("*.png")) if plots_dir.exists() else []
  65. report_path = out_dir / "plots_report.md"
  66. lines = [
  67. f"# {backend.title()} Analysis Plot Report",
  68. "",
  69. "This document lists generated analysis plots with brief descriptions.",
  70. "",
  71. ]
  72. if not images:
  73. lines.append("No plot images were generated for this backend run.")
  74. else:
  75. for image_path in images:
  76. rel = image_path.relative_to(out_dir).as_posix()
  77. title = image_path.stem.replace("_", " ").title()
  78. lines.append(f"## {title}")
  79. lines.append(_plot_description(image_path.name))
  80. lines.append("")
  81. lines.append(f"![{title}]({rel})")
  82. lines.append("")
  83. report_path.write_text("\n".join(lines), encoding="utf-8")
  84. return report_path
  85. def _parse_args() -> argparse.Namespace:
  86. parser = argparse.ArgumentParser(
  87. description=(
  88. "Run modular evaluation analyses for ensemble and bayesian models. "
  89. "All outputs are written to alnn_rewrite/analysis_output."
  90. )
  91. )
  92. parser.add_argument(
  93. "--backend",
  94. nargs="+",
  95. choices=["ensemble", "bayesian"],
  96. default=DEFAULT_BACKENDS,
  97. help="Backends to evaluate.",
  98. )
  99. parser.add_argument(
  100. "--run-name",
  101. default=None,
  102. help="Optional run directory name under analysis_output.",
  103. )
  104. parser.add_argument(
  105. "--skip-noise",
  106. action="store_true",
  107. help="Skip Gaussian noise sensitivity analysis.",
  108. )
  109. parser.add_argument(
  110. "--longitudinal-breakdown-only",
  111. action="store_true",
  112. help=(
  113. "Run only longitudinal cohort breakdown audit from existing model "
  114. "evaluation outputs (no full analysis rerun)."
  115. ),
  116. )
  117. parser.add_argument(
  118. "--noise-correlation-only",
  119. action="store_true",
  120. help=(
  121. "Run only the noise uncertainty-vs-accuracy correlation/regression "
  122. "analysis from an existing noise_sensitivity.csv per backend."
  123. ),
  124. )
  125. parser.add_argument(
  126. "--dataset-summary-only",
  127. action="store_true",
  128. help=(
  129. "Generate only dataset composition summary documentation "
  130. "(overall and train/validation/test class breakdown)."
  131. ),
  132. )
  133. args = parser.parse_args()
  134. only_modes = [
  135. bool(args.longitudinal_breakdown_only),
  136. bool(args.noise_correlation_only),
  137. bool(args.dataset_summary_only),
  138. ]
  139. if sum(only_modes) > 1:
  140. parser.error(
  141. "Only one of --longitudinal-breakdown-only, "
  142. "--noise-correlation-only, and --dataset-summary-only may be used at once."
  143. )
  144. return args
  145. def _run_longitudinal_breakdown_only(
  146. config: dict[str, Any],
  147. backend: str,
  148. clinical_df: pd.DataFrame,
  149. out_dir: Path,
  150. ) -> dict[str, Any]:
  151. evaluation = load_backend_evaluation(
  152. config=config,
  153. backend=backend,
  154. class_index=DEFAULT_POSITIVE_CLASS_INDEX,
  155. )
  156. summary = run_longitudinal_breakdown_audit(
  157. evaluation=evaluation,
  158. clinical_df=clinical_df,
  159. output_dir=out_dir,
  160. )
  161. write_json(out_dir / "longitudinal_breakdown_backend_summary.json", summary)
  162. return summary
  163. def _run_noise_correlation_only(
  164. backend: str,
  165. out_dir: Path,
  166. ) -> dict[str, Any]:
  167. noise_table_path = out_dir / "noise_sensitivity.csv"
  168. if not noise_table_path.exists():
  169. raise FileNotFoundError(
  170. f"Expected existing noise table for --noise-correlation-only: {noise_table_path}"
  171. )
  172. noise_df = pd.read_csv(noise_table_path)
  173. summary = run_noise_accuracy_uncertainty_analysis(
  174. noise_df=noise_df,
  175. backend=backend,
  176. output_dir=out_dir,
  177. )
  178. write_json(out_dir / "noise_accuracy_uncertainty_backend_summary.json", summary)
  179. return summary
  180. def _run_backend(
  181. config: dict[str, Any],
  182. root_dir: Path,
  183. backend: str,
  184. clinical_df: pd.DataFrame,
  185. skip_noise: bool,
  186. out_dir: Path,
  187. ) -> dict[str, Any]:
  188. netcdf_path = ensure_backend_netcdf(
  189. config=config,
  190. root_dir=root_dir,
  191. backend=backend,
  192. bayesian_mc_passes=DEFAULT_BAYESIAN_MC_PASSES,
  193. )
  194. evaluation = load_backend_evaluation(
  195. config=config,
  196. backend=backend,
  197. class_index=DEFAULT_POSITIVE_CLASS_INDEX,
  198. )
  199. thresholds = threshold_grid()
  200. noise_factors = noise_factor_grid()
  201. summary: dict[str, Any] = {
  202. "backend": backend,
  203. "netcdf": str(netcdf_path),
  204. "source_file": str(evaluation.source_file),
  205. "uncertainty_metric": evaluation.uncertainty_metric,
  206. }
  207. summary["performance"] = run_performance(
  208. evaluation=evaluation,
  209. output_dir=out_dir,
  210. thresholds=thresholds,
  211. )
  212. summary["calibration"] = run_calibration(
  213. evaluation=evaluation,
  214. output_dir=out_dir,
  215. bins=DEFAULT_CALIBRATION_BINS,
  216. )
  217. summary["physician"] = run_physician(
  218. evaluation=evaluation,
  219. clinical_df=clinical_df,
  220. output_dir=out_dir,
  221. )
  222. summary["longitudinal"] = run_longitudinal(
  223. evaluation=evaluation,
  224. clinical_df=clinical_df,
  225. output_dir=out_dir,
  226. )
  227. if skip_noise:
  228. summary["noise"] = {"skipped": True, "reason": "--skip-noise supplied"}
  229. summary["noise_accuracy_uncertainty"] = {
  230. "skipped": True,
  231. "reason": "Noise analysis skipped, so no noise table available.",
  232. }
  233. else:
  234. try:
  235. summary["noise"] = run_noise_analysis(
  236. config=config,
  237. root_dir=root_dir,
  238. backend=backend,
  239. output_dir=out_dir,
  240. class_index=DEFAULT_POSITIVE_CLASS_INDEX,
  241. noise_sigmas=noise_factors,
  242. threshold=DEFAULT_DECISION_THRESHOLD,
  243. calibration_bins=DEFAULT_CALIBRATION_BINS,
  244. bayesian_mc_passes=DEFAULT_BAYESIAN_MC_PASSES,
  245. )
  246. noise_table_path = Path(str(summary["noise"]["table"]))
  247. noise_df = pd.read_csv(noise_table_path)
  248. summary["noise_accuracy_uncertainty"] = (
  249. run_noise_accuracy_uncertainty_analysis(
  250. noise_df=noise_df,
  251. backend=backend,
  252. output_dir=out_dir,
  253. )
  254. )
  255. except Exception as exc:
  256. summary["noise"] = {
  257. "skipped": True,
  258. "reason": f"Noise analysis failed: {exc}",
  259. }
  260. summary["noise_accuracy_uncertainty"] = {
  261. "skipped": True,
  262. "reason": f"Noise relationship analysis failed: {exc}",
  263. }
  264. report_path = _write_backend_plot_report(backend=backend, out_dir=out_dir)
  265. summary["plots_report"] = str(report_path)
  266. write_json(out_dir / "backend_summary.json", summary)
  267. return summary
  268. def main() -> None:
  269. args = _parse_args()
  270. analysis_dir = Path(__file__).resolve().parent
  271. paths = init_runtime_paths(analysis_dir=analysis_dir, run_name=args.run_name)
  272. config = load_config(paths.root_dir)
  273. clinical_df = load_clinical_table(config=config, root_dir=paths.root_dir)
  274. manifest: dict[str, Any] = {
  275. "run_dir": str(paths.run_dir),
  276. "output_root": str(paths.output_root),
  277. "mode": (
  278. "dataset_summary_only"
  279. if bool(args.dataset_summary_only)
  280. else (
  281. "longitudinal_breakdown_only"
  282. if bool(args.longitudinal_breakdown_only)
  283. else (
  284. "noise_correlation_only"
  285. if bool(args.noise_correlation_only)
  286. else "full"
  287. )
  288. )
  289. ),
  290. "positive_class_index": DEFAULT_POSITIVE_CLASS_INDEX,
  291. "threshold_sweep": {
  292. "values": [float(v) for v in threshold_grid().tolist()],
  293. },
  294. "calibration_bins": DEFAULT_CALIBRATION_BINS,
  295. "noise_factors": noise_factor_grid(),
  296. "bayesian_mc_passes": DEFAULT_BAYESIAN_MC_PASSES,
  297. "decision_threshold": DEFAULT_DECISION_THRESHOLD,
  298. "backends": {},
  299. }
  300. if args.dataset_summary_only:
  301. manifest["dataset_summary"] = run_dataset_summary(
  302. config=config,
  303. root_dir=paths.root_dir,
  304. output_dir=paths.run_dir,
  305. positive_class_index=DEFAULT_POSITIVE_CLASS_INDEX,
  306. )
  307. write_json(paths.run_dir / "run_manifest.json", manifest)
  308. print(f"Dataset summary complete. Results saved to {paths.run_dir}")
  309. return
  310. for backend in args.backend:
  311. out_dir = backend_dir(paths, backend)
  312. if args.longitudinal_breakdown_only:
  313. manifest["backends"][backend] = _run_longitudinal_breakdown_only(
  314. config=config,
  315. backend=backend,
  316. clinical_df=clinical_df,
  317. out_dir=out_dir,
  318. )
  319. elif args.noise_correlation_only:
  320. manifest["backends"][backend] = _run_noise_correlation_only(
  321. backend=backend,
  322. out_dir=out_dir,
  323. )
  324. else:
  325. manifest["backends"][backend] = _run_backend(
  326. config=config,
  327. root_dir=paths.root_dir,
  328. backend=backend,
  329. clinical_df=clinical_df,
  330. skip_noise=bool(args.skip_noise),
  331. out_dir=out_dir,
  332. )
  333. write_json(paths.run_dir / "run_manifest.json", manifest)
  334. print(f"Analysis complete. Results saved to {paths.run_dir}")
  335. if __name__ == "__main__":
  336. main()