"""Plotting helpers (plotting.py).

Matplotlib routines that render threshold, uncertainty-cutoff, calibration,
boxplot, and noise-related figures and save them to disk.
"""
# pyright: basic
from __future__ import annotations

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.axes import Axes
# Easily editable plot text overrides, keyed by plot key.
# Each entry may define any of "title", "x_label" and "y_label"; fields that
# are left out fall back to the defaults supplied to _resolve_plot_text.
# Example:
#     "performance_threshold_accuracy": {
#         "title": "Custom Title",
#         "x_label": "Custom X",
#         "y_label": "Custom Y",
#     }
# Common keys:
# - performance_threshold_accuracy
# - performance_threshold_f1
# - performance_uncertainty_cutoff_accuracy
# - performance_uncertainty_cutoff_f1
# - performance_uncertainty_percentile_cutoff_accuracy
# - performance_uncertainty_percentile_cutoff_f1
# - calibration_reliability
# - noise_sensitivity_accuracy
# - noise_sensitivity_f1
# - noise_confidence
# - noise_standard_deviation
# - noise_predictive_uncertainty
# - boxplot
PLOT_TEXT_OVERRIDES: dict[str, dict[str, str]] = {}
  31. def _resolve_plot_text(
  32. plot_key: str,
  33. default_title: str,
  34. default_x_label: str,
  35. default_y_label: str,
  36. ) -> tuple[str, str, str]:
  37. override = PLOT_TEXT_OVERRIDES.get(plot_key, {})
  38. return (
  39. override.get("title", default_title),
  40. override.get("x_label", default_x_label),
  41. override.get("y_label", default_y_label),
  42. )
  43. def annotate_stats_box(
  44. ax: Axes,
  45. lines: list[str],
  46. location: str = "upper left",
  47. ) -> None:
  48. if not lines:
  49. return
  50. locations: dict[str, tuple[float, float, str, str]] = {
  51. "upper left": (0.02, 0.98, "left", "top"),
  52. "upper right": (0.98, 0.98, "right", "top"),
  53. "lower left": (0.02, 0.02, "left", "bottom"),
  54. "lower right": (0.98, 0.02, "right", "bottom"),
  55. }
  56. x, y, ha, va = locations.get(location, locations["upper left"])
  57. ax.text(
  58. x,
  59. y,
  60. "\n".join(lines),
  61. transform=ax.transAxes,
  62. ha=ha,
  63. va=va,
  64. fontsize=10,
  65. bbox={
  66. "boxstyle": "round,pad=0.35",
  67. "facecolor": "white",
  68. "edgecolor": "#555555",
  69. "alpha": 0.9,
  70. },
  71. )
  72. def _plot_correct_incorrect_bars(
  73. ax: Axes,
  74. x_values: pd.Series,
  75. n_correct: pd.Series,
  76. n_incorrect: pd.Series,
  77. ) -> None:
  78. x = np.asarray(x_values, dtype=float)
  79. correct = np.asarray(n_correct, dtype=float)
  80. incorrect = np.asarray(n_incorrect, dtype=float)
  81. if x.size == 0 or correct.size == 0 or incorrect.size == 0:
  82. return
  83. width = float(np.diff(np.sort(x)).min()) * 0.8 if x.size > 1 else 0.04
  84. max_count = float(max(np.nanmax(correct), np.nanmax(incorrect), 1.0))
  85. bars_ax = ax.twinx()
  86. bars_ax.patch.set_alpha(0.0)
  87. bars_ax.bar(
  88. x,
  89. correct,
  90. width=width,
  91. color="#2ca02c",
  92. alpha=0.2,
  93. label="correct",
  94. zorder=0,
  95. align="center",
  96. )
  97. bars_ax.bar(
  98. x,
  99. -incorrect,
  100. width=width,
  101. color="#d62728",
  102. alpha=0.2,
  103. label="incorrect",
  104. zorder=0,
  105. align="center",
  106. )
  107. bars_ax.axhline(0.0, color="gray", linewidth=0.8, alpha=0.4)
  108. bars_ax.set_ylim(-1.15 * max_count, 1.15 * max_count)
  109. bars_ax.set_yticks([])
  110. bars_ax.grid(False)
  111. def plots_dir(output_dir: Path) -> Path:
  112. plots = output_dir / "plots"
  113. plots.mkdir(parents=True, exist_ok=True)
  114. return plots
  115. def save_performance_threshold_plot(
  116. df: pd.DataFrame,
  117. backend: str,
  118. output_path: Path,
  119. metric_column: str,
  120. metric_label: str,
  121. plot_key: str,
  122. ) -> None:
  123. title, x_label, y_label = _resolve_plot_text(
  124. plot_key=plot_key,
  125. default_title=f"{metric_label} vs Decision Threshold ({backend})",
  126. default_x_label="Decision Threshold",
  127. default_y_label=metric_label,
  128. )
  129. n_correct = pd.to_numeric(df["tp"], errors="coerce") + pd.to_numeric(
  130. df["tn"], errors="coerce"
  131. )
  132. n_incorrect = pd.to_numeric(df["fp"], errors="coerce") + pd.to_numeric(
  133. df["fn"], errors="coerce"
  134. )
  135. fig, ax = plt.subplots(figsize=(10, 5))
  136. _plot_correct_incorrect_bars(ax, df["threshold"], n_correct, n_incorrect)
  137. ax.plot(df["threshold"], df[metric_column], label=metric_label, marker="o")
  138. ax.set_xlabel(x_label)
  139. ax.set_ylabel(y_label)
  140. ax.set_title(title)
  141. ax.grid(True, alpha=0.3)
  142. ax.legend()
  143. fig.tight_layout()
  144. output_path.parent.mkdir(parents=True, exist_ok=True)
  145. fig.savefig(output_path)
  146. plt.close(fig)
  147. def save_performance_threshold_pair_plot(
  148. df: pd.DataFrame,
  149. backend: str,
  150. output_path: Path,
  151. plot_key: str,
  152. ) -> None:
  153. title, x_label, _ = _resolve_plot_text(
  154. plot_key=plot_key,
  155. default_title=f"Accuracy and F1 vs Decision Threshold ({backend})",
  156. default_x_label="Decision Threshold",
  157. default_y_label="Accuracy/F1",
  158. )
  159. n_correct = pd.to_numeric(df["tp"], errors="coerce") + pd.to_numeric(
  160. df["tn"], errors="coerce"
  161. )
  162. n_incorrect = pd.to_numeric(df["fp"], errors="coerce") + pd.to_numeric(
  163. df["fn"], errors="coerce"
  164. )
  165. fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharex=True)
  166. for ax, metric_col, metric_label, marker in [
  167. (axes[0], "accuracy", "Accuracy", "o"),
  168. (axes[1], "f1", "F1", "s"),
  169. ]:
  170. _plot_correct_incorrect_bars(ax, df["threshold"], n_correct, n_incorrect)
  171. ax.plot(df["threshold"], df[metric_col], label=metric_label, marker=marker)
  172. ax.set_xlabel(x_label)
  173. ax.set_ylabel(metric_label)
  174. ax.set_title(f"{metric_label}")
  175. ax.grid(True, alpha=0.3)
  176. ax.legend()
  177. fig.suptitle(title)
  178. fig.tight_layout()
  179. output_path.parent.mkdir(parents=True, exist_ok=True)
  180. fig.savefig(output_path)
  181. plt.close(fig)
  182. def save_uncertainty_cutoff_plot(
  183. cutoff_df: pd.DataFrame,
  184. title_prefix: str,
  185. x_label: str,
  186. output_path: Path,
  187. metric_column: str,
  188. metric_label: str,
  189. plot_key: str,
  190. ) -> None:
  191. title, x_label_final, y_label = _resolve_plot_text(
  192. plot_key=plot_key,
  193. default_title=f"{metric_label} vs {title_prefix}",
  194. default_x_label=x_label,
  195. default_y_label=metric_label,
  196. )
  197. fig, ax = plt.subplots(figsize=(10, 5))
  198. first_group = (
  199. cutoff_df.sort_values(["uncertainty_type", "restriction_level"])
  200. .groupby("uncertainty_type", as_index=False)
  201. .head(1)
  202. )
  203. if not first_group.empty:
  204. # Draw count bars once; uncertainty lines are overlaid afterwards.
  205. rep_name = str(first_group.iloc[0]["uncertainty_type"])
  206. rep = cutoff_df[cutoff_df["uncertainty_type"] == rep_name].sort_values(
  207. "restriction_level"
  208. )
  209. _plot_correct_incorrect_bars(
  210. ax,
  211. rep["restriction_level"],
  212. pd.to_numeric(rep["n_correct"], errors="coerce"),
  213. pd.to_numeric(rep["n_incorrect"], errors="coerce"),
  214. )
  215. for uncertainty_name, group in cutoff_df.groupby("uncertainty_type"):
  216. g = group.sort_values("restriction_level")
  217. ax.plot(
  218. g["restriction_level"],
  219. g[metric_column],
  220. marker="o",
  221. label=uncertainty_name,
  222. )
  223. ax.set_title(title)
  224. ax.set_xlabel(x_label_final)
  225. ax.set_ylabel(y_label)
  226. ax.grid(True, alpha=0.3)
  227. ax.legend()
  228. fig.tight_layout()
  229. output_path.parent.mkdir(parents=True, exist_ok=True)
  230. fig.savefig(output_path)
  231. plt.close(fig)
  232. def save_uncertainty_cutoff_pair_plot(
  233. cutoff_df: pd.DataFrame,
  234. title_prefix: str,
  235. x_label: str,
  236. output_path: Path,
  237. plot_key: str,
  238. ) -> None:
  239. title, x_label_final, _ = _resolve_plot_text(
  240. plot_key=plot_key,
  241. default_title=f"Accuracy and F1 vs {title_prefix}",
  242. default_x_label=x_label,
  243. default_y_label="Accuracy/F1",
  244. )
  245. fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharex=True)
  246. first_group = (
  247. cutoff_df.sort_values(["uncertainty_type", "restriction_level"])
  248. .groupby("uncertainty_type", as_index=False)
  249. .head(1)
  250. )
  251. if not first_group.empty:
  252. rep_name = str(first_group.iloc[0]["uncertainty_type"])
  253. rep = cutoff_df[cutoff_df["uncertainty_type"] == rep_name].sort_values(
  254. "restriction_level"
  255. )
  256. for ax in axes:
  257. _plot_correct_incorrect_bars(
  258. ax,
  259. rep["restriction_level"],
  260. pd.to_numeric(rep["n_correct"], errors="coerce"),
  261. pd.to_numeric(rep["n_incorrect"], errors="coerce"),
  262. )
  263. for uncertainty_name, group in cutoff_df.groupby("uncertainty_type"):
  264. g = group.sort_values("restriction_level")
  265. axes[0].plot(
  266. g["restriction_level"], g["accuracy"], marker="o", label=uncertainty_name
  267. )
  268. axes[1].plot(
  269. g["restriction_level"], g["f1"], marker="s", label=uncertainty_name
  270. )
  271. axes[0].set_title("Accuracy")
  272. axes[1].set_title("F1")
  273. for ax, metric_label in [(axes[0], "Accuracy"), (axes[1], "F1")]:
  274. ax.set_xlabel(x_label_final)
  275. ax.set_ylabel(metric_label)
  276. ax.grid(True, alpha=0.3)
  277. ax.legend()
  278. fig.suptitle(title)
  279. fig.tight_layout()
  280. output_path.parent.mkdir(parents=True, exist_ok=True)
  281. fig.savefig(output_path)
  282. plt.close(fig)
  283. def save_calibration_plot(per_bin: np.ndarray, backend: str, output_path: Path) -> None:
  284. title, x_label, y_label = _resolve_plot_text(
  285. plot_key="calibration_reliability",
  286. default_title=f"Reliability Diagram ({backend})",
  287. default_x_label="Mean Predicted Probability",
  288. default_y_label="Empirical Fraction Positive",
  289. )
  290. fig, ax = plt.subplots(figsize=(6, 6))
  291. valid = ~np.isnan(per_bin[:, 1])
  292. ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="ideal")
  293. ax.plot(per_bin[valid, 0], per_bin[valid, 1], marker="o", label=backend)
  294. ax.set_xlabel(x_label)
  295. ax.set_ylabel(y_label)
  296. ax.set_title(title)
  297. ax.legend()
  298. ax.grid(True, alpha=0.3)
  299. fig.tight_layout()
  300. output_path.parent.mkdir(parents=True, exist_ok=True)
  301. fig.savefig(output_path)
  302. plt.close(fig)
  303. def save_boxplot(
  304. data: list[np.ndarray],
  305. tick_labels: list[str],
  306. x_label: str,
  307. y_label: str,
  308. title: str,
  309. output_path: Path,
  310. ) -> None:
  311. title_final, x_label_final, y_label_final = _resolve_plot_text(
  312. plot_key="boxplot",
  313. default_title=title,
  314. default_x_label=x_label,
  315. default_y_label=y_label,
  316. )
  317. labels_with_n = [
  318. f"{label}\n(n={len(np.asarray(values, dtype=float))})"
  319. for label, values in zip(tick_labels, data)
  320. ]
  321. fig, ax = plt.subplots(figsize=(9, 5))
  322. ax.boxplot(data, tick_labels=labels_with_n)
  323. ax.set_xlabel(x_label_final)
  324. ax.set_ylabel(y_label_final)
  325. ax.set_title(title_final)
  326. ax.grid(True, axis="y", alpha=0.3)
  327. fig.tight_layout()
  328. output_path.parent.mkdir(parents=True, exist_ok=True)
  329. fig.savefig(output_path)
  330. plt.close(fig)
  331. def _central_slice(volume: torch.Tensor) -> np.ndarray:
  332. tensor = volume.detach().cpu()
  333. if tensor.ndim == 5:
  334. tensor = tensor[0]
  335. if tensor.ndim == 4:
  336. tensor = tensor[0]
  337. if tensor.ndim != 3:
  338. raise ValueError(
  339. f"Expected a 3D volume after squeezing batch/channel, got shape {tuple(tensor.shape)}"
  340. )
  341. center_index = tensor.shape[0] // 2
  342. return tensor[center_index].numpy().astype(float)
  343. def _normalize_for_display(image: np.ndarray) -> np.ndarray:
  344. low = float(np.percentile(image, 1))
  345. high = float(np.percentile(image, 99))
  346. if high <= low:
  347. return np.zeros_like(image, dtype=float)
  348. clipped = np.clip(image, low, high)
  349. return (clipped - low) / (high - low)
  350. def save_noise_example_grid(
  351. original_mri: torch.Tensor,
  352. noisy_by_sigma: list[tuple[float, torch.Tensor]],
  353. output_path: Path,
  354. title: str,
  355. max_images: int = 9,
  356. n_rows: int = 2,
  357. ) -> None:
  358. if not noisy_by_sigma:
  359. return
  360. n_total = len(noisy_by_sigma)
  361. target = max(1, min(int(max_images), n_total))
  362. if target >= n_total:
  363. selected = noisy_by_sigma
  364. else:
  365. # Sample indices across the full tested range, including first and last.
  366. raw_idx = np.linspace(0, n_total - 1, num=target)
  367. idx = np.round(raw_idx).astype(int)
  368. selected_indices = sorted(set(idx.tolist()))
  369. if len(selected_indices) < target:
  370. existing = set(selected_indices)
  371. for i in range(n_total):
  372. if i in existing:
  373. continue
  374. selected_indices.append(i)
  375. if len(selected_indices) >= target:
  376. break
  377. selected_indices = sorted(selected_indices)
  378. selected = [noisy_by_sigma[i] for i in selected_indices]
  379. n_images = len(selected)
  380. n_rows = max(1, int(n_rows))
  381. n_cols = int(np.ceil(n_images / n_rows))
  382. fig, axes = plt.subplots(n_rows, n_cols, figsize=(3.8 * n_cols, 3.2 * n_rows))
  383. axes_flat = np.atleast_1d(axes).reshape(-1)
  384. for idx, (sigma, noisy_tensor) in enumerate(selected):
  385. ax = axes_flat[idx]
  386. noisy_slice = _normalize_for_display(_central_slice(noisy_tensor))
  387. ax.imshow(noisy_slice, cmap="gray")
  388. ax.set_title(f"Noise factor={sigma:g}")
  389. ax.axis("off")
  390. for idx in range(n_images, len(axes_flat)):
  391. axes_flat[idx].axis("off")
  392. fig.suptitle(title)
  393. fig.tight_layout()
  394. output_path.parent.mkdir(parents=True, exist_ok=True)
  395. fig.savefig(output_path)
  396. plt.close(fig)
  397. def save_clean_scan_image(
  398. original_mri: torch.Tensor,
  399. output_path: Path,
  400. ) -> None:
  401. image = _normalize_for_display(_central_slice(original_mri))
  402. fig, ax = plt.subplots(figsize=(4, 4))
  403. ax.imshow(image, cmap="gray")
  404. ax.axis("off")
  405. fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
  406. output_path.parent.mkdir(parents=True, exist_ok=True)
  407. fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
  408. plt.close(fig)
  409. def save_noise_metrics_plot(
  410. x: pd.Series,
  411. y: pd.Series,
  412. legend_label: str,
  413. marker: str,
  414. x_label: str,
  415. y_label: str,
  416. title: str,
  417. output_path: Path,
  418. plot_key: str,
  419. ) -> None:
  420. title_final, x_label_final, y_label_final = _resolve_plot_text(
  421. plot_key=plot_key,
  422. default_title=title,
  423. default_x_label=x_label,
  424. default_y_label=y_label,
  425. )
  426. fig, ax = plt.subplots(figsize=(10, 5))
  427. ax.plot(x, y, marker=marker, label=legend_label)
  428. ax.set_xlabel(x_label_final)
  429. ax.set_ylabel(y_label_final)
  430. ax.set_title(title_final)
  431. ax.grid(True, alpha=0.3)
  432. ax.legend()
  433. fig.tight_layout()
  434. output_path.parent.mkdir(parents=True, exist_ok=True)
  435. fig.savefig(output_path)
  436. plt.close(fig)
  437. def save_metric_pair_plot(
  438. x: pd.Series,
  439. left_y: pd.Series,
  440. right_y: pd.Series,
  441. left_label: str,
  442. right_label: str,
  443. x_label: str,
  444. y_label: str,
  445. title: str,
  446. output_path: Path,
  447. plot_key: str,
  448. ) -> None:
  449. title_final, x_label_final, y_label_final = _resolve_plot_text(
  450. plot_key=plot_key,
  451. default_title=title,
  452. default_x_label=x_label,
  453. default_y_label=y_label,
  454. )
  455. fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharex=True)
  456. axes[0].plot(x, left_y, marker="o", label=left_label)
  457. axes[1].plot(x, right_y, marker="s", label=right_label)
  458. for ax, name in [(axes[0], left_label), (axes[1], right_label)]:
  459. ax.set_xlabel(x_label_final)
  460. ax.set_ylabel(name)
  461. ax.set_title(name)
  462. ax.grid(True, alpha=0.3)
  463. ax.legend()
  464. fig.suptitle(title_final)
  465. fig.tight_layout()
  466. output_path.parent.mkdir(parents=True, exist_ok=True)
  467. fig.savefig(output_path)
  468. plt.close(fig)