regenerate_plots.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. # pyright: basic
  2. """Regenerate analysis plots from existing computed data (CSV files).
  3. This script regenerates all plots from previously computed analysis results
  4. without re-running the full analysis pipeline. Useful when making changes
  5. to plotting parameters or fixing visualizations.
  6. Usage: Run from the project root (alnn_rewrite directory):
  7. python analysis/regenerate_plots.py /path/to/run_directory/backend_name
  8. Example:
  9. python analysis/regenerate_plots.py analysis_output/run_20260428_120000/ensemble
  10. """
  11. from __future__ import annotations
  12. import argparse
  13. import sys
  14. from pathlib import Path
  15. from typing import Any
  16. import numpy as np
  17. import pandas as pd
  18. # Add parent directory to path for imports
  19. sys.path.insert(0, str(Path(__file__).parent.parent))
  20. from analysis.analysis_modules import _uncertainty_cutoff_analysis
  21. from analysis.defaults import (
  22. DEFAULT_CALIBRATION_BINS,
  23. DEFAULT_DECISION_THRESHOLD,
  24. uncertainty_cutoff_percentiles,
  25. )
  26. from analysis.plotting import (
  27. plots_dir,
  28. save_calibration_plot,
  29. save_performance_threshold_pair_plot,
  30. save_performance_threshold_plot,
  31. save_uncertainty_cutoff_pair_plot,
  32. save_uncertainty_cutoff_plot,
  33. )
  34. from analysis.runtime import write_json
  35. def _plot_description(filename: str) -> str:
  36. descriptions = {
  37. "performance_threshold_accuracy.png": "Accuracy as the decision threshold varies.",
  38. "performance_threshold_f1.png": "F1 score as the decision threshold varies.",
  39. "performance_threshold_accuracy_f1.png": "Accuracy and F1 shown side-by-side as the decision threshold varies.",
  40. "performance_uncertainty_cutoff_accuracy.png": "Accuracy while progressively restricting to higher-confidence and uncertainty-metric subsets.",
  41. "performance_uncertainty_cutoff_f1.png": "F1 score while progressively restricting to higher-confidence and uncertainty-metric subsets.",
  42. "performance_uncertainty_cutoff_accuracy_f1.png": "Accuracy and F1 shown side-by-side across uncertainty-cutoff restriction levels.",
  43. "performance_uncertainty_percentile_cutoff_accuracy.png": "Accuracy from least to most restricted percentile-wise subset selection.",
  44. "performance_uncertainty_percentile_cutoff_f1.png": "F1 score from least to most restricted percentile-wise subset selection.",
  45. "performance_uncertainty_percentile_cutoff_accuracy_f1.png": "Accuracy and F1 shown side-by-side across percentile-floor restriction levels.",
  46. "calibration_reliability.png": "Reliability diagram comparing predicted probability to empirical outcome frequency.",
  47. "performance_threshold_accuracy_coverage.png": "Sample distribution (correct vs incorrect) across decision thresholds.",
  48. "performance_threshold_f1_coverage.png": "Sample distribution (correct vs incorrect) across decision thresholds.",
  49. "performance_threshold_accuracy_f1_coverage.png": "Sample distribution (correct vs incorrect) across decision thresholds.",
  50. "performance_uncertainty_cutoff_accuracy_coverage.png": "Sample coverage breakdown across restriction levels.",
  51. "performance_uncertainty_cutoff_f1_coverage.png": "Sample coverage breakdown across restriction levels.",
  52. "performance_uncertainty_cutoff_accuracy_f1_coverage.png": "Sample coverage breakdown across restriction levels.",
  53. "performance_uncertainty_percentile_cutoff_accuracy_coverage.png": "Sample coverage breakdown as percentile floor increases.",
  54. "performance_uncertainty_percentile_cutoff_f1_coverage.png": "Sample coverage breakdown as percentile floor increases.",
  55. "performance_uncertainty_percentile_cutoff_accuracy_f1_coverage.png": "Sample coverage breakdown as percentile floor increases.",
  56. }
  57. return descriptions.get(filename, "Generated analysis plot.")
  58. def _write_backend_plot_report(backend: str, out_dir: Path) -> Path:
  59. plots = out_dir / "plots"
  60. images = sorted(plots.rglob("*.png")) if plots.exists() else []
  61. report_path = out_dir / "plots_report.md"
  62. lines = [
  63. f"# {backend.title()} Analysis Plot Report (Regenerated)",
  64. "",
  65. "This document lists regenerated analysis plots with brief descriptions.",
  66. "",
  67. ]
  68. if not images:
  69. lines.append("No plot images were found for this backend run.")
  70. else:
  71. for image_path in images:
  72. rel = image_path.relative_to(out_dir).as_posix()
  73. title = image_path.stem.replace("_", " ").title()
  74. lines.append(f"## {title}")
  75. lines.append(_plot_description(image_path.name))
  76. lines.append("")
  77. lines.append(f"![{title}]({rel})")
  78. lines.append("")
  79. report_path.write_text("\n".join(lines), encoding="utf-8")
  80. return report_path
  81. def regenerate_performance_plots(backend_dir: Path) -> dict[str, Any]:
  82. """Regenerate performance threshold plots from existing CSV."""
  83. perf_csv = backend_dir / "performance_threshold_sweep.csv"
  84. if not perf_csv.exists():
  85. return {"status": "skipped", "reason": "no performance_threshold_sweep.csv"}
  86. df = pd.read_csv(perf_csv)
  87. backend = backend_dir.name if backend_dir.name != "plots" else "ensemble"
  88. # Get backend name from parent directory name if not found
  89. if backend_dir.parent.name not in ["ensemble", "bayesian"]:
  90. parent_name = backend_dir.name
  91. if parent_name in {"ensemble", "bayesian"}:
  92. backend = parent_name
  93. accuracy_plot_path = plots_dir(backend_dir) / "performance_threshold_accuracy.png"
  94. f1_plot_path = plots_dir(backend_dir) / "performance_threshold_f1.png"
  95. pair_plot_path = plots_dir(backend_dir) / "performance_threshold_accuracy_f1.png"
  96. save_performance_threshold_plot(
  97. df=df,
  98. backend=backend,
  99. output_path=accuracy_plot_path,
  100. metric_column="accuracy",
  101. metric_label="Accuracy",
  102. plot_key="performance_threshold_accuracy",
  103. )
  104. save_performance_threshold_plot(
  105. df=df,
  106. backend=backend,
  107. output_path=f1_plot_path,
  108. metric_column="f1",
  109. metric_label="F1",
  110. plot_key="performance_threshold_f1",
  111. )
  112. save_performance_threshold_pair_plot(
  113. df=df,
  114. backend=backend,
  115. output_path=pair_plot_path,
  116. plot_key="performance_threshold_accuracy_f1",
  117. )
  118. return {
  119. "status": "regenerated",
  120. "performance_threshold_accuracy": str(accuracy_plot_path),
  121. "performance_threshold_f1": str(f1_plot_path),
  122. "performance_threshold_accuracy_f1": str(pair_plot_path),
  123. }
  124. def regenerate_uncertainty_cutoff_plots(backend_dir: Path) -> dict[str, Any]:
  125. """Regenerate uncertainty cutoff plots from existing CSV."""
  126. cutoff_csv = backend_dir / "performance_uncertainty_cutoff.csv"
  127. percentile_csv = backend_dir / "performance_uncertainty_percentile_cutoff.csv"
  128. results = {"status": "skipped", "reason": "no cutoff CSV files found"}
  129. if cutoff_csv.exists():
  130. cutoff_df = pd.read_csv(cutoff_csv)
  131. results["status"] = "regenerated"
  132. # Create plots by uncertainty type
  133. for uncertainty_name in sorted(pd.unique(cutoff_df["uncertainty_type"])):
  134. sub_df = cutoff_df[cutoff_df["uncertainty_type"] == uncertainty_name].copy()
  135. slug = uncertainty_name.lower().replace(" ", "_")
  136. sub_accuracy_plot_path = (
  137. plots_dir(backend_dir)
  138. / f"performance_uncertainty_cutoff_{slug}_accuracy.png"
  139. )
  140. sub_f1_plot_path = (
  141. plots_dir(backend_dir) / f"performance_uncertainty_cutoff_{slug}_f1.png"
  142. )
  143. sub_pair_plot_path = (
  144. plots_dir(backend_dir)
  145. / f"performance_uncertainty_cutoff_{slug}_accuracy_f1.png"
  146. )
  147. save_uncertainty_cutoff_plot(
  148. cutoff_df=sub_df,
  149. title_prefix="Model Output / Uncertainty Cutoff Percentile",
  150. x_label="Restriction Level (0 = all samples, 100 = most restricted subset)",
  151. output_path=sub_accuracy_plot_path,
  152. metric_column="accuracy",
  153. metric_label="Accuracy",
  154. plot_key="performance_uncertainty_cutoff_accuracy",
  155. )
  156. save_uncertainty_cutoff_plot(
  157. cutoff_df=sub_df,
  158. title_prefix="Model Output / Uncertainty Cutoff Percentile",
  159. x_label="Restriction Level (0 = all samples, 100 = most restricted subset)",
  160. output_path=sub_f1_plot_path,
  161. metric_column="f1",
  162. metric_label="F1",
  163. plot_key="performance_uncertainty_cutoff_f1",
  164. )
  165. save_uncertainty_cutoff_pair_plot(
  166. cutoff_df=sub_df,
  167. title_prefix="Model Output / Uncertainty Cutoff Percentile",
  168. x_label="Restriction Level (0 = all samples, 100 = most restricted subset)",
  169. output_path=sub_pair_plot_path,
  170. plot_key="performance_uncertainty_cutoff_accuracy_f1",
  171. )
  172. if percentile_csv.exists():
  173. percentile_df = pd.read_csv(percentile_csv)
  174. results["status"] = "regenerated"
  175. # Create plots by uncertainty type
  176. for uncertainty_name in sorted(pd.unique(percentile_df["uncertainty_type"])):
  177. sub_df = percentile_df[
  178. percentile_df["uncertainty_type"] == uncertainty_name
  179. ].copy()
  180. slug = uncertainty_name.lower().replace(" ", "_")
  181. sub_accuracy_plot_path = (
  182. plots_dir(backend_dir)
  183. / f"performance_uncertainty_percentile_cutoff_{slug}_accuracy.png"
  184. )
  185. sub_f1_plot_path = (
  186. plots_dir(backend_dir)
  187. / f"performance_uncertainty_percentile_cutoff_{slug}_f1.png"
  188. )
  189. sub_pair_plot_path = (
  190. plots_dir(backend_dir)
  191. / f"performance_uncertainty_percentile_cutoff_{slug}_accuracy_f1.png"
  192. )
  193. save_uncertainty_cutoff_plot(
  194. cutoff_df=sub_df,
  195. title_prefix="Model Output / Uncertainty Percentile Floor",
  196. x_label="Percentile Floor (0 = all samples, 100 = top percentile subset)",
  197. output_path=sub_accuracy_plot_path,
  198. metric_column="accuracy",
  199. metric_label="Accuracy",
  200. plot_key="performance_uncertainty_percentile_cutoff_accuracy",
  201. )
  202. save_uncertainty_cutoff_plot(
  203. cutoff_df=sub_df,
  204. title_prefix="Model Output / Uncertainty Percentile Floor",
  205. x_label="Percentile Floor (0 = all samples, 100 = top percentile subset)",
  206. output_path=sub_f1_plot_path,
  207. metric_column="f1",
  208. metric_label="F1",
  209. plot_key="performance_uncertainty_percentile_cutoff_f1",
  210. )
  211. save_uncertainty_cutoff_pair_plot(
  212. cutoff_df=sub_df,
  213. title_prefix="Model Output / Uncertainty Percentile Floor",
  214. x_label="Percentile Floor (0 = all samples, 100 = top percentile subset)",
  215. output_path=sub_pair_plot_path,
  216. plot_key="performance_uncertainty_percentile_cutoff_accuracy_f1",
  217. )
  218. return results
  219. def regenerate_calibration_plots(backend_dir: Path) -> dict[str, Any]:
  220. """Regenerate calibration plots from existing calibration data."""
  221. calib_path = backend_dir / "calibration_per_bin.npy"
  222. if not calib_path.exists():
  223. return {"status": "skipped", "reason": "no calibration_per_bin.npy"}
  224. per_bin = np.load(calib_path)
  225. backend = backend_dir.name if backend_dir.name != "plots" else "ensemble"
  226. # Get backend name from parent directory name if not found
  227. if backend_dir.parent.name not in ["ensemble", "bayesian"]:
  228. parent_name = backend_dir.name
  229. if parent_name in {"ensemble", "bayesian"}:
  230. backend = parent_name
  231. plot_path = plots_dir(backend_dir) / "calibration_reliability.png"
  232. save_calibration_plot(per_bin=per_bin, backend=backend, output_path=plot_path)
  233. return {
  234. "status": "regenerated",
  235. "calibration_reliability": str(plot_path),
  236. }
  237. def main() -> None:
  238. parser = argparse.ArgumentParser(
  239. description="Regenerate analysis plots from existing computed data CSV files."
  240. )
  241. parser.add_argument(
  242. "backend_dir",
  243. type=Path,
  244. help="Path to backend-specific analysis output directory "
  245. "(e.g., analysis_output/run_xxx/ensemble)",
  246. )
  247. args = parser.parse_args()
  248. backend_dir = args.backend_dir.resolve()
  249. if not backend_dir.exists():
  250. print(
  251. f"Error: Backend directory does not exist: {backend_dir}", file=sys.stderr
  252. )
  253. sys.exit(1)
  254. print(f"Regenerating plots from: {backend_dir}")
  255. results: dict[str, Any] = {
  256. "backend_dir": str(backend_dir),
  257. "performance": regenerate_performance_plots(backend_dir),
  258. "uncertainty_cutoff": regenerate_uncertainty_cutoff_plots(backend_dir),
  259. "calibration": regenerate_calibration_plots(backend_dir),
  260. }
  261. # Write updated report
  262. report_path = _write_backend_plot_report(
  263. backend=backend_dir.name, out_dir=backend_dir
  264. )
  265. results["plots_report"] = str(report_path)
  266. print(f"\nPlot regeneration complete!")
  267. print(f"Results summary:")
  268. print(f" Performance plots: {results['performance'].get('status', 'unknown')}")
  269. print(
  270. f" Uncertainty cutoff plots: {results['uncertainty_cutoff'].get('status', 'unknown')}"
  271. )
  272. print(f" Calibration plots: {results['calibration'].get('status', 'unknown')}")
  273. print(f" Report written to: {report_path}")
  274. write_json(backend_dir / "plot_regeneration_log.json", results)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()