regenerate_plots.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. # pyright: basic
  2. """Regenerate analysis plots from existing computed data (CSV files).
  3. This script regenerates all plots from previously computed analysis results
  4. without re-running the full analysis pipeline. Useful when making changes
  5. to plotting parameters or fixing visualizations.
  6. Usage: Run from the project root (alnn_rewrite directory):
  7. python analysis/regenerate_plots.py /path/to/run_directory/backend_name
  8. Example:
  9. python analysis/regenerate_plots.py analysis_output/run_20260428_120000/ensemble
  10. """
  11. from __future__ import annotations
  12. import argparse
  13. import sys
  14. from pathlib import Path
  15. from typing import Any
  16. import numpy as np
  17. import pandas as pd
  18. if __package__ in (None, ""):
  19. sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
  20. from analysis.analysis_modules import _uncertainty_cutoff_analysis
  21. from analysis.defaults import (
  22. DEFAULT_CALIBRATION_BINS,
  23. DEFAULT_DECISION_THRESHOLD,
  24. uncertainty_cutoff_percentiles,
  25. )
  26. from analysis.plotting import (
  27. plots_dir,
  28. save_calibration_plot,
  29. save_performance_threshold_pair_plot,
  30. save_performance_threshold_plot,
  31. save_uncertainty_cutoff_pair_plot,
  32. save_uncertainty_cutoff_plot,
  33. )
  34. from analysis.runtime import write_json
  35. else:
  36. from .analysis_modules import _uncertainty_cutoff_analysis
  37. from .defaults import (
  38. DEFAULT_CALIBRATION_BINS,
  39. DEFAULT_DECISION_THRESHOLD,
  40. uncertainty_cutoff_percentiles,
  41. )
  42. from .plotting import (
  43. plots_dir,
  44. save_calibration_plot,
  45. save_performance_threshold_pair_plot,
  46. save_performance_threshold_plot,
  47. save_uncertainty_cutoff_pair_plot,
  48. save_uncertainty_cutoff_plot,
  49. )
  50. from .runtime import write_json
  51. def _plot_description(filename: str) -> str:
  52. descriptions = {
  53. "performance_threshold_accuracy.png": "Accuracy as the decision threshold varies.",
  54. "performance_threshold_f1.png": "F1 score as the decision threshold varies.",
  55. "performance_threshold_accuracy_f1.png": "Accuracy and F1 shown side-by-side as the decision threshold varies.",
  56. "performance_uncertainty_cutoff_accuracy.png": "Accuracy while progressively restricting to higher-confidence and uncertainty-metric subsets.",
  57. "performance_uncertainty_cutoff_f1.png": "F1 score while progressively restricting to higher-confidence and uncertainty-metric subsets.",
  58. "performance_uncertainty_cutoff_accuracy_f1.png": "Accuracy and F1 shown side-by-side across uncertainty-cutoff restriction levels.",
  59. "performance_uncertainty_percentile_cutoff_accuracy.png": "Accuracy from least to most restricted percentile-wise subset selection.",
  60. "performance_uncertainty_percentile_cutoff_f1.png": "F1 score from least to most restricted percentile-wise subset selection.",
  61. "performance_uncertainty_percentile_cutoff_accuracy_f1.png": "Accuracy and F1 shown side-by-side across percentile-floor restriction levels.",
  62. "calibration_reliability.png": "Reliability diagram comparing predicted probability to empirical outcome frequency.",
  63. "performance_threshold_accuracy_coverage.png": "Sample distribution (correct vs incorrect) across decision thresholds.",
  64. "performance_threshold_f1_coverage.png": "Sample distribution (correct vs incorrect) across decision thresholds.",
  65. "performance_threshold_accuracy_f1_coverage.png": "Sample distribution (correct vs incorrect) across decision thresholds.",
  66. "performance_uncertainty_cutoff_accuracy_coverage.png": "Sample coverage breakdown across restriction levels.",
  67. "performance_uncertainty_cutoff_f1_coverage.png": "Sample coverage breakdown across restriction levels.",
  68. "performance_uncertainty_cutoff_accuracy_f1_coverage.png": "Sample coverage breakdown across restriction levels.",
  69. "performance_uncertainty_percentile_cutoff_accuracy_coverage.png": "Sample coverage breakdown as percentile floor increases.",
  70. "performance_uncertainty_percentile_cutoff_f1_coverage.png": "Sample coverage breakdown as percentile floor increases.",
  71. "performance_uncertainty_percentile_cutoff_accuracy_f1_coverage.png": "Sample coverage breakdown as percentile floor increases.",
  72. }
  73. return descriptions.get(filename, "Generated analysis plot.")
  74. def _write_backend_plot_report(backend: str, out_dir: Path) -> Path:
  75. plots = out_dir / "plots"
  76. images = sorted(plots.rglob("*.png")) if plots.exists() else []
  77. report_path = out_dir / "plots_report.md"
  78. lines = [
  79. f"# {backend.title()} Analysis Plot Report (Regenerated)",
  80. "",
  81. "This document lists regenerated analysis plots with brief descriptions.",
  82. "",
  83. ]
  84. if not images:
  85. lines.append("No plot images were found for this backend run.")
  86. else:
  87. for image_path in images:
  88. rel = image_path.relative_to(out_dir).as_posix()
  89. title = image_path.stem.replace("_", " ").title()
  90. lines.append(f"## {title}")
  91. lines.append(_plot_description(image_path.name))
  92. lines.append("")
  93. lines.append(f"![{title}]({rel})")
  94. lines.append("")
  95. report_path.write_text("\n".join(lines), encoding="utf-8")
  96. return report_path
  97. def regenerate_performance_plots(backend_dir: Path) -> dict[str, Any]:
  98. """Regenerate performance threshold plots from existing CSV."""
  99. perf_csv = backend_dir / "performance_threshold_sweep.csv"
  100. if not perf_csv.exists():
  101. return {"status": "skipped", "reason": "no performance_threshold_sweep.csv"}
  102. df = pd.read_csv(perf_csv)
  103. backend = backend_dir.name if backend_dir.name != "plots" else "ensemble"
  104. # Get backend name from parent directory name if not found
  105. if backend_dir.parent.name not in ["ensemble", "bayesian"]:
  106. parent_name = backend_dir.name
  107. if parent_name in {"ensemble", "bayesian"}:
  108. backend = parent_name
  109. accuracy_plot_path = plots_dir(backend_dir) / "performance_threshold_accuracy.png"
  110. f1_plot_path = plots_dir(backend_dir) / "performance_threshold_f1.png"
  111. pair_plot_path = plots_dir(backend_dir) / "performance_threshold_accuracy_f1.png"
  112. save_performance_threshold_plot(
  113. df=df,
  114. backend=backend,
  115. output_path=accuracy_plot_path,
  116. metric_column="accuracy",
  117. metric_label="Accuracy",
  118. plot_key="performance_threshold_accuracy",
  119. )
  120. save_performance_threshold_plot(
  121. df=df,
  122. backend=backend,
  123. output_path=f1_plot_path,
  124. metric_column="f1",
  125. metric_label="F1",
  126. plot_key="performance_threshold_f1",
  127. )
  128. save_performance_threshold_pair_plot(
  129. df=df,
  130. backend=backend,
  131. output_path=pair_plot_path,
  132. plot_key="performance_threshold_accuracy_f1",
  133. )
  134. return {
  135. "status": "regenerated",
  136. "performance_threshold_accuracy": str(accuracy_plot_path),
  137. "performance_threshold_f1": str(f1_plot_path),
  138. "performance_threshold_accuracy_f1": str(pair_plot_path),
  139. }
  140. def regenerate_uncertainty_cutoff_plots(backend_dir: Path) -> dict[str, Any]:
  141. """Regenerate uncertainty cutoff plots from existing CSV."""
  142. cutoff_csv = backend_dir / "performance_uncertainty_cutoff.csv"
  143. percentile_csv = backend_dir / "performance_uncertainty_percentile_cutoff.csv"
  144. results = {"status": "skipped", "reason": "no cutoff CSV files found"}
  145. if cutoff_csv.exists():
  146. cutoff_df = pd.read_csv(cutoff_csv)
  147. results["status"] = "regenerated"
  148. # Create plots by uncertainty type
  149. for uncertainty_name in sorted(pd.unique(cutoff_df["uncertainty_type"])):
  150. sub_df = cutoff_df[cutoff_df["uncertainty_type"] == uncertainty_name].copy()
  151. slug = uncertainty_name.lower().replace(" ", "_")
  152. sub_accuracy_plot_path = (
  153. plots_dir(backend_dir)
  154. / f"performance_uncertainty_cutoff_{slug}_accuracy.png"
  155. )
  156. sub_f1_plot_path = (
  157. plots_dir(backend_dir) / f"performance_uncertainty_cutoff_{slug}_f1.png"
  158. )
  159. sub_pair_plot_path = (
  160. plots_dir(backend_dir)
  161. / f"performance_uncertainty_cutoff_{slug}_accuracy_f1.png"
  162. )
  163. save_uncertainty_cutoff_plot(
  164. cutoff_df=sub_df,
  165. title_prefix="Model Output / Uncertainty Cutoff Percentile",
  166. x_label="Restriction Level (0 = all samples, 100 = most restricted subset)",
  167. output_path=sub_accuracy_plot_path,
  168. metric_column="accuracy",
  169. metric_label="Accuracy",
  170. plot_key="performance_uncertainty_cutoff_accuracy",
  171. )
  172. save_uncertainty_cutoff_plot(
  173. cutoff_df=sub_df,
  174. title_prefix="Model Output / Uncertainty Cutoff Percentile",
  175. x_label="Restriction Level (0 = all samples, 100 = most restricted subset)",
  176. output_path=sub_f1_plot_path,
  177. metric_column="f1",
  178. metric_label="F1",
  179. plot_key="performance_uncertainty_cutoff_f1",
  180. )
  181. save_uncertainty_cutoff_pair_plot(
  182. cutoff_df=sub_df,
  183. title_prefix="Model Output / Uncertainty Cutoff Percentile",
  184. x_label="Restriction Level (0 = all samples, 100 = most restricted subset)",
  185. output_path=sub_pair_plot_path,
  186. plot_key="performance_uncertainty_cutoff_accuracy_f1",
  187. )
  188. if percentile_csv.exists():
  189. percentile_df = pd.read_csv(percentile_csv)
  190. results["status"] = "regenerated"
  191. # Create plots by uncertainty type
  192. for uncertainty_name in sorted(pd.unique(percentile_df["uncertainty_type"])):
  193. sub_df = percentile_df[
  194. percentile_df["uncertainty_type"] == uncertainty_name
  195. ].copy()
  196. slug = uncertainty_name.lower().replace(" ", "_")
  197. sub_accuracy_plot_path = (
  198. plots_dir(backend_dir)
  199. / f"performance_uncertainty_percentile_cutoff_{slug}_accuracy.png"
  200. )
  201. sub_f1_plot_path = (
  202. plots_dir(backend_dir)
  203. / f"performance_uncertainty_percentile_cutoff_{slug}_f1.png"
  204. )
  205. sub_pair_plot_path = (
  206. plots_dir(backend_dir)
  207. / f"performance_uncertainty_percentile_cutoff_{slug}_accuracy_f1.png"
  208. )
  209. save_uncertainty_cutoff_plot(
  210. cutoff_df=sub_df,
  211. title_prefix="Model Output / Uncertainty Percentile Floor",
  212. x_label="Percentile Floor (0 = all samples, 100 = top percentile subset)",
  213. output_path=sub_accuracy_plot_path,
  214. metric_column="accuracy",
  215. metric_label="Accuracy",
  216. plot_key="performance_uncertainty_percentile_cutoff_accuracy",
  217. )
  218. save_uncertainty_cutoff_plot(
  219. cutoff_df=sub_df,
  220. title_prefix="Model Output / Uncertainty Percentile Floor",
  221. x_label="Percentile Floor (0 = all samples, 100 = top percentile subset)",
  222. output_path=sub_f1_plot_path,
  223. metric_column="f1",
  224. metric_label="F1",
  225. plot_key="performance_uncertainty_percentile_cutoff_f1",
  226. )
  227. save_uncertainty_cutoff_pair_plot(
  228. cutoff_df=sub_df,
  229. title_prefix="Model Output / Uncertainty Percentile Floor",
  230. x_label="Percentile Floor (0 = all samples, 100 = top percentile subset)",
  231. output_path=sub_pair_plot_path,
  232. plot_key="performance_uncertainty_percentile_cutoff_accuracy_f1",
  233. )
  234. return results
  235. def regenerate_calibration_plots(backend_dir: Path) -> dict[str, Any]:
  236. """Regenerate calibration plots from existing calibration data."""
  237. calib_path = backend_dir / "calibration_per_bin.npy"
  238. if not calib_path.exists():
  239. return {"status": "skipped", "reason": "no calibration_per_bin.npy"}
  240. per_bin = np.load(calib_path)
  241. backend = backend_dir.name if backend_dir.name != "plots" else "ensemble"
  242. # Get backend name from parent directory name if not found
  243. if backend_dir.parent.name not in ["ensemble", "bayesian"]:
  244. parent_name = backend_dir.name
  245. if parent_name in {"ensemble", "bayesian"}:
  246. backend = parent_name
  247. plot_path = plots_dir(backend_dir) / "calibration_reliability.png"
  248. save_calibration_plot(per_bin=per_bin, backend=backend, output_path=plot_path)
  249. return {
  250. "status": "regenerated",
  251. "calibration_reliability": str(plot_path),
  252. }
  253. def main() -> None:
  254. parser = argparse.ArgumentParser(
  255. description="Regenerate analysis plots from existing computed data CSV files."
  256. )
  257. parser.add_argument(
  258. "backend_dir",
  259. type=Path,
  260. help="Path to backend-specific analysis output directory "
  261. "(e.g., analysis_output/run_xxx/ensemble)",
  262. )
  263. args = parser.parse_args()
  264. backend_dir = args.backend_dir.resolve()
  265. if not backend_dir.exists():
  266. print(
  267. f"Error: Backend directory does not exist: {backend_dir}", file=sys.stderr
  268. )
  269. sys.exit(1)
  270. print(f"Regenerating plots from: {backend_dir}")
  271. results: dict[str, Any] = {
  272. "backend_dir": str(backend_dir),
  273. "performance": regenerate_performance_plots(backend_dir),
  274. "uncertainty_cutoff": regenerate_uncertainty_cutoff_plots(backend_dir),
  275. "calibration": regenerate_calibration_plots(backend_dir),
  276. }
  277. # Write updated report
  278. report_path = _write_backend_plot_report(
  279. backend=backend_dir.name, out_dir=backend_dir
  280. )
  281. results["plots_report"] = str(report_path)
  282. print(f"\nPlot regeneration complete!")
  283. print(f"Results summary:")
  284. print(f" Performance plots: {results['performance'].get('status', 'unknown')}")
  285. print(
  286. f" Uncertainty cutoff plots: {results['uncertainty_cutoff'].get('status', 'unknown')}"
  287. )
  288. print(f" Calibration plots: {results['calibration'].get('status', 'unknown')}")
  289. print(f" Report written to: {report_path}")
  290. write_json(backend_dir / "plot_regeneration_log.json", results)
  291. if __name__ == "__main__":
  292. main()