# evaluate_models.py
# pyright: basic
from __future__ import annotations

import argparse
import logging
from pathlib import Path
from typing import Any

import pandas as pd

from analysis.analysis_modules import (
    run_calibration,
    run_longitudinal,
    run_performance,
    run_physician,
)
from analysis.data_access import load_backend_evaluation, load_clinical_table
from analysis.defaults import (
    DEFAULT_BACKENDS,
    DEFAULT_BAYESIAN_MC_PASSES,
    DEFAULT_CALIBRATION_BINS,
    DEFAULT_DECISION_THRESHOLD,
    DEFAULT_POSITIVE_CLASS_INDEX,
    noise_factor_grid,
    threshold_grid,
)
from analysis.holdout_evaluation import ensure_backend_netcdf
from analysis.noise_analysis import run_noise_analysis
from analysis.runtime import backend_dir, init_runtime_paths, load_config, write_json


  26. def _plot_description(filename: str) -> str:
  27. descriptions = {
  28. "performance_threshold_sweep.png": "Accuracy and F1 as the decision threshold varies.",
  29. "performance_uncertainty_cutoff.png": "Performance while progressively restricting to lower-uncertainty predictions.",
  30. "performance_uncertainty_percentile_cutoff.png": "Percentile-ranked low-uncertainty subset performance from least to most restricted.",
  31. "calibration_reliability.png": "Reliability diagram comparing predicted probability to empirical outcome frequency.",
  32. "physician_confidence_boxplot.png": "Model confidence grouped by physician confidence ratings.",
  33. "physician_std_boxplot.png": "Model secondary uncertainty grouped by physician confidence ratings.",
  34. "physician_predictive_entropy_boxplot.png": "Predictive entropy grouped by physician confidence ratings.",
  35. "longitudinal_cohort_confidence.png": "Longitudinal cohort comparison using model confidence.",
  36. "longitudinal_cohort_std.png": "Longitudinal cohort comparison using ensemble standard deviation uncertainty.",
  37. "longitudinal_cohort_predictive_entropy.png": "Longitudinal cohort comparison using predictive entropy uncertainty.",
  38. "noise_sensitivity.png": "Performance metrics across increasing Gaussian noise factors.",
  39. "noise_uncertainty.png": "Uncertainty metrics across increasing Gaussian noise factors.",
  40. "noise_confidence_certainty.png": "Confidence certainty trend across increasing Gaussian noise factors.",
  41. "ensemble_noise_examples.png": "Representative image slices with progressively larger Gaussian noise factors.",
  42. "bayesian_noise_examples.png": "Representative image slices with progressively larger Gaussian noise factors.",
  43. }
  44. return descriptions.get(filename, "Generated analysis plot.")
  45. def _write_backend_plot_report(backend: str, out_dir: Path) -> Path:
  46. plots_dir = out_dir / "plots"
  47. images = sorted(plots_dir.rglob("*.png")) if plots_dir.exists() else []
  48. report_path = out_dir / "plots_report.md"
  49. lines = [
  50. f"# {backend.title()} Analysis Plot Report",
  51. "",
  52. "This document lists generated analysis plots with brief descriptions.",
  53. "",
  54. ]
  55. if not images:
  56. lines.append("No plot images were generated for this backend run.")
  57. else:
  58. for image_path in images:
  59. rel = image_path.relative_to(out_dir).as_posix()
  60. title = image_path.stem.replace("_", " ").title()
  61. lines.append(f"## {title}")
  62. lines.append(_plot_description(image_path.name))
  63. lines.append("")
  64. lines.append(f"![{title}]({rel})")
  65. lines.append("")
  66. report_path.write_text("\n".join(lines), encoding="utf-8")
  67. return report_path
  68. def _parse_args() -> argparse.Namespace:
  69. parser = argparse.ArgumentParser(
  70. description=(
  71. "Run modular evaluation analyses for ensemble and bayesian models. "
  72. "All outputs are written to alnn_rewrite/analysis_output."
  73. )
  74. )
  75. parser.add_argument(
  76. "--backend",
  77. nargs="+",
  78. choices=["ensemble", "bayesian"],
  79. default=DEFAULT_BACKENDS,
  80. help="Backends to evaluate.",
  81. )
  82. parser.add_argument(
  83. "--run-name",
  84. default=None,
  85. help="Optional run directory name under analysis_output.",
  86. )
  87. parser.add_argument(
  88. "--skip-noise",
  89. action="store_true",
  90. help="Skip Gaussian noise sensitivity analysis.",
  91. )
  92. return parser.parse_args()
  93. def _run_backend(
  94. config: dict[str, Any],
  95. root_dir: Path,
  96. backend: str,
  97. clinical_df: pd.DataFrame,
  98. skip_noise: bool,
  99. out_dir: Path,
  100. ) -> dict[str, Any]:
  101. netcdf_path = ensure_backend_netcdf(
  102. config=config,
  103. root_dir=root_dir,
  104. backend=backend,
  105. bayesian_mc_passes=DEFAULT_BAYESIAN_MC_PASSES,
  106. )
  107. evaluation = load_backend_evaluation(
  108. config=config,
  109. backend=backend,
  110. class_index=DEFAULT_POSITIVE_CLASS_INDEX,
  111. )
  112. thresholds = threshold_grid()
  113. noise_factors = noise_factor_grid()
  114. summary: dict[str, Any] = {
  115. "backend": backend,
  116. "netcdf": str(netcdf_path),
  117. "source_file": str(evaluation.source_file),
  118. "uncertainty_metric": evaluation.uncertainty_metric,
  119. }
  120. summary["performance"] = run_performance(
  121. evaluation=evaluation,
  122. output_dir=out_dir,
  123. thresholds=thresholds,
  124. )
  125. summary["calibration"] = run_calibration(
  126. evaluation=evaluation,
  127. output_dir=out_dir,
  128. bins=DEFAULT_CALIBRATION_BINS,
  129. )
  130. summary["physician"] = run_physician(
  131. evaluation=evaluation,
  132. clinical_df=clinical_df,
  133. output_dir=out_dir,
  134. )
  135. summary["longitudinal"] = run_longitudinal(
  136. evaluation=evaluation,
  137. clinical_df=clinical_df,
  138. output_dir=out_dir,
  139. )
  140. if skip_noise:
  141. summary["noise"] = {"skipped": True, "reason": "--skip-noise supplied"}
  142. else:
  143. try:
  144. summary["noise"] = run_noise_analysis(
  145. config=config,
  146. root_dir=root_dir,
  147. backend=backend,
  148. output_dir=out_dir,
  149. class_index=DEFAULT_POSITIVE_CLASS_INDEX,
  150. noise_sigmas=noise_factors,
  151. threshold=DEFAULT_DECISION_THRESHOLD,
  152. calibration_bins=DEFAULT_CALIBRATION_BINS,
  153. bayesian_mc_passes=DEFAULT_BAYESIAN_MC_PASSES,
  154. )
  155. except Exception as exc:
  156. summary["noise"] = {
  157. "skipped": True,
  158. "reason": f"Noise analysis failed: {exc}",
  159. }
  160. report_path = _write_backend_plot_report(backend=backend, out_dir=out_dir)
  161. summary["plots_report"] = str(report_path)
  162. write_json(out_dir / "backend_summary.json", summary)
  163. return summary
  164. def main() -> None:
  165. args = _parse_args()
  166. analysis_dir = Path(__file__).resolve().parent
  167. paths = init_runtime_paths(analysis_dir=analysis_dir, run_name=args.run_name)
  168. config = load_config(paths.root_dir)
  169. clinical_df = load_clinical_table(config=config, root_dir=paths.root_dir)
  170. manifest: dict[str, Any] = {
  171. "run_dir": str(paths.run_dir),
  172. "output_root": str(paths.output_root),
  173. "positive_class_index": DEFAULT_POSITIVE_CLASS_INDEX,
  174. "threshold_sweep": {
  175. "values": [float(v) for v in threshold_grid().tolist()],
  176. },
  177. "calibration_bins": DEFAULT_CALIBRATION_BINS,
  178. "noise_factors": noise_factor_grid(),
  179. "bayesian_mc_passes": DEFAULT_BAYESIAN_MC_PASSES,
  180. "decision_threshold": DEFAULT_DECISION_THRESHOLD,
  181. "backends": {},
  182. }
  183. for backend in args.backend:
  184. out_dir = backend_dir(paths, backend)
  185. manifest["backends"][backend] = _run_backend(
  186. config=config,
  187. root_dir=paths.root_dir,
  188. backend=backend,
  189. clinical_df=clinical_df,
  190. skip_noise=bool(args.skip_noise),
  191. out_dir=out_dir,
  192. )
  193. write_json(paths.run_dir / "run_manifest.json", manifest)
  194. print(f"Analysis complete. Results saved to {paths.run_dir}")
  195. if __name__ == "__main__":
  196. main()