plotting.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612
  1. # pyright: basic
  2. from __future__ import annotations
  3. from pathlib import Path
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import pandas as pd
  7. import torch
  8. from matplotlib.axes import Axes
  9. # Easily editable plot text overrides by plot key.
  10. # Example:
  11. # "performance_threshold": {
  12. # "title": "Custom Title",
  13. # "x_label": "Custom X",
  14. # "y_label": "Custom Y",
  15. # }
  16. # Common keys:
  17. # - performance_threshold_accuracy
  18. # - performance_threshold_f1
  19. # - performance_uncertainty_cutoff_accuracy
  20. # - performance_uncertainty_cutoff_f1
  21. # - performance_uncertainty_percentile_cutoff_accuracy
  22. # - performance_uncertainty_percentile_cutoff_f1
  23. # - calibration_reliability
  24. # - noise_sensitivity_accuracy
  25. # - noise_sensitivity_f1
  26. # - noise_confidence
  27. # - noise_standard_deviation
  28. # - noise_predictive_uncertainty
  29. # - boxplot
  30. PLOT_TEXT_OVERRIDES: dict[str, dict[str, str]] = {}
  31. def _resolve_plot_text(
  32. plot_key: str,
  33. default_title: str,
  34. default_x_label: str,
  35. default_y_label: str,
  36. ) -> tuple[str, str, str]:
  37. override = PLOT_TEXT_OVERRIDES.get(plot_key, {})
  38. return (
  39. override.get("title", default_title),
  40. override.get("x_label", default_x_label),
  41. override.get("y_label", default_y_label),
  42. )
  43. def annotate_stats_box(
  44. ax: Axes,
  45. lines: list[str],
  46. location: str = "upper left",
  47. ) -> None:
  48. if not lines:
  49. return
  50. locations: dict[str, tuple[float, float, str, str]] = {
  51. "upper left": (0.02, 0.98, "left", "top"),
  52. "upper right": (0.98, 0.98, "right", "top"),
  53. "lower left": (0.02, 0.02, "left", "bottom"),
  54. "lower right": (0.98, 0.02, "right", "bottom"),
  55. }
  56. x, y, ha, va = locations.get(location, locations["upper left"])
  57. ax.text(
  58. x,
  59. y,
  60. "\n".join(lines),
  61. transform=ax.transAxes,
  62. ha=ha,
  63. va=va,
  64. fontsize=10,
  65. bbox={
  66. "boxstyle": "round,pad=0.35",
  67. "facecolor": "white",
  68. "edgecolor": "#555555",
  69. "alpha": 0.9,
  70. },
  71. )
  72. def _plot_correct_incorrect_bars(
  73. ax: Axes,
  74. x_values: pd.Series,
  75. n_correct: pd.Series,
  76. n_incorrect: pd.Series,
  77. ) -> None:
  78. x = np.asarray(x_values, dtype=float)
  79. correct = np.asarray(n_correct, dtype=float)
  80. incorrect = np.asarray(n_incorrect, dtype=float)
  81. if x.size == 0 or correct.size == 0 or incorrect.size == 0:
  82. return
  83. width = float(np.diff(np.sort(x)).min()) * 0.8 if x.size > 1 else 0.04
  84. max_count = float(max(np.nanmax(correct), np.nanmax(incorrect), 1.0))
  85. bars_ax = ax.twinx()
  86. bars_ax.patch.set_alpha(0.0)
  87. bars_ax.bar(
  88. x,
  89. correct,
  90. width=width,
  91. color="#2ca02c",
  92. alpha=0.2,
  93. label="correct",
  94. zorder=0,
  95. align="center",
  96. )
  97. bars_ax.bar(
  98. x,
  99. -incorrect,
  100. width=width,
  101. color="#d62728",
  102. alpha=0.2,
  103. label="incorrect",
  104. zorder=0,
  105. align="center",
  106. )
  107. bars_ax.axhline(0.0, color="gray", linewidth=0.8, alpha=0.4)
  108. bars_ax.set_ylim(-1.15 * max_count, 1.15 * max_count)
  109. bars_ax.set_yticks([])
  110. bars_ax.grid(False)
  111. def save_coverage_bar_plot(
  112. x_values: pd.Series | np.ndarray,
  113. n_correct: pd.Series | np.ndarray,
  114. n_incorrect: pd.Series | np.ndarray,
  115. x_label: str,
  116. title: str,
  117. output_path: Path,
  118. ) -> None:
  119. """Save a standalone bar chart showing sample counts (correct vs incorrect)."""
  120. x = np.asarray(x_values, dtype=float)
  121. correct = np.asarray(n_correct, dtype=float)
  122. incorrect = np.asarray(n_incorrect, dtype=float)
  123. if x.size == 0 or correct.size == 0 or incorrect.size == 0:
  124. return
  125. width = float(np.diff(np.sort(x)).min()) * 0.8 if x.size > 1 else 0.04
  126. max_count = float(max(np.nanmax(correct), np.nanmax(incorrect), 1.0))
  127. fig, ax = plt.subplots(figsize=(10, 5))
  128. ax.bar(
  129. x,
  130. correct,
  131. width=width,
  132. color="#2ca02c",
  133. alpha=0.6,
  134. label="correct",
  135. align="center",
  136. )
  137. ax.bar(
  138. x,
  139. -incorrect,
  140. width=width,
  141. color="#d62728",
  142. alpha=0.6,
  143. label="incorrect",
  144. align="center",
  145. )
  146. ax.axhline(0.0, color="gray", linewidth=0.8, alpha=0.4)
  147. ax.set_ylim(-1.15 * max_count, 1.15 * max_count)
  148. ax.set_xlabel(x_label)
  149. ax.set_ylabel("Sample Count")
  150. ax.set_title(title)
  151. ax.legend()
  152. ax.grid(True, alpha=0.3)
  153. fig.tight_layout()
  154. output_path.parent.mkdir(parents=True, exist_ok=True)
  155. fig.savefig(output_path)
  156. plt.close(fig)
  157. def plots_dir(output_dir: Path) -> Path:
  158. plots = output_dir / "plots"
  159. plots.mkdir(parents=True, exist_ok=True)
  160. return plots
  161. def save_performance_threshold_plot(
  162. df: pd.DataFrame,
  163. backend: str,
  164. output_path: Path,
  165. metric_column: str,
  166. metric_label: str,
  167. plot_key: str,
  168. ) -> None:
  169. title, x_label, y_label = _resolve_plot_text(
  170. plot_key=plot_key,
  171. default_title=f"{metric_label} vs Decision Threshold ({backend})",
  172. default_x_label="Decision Threshold",
  173. default_y_label=metric_label,
  174. )
  175. n_correct = pd.to_numeric(df["tp"], errors="coerce") + pd.to_numeric(
  176. df["tn"], errors="coerce"
  177. )
  178. n_incorrect = pd.to_numeric(df["fp"], errors="coerce") + pd.to_numeric(
  179. df["fn"], errors="coerce"
  180. )
  181. fig, ax = plt.subplots(figsize=(10, 5))
  182. ax.plot(df["threshold"], df[metric_column], label=metric_label, marker="o")
  183. ax.set_xlabel(x_label)
  184. ax.set_ylabel(y_label)
  185. ax.set_title(title)
  186. ax.grid(True, alpha=0.3)
  187. ax.legend()
  188. fig.tight_layout()
  189. output_path.parent.mkdir(parents=True, exist_ok=True)
  190. fig.savefig(output_path)
  191. plt.close(fig)
  192. # Generate separate coverage bar plot
  193. coverage_path = output_path.parent / f"{output_path.stem}_coverage.png"
  194. save_coverage_bar_plot(
  195. x_values=df["threshold"],
  196. n_correct=n_correct,
  197. n_incorrect=n_incorrect,
  198. x_label=x_label,
  199. title=f"Sample Distribution vs Decision Threshold ({backend})",
  200. output_path=coverage_path,
  201. )
  202. def save_performance_threshold_pair_plot(
  203. df: pd.DataFrame,
  204. backend: str,
  205. output_path: Path,
  206. plot_key: str,
  207. ) -> None:
  208. title, x_label, _ = _resolve_plot_text(
  209. plot_key=plot_key,
  210. default_title=f"Accuracy and F1 vs Decision Threshold ({backend})",
  211. default_x_label="Decision Threshold",
  212. default_y_label="Accuracy/F1",
  213. )
  214. n_correct = pd.to_numeric(df["tp"], errors="coerce") + pd.to_numeric(
  215. df["tn"], errors="coerce"
  216. )
  217. n_incorrect = pd.to_numeric(df["fp"], errors="coerce") + pd.to_numeric(
  218. df["fn"], errors="coerce"
  219. )
  220. fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharex=True)
  221. for ax, metric_col, metric_label, marker in [
  222. (axes[0], "accuracy", "Accuracy", "o"),
  223. (axes[1], "f1", "F1", "s"),
  224. ]:
  225. ax.plot(df["threshold"], df[metric_col], label=metric_label, marker=marker)
  226. ax.set_xlabel(x_label)
  227. ax.set_ylabel(metric_label)
  228. ax.set_title(f"{metric_label}")
  229. ax.grid(True, alpha=0.3)
  230. ax.legend()
  231. fig.suptitle(title)
  232. fig.tight_layout()
  233. output_path.parent.mkdir(parents=True, exist_ok=True)
  234. fig.savefig(output_path)
  235. plt.close(fig)
  236. # Generate separate coverage bar plot
  237. coverage_path = output_path.parent / f"{output_path.stem}_coverage.png"
  238. save_coverage_bar_plot(
  239. x_values=df["threshold"],
  240. n_correct=n_correct,
  241. n_incorrect=n_incorrect,
  242. x_label=x_label,
  243. title=f"Sample Distribution vs Decision Threshold ({backend})",
  244. output_path=coverage_path,
  245. )
  246. def save_uncertainty_cutoff_plot(
  247. cutoff_df: pd.DataFrame,
  248. title_prefix: str,
  249. x_label: str,
  250. output_path: Path,
  251. metric_column: str,
  252. metric_label: str,
  253. plot_key: str,
  254. ) -> None:
  255. title, x_label_final, y_label = _resolve_plot_text(
  256. plot_key=plot_key,
  257. default_title=f"{metric_label} vs {title_prefix}",
  258. default_x_label=x_label,
  259. default_y_label=metric_label,
  260. )
  261. fig, ax = plt.subplots(figsize=(10, 5))
  262. for uncertainty_name, group in cutoff_df.groupby("uncertainty_type"):
  263. g = group.sort_values("restriction_level")
  264. ax.plot(
  265. g["restriction_level"],
  266. g[metric_column],
  267. marker="o",
  268. label=uncertainty_name,
  269. )
  270. ax.set_title(title)
  271. ax.set_xlabel(x_label_final)
  272. ax.set_ylabel(y_label)
  273. ax.grid(True, alpha=0.3)
  274. ax.legend()
  275. fig.tight_layout()
  276. output_path.parent.mkdir(parents=True, exist_ok=True)
  277. fig.savefig(output_path)
  278. plt.close(fig)
  279. # Generate separate coverage bar plot
  280. first_group = (
  281. cutoff_df.sort_values(["uncertainty_type", "restriction_level"])
  282. .groupby("uncertainty_type", as_index=False)
  283. .head(1)
  284. )
  285. if not first_group.empty:
  286. rep_name = str(first_group.iloc[0]["uncertainty_type"])
  287. rep = cutoff_df[cutoff_df["uncertainty_type"] == rep_name].sort_values(
  288. "restriction_level"
  289. )
  290. coverage_path = output_path.parent / f"{output_path.stem}_coverage.png"
  291. save_coverage_bar_plot(
  292. x_values=rep["restriction_level"],
  293. n_correct=pd.to_numeric(rep["n_correct"], errors="coerce"),
  294. n_incorrect=pd.to_numeric(rep["n_incorrect"], errors="coerce"),
  295. x_label=x_label_final,
  296. title=f"Sample Coverage vs {title_prefix}",
  297. output_path=coverage_path,
  298. )
  299. def save_uncertainty_cutoff_pair_plot(
  300. cutoff_df: pd.DataFrame,
  301. title_prefix: str,
  302. x_label: str,
  303. output_path: Path,
  304. plot_key: str,
  305. ) -> None:
  306. title, x_label_final, _ = _resolve_plot_text(
  307. plot_key=plot_key,
  308. default_title=f"Accuracy and F1 vs {title_prefix}",
  309. default_x_label=x_label,
  310. default_y_label="Accuracy/F1",
  311. )
  312. fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharex=True)
  313. for uncertainty_name, group in cutoff_df.groupby("uncertainty_type"):
  314. g = group.sort_values("restriction_level")
  315. axes[0].plot(
  316. g["restriction_level"], g["accuracy"], marker="o", label=uncertainty_name
  317. )
  318. axes[1].plot(
  319. g["restriction_level"], g["f1"], marker="s", label=uncertainty_name
  320. )
  321. axes[0].set_title("Accuracy")
  322. axes[1].set_title("F1")
  323. for ax, metric_label in [(axes[0], "Accuracy"), (axes[1], "F1")]:
  324. ax.set_xlabel(x_label_final)
  325. ax.set_ylabel(metric_label)
  326. ax.grid(True, alpha=0.3)
  327. ax.legend()
  328. fig.suptitle(title)
  329. fig.tight_layout()
  330. output_path.parent.mkdir(parents=True, exist_ok=True)
  331. fig.savefig(output_path)
  332. plt.close(fig)
  333. # Generate separate coverage bar plot
  334. first_group = (
  335. cutoff_df.sort_values(["uncertainty_type", "restriction_level"])
  336. .groupby("uncertainty_type", as_index=False)
  337. .head(1)
  338. )
  339. if not first_group.empty:
  340. rep_name = str(first_group.iloc[0]["uncertainty_type"])
  341. rep = cutoff_df[cutoff_df["uncertainty_type"] == rep_name].sort_values(
  342. "restriction_level"
  343. )
  344. coverage_path = output_path.parent / f"{output_path.stem}_coverage.png"
  345. save_coverage_bar_plot(
  346. x_values=rep["restriction_level"],
  347. n_correct=pd.to_numeric(rep["n_correct"], errors="coerce"),
  348. n_incorrect=pd.to_numeric(rep["n_incorrect"], errors="coerce"),
  349. x_label=x_label_final,
  350. title=f"Sample Coverage vs {title_prefix}",
  351. output_path=coverage_path,
  352. )
  353. def save_calibration_plot(per_bin: np.ndarray, backend: str, output_path: Path) -> None:
  354. title, x_label, y_label = _resolve_plot_text(
  355. plot_key="calibration_reliability",
  356. default_title=f"Reliability Diagram ({backend})",
  357. default_x_label="Mean Predicted Probability",
  358. default_y_label="Empirical Fraction Positive",
  359. )
  360. fig, ax = plt.subplots(figsize=(6, 6))
  361. valid = ~np.isnan(per_bin[:, 1])
  362. ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="ideal")
  363. ax.plot(per_bin[valid, 0], per_bin[valid, 1], marker="o", label=backend)
  364. ax.set_xlabel(x_label)
  365. ax.set_ylabel(y_label)
  366. ax.set_title(title)
  367. ax.legend()
  368. ax.grid(True, alpha=0.3)
  369. fig.tight_layout()
  370. output_path.parent.mkdir(parents=True, exist_ok=True)
  371. fig.savefig(output_path)
  372. plt.close(fig)
  373. def save_boxplot(
  374. data: list[np.ndarray],
  375. tick_labels: list[str],
  376. x_label: str,
  377. y_label: str,
  378. title: str,
  379. output_path: Path,
  380. ) -> None:
  381. title_final, x_label_final, y_label_final = _resolve_plot_text(
  382. plot_key="boxplot",
  383. default_title=title,
  384. default_x_label=x_label,
  385. default_y_label=y_label,
  386. )
  387. labels_with_n = [
  388. f"{label}\n(n={len(np.asarray(values, dtype=float))})"
  389. for label, values in zip(tick_labels, data)
  390. ]
  391. fig, ax = plt.subplots(figsize=(9, 5))
  392. ax.boxplot(data, tick_labels=labels_with_n)
  393. ax.set_xlabel(x_label_final)
  394. ax.set_ylabel(y_label_final)
  395. ax.set_title(title_final)
  396. ax.grid(True, axis="y", alpha=0.3)
  397. fig.tight_layout()
  398. output_path.parent.mkdir(parents=True, exist_ok=True)
  399. fig.savefig(output_path)
  400. plt.close(fig)
  401. def _central_slice(volume: torch.Tensor) -> np.ndarray:
  402. tensor = volume.detach().cpu()
  403. if tensor.ndim == 5:
  404. tensor = tensor[0]
  405. if tensor.ndim == 4:
  406. tensor = tensor[0]
  407. if tensor.ndim != 3:
  408. raise ValueError(
  409. f"Expected a 3D volume after squeezing batch/channel, got shape {tuple(tensor.shape)}"
  410. )
  411. center_index = tensor.shape[0] // 2
  412. return tensor[center_index].numpy().astype(float)
  413. def _normalize_for_display(image: np.ndarray) -> np.ndarray:
  414. low = float(np.percentile(image, 1))
  415. high = float(np.percentile(image, 99))
  416. if high <= low:
  417. return np.zeros_like(image, dtype=float)
  418. clipped = np.clip(image, low, high)
  419. return (clipped - low) / (high - low)
  420. def save_noise_example_grid(
  421. original_mri: torch.Tensor,
  422. noisy_by_sigma: list[tuple[float, torch.Tensor]],
  423. output_path: Path,
  424. title: str,
  425. max_images: int = 9,
  426. n_rows: int = 2,
  427. ) -> None:
  428. if not noisy_by_sigma:
  429. return
  430. n_total = len(noisy_by_sigma)
  431. target = max(1, min(int(max_images), n_total))
  432. if target >= n_total:
  433. selected = noisy_by_sigma
  434. else:
  435. # Sample indices across the full tested range, including first and last.
  436. raw_idx = np.linspace(0, n_total - 1, num=target)
  437. idx = np.round(raw_idx).astype(int)
  438. selected_indices = sorted(set(idx.tolist()))
  439. if len(selected_indices) < target:
  440. existing = set(selected_indices)
  441. for i in range(n_total):
  442. if i in existing:
  443. continue
  444. selected_indices.append(i)
  445. if len(selected_indices) >= target:
  446. break
  447. selected_indices = sorted(selected_indices)
  448. selected = [noisy_by_sigma[i] for i in selected_indices]
  449. n_images = len(selected)
  450. n_rows = max(1, int(n_rows))
  451. n_cols = int(np.ceil(n_images / n_rows))
  452. fig, axes = plt.subplots(n_rows, n_cols, figsize=(3.8 * n_cols, 3.2 * n_rows))
  453. axes_flat = np.atleast_1d(axes).reshape(-1)
  454. for idx, (sigma, noisy_tensor) in enumerate(selected):
  455. ax = axes_flat[idx]
  456. noisy_slice = _normalize_for_display(_central_slice(noisy_tensor))
  457. ax.imshow(noisy_slice, cmap="gray")
  458. ax.set_title(f"Noise factor={sigma:g}")
  459. ax.axis("off")
  460. for idx in range(n_images, len(axes_flat)):
  461. axes_flat[idx].axis("off")
  462. fig.suptitle(title)
  463. fig.tight_layout()
  464. output_path.parent.mkdir(parents=True, exist_ok=True)
  465. fig.savefig(output_path)
  466. plt.close(fig)
  467. def save_clean_scan_image(
  468. original_mri: torch.Tensor,
  469. output_path: Path,
  470. ) -> None:
  471. image = _normalize_for_display(_central_slice(original_mri))
  472. fig, ax = plt.subplots(figsize=(4, 4))
  473. ax.imshow(image, cmap="gray")
  474. ax.axis("off")
  475. fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
  476. output_path.parent.mkdir(parents=True, exist_ok=True)
  477. fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
  478. plt.close(fig)
  479. def save_noise_metrics_plot(
  480. x: pd.Series,
  481. y: pd.Series,
  482. legend_label: str,
  483. marker: str,
  484. x_label: str,
  485. y_label: str,
  486. title: str,
  487. output_path: Path,
  488. plot_key: str,
  489. ) -> None:
  490. title_final, x_label_final, y_label_final = _resolve_plot_text(
  491. plot_key=plot_key,
  492. default_title=title,
  493. default_x_label=x_label,
  494. default_y_label=y_label,
  495. )
  496. fig, ax = plt.subplots(figsize=(10, 5))
  497. ax.plot(x, y, marker=marker, label=legend_label)
  498. ax.set_xlabel(x_label_final)
  499. ax.set_ylabel(y_label_final)
  500. ax.set_title(title_final)
  501. ax.grid(True, alpha=0.3)
  502. ax.legend()
  503. fig.tight_layout()
  504. output_path.parent.mkdir(parents=True, exist_ok=True)
  505. fig.savefig(output_path)
  506. plt.close(fig)
  507. def save_metric_pair_plot(
  508. x: pd.Series,
  509. left_y: pd.Series,
  510. right_y: pd.Series,
  511. left_label: str,
  512. right_label: str,
  513. x_label: str,
  514. y_label: str,
  515. title: str,
  516. output_path: Path,
  517. plot_key: str,
  518. ) -> None:
  519. title_final, x_label_final, y_label_final = _resolve_plot_text(
  520. plot_key=plot_key,
  521. default_title=title,
  522. default_x_label=x_label,
  523. default_y_label=y_label,
  524. )
  525. fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharex=True)
  526. axes[0].plot(x, left_y, marker="o", label=left_label)
  527. axes[1].plot(x, right_y, marker="s", label=right_label)
  528. for ax, name in [(axes[0], left_label), (axes[1], right_label)]:
  529. ax.set_xlabel(x_label_final)
  530. ax.set_ylabel(name)
  531. ax.set_title(name)
  532. ax.grid(True, alpha=0.3)
  533. ax.legend()
  534. fig.suptitle(title_final)
  535. fig.tight_layout()
  536. output_path.parent.mkdir(parents=True, exist_ok=True)
  537. fig.savefig(output_path)
  538. plt.close(fig)