plotting.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # pyright: basic
  2. from __future__ import annotations
  3. from pathlib import Path
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import pandas as pd
  7. import torch
  8. def plots_dir(output_dir: Path) -> Path:
  9. plots = output_dir / "plots"
  10. plots.mkdir(parents=True, exist_ok=True)
  11. return plots
  12. def save_performance_threshold_plot(
  13. df: pd.DataFrame, backend: str, output_path: Path
  14. ) -> None:
  15. fig, ax = plt.subplots(figsize=(10, 5))
  16. ax.plot(df["threshold"], df["accuracy"], label="accuracy", marker="o")
  17. ax.plot(df["threshold"], df["f1"], label="f1", marker="s")
  18. ax.set_xlabel("Threshold")
  19. ax.set_ylabel("Score")
  20. ax.set_title(f"Performance vs Threshold ({backend})")
  21. ax.grid(True, alpha=0.3)
  22. ax.legend()
  23. fig.tight_layout()
  24. output_path.parent.mkdir(parents=True, exist_ok=True)
  25. fig.savefig(output_path)
  26. plt.close(fig)
  27. def save_uncertainty_cutoff_plot(
  28. cutoff_df: pd.DataFrame,
  29. title_prefix: str,
  30. x_label: str,
  31. output_path: Path,
  32. ) -> None:
  33. fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharex=True)
  34. for uncertainty_name, group in cutoff_df.groupby("uncertainty_type"):
  35. g = group.sort_values("restriction_level")
  36. axes[0].plot(
  37. g["restriction_level"], g["accuracy"], marker="o", label=uncertainty_name
  38. )
  39. axes[1].plot(
  40. g["restriction_level"], g["f1"], marker="s", label=uncertainty_name
  41. )
  42. axes[0].set_title(f"Accuracy vs {title_prefix}")
  43. axes[1].set_title(f"F1 vs {title_prefix}")
  44. for ax in axes:
  45. ax.set_xlabel(x_label)
  46. ax.grid(True, alpha=0.3)
  47. ax.legend()
  48. axes[0].set_ylabel("Accuracy")
  49. axes[1].set_ylabel("F1")
  50. fig.tight_layout()
  51. output_path.parent.mkdir(parents=True, exist_ok=True)
  52. fig.savefig(output_path)
  53. plt.close(fig)
  54. def save_calibration_plot(per_bin: np.ndarray, backend: str, output_path: Path) -> None:
  55. fig, ax = plt.subplots(figsize=(6, 6))
  56. valid = ~np.isnan(per_bin[:, 1])
  57. ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="ideal")
  58. ax.plot(per_bin[valid, 0], per_bin[valid, 1], marker="o", label=backend)
  59. ax.set_xlabel("Mean Predicted Probability")
  60. ax.set_ylabel("Empirical Fraction Positive")
  61. ax.set_title(f"Reliability Diagram ({backend})")
  62. ax.legend()
  63. ax.grid(True, alpha=0.3)
  64. fig.tight_layout()
  65. output_path.parent.mkdir(parents=True, exist_ok=True)
  66. fig.savefig(output_path)
  67. plt.close(fig)
  68. def save_boxplot(
  69. data: list[np.ndarray],
  70. tick_labels: list[str],
  71. x_label: str,
  72. y_label: str,
  73. title: str,
  74. output_path: Path,
  75. ) -> None:
  76. fig, ax = plt.subplots(figsize=(9, 5))
  77. ax.boxplot(data, tick_labels=tick_labels)
  78. ax.set_xlabel(x_label)
  79. ax.set_ylabel(y_label)
  80. ax.set_title(title)
  81. ax.grid(True, axis="y", alpha=0.3)
  82. fig.tight_layout()
  83. output_path.parent.mkdir(parents=True, exist_ok=True)
  84. fig.savefig(output_path)
  85. plt.close(fig)
  86. def _central_slice(volume: torch.Tensor) -> np.ndarray:
  87. tensor = volume.detach().cpu()
  88. if tensor.ndim == 5:
  89. tensor = tensor[0]
  90. if tensor.ndim == 4:
  91. tensor = tensor[0]
  92. if tensor.ndim != 3:
  93. raise ValueError(
  94. f"Expected a 3D volume after squeezing batch/channel, got shape {tuple(tensor.shape)}"
  95. )
  96. center_index = tensor.shape[0] // 2
  97. return tensor[center_index].numpy().astype(float)
  98. def _normalize_for_display(image: np.ndarray) -> np.ndarray:
  99. low = float(np.percentile(image, 1))
  100. high = float(np.percentile(image, 99))
  101. if high <= low:
  102. return np.zeros_like(image, dtype=float)
  103. clipped = np.clip(image, low, high)
  104. return (clipped - low) / (high - low)
  105. def save_noise_example_grid(
  106. original_mri: torch.Tensor,
  107. noisy_by_sigma: list[tuple[float, torch.Tensor]],
  108. output_path: Path,
  109. title: str,
  110. ) -> None:
  111. if not noisy_by_sigma:
  112. return
  113. original_slice = _normalize_for_display(_central_slice(original_mri))
  114. n_rows = len(noisy_by_sigma)
  115. fig, axes = plt.subplots(n_rows, 2, figsize=(8, 3.2 * n_rows))
  116. if n_rows == 1:
  117. axes = np.array([axes])
  118. for row_idx, (sigma, noisy_tensor) in enumerate(noisy_by_sigma):
  119. noisy_slice = _normalize_for_display(_central_slice(noisy_tensor))
  120. ax_orig, ax_noisy = axes[row_idx]
  121. ax_orig.imshow(original_slice, cmap="gray")
  122. ax_orig.set_title("Original")
  123. ax_orig.axis("off")
  124. ax_noisy.imshow(noisy_slice, cmap="gray")
  125. ax_noisy.set_title(f"Noisy factor={sigma:g}")
  126. ax_noisy.axis("off")
  127. fig.suptitle(title)
  128. fig.tight_layout()
  129. output_path.parent.mkdir(parents=True, exist_ok=True)
  130. fig.savefig(output_path)
  131. plt.close(fig)
  132. def save_noise_metrics_plot(
  133. x: pd.Series,
  134. y_by_label: list[tuple[pd.Series, str, str]],
  135. x_label: str,
  136. y_label: str,
  137. title: str,
  138. output_path: Path,
  139. ) -> None:
  140. fig, ax = plt.subplots(figsize=(10, 5))
  141. for series, marker, label in y_by_label:
  142. ax.plot(x, series, marker=marker, label=label)
  143. ax.set_xlabel(x_label)
  144. ax.set_ylabel(y_label)
  145. ax.set_title(title)
  146. ax.grid(True, alpha=0.3)
  147. ax.legend()
  148. fig.tight_layout()
  149. output_path.parent.mkdir(parents=True, exist_ok=True)
  150. fig.savefig(output_path)
  151. plt.close(fig)