# analysis_modules.py
  1. # pyright: basic
  2. from __future__ import annotations
  3. from pathlib import Path
  4. from typing import Any
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7. import pandas as pd
  8. from .data_access import BackendEvaluation, physician_column
  9. from .metrics import calibration_stats, performance_at_threshold, threshold_sweep
  10. from .runtime import write_json
  11. def _save_table(rows: list[dict[str, Any]], out_path: Path) -> pd.DataFrame:
  12. df = pd.DataFrame(rows)
  13. out_path.parent.mkdir(parents=True, exist_ok=True)
  14. df.to_csv(out_path, index=False)
  15. return df
  16. def run_performance(
  17. evaluation: BackendEvaluation,
  18. output_dir: Path,
  19. thresholds: np.ndarray,
  20. ) -> dict[str, Any]:
  21. rows = threshold_sweep(evaluation.y_true, evaluation.y_prob, thresholds)
  22. table_path = output_dir / "performance_threshold_sweep.csv"
  23. df = _save_table(rows, table_path)
  24. fig, ax = plt.subplots(figsize=(10, 5))
  25. ax.plot(df["threshold"], df["accuracy"], label="accuracy", marker="o")
  26. ax.plot(df["threshold"], df["f1"], label="f1", marker="s")
  27. ax.set_xlabel("Threshold")
  28. ax.set_ylabel("Score")
  29. ax.set_title(f"Performance vs Threshold ({evaluation.backend})")
  30. ax.grid(True, alpha=0.3)
  31. ax.legend()
  32. fig.tight_layout()
  33. plot_path = output_dir / "performance_threshold_sweep.png"
  34. fig.savefig(plot_path)
  35. plt.close(fig)
  36. best_idx = int(df["f1"].idxmax())
  37. best = df.iloc[best_idx].to_dict()
  38. cutoff_percentiles = np.array(
  39. [100, 95, 90, 85, 80, 75, 70, 60, 50, 40, 30, 20, 10, 5, 2, 1],
  40. dtype=float,
  41. )
  42. confidence_uncertainty = 1.0 - np.asarray(
  43. evaluation.uncertainty_confidence, dtype=float
  44. )
  45. secondary_uncertainty = np.asarray(evaluation.uncertainty_std, dtype=float)
  46. uncertainty_types = [
  47. ("confidence_uncertainty", confidence_uncertainty),
  48. (evaluation.uncertainty_metric, secondary_uncertainty),
  49. ]
  50. cutoff_rows: list[dict[str, Any]] = []
  51. for uncertainty_name, values in uncertainty_types:
  52. finite_mask = np.isfinite(values)
  53. if not finite_mask.any():
  54. continue
  55. values_valid = values[finite_mask]
  56. y_true_valid = evaluation.y_true[finite_mask]
  57. y_prob_valid = evaluation.y_prob[finite_mask]
  58. for cutoff_percentile in cutoff_percentiles:
  59. # Keep predictions whose uncertainty is <= percentile cutoff.
  60. cutoff_value = float(np.percentile(values_valid, cutoff_percentile))
  61. keep_mask = values_valid <= cutoff_value
  62. retained = int(keep_mask.sum())
  63. if retained == 0:
  64. continue
  65. perf = performance_at_threshold(
  66. y_true=y_true_valid[keep_mask],
  67. y_prob=y_prob_valid[keep_mask],
  68. threshold=0.5,
  69. )
  70. cutoff_rows.append(
  71. {
  72. "uncertainty_type": uncertainty_name,
  73. "cutoff_percentile": float(cutoff_percentile),
  74. "cutoff_value": cutoff_value,
  75. "n_retained": retained,
  76. "coverage": float(retained / len(values_valid)),
  77. "accuracy": float(perf["accuracy"]),
  78. "f1": float(perf["f1"]),
  79. }
  80. )
  81. cutoff_table_path = output_dir / "performance_uncertainty_cutoff.csv"
  82. cutoff_plot_path = output_dir / "performance_uncertainty_cutoff.png"
  83. if cutoff_rows:
  84. cutoff_df = pd.DataFrame(cutoff_rows)
  85. cutoff_df.to_csv(cutoff_table_path, index=False)
  86. fig_u, axes_u = plt.subplots(1, 2, figsize=(14, 5), sharex=True)
  87. for uncertainty_name, group in cutoff_df.groupby("uncertainty_type"):
  88. g = group.sort_values("cutoff_percentile", ascending=False)
  89. axes_u[0].plot(
  90. g["cutoff_percentile"],
  91. g["accuracy"],
  92. marker="o",
  93. label=uncertainty_name,
  94. )
  95. axes_u[1].plot(
  96. g["cutoff_percentile"],
  97. g["f1"],
  98. marker="s",
  99. label=uncertainty_name,
  100. )
  101. axes_u[0].set_title("Accuracy vs Uncertainty Cutoff Percentile")
  102. axes_u[1].set_title("F1 vs Uncertainty Cutoff Percentile")
  103. for ax in axes_u:
  104. ax.set_xlabel("Uncertainty Cutoff Percentile (100 = no cutoff)")
  105. ax.grid(True, alpha=0.3)
  106. ax.legend()
  107. axes_u[0].set_ylabel("Accuracy")
  108. axes_u[1].set_ylabel("F1")
  109. fig_u.tight_layout()
  110. fig_u.savefig(cutoff_plot_path)
  111. plt.close(fig_u)
  112. summary = {
  113. "best_by_f1": {
  114. k: float(v) for k, v in best.items() if isinstance(v, (int, float))
  115. },
  116. "table": str(table_path),
  117. "plot": str(plot_path),
  118. "uncertainty_cutoff": {
  119. "table": str(cutoff_table_path),
  120. "plot": str(cutoff_plot_path),
  121. "decision_threshold": 0.5,
  122. },
  123. }
  124. write_json(output_dir / "performance_summary.json", summary)
  125. return summary
  126. def run_calibration(
  127. evaluation: BackendEvaluation,
  128. output_dir: Path,
  129. bins: int,
  130. ) -> dict[str, Any]:
  131. summary, per_bin = calibration_stats(
  132. evaluation.y_true, evaluation.y_prob, bins=bins
  133. )
  134. bin_df = pd.DataFrame(
  135. per_bin,
  136. columns=["mean_confidence", "fraction_positive", "count"],
  137. )
  138. table_path = output_dir / "calibration_bins.csv"
  139. bin_df.to_csv(table_path, index=False)
  140. fig, ax = plt.subplots(figsize=(6, 6))
  141. valid = ~np.isnan(per_bin[:, 1])
  142. ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="ideal")
  143. ax.plot(
  144. per_bin[valid, 0],
  145. per_bin[valid, 1],
  146. marker="o",
  147. label=f"{evaluation.backend}",
  148. )
  149. ax.set_xlabel("Mean Predicted Probability")
  150. ax.set_ylabel("Empirical Fraction Positive")
  151. ax.set_title(f"Reliability Diagram ({evaluation.backend})")
  152. ax.legend()
  153. ax.grid(True, alpha=0.3)
  154. fig.tight_layout()
  155. plot_path = output_dir / "calibration_reliability.png"
  156. fig.savefig(plot_path)
  157. plt.close(fig)
  158. out = {
  159. **summary,
  160. "table": str(table_path),
  161. "plot": str(plot_path),
  162. }
  163. write_json(output_dir / "calibration_summary.json", out)
  164. return out
  165. def run_physician(
  166. evaluation: BackendEvaluation,
  167. clinical_df: pd.DataFrame,
  168. output_dir: Path,
  169. ) -> dict[str, Any]:
  170. secondary_key = (
  171. "predictive_entropy"
  172. if evaluation.uncertainty_metric == "predictive_entropy"
  173. else "std"
  174. )
  175. secondary_label = (
  176. "Model Predictive Entropy"
  177. if secondary_key == "predictive_entropy"
  178. else "Model Uncertainty Std"
  179. )
  180. col = physician_column(clinical_df)
  181. subset = clinical_df[["Image Data ID", col]].copy()
  182. subset[col] = pd.to_numeric(subset[col], errors="coerce")
  183. subset = subset.dropna(subset=["Image Data ID", col])
  184. subset["Image Data ID"] = subset["Image Data ID"].astype(int)
  185. subset[col] = subset[col].astype(int)
  186. eval_df = pd.DataFrame(
  187. {
  188. "Image Data ID": evaluation.image_ids.astype(int),
  189. "model_confidence": evaluation.uncertainty_confidence,
  190. "model_std": evaluation.uncertainty_std,
  191. "model_prob": evaluation.y_prob,
  192. }
  193. )
  194. merged = eval_df.merge(subset, on="Image Data ID", how="inner")
  195. if merged.empty:
  196. raise ValueError("No overlapping Image Data ID rows for physician analysis")
  197. grouped_rows: list[dict[str, Any]] = []
  198. uncertainty_specs = [
  199. ("confidence", "model_confidence", "Model Confidence (2*|p-0.5|)"),
  200. (secondary_key, "model_std", secondary_label),
  201. ]
  202. ratings = [int(r) for r in sorted(pd.unique(merged[col]))]
  203. plot_paths: dict[str, str] = {}
  204. correlations: dict[str, float] = {}
  205. for metric_name, metric_col, metric_label in uncertainty_specs:
  206. grouped_metric = (
  207. merged.groupby(col)
  208. .agg(
  209. n=("Image Data ID", "count"),
  210. mean_value=(metric_col, "mean"),
  211. std_value=(metric_col, "std"),
  212. mean_prob=("model_prob", "mean"),
  213. )
  214. .reset_index()
  215. .rename(columns={col: "physician_rating"})
  216. )
  217. grouped_metric["uncertainty_type"] = metric_name
  218. grouped_rows.extend(
  219. [
  220. {str(k): v for k, v in rec.items()}
  221. for rec in grouped_metric.to_dict(orient="records")
  222. ]
  223. )
  224. fig, ax = plt.subplots(figsize=(9, 5))
  225. data = [
  226. np.asarray(merged.loc[merged[col] == r, metric_col], dtype=float)
  227. for r in ratings
  228. ]
  229. ax.boxplot(data, tick_labels=[str(r) for r in ratings])
  230. ax.set_xlabel("Physician Confidence Rating (DXCONFID)")
  231. ax.set_ylabel(metric_label)
  232. ax.set_title(f"{metric_label} vs Physician Confidence ({evaluation.backend})")
  233. ax.grid(True, axis="y", alpha=0.3)
  234. fig.tight_layout()
  235. plot_path = output_dir / f"physician_{metric_name}_boxplot.png"
  236. fig.savefig(plot_path)
  237. plt.close(fig)
  238. corr = float(
  239. pd.to_numeric(
  240. merged[[metric_col, col]].corr(method="spearman").iloc[0, 1],
  241. errors="coerce",
  242. )
  243. )
  244. correlations[metric_name] = corr
  245. plot_paths[metric_name] = str(plot_path)
  246. grouped = pd.DataFrame(grouped_rows)
  247. table_path = output_dir / "physician_grouped_metrics.csv"
  248. grouped.to_csv(table_path, index=False)
  249. confidence_table = output_dir / "physician_confidence_grouped_metrics.csv"
  250. std_table = output_dir / "physician_std_grouped_metrics.csv"
  251. secondary_table = output_dir / f"physician_{secondary_key}_grouped_metrics.csv"
  252. grouped[grouped["uncertainty_type"] == "confidence"].to_csv(
  253. confidence_table, index=False
  254. )
  255. grouped[grouped["uncertainty_type"] == secondary_key].to_csv(
  256. secondary_table, index=False
  257. )
  258. grouped[grouped["uncertainty_type"] == secondary_key].to_csv(std_table, index=False)
  259. out = {
  260. "n_overlap": int(len(merged)),
  261. "spearman_vs_dxconfid": correlations,
  262. "table": str(table_path),
  263. "tables": {
  264. "confidence": str(confidence_table),
  265. secondary_key: str(secondary_table),
  266. "std": str(std_table),
  267. },
  268. "plots": plot_paths,
  269. }
  270. write_json(output_dir / "physician_summary.json", out)
  271. return out
  272. def _normalize_dx(value: Any) -> str:
  273. if value is None or (isinstance(value, float) and np.isnan(value)):
  274. return ""
  275. v = str(value).strip().upper()
  276. if v in {"NL", "NORMAL"}:
  277. return "CN"
  278. return v
  279. def run_longitudinal(
  280. evaluation: BackendEvaluation,
  281. clinical_df: pd.DataFrame,
  282. output_dir: Path,
  283. ) -> dict[str, Any]:
  284. secondary_key = (
  285. "predictive_entropy"
  286. if evaluation.uncertainty_metric == "predictive_entropy"
  287. else "std"
  288. )
  289. secondary_label = (
  290. "Mean Model Predictive Entropy"
  291. if secondary_key == "predictive_entropy"
  292. else "Mean Model Uncertainty Std"
  293. )
  294. required = ["Image Data ID", "PTID"]
  295. missing = [c for c in required if c not in clinical_df.columns]
  296. if missing:
  297. raise KeyError(f"Missing columns for longitudinal analysis: {missing}")
  298. diagnosis_col = None
  299. for candidate in ["Class", "DX", "Diagnosis"]:
  300. if candidate in clinical_df.columns:
  301. diagnosis_col = candidate
  302. break
  303. if diagnosis_col is None:
  304. raise KeyError(
  305. "No diagnosis column found. Expected one of: Class, DX, Diagnosis"
  306. )
  307. work = clinical_df[
  308. ["Image Data ID", "PTID", diagnosis_col]
  309. + [c for c in ["EXAMDATE"] if c in clinical_df.columns]
  310. ].copy()
  311. work["Image Data ID"] = pd.to_numeric(work["Image Data ID"], errors="coerce")
  312. work = work.dropna(subset=["Image Data ID", "PTID"])
  313. work["Image Data ID"] = work["Image Data ID"].astype(int)
  314. work["PTID"] = work["PTID"].astype(str).str.strip()
  315. work["diagnosis"] = work[diagnosis_col].map(_normalize_dx)
  316. if "EXAMDATE" in work.columns:
  317. work["EXAMDATE"] = pd.to_datetime(work["EXAMDATE"], errors="coerce")
  318. work = work.sort_values(["PTID", "EXAMDATE"], na_position="last")
  319. else:
  320. work = work.sort_values(["PTID", "Image Data ID"])
  321. eval_df = pd.DataFrame(
  322. {
  323. "Image Data ID": evaluation.image_ids.astype(int),
  324. "model_confidence": evaluation.uncertainty_confidence,
  325. "model_std": evaluation.uncertainty_std,
  326. "model_prob": evaluation.y_prob,
  327. }
  328. )
  329. merged = work.merge(eval_df, on="Image Data ID", how="inner")
  330. if merged.empty:
  331. raise ValueError("No overlapping Image Data ID rows for longitudinal analysis")
  332. patient_rows: list[dict[str, Any]] = []
  333. for ptid, group in merged.groupby("PTID"):
  334. diagnoses = [d for d in group["diagnosis"].tolist() if d]
  335. if len(diagnoses) < 2:
  336. continue
  337. first_dx = diagnoses[0]
  338. last_dx = diagnoses[-1]
  339. unique_dx = set(diagnoses)
  340. cohort = "other"
  341. if unique_dx.issubset({"CN"}):
  342. cohort = "stable_cn"
  343. elif unique_dx.issubset({"AD"}):
  344. cohort = "stable_ad"
  345. elif first_dx == "CN" and "AD" in unique_dx and last_dx == "AD":
  346. cohort = "cn_to_ad"
  347. patient_rows.append(
  348. {
  349. "PTID": ptid,
  350. "n_visits": int(len(group)),
  351. "first_dx": first_dx,
  352. "last_dx": last_dx,
  353. "cohort": cohort,
  354. "mean_confidence": float(group["model_confidence"].mean()),
  355. "mean_std": float(group["model_std"].mean()),
  356. "mean_prob": float(group["model_prob"].mean()),
  357. }
  358. )
  359. patient_df = pd.DataFrame(patient_rows)
  360. table_path = output_dir / "longitudinal_patient_summary.csv"
  361. patient_df.to_csv(table_path, index=False)
  362. cohort_df = (
  363. patient_df.groupby("cohort")
  364. .agg(
  365. n_patients=("PTID", "count"),
  366. mean_confidence=("mean_confidence", "mean"),
  367. mean_std=("mean_std", "mean"),
  368. mean_prob=("mean_prob", "mean"),
  369. )
  370. .reset_index()
  371. )
  372. cohort_table = output_dir / "longitudinal_cohort_summary.csv"
  373. cohort_df.to_csv(cohort_table, index=False)
  374. cohorts = ["stable_cn", "stable_ad", "cn_to_ad"]
  375. uncertainty_specs = [
  376. ("confidence", "mean_confidence", "Mean Model Confidence"),
  377. (secondary_key, "mean_std", secondary_label),
  378. ]
  379. plot_paths: dict[str, str] = {}
  380. for metric_name, metric_col, metric_label in uncertainty_specs:
  381. fig, ax = plt.subplots(figsize=(9, 5))
  382. values = [
  383. np.asarray(
  384. patient_df.loc[patient_df["cohort"] == c, metric_col], dtype=float
  385. )
  386. for c in cohorts
  387. ]
  388. ax.boxplot(values, tick_labels=cohorts)
  389. ax.set_ylabel(metric_label)
  390. ax.set_title(f"Longitudinal Cohort {metric_label} ({evaluation.backend})")
  391. ax.grid(True, axis="y", alpha=0.3)
  392. fig.tight_layout()
  393. plot_path = output_dir / f"longitudinal_cohort_{metric_name}.png"
  394. fig.savefig(plot_path)
  395. plt.close(fig)
  396. plot_paths[metric_name] = str(plot_path)
  397. uncertainty_by_cohort = cohort_df.melt(
  398. id_vars=["cohort", "n_patients"],
  399. value_vars=["mean_confidence", "mean_std"],
  400. var_name="uncertainty_type",
  401. value_name="mean_value",
  402. ).replace(
  403. {
  404. "uncertainty_type": {
  405. "mean_confidence": "confidence",
  406. "mean_std": secondary_key,
  407. }
  408. }
  409. )
  410. uncertainty_table = output_dir / "longitudinal_uncertainty_by_cohort.csv"
  411. uncertainty_by_cohort.to_csv(uncertainty_table, index=False)
  412. confidence_patient_table = (
  413. output_dir / "longitudinal_confidence_patient_summary.csv"
  414. )
  415. std_patient_table = output_dir / "longitudinal_std_patient_summary.csv"
  416. confidence_cohort_table = output_dir / "longitudinal_confidence_cohort_summary.csv"
  417. std_cohort_table = output_dir / "longitudinal_std_cohort_summary.csv"
  418. secondary_patient_table = (
  419. output_dir / f"longitudinal_{secondary_key}_patient_summary.csv"
  420. )
  421. secondary_cohort_table = (
  422. output_dir / f"longitudinal_{secondary_key}_cohort_summary.csv"
  423. )
  424. patient_df[
  425. ["PTID", "n_visits", "first_dx", "last_dx", "cohort", "mean_confidence"]
  426. ].to_csv(confidence_patient_table, index=False)
  427. patient_df[
  428. ["PTID", "n_visits", "first_dx", "last_dx", "cohort", "mean_std"]
  429. ].to_csv(std_patient_table, index=False)
  430. patient_df[
  431. ["PTID", "n_visits", "first_dx", "last_dx", "cohort", "mean_std"]
  432. ].to_csv(secondary_patient_table, index=False)
  433. cohort_df[["cohort", "n_patients", "mean_confidence"]].to_csv(
  434. confidence_cohort_table, index=False
  435. )
  436. cohort_df[["cohort", "n_patients", "mean_std"]].to_csv(
  437. std_cohort_table, index=False
  438. )
  439. cohort_df[["cohort", "n_patients", "mean_std"]].to_csv(
  440. secondary_cohort_table, index=False
  441. )
  442. out = {
  443. "n_patients_analyzed": int(len(patient_df)),
  444. "table_patient": str(table_path),
  445. "table_cohort": str(cohort_table),
  446. "table_uncertainty": str(uncertainty_table),
  447. "tables": {
  448. "confidence": {
  449. "patient": str(confidence_patient_table),
  450. "cohort": str(confidence_cohort_table),
  451. },
  452. secondary_key: {
  453. "patient": str(secondary_patient_table),
  454. "cohort": str(secondary_cohort_table),
  455. },
  456. "std": {
  457. "patient": str(std_patient_table),
  458. "cohort": str(std_cohort_table),
  459. },
  460. },
  461. "plots": plot_paths,
  462. }
  463. write_json(output_dir / "longitudinal_summary.json", out)
  464. return out