analysis_modules.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. # pyright: basic
  2. from __future__ import annotations
  3. from pathlib import Path
  4. from typing import Any
  5. import numpy as np
  6. import pandas as pd
  7. from .data_access import BackendEvaluation, physician_column
  8. from .defaults import DEFAULT_DECISION_THRESHOLD, uncertainty_cutoff_percentiles
  9. from .metrics import calibration_stats, performance_at_threshold, threshold_sweep
  10. from .plotting import (
  11. plots_dir,
  12. save_boxplot,
  13. save_calibration_plot,
  14. save_performance_threshold_plot,
  15. save_uncertainty_cutoff_plot,
  16. )
  17. from .runtime import write_json
  18. from .uncertainty import confidence_uncertainty
  19. def _save_table(rows: list[dict[str, Any]], out_path: Path) -> pd.DataFrame:
  20. df = pd.DataFrame(rows)
  21. out_path.parent.mkdir(parents=True, exist_ok=True)
  22. df.to_csv(out_path, index=False)
  23. return df
  24. def _uncertainty_percentiles(values: np.ndarray) -> np.ndarray:
  25. vals = np.asarray(values, dtype=float)
  26. n = len(vals)
  27. if n == 0:
  28. return np.asarray([], dtype=float)
  29. if n == 1:
  30. return np.asarray([0.0], dtype=float)
  31. order = np.argsort(vals)
  32. ranks = np.empty(n, dtype=float)
  33. ranks[order] = np.arange(n, dtype=float)
  34. return (ranks / float(n - 1)) * 100.0
  35. def _save_ensemble_prediction_debug(
  36. evaluation: BackendEvaluation,
  37. output_dir: Path,
  38. ) -> str:
  39. uncertainty_vals = confidence_uncertainty(evaluation.y_prob)
  40. uncertainty_pct = _uncertainty_percentiles(uncertainty_vals)
  41. pred_label = (evaluation.y_prob >= DEFAULT_DECISION_THRESHOLD).astype(int)
  42. true_label = np.asarray(evaluation.y_true, dtype=int)
  43. is_correct = (pred_label == true_label).astype(int)
  44. debug_df = pd.DataFrame(
  45. {
  46. "image_id": evaluation.image_ids.astype(int),
  47. "predicted_probability": np.asarray(evaluation.y_prob, dtype=float),
  48. "predicted_label": pred_label,
  49. "actual_label": true_label,
  50. "is_correct": is_correct,
  51. "uncertainty": uncertainty_vals,
  52. "uncertainty_percentile": uncertainty_pct,
  53. }
  54. ).sort_values("uncertainty", ascending=True)
  55. path = output_dir / "ensemble_prediction_debug.csv"
  56. debug_df.to_csv(path, index=False)
  57. return str(path)
  58. def _uncertainty_cutoff_analysis(
  59. evaluation: BackendEvaluation,
  60. output_dir: Path,
  61. uncertainty_types: list[tuple[str, np.ndarray]],
  62. cutoff_percentiles: np.ndarray,
  63. keep_higher: bool,
  64. table_filename: str,
  65. plot_filename: str,
  66. title_prefix: str,
  67. x_label: str,
  68. selection_label: str,
  69. ) -> dict[str, Any]:
  70. cutoff_rows: list[dict[str, Any]] = []
  71. for uncertainty_name, values in uncertainty_types:
  72. finite_mask = np.isfinite(values)
  73. if not finite_mask.any():
  74. continue
  75. values_valid = values[finite_mask]
  76. y_true_valid = evaluation.y_true[finite_mask]
  77. y_prob_valid = evaluation.y_prob[finite_mask]
  78. for cutoff_percentile in cutoff_percentiles:
  79. cutoff_value = float(np.percentile(values_valid, cutoff_percentile))
  80. if keep_higher:
  81. keep_mask = values_valid >= cutoff_value
  82. else:
  83. keep_mask = values_valid <= cutoff_value
  84. retained = int(keep_mask.sum())
  85. if retained == 0:
  86. continue
  87. perf = performance_at_threshold(
  88. y_true=y_true_valid[keep_mask],
  89. y_prob=y_prob_valid[keep_mask],
  90. threshold=DEFAULT_DECISION_THRESHOLD,
  91. )
  92. cutoff_rows.append(
  93. {
  94. "uncertainty_type": uncertainty_name,
  95. "cutoff_percentile": float(cutoff_percentile),
  96. "cutoff_value": cutoff_value,
  97. "n_retained": retained,
  98. "coverage": float(retained / len(values_valid)),
  99. "selection_rule": selection_label,
  100. "accuracy": float(perf["accuracy"]),
  101. "f1": float(perf["f1"]),
  102. }
  103. )
  104. table_path = output_dir / table_filename
  105. plot_path = plots_dir(output_dir) / plot_filename
  106. if not cutoff_rows:
  107. return {"table": str(table_path), "plot": str(plot_path), "rows": 0}
  108. cutoff_df = pd.DataFrame(cutoff_rows)
  109. cutoff_df.to_csv(table_path, index=False)
  110. # Restriction level increases from left to right so right-most points are most restricted.
  111. if keep_higher:
  112. cutoff_df["restriction_level"] = cutoff_df["cutoff_percentile"]
  113. else:
  114. cutoff_df["restriction_level"] = 100.0 - cutoff_df["cutoff_percentile"]
  115. save_uncertainty_cutoff_plot(
  116. cutoff_df=cutoff_df,
  117. title_prefix=title_prefix,
  118. x_label=x_label,
  119. output_path=plot_path,
  120. )
  121. return {
  122. "table": str(table_path),
  123. "plot": str(plot_path),
  124. "rows": int(len(cutoff_df)),
  125. }
  126. def run_performance(
  127. evaluation: BackendEvaluation,
  128. output_dir: Path,
  129. thresholds: np.ndarray,
  130. ) -> dict[str, Any]:
  131. rows = threshold_sweep(evaluation.y_true, evaluation.y_prob, thresholds)
  132. table_path = output_dir / "performance_threshold_sweep.csv"
  133. df = _save_table(rows, table_path)
  134. plot_path = plots_dir(output_dir) / "performance_threshold_sweep.png"
  135. save_performance_threshold_plot(
  136. df=df,
  137. backend=evaluation.backend,
  138. output_path=plot_path,
  139. )
  140. best_idx = int(df["f1"].idxmax())
  141. best = df.iloc[best_idx].to_dict()
  142. cutoff_percentiles = uncertainty_cutoff_percentiles()
  143. confidence_uncertainty_vals = confidence_uncertainty(evaluation.y_prob)
  144. secondary_uncertainty = np.asarray(evaluation.uncertainty_std, dtype=float)
  145. uncertainty_types = [
  146. ("confidence_uncertainty", confidence_uncertainty_vals),
  147. (evaluation.uncertainty_metric, secondary_uncertainty),
  148. ]
  149. uncertainty_cutoff = _uncertainty_cutoff_analysis(
  150. evaluation=evaluation,
  151. output_dir=output_dir,
  152. uncertainty_types=uncertainty_types,
  153. cutoff_percentiles=cutoff_percentiles,
  154. keep_higher=False,
  155. table_filename="performance_uncertainty_cutoff.csv",
  156. plot_filename="performance_uncertainty_cutoff.png",
  157. title_prefix="Uncertainty Cutoff Percentile",
  158. x_label="Restriction Level (0 = all results, 100 = lowest-uncertainty subset)",
  159. selection_label="uncertainty <= percentile cutoff",
  160. )
  161. percentile_cutoffs = uncertainty_cutoff_percentiles()
  162. uncertainty_percentile_cutoff = _uncertainty_cutoff_analysis(
  163. evaluation=evaluation,
  164. output_dir=output_dir,
  165. uncertainty_types=uncertainty_types,
  166. cutoff_percentiles=percentile_cutoffs,
  167. keep_higher=True,
  168. table_filename="performance_uncertainty_percentile_cutoff.csv",
  169. plot_filename="performance_uncertainty_percentile_cutoff.png",
  170. title_prefix="Uncertainty Percentile Floor",
  171. x_label="Uncertainty Percentile Floor (0 = all results, 100 = highest-uncertainty subset)",
  172. selection_label="uncertainty >= percentile floor",
  173. )
  174. cutoff_table_path = Path(uncertainty_cutoff["table"])
  175. cutoff_plot_path = Path(uncertainty_cutoff["plot"])
  176. percentile_cutoff_table_path = Path(uncertainty_percentile_cutoff["table"])
  177. percentile_cutoff_plot_path = Path(uncertainty_percentile_cutoff["plot"])
  178. summary = {
  179. "best_by_f1": {
  180. k: float(v) for k, v in best.items() if isinstance(v, (int, float))
  181. },
  182. "table": str(table_path),
  183. "plot": str(plot_path),
  184. "uncertainty_cutoff": {
  185. "table": str(cutoff_table_path),
  186. "plot": str(cutoff_plot_path),
  187. "decision_threshold": DEFAULT_DECISION_THRESHOLD,
  188. },
  189. "uncertainty_percentile_cutoff": {
  190. "table": str(percentile_cutoff_table_path),
  191. "plot": str(percentile_cutoff_plot_path),
  192. "decision_threshold": DEFAULT_DECISION_THRESHOLD,
  193. },
  194. }
  195. if evaluation.backend == "ensemble":
  196. summary["ensemble_prediction_debug"] = _save_ensemble_prediction_debug(
  197. evaluation=evaluation,
  198. output_dir=output_dir,
  199. )
  200. write_json(output_dir / "performance_summary.json", summary)
  201. return summary
  202. def run_calibration(
  203. evaluation: BackendEvaluation,
  204. output_dir: Path,
  205. bins: int,
  206. ) -> dict[str, Any]:
  207. summary, per_bin = calibration_stats(
  208. evaluation.y_true, evaluation.y_prob, bins=bins
  209. )
  210. bin_df = pd.DataFrame(
  211. per_bin,
  212. columns=["mean_confidence", "fraction_positive", "count"],
  213. )
  214. table_path = output_dir / "calibration_bins.csv"
  215. bin_df.to_csv(table_path, index=False)
  216. plot_path = plots_dir(output_dir) / "calibration_reliability.png"
  217. save_calibration_plot(
  218. per_bin=per_bin,
  219. backend=evaluation.backend,
  220. output_path=plot_path,
  221. )
  222. out = {
  223. **summary,
  224. "table": str(table_path),
  225. "plot": str(plot_path),
  226. }
  227. write_json(output_dir / "calibration_summary.json", out)
  228. return out
  229. def run_physician(
  230. evaluation: BackendEvaluation,
  231. clinical_df: pd.DataFrame,
  232. output_dir: Path,
  233. ) -> dict[str, Any]:
  234. secondary_key = (
  235. "predictive_entropy"
  236. if evaluation.uncertainty_metric == "predictive_entropy"
  237. else "std"
  238. )
  239. secondary_label = (
  240. "Model Predictive Entropy"
  241. if secondary_key == "predictive_entropy"
  242. else "Model Uncertainty Std"
  243. )
  244. col = physician_column(clinical_df)
  245. subset = clinical_df[["Image Data ID", col]].copy()
  246. subset[col] = pd.to_numeric(subset[col], errors="coerce")
  247. subset = subset.dropna(subset=["Image Data ID", col])
  248. subset["Image Data ID"] = subset["Image Data ID"].astype(int)
  249. subset[col] = subset[col].astype(int)
  250. eval_df = pd.DataFrame(
  251. {
  252. "Image Data ID": evaluation.image_ids.astype(int),
  253. "model_confidence": evaluation.uncertainty_confidence,
  254. "model_std": evaluation.uncertainty_std,
  255. "model_prob": evaluation.y_prob,
  256. }
  257. )
  258. merged = eval_df.merge(subset, on="Image Data ID", how="inner")
  259. if merged.empty:
  260. raise ValueError("No overlapping Image Data ID rows for physician analysis")
  261. grouped_rows: list[dict[str, Any]] = []
  262. uncertainty_specs = [
  263. ("confidence", "model_confidence", "Model Confidence (2*|p-0.5|)"),
  264. (secondary_key, "model_std", secondary_label),
  265. ]
  266. ratings = [int(r) for r in sorted(pd.unique(merged[col]))]
  267. plot_paths: dict[str, str] = {}
  268. correlations: dict[str, float] = {}
  269. for metric_name, metric_col, metric_label in uncertainty_specs:
  270. grouped_metric = (
  271. merged.groupby(col)
  272. .agg(
  273. n=("Image Data ID", "count"),
  274. mean_value=(metric_col, "mean"),
  275. std_value=(metric_col, "std"),
  276. mean_prob=("model_prob", "mean"),
  277. )
  278. .reset_index()
  279. .rename(columns={col: "physician_rating"})
  280. )
  281. grouped_metric["uncertainty_type"] = metric_name
  282. grouped_rows.extend(
  283. [
  284. {str(k): v for k, v in rec.items()}
  285. for rec in grouped_metric.to_dict(orient="records")
  286. ]
  287. )
  288. data = [
  289. np.asarray(merged.loc[merged[col] == r, metric_col], dtype=float)
  290. for r in ratings
  291. ]
  292. plot_path = plots_dir(output_dir) / f"physician_{metric_name}_boxplot.png"
  293. save_boxplot(
  294. data=data,
  295. tick_labels=[str(r) for r in ratings],
  296. x_label="Physician Confidence Rating (DXCONFID)",
  297. y_label=metric_label,
  298. title=f"{metric_label} vs Physician Confidence ({evaluation.backend})",
  299. output_path=plot_path,
  300. )
  301. corr = float(
  302. pd.to_numeric(
  303. merged[[metric_col, col]].corr(method="spearman").iloc[0, 1],
  304. errors="coerce",
  305. )
  306. )
  307. correlations[metric_name] = corr
  308. plot_paths[metric_name] = str(plot_path)
  309. grouped = pd.DataFrame(grouped_rows)
  310. table_path = output_dir / "physician_grouped_metrics.csv"
  311. grouped.to_csv(table_path, index=False)
  312. confidence_table = output_dir / "physician_confidence_grouped_metrics.csv"
  313. std_table = output_dir / "physician_std_grouped_metrics.csv"
  314. secondary_table = output_dir / f"physician_{secondary_key}_grouped_metrics.csv"
  315. grouped[grouped["uncertainty_type"] == "confidence"].to_csv(
  316. confidence_table, index=False
  317. )
  318. grouped[grouped["uncertainty_type"] == secondary_key].to_csv(
  319. secondary_table, index=False
  320. )
  321. grouped[grouped["uncertainty_type"] == secondary_key].to_csv(std_table, index=False)
  322. out = {
  323. "n_overlap": int(len(merged)),
  324. "spearman_vs_dxconfid": correlations,
  325. "table": str(table_path),
  326. "tables": {
  327. "confidence": str(confidence_table),
  328. secondary_key: str(secondary_table),
  329. "std": str(std_table),
  330. },
  331. "plots": plot_paths,
  332. }
  333. write_json(output_dir / "physician_summary.json", out)
  334. return out
  335. def _normalize_dx(value: Any) -> str:
  336. if value is None or (isinstance(value, float) and np.isnan(value)):
  337. return ""
  338. v = str(value).strip().upper()
  339. if v in {"NL", "NORMAL"}:
  340. return "CN"
  341. return v
  342. def run_longitudinal(
  343. evaluation: BackendEvaluation,
  344. clinical_df: pd.DataFrame,
  345. output_dir: Path,
  346. ) -> dict[str, Any]:
  347. secondary_key = (
  348. "predictive_entropy"
  349. if evaluation.uncertainty_metric == "predictive_entropy"
  350. else "std"
  351. )
  352. secondary_label = (
  353. "Mean Model Predictive Entropy"
  354. if secondary_key == "predictive_entropy"
  355. else "Mean Model Uncertainty Std"
  356. )
  357. required = ["Image Data ID", "PTID"]
  358. missing = [c for c in required if c not in clinical_df.columns]
  359. if missing:
  360. raise KeyError(f"Missing columns for longitudinal analysis: {missing}")
  361. diagnosis_col = None
  362. for candidate in ["Class", "DX", "Diagnosis"]:
  363. if candidate in clinical_df.columns:
  364. diagnosis_col = candidate
  365. break
  366. if diagnosis_col is None:
  367. raise KeyError(
  368. "No diagnosis column found. Expected one of: Class, DX, Diagnosis"
  369. )
  370. work = clinical_df[
  371. ["Image Data ID", "PTID", diagnosis_col]
  372. + [c for c in ["EXAMDATE"] if c in clinical_df.columns]
  373. ].copy()
  374. work["Image Data ID"] = pd.to_numeric(work["Image Data ID"], errors="coerce")
  375. work = work.dropna(subset=["Image Data ID", "PTID"])
  376. work["Image Data ID"] = work["Image Data ID"].astype(int)
  377. work["PTID"] = work["PTID"].astype(str).str.strip()
  378. work["diagnosis"] = work[diagnosis_col].map(_normalize_dx)
  379. if "EXAMDATE" in work.columns:
  380. work["EXAMDATE"] = pd.to_datetime(work["EXAMDATE"], errors="coerce")
  381. work = work.sort_values(["PTID", "EXAMDATE"], na_position="last")
  382. else:
  383. work = work.sort_values(["PTID", "Image Data ID"])
  384. eval_df = pd.DataFrame(
  385. {
  386. "Image Data ID": evaluation.image_ids.astype(int),
  387. "model_confidence": evaluation.uncertainty_confidence,
  388. "model_std": evaluation.uncertainty_std,
  389. "model_prob": evaluation.y_prob,
  390. }
  391. )
  392. merged = work.merge(eval_df, on="Image Data ID", how="inner")
  393. if merged.empty:
  394. raise ValueError("No overlapping Image Data ID rows for longitudinal analysis")
  395. patient_rows: list[dict[str, Any]] = []
  396. for ptid, group in merged.groupby("PTID"):
  397. diagnoses = [d for d in group["diagnosis"].tolist() if d]
  398. if len(diagnoses) < 2:
  399. continue
  400. first_dx = diagnoses[0]
  401. last_dx = diagnoses[-1]
  402. unique_dx = set(diagnoses)
  403. cohort = "other"
  404. if unique_dx.issubset({"CN"}):
  405. cohort = "stable_cn"
  406. elif unique_dx.issubset({"AD"}):
  407. cohort = "stable_ad"
  408. elif first_dx == "CN" and "AD" in unique_dx and last_dx == "AD":
  409. cohort = "cn_to_ad"
  410. patient_rows.append(
  411. {
  412. "PTID": ptid,
  413. "n_visits": int(len(group)),
  414. "first_dx": first_dx,
  415. "last_dx": last_dx,
  416. "cohort": cohort,
  417. "mean_confidence": float(group["model_confidence"].mean()),
  418. "mean_std": float(group["model_std"].mean()),
  419. "mean_prob": float(group["model_prob"].mean()),
  420. }
  421. )
  422. patient_df = pd.DataFrame(patient_rows)
  423. table_path = output_dir / "longitudinal_patient_summary.csv"
  424. patient_df.to_csv(table_path, index=False)
  425. cohort_df = (
  426. patient_df.groupby("cohort")
  427. .agg(
  428. n_patients=("PTID", "count"),
  429. mean_confidence=("mean_confidence", "mean"),
  430. mean_std=("mean_std", "mean"),
  431. mean_prob=("mean_prob", "mean"),
  432. )
  433. .reset_index()
  434. )
  435. cohort_table = output_dir / "longitudinal_cohort_summary.csv"
  436. cohort_df.to_csv(cohort_table, index=False)
  437. cohorts = ["stable_cn", "stable_ad", "cn_to_ad"]
  438. uncertainty_specs = [
  439. ("confidence", "mean_confidence", "Mean Model Confidence"),
  440. (secondary_key, "mean_std", secondary_label),
  441. ]
  442. plot_paths: dict[str, str] = {}
  443. for metric_name, metric_col, metric_label in uncertainty_specs:
  444. values = [
  445. np.asarray(
  446. patient_df.loc[patient_df["cohort"] == c, metric_col], dtype=float
  447. )
  448. for c in cohorts
  449. ]
  450. plot_path = plots_dir(output_dir) / f"longitudinal_cohort_{metric_name}.png"
  451. save_boxplot(
  452. data=values,
  453. tick_labels=cohorts,
  454. x_label="Cohort",
  455. y_label=metric_label,
  456. title=f"Longitudinal Cohort {metric_label} ({evaluation.backend})",
  457. output_path=plot_path,
  458. )
  459. plot_paths[metric_name] = str(plot_path)
  460. uncertainty_by_cohort = cohort_df.melt(
  461. id_vars=["cohort", "n_patients"],
  462. value_vars=["mean_confidence", "mean_std"],
  463. var_name="uncertainty_type",
  464. value_name="mean_value",
  465. ).replace(
  466. {
  467. "uncertainty_type": {
  468. "mean_confidence": "confidence",
  469. "mean_std": secondary_key,
  470. }
  471. }
  472. )
  473. uncertainty_table = output_dir / "longitudinal_uncertainty_by_cohort.csv"
  474. uncertainty_by_cohort.to_csv(uncertainty_table, index=False)
  475. confidence_patient_table = (
  476. output_dir / "longitudinal_confidence_patient_summary.csv"
  477. )
  478. std_patient_table = output_dir / "longitudinal_std_patient_summary.csv"
  479. confidence_cohort_table = output_dir / "longitudinal_confidence_cohort_summary.csv"
  480. std_cohort_table = output_dir / "longitudinal_std_cohort_summary.csv"
  481. secondary_patient_table = (
  482. output_dir / f"longitudinal_{secondary_key}_patient_summary.csv"
  483. )
  484. secondary_cohort_table = (
  485. output_dir / f"longitudinal_{secondary_key}_cohort_summary.csv"
  486. )
  487. patient_df[
  488. ["PTID", "n_visits", "first_dx", "last_dx", "cohort", "mean_confidence"]
  489. ].to_csv(confidence_patient_table, index=False)
  490. patient_df[
  491. ["PTID", "n_visits", "first_dx", "last_dx", "cohort", "mean_std"]
  492. ].to_csv(std_patient_table, index=False)
  493. patient_df[
  494. ["PTID", "n_visits", "first_dx", "last_dx", "cohort", "mean_std"]
  495. ].to_csv(secondary_patient_table, index=False)
  496. cohort_df[["cohort", "n_patients", "mean_confidence"]].to_csv(
  497. confidence_cohort_table, index=False
  498. )
  499. cohort_df[["cohort", "n_patients", "mean_std"]].to_csv(
  500. std_cohort_table, index=False
  501. )
  502. cohort_df[["cohort", "n_patients", "mean_std"]].to_csv(
  503. secondary_cohort_table, index=False
  504. )
  505. out = {
  506. "n_patients_analyzed": int(len(patient_df)),
  507. "table_patient": str(table_path),
  508. "table_cohort": str(cohort_table),
  509. "table_uncertainty": str(uncertainty_table),
  510. "tables": {
  511. "confidence": {
  512. "patient": str(confidence_patient_table),
  513. "cohort": str(confidence_cohort_table),
  514. },
  515. secondary_key: {
  516. "patient": str(secondary_patient_table),
  517. "cohort": str(secondary_cohort_table),
  518. },
  519. "std": {
  520. "patient": str(std_patient_table),
  521. "cohort": str(std_cohort_table),
  522. },
  523. },
  524. "plots": plot_paths,
  525. }
  526. write_json(output_dir / "longitudinal_summary.json", out)
  527. return out