analysis_modules.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691
  1. # pyright: basic
  2. from __future__ import annotations
  3. from pathlib import Path
  4. import re
  5. from typing import Any
  6. import numpy as np
  7. import pandas as pd
  8. from .data_access import BackendEvaluation, physician_column
  9. from .defaults import DEFAULT_DECISION_THRESHOLD, uncertainty_cutoff_percentiles
  10. from .metrics import calibration_stats, performance_at_threshold, threshold_sweep
  11. from .plotting import (
  12. plots_dir,
  13. save_boxplot,
  14. save_calibration_plot,
  15. save_performance_threshold_plot,
  16. save_performance_threshold_pair_plot,
  17. save_uncertainty_cutoff_plot,
  18. save_uncertainty_cutoff_pair_plot,
  19. )
  20. from .runtime import write_json
  21. def _save_table(rows: list[dict[str, Any]], out_path: Path) -> pd.DataFrame:
  22. df = pd.DataFrame(rows)
  23. out_path.parent.mkdir(parents=True, exist_ok=True)
  24. df.to_csv(out_path, index=False)
  25. return df
  26. def _uncertainty_percentiles(values: np.ndarray) -> np.ndarray:
  27. vals = np.asarray(values, dtype=float)
  28. n = len(vals)
  29. if n == 0:
  30. return np.asarray([], dtype=float)
  31. if n == 1:
  32. return np.asarray([0.0], dtype=float)
  33. order = np.argsort(vals)
  34. ranks = np.empty(n, dtype=float)
  35. ranks[order] = np.arange(0, n, dtype=float)
  36. return (ranks / float(n - 1)) * 100.0
  37. def _save_ensemble_prediction_debug(
  38. evaluation: BackendEvaluation,
  39. output_dir: Path,
  40. ) -> str:
  41. uncertainty_vals = np.asarray(evaluation.uncertainty_confidence, dtype=float)
  42. uncertainty_pct = _uncertainty_percentiles(uncertainty_vals)
  43. pred_label = (evaluation.y_prob >= DEFAULT_DECISION_THRESHOLD).astype(int)
  44. true_label = np.asarray(evaluation.y_true, dtype=int)
  45. is_correct = (pred_label == true_label).astype(int)
  46. debug_df = pd.DataFrame(
  47. {
  48. "image_id": evaluation.image_ids.astype(int),
  49. "predicted_probability": np.asarray(evaluation.y_prob, dtype=float),
  50. "predicted_label": pred_label,
  51. "actual_label": true_label,
  52. "is_correct": is_correct,
  53. "confidence": uncertainty_vals,
  54. "confidence_percentile": uncertainty_pct,
  55. }
  56. ).sort_values("confidence", ascending=False)
  57. path = output_dir / "ensemble_prediction_debug.csv"
  58. debug_df.to_csv(path, index=False)
  59. return str(path)
  60. def _uncertainty_cutoff_analysis(
  61. evaluation: BackendEvaluation,
  62. output_dir: Path,
  63. uncertainty_types: list[tuple[str, np.ndarray]],
  64. cutoff_percentiles: np.ndarray,
  65. table_filename: str,
  66. plot_filename_prefix: str,
  67. title_prefix: str,
  68. x_label: str,
  69. ) -> dict[str, Any]:
  70. def _slug(name: str) -> str:
  71. s = re.sub(r"[^a-z0-9]+", "_", name.strip().lower())
  72. return s.strip("_") or "metric"
  73. def _confidence_like(name: str) -> bool:
  74. return "confidence" in name.strip().lower()
  75. cutoff_rows: list[dict[str, Any]] = []
  76. for uncertainty_name, values in uncertainty_types:
  77. finite_mask = np.isfinite(values)
  78. if not finite_mask.any():
  79. continue
  80. values_valid = values[finite_mask]
  81. y_true_valid = evaluation.y_true[finite_mask]
  82. y_prob_valid = evaluation.y_prob[finite_mask]
  83. keep_higher = _confidence_like(uncertainty_name)
  84. selection_label = (
  85. "metric >= percentile cutoff (higher is lower uncertainty)"
  86. if keep_higher
  87. else "metric <= percentile cutoff (lower is lower uncertainty)"
  88. )
  89. for cutoff_percentile in cutoff_percentiles:
  90. cutoff_value = float(np.percentile(values_valid, cutoff_percentile))
  91. if keep_higher:
  92. keep_mask = values_valid >= cutoff_value
  93. else:
  94. keep_mask = values_valid <= cutoff_value
  95. retained = int(keep_mask.sum())
  96. if retained == 0:
  97. continue
  98. perf = performance_at_threshold(
  99. y_true=y_true_valid[keep_mask],
  100. y_prob=y_prob_valid[keep_mask],
  101. threshold=DEFAULT_DECISION_THRESHOLD,
  102. )
  103. cutoff_rows.append(
  104. {
  105. "uncertainty_type": uncertainty_name,
  106. "cutoff_percentile": float(cutoff_percentile),
  107. "cutoff_value": cutoff_value,
  108. "n_retained": retained,
  109. "coverage": float(retained / len(values_valid)),
  110. "selection_rule": selection_label,
  111. "keep_higher": bool(keep_higher),
  112. "accuracy": float(perf["accuracy"]),
  113. "f1": float(perf["f1"]),
  114. "n_correct": int(perf["tp"] + perf["tn"]),
  115. "n_incorrect": int(perf["fp"] + perf["fn"]),
  116. }
  117. )
  118. table_path = output_dir / table_filename
  119. accuracy_plot_path = plots_dir(output_dir) / f"{plot_filename_prefix}_accuracy.png"
  120. f1_plot_path = plots_dir(output_dir) / f"{plot_filename_prefix}_f1.png"
  121. pair_plot_path = plots_dir(output_dir) / f"{plot_filename_prefix}_accuracy_f1.png"
  122. if not cutoff_rows:
  123. return {
  124. "table": str(table_path),
  125. "accuracy_plot": str(accuracy_plot_path),
  126. "f1_plot": str(f1_plot_path),
  127. "pair_plot": str(pair_plot_path),
  128. "rows": 0,
  129. }
  130. cutoff_df = pd.DataFrame(cutoff_rows)
  131. cutoff_df.to_csv(table_path, index=False)
  132. # Restriction level increases from left to right so right-most points are most restricted.
  133. cutoff_df["restriction_level"] = np.where(
  134. cutoff_df["keep_higher"].astype(bool),
  135. cutoff_df["cutoff_percentile"],
  136. 100.0 - cutoff_df["cutoff_percentile"],
  137. )
  138. plots_by_uncertainty: dict[str, dict[str, str]] = {}
  139. for uncertainty_name in sorted(pd.unique(cutoff_df["uncertainty_type"])):
  140. sub_df = cutoff_df[cutoff_df["uncertainty_type"] == uncertainty_name].copy()
  141. slug = _slug(str(uncertainty_name))
  142. sub_accuracy_plot_path = (
  143. plots_dir(output_dir) / f"{plot_filename_prefix}_{slug}_accuracy.png"
  144. )
  145. sub_f1_plot_path = (
  146. plots_dir(output_dir) / f"{plot_filename_prefix}_{slug}_f1.png"
  147. )
  148. sub_pair_plot_path = (
  149. plots_dir(output_dir) / f"{plot_filename_prefix}_{slug}_accuracy_f1.png"
  150. )
  151. save_uncertainty_cutoff_plot(
  152. cutoff_df=sub_df,
  153. title_prefix=f"{title_prefix} ({uncertainty_name})",
  154. x_label=x_label,
  155. output_path=sub_accuracy_plot_path,
  156. metric_column="accuracy",
  157. metric_label="Accuracy",
  158. plot_key=f"{plot_filename_prefix}_{slug}_accuracy",
  159. )
  160. save_uncertainty_cutoff_plot(
  161. cutoff_df=sub_df,
  162. title_prefix=f"{title_prefix} ({uncertainty_name})",
  163. x_label=x_label,
  164. output_path=sub_f1_plot_path,
  165. metric_column="f1",
  166. metric_label="F1",
  167. plot_key=f"{plot_filename_prefix}_{slug}_f1",
  168. )
  169. save_uncertainty_cutoff_pair_plot(
  170. cutoff_df=sub_df,
  171. title_prefix=f"{title_prefix} ({uncertainty_name})",
  172. x_label=x_label,
  173. output_path=sub_pair_plot_path,
  174. plot_key=f"{plot_filename_prefix}_{slug}_accuracy_f1",
  175. )
  176. plots_by_uncertainty[str(uncertainty_name)] = {
  177. "accuracy": str(sub_accuracy_plot_path),
  178. "f1": str(sub_f1_plot_path),
  179. "accuracy_f1": str(sub_pair_plot_path),
  180. }
  181. return {
  182. "table": str(table_path),
  183. "plots_by_uncertainty": plots_by_uncertainty,
  184. "rows": int(len(cutoff_df)),
  185. }
  186. def run_performance(
  187. evaluation: BackendEvaluation,
  188. output_dir: Path,
  189. thresholds: np.ndarray,
  190. ) -> dict[str, Any]:
  191. rows = threshold_sweep(evaluation.y_true, evaluation.y_prob, thresholds)
  192. table_path = output_dir / "performance_threshold_sweep.csv"
  193. df = _save_table(rows, table_path)
  194. accuracy_plot_path = plots_dir(output_dir) / "performance_threshold_accuracy.png"
  195. f1_plot_path = plots_dir(output_dir) / "performance_threshold_f1.png"
  196. pair_plot_path = plots_dir(output_dir) / "performance_threshold_accuracy_f1.png"
  197. save_performance_threshold_plot(
  198. df=df,
  199. backend=evaluation.backend,
  200. output_path=accuracy_plot_path,
  201. metric_column="accuracy",
  202. metric_label="Accuracy",
  203. plot_key="performance_threshold_accuracy",
  204. )
  205. save_performance_threshold_plot(
  206. df=df,
  207. backend=evaluation.backend,
  208. output_path=f1_plot_path,
  209. metric_column="f1",
  210. metric_label="F1",
  211. plot_key="performance_threshold_f1",
  212. )
  213. save_performance_threshold_pair_plot(
  214. df=df,
  215. backend=evaluation.backend,
  216. output_path=pair_plot_path,
  217. plot_key="performance_threshold_accuracy_f1",
  218. )
  219. best_idx = int(df["f1"].idxmax())
  220. best = df.iloc[best_idx].to_dict()
  221. cutoff_percentiles = uncertainty_cutoff_percentiles()
  222. model_output_vals = np.asarray(evaluation.uncertainty_confidence, dtype=float)
  223. secondary_uncertainty = np.asarray(evaluation.uncertainty_std, dtype=float)
  224. secondary_uncertainty_name = (
  225. "predictive uncertainty"
  226. if evaluation.uncertainty_metric == "predictive_entropy"
  227. else "standard deviation"
  228. )
  229. uncertainty_types = [
  230. ("confidence", model_output_vals),
  231. (secondary_uncertainty_name, secondary_uncertainty),
  232. ]
  233. uncertainty_cutoff = _uncertainty_cutoff_analysis(
  234. evaluation=evaluation,
  235. output_dir=output_dir,
  236. uncertainty_types=uncertainty_types,
  237. cutoff_percentiles=cutoff_percentiles,
  238. table_filename="performance_uncertainty_cutoff.csv",
  239. plot_filename_prefix="performance_uncertainty_cutoff",
  240. title_prefix="Model Output / Uncertainty Cutoff Percentile",
  241. x_label="Restriction Level (0 = all samples, 100 = most restricted subset)",
  242. )
  243. percentile_cutoffs = uncertainty_cutoff_percentiles()
  244. uncertainty_percentile_cutoff = _uncertainty_cutoff_analysis(
  245. evaluation=evaluation,
  246. output_dir=output_dir,
  247. uncertainty_types=uncertainty_types,
  248. cutoff_percentiles=percentile_cutoffs,
  249. table_filename="performance_uncertainty_percentile_cutoff.csv",
  250. plot_filename_prefix="performance_uncertainty_percentile_cutoff",
  251. title_prefix="Model Output / Uncertainty Percentile Floor",
  252. x_label="Percentile Floor (0 = all samples, 100 = top percentile subset)",
  253. )
  254. cutoff_table_path = Path(uncertainty_cutoff["table"])
  255. percentile_cutoff_table_path = Path(uncertainty_percentile_cutoff["table"])
  256. summary = {
  257. "best_by_f1": {
  258. k: float(v) for k, v in best.items() if isinstance(v, (int, float))
  259. },
  260. "table": str(table_path),
  261. "plots": {
  262. "accuracy": str(accuracy_plot_path),
  263. "f1": str(f1_plot_path),
  264. "accuracy_f1": str(pair_plot_path),
  265. },
  266. "uncertainty_cutoff": {
  267. "table": str(cutoff_table_path),
  268. "plots_by_uncertainty": uncertainty_cutoff["plots_by_uncertainty"],
  269. "decision_threshold": DEFAULT_DECISION_THRESHOLD,
  270. },
  271. "uncertainty_percentile_cutoff": {
  272. "table": str(percentile_cutoff_table_path),
  273. "plots_by_uncertainty": uncertainty_percentile_cutoff[
  274. "plots_by_uncertainty"
  275. ],
  276. "decision_threshold": DEFAULT_DECISION_THRESHOLD,
  277. },
  278. }
  279. if evaluation.backend == "ensemble":
  280. summary["ensemble_prediction_debug"] = _save_ensemble_prediction_debug(
  281. evaluation=evaluation,
  282. output_dir=output_dir,
  283. )
  284. write_json(output_dir / "performance_summary.json", summary)
  285. return summary
  286. def run_calibration(
  287. evaluation: BackendEvaluation,
  288. output_dir: Path,
  289. bins: int,
  290. ) -> dict[str, Any]:
  291. summary, per_bin = calibration_stats(
  292. evaluation.y_true, evaluation.y_prob, bins=bins
  293. )
  294. bin_df = pd.DataFrame(
  295. per_bin,
  296. columns=["mean_confidence", "fraction_positive", "count"],
  297. )
  298. table_path = output_dir / "calibration_bins.csv"
  299. bin_df.to_csv(table_path, index=False)
  300. plot_path = plots_dir(output_dir) / "calibration_reliability.png"
  301. save_calibration_plot(
  302. per_bin=per_bin,
  303. backend=evaluation.backend,
  304. output_path=plot_path,
  305. )
  306. out = {
  307. **summary,
  308. "table": str(table_path),
  309. "plot": str(plot_path),
  310. }
  311. write_json(output_dir / "calibration_summary.json", out)
  312. return out
  313. def run_physician(
  314. evaluation: BackendEvaluation,
  315. clinical_df: pd.DataFrame,
  316. output_dir: Path,
  317. ) -> dict[str, Any]:
  318. secondary_key = (
  319. "predictive_entropy"
  320. if evaluation.uncertainty_metric == "predictive_entropy"
  321. else "std"
  322. )
  323. secondary_label = (
  324. "Predictive Uncertainty"
  325. if secondary_key == "predictive_entropy"
  326. else "Standard Deviation"
  327. )
  328. col = physician_column(clinical_df)
  329. subset = clinical_df[["Image Data ID", col]].copy()
  330. subset[col] = pd.to_numeric(subset[col], errors="coerce")
  331. subset = subset.dropna(subset=["Image Data ID", col])
  332. subset["Image Data ID"] = subset["Image Data ID"].astype(int)
  333. subset[col] = subset[col].astype(int)
  334. eval_df = pd.DataFrame(
  335. {
  336. "Image Data ID": evaluation.image_ids.astype(int),
  337. "model_probability": evaluation.uncertainty_confidence,
  338. "model_std": evaluation.uncertainty_std,
  339. "model_prob": evaluation.y_prob,
  340. }
  341. )
  342. merged = eval_df.merge(subset, on="Image Data ID", how="inner")
  343. if merged.empty:
  344. raise ValueError("No overlapping Image Data ID rows for physician analysis")
  345. grouped_rows: list[dict[str, Any]] = []
  346. uncertainty_specs = [
  347. ("confidence", "model_probability", "Confidence"),
  348. (secondary_key, "model_std", secondary_label),
  349. ]
  350. ratings = [int(r) for r in sorted(pd.unique(merged[col]))]
  351. plot_paths: dict[str, str] = {}
  352. correlations: dict[str, float] = {}
  353. for metric_name, metric_col, metric_label in uncertainty_specs:
  354. grouped_metric = (
  355. merged.groupby(col)
  356. .agg(
  357. n=("Image Data ID", "count"),
  358. mean_value=(metric_col, "mean"),
  359. std_value=(metric_col, "std"),
  360. mean_prob=("model_prob", "mean"),
  361. )
  362. .reset_index()
  363. .rename(columns={col: "physician_rating"})
  364. )
  365. grouped_metric["uncertainty_type"] = metric_name
  366. grouped_rows.extend(
  367. [
  368. {str(k): v for k, v in rec.items()}
  369. for rec in grouped_metric.to_dict(orient="records")
  370. ]
  371. )
  372. data = [
  373. np.asarray(merged.loc[merged[col] == r, metric_col], dtype=float)
  374. for r in ratings
  375. ]
  376. plot_path = plots_dir(output_dir) / f"physician_{metric_name}_boxplot.png"
  377. save_boxplot(
  378. data=data,
  379. tick_labels=[str(r) for r in ratings],
  380. x_label="Physician Confidence Rating (DXCONFID)",
  381. y_label=metric_label,
  382. title=f"{metric_label} by Physician Confidence Rating ({evaluation.backend})",
  383. output_path=plot_path,
  384. )
  385. corr = float(
  386. pd.to_numeric(
  387. merged[[metric_col, col]].corr(method="spearman").iloc[0, 1],
  388. errors="coerce",
  389. )
  390. )
  391. correlations[metric_name] = corr
  392. plot_paths[metric_name] = str(plot_path)
  393. grouped = pd.DataFrame(grouped_rows)
  394. table_path = output_dir / "physician_grouped_metrics.csv"
  395. grouped.to_csv(table_path, index=False)
  396. confidence_table = output_dir / "physician_confidence_grouped_metrics.csv"
  397. std_table = output_dir / "physician_std_grouped_metrics.csv"
  398. secondary_table = output_dir / f"physician_{secondary_key}_grouped_metrics.csv"
  399. grouped[grouped["uncertainty_type"] == "confidence"].to_csv(
  400. confidence_table, index=False
  401. )
  402. grouped[grouped["uncertainty_type"] == secondary_key].to_csv(
  403. secondary_table, index=False
  404. )
  405. grouped[grouped["uncertainty_type"] == secondary_key].to_csv(std_table, index=False)
  406. out = {
  407. "n_overlap": int(len(merged)),
  408. "spearman_vs_dxconfid": correlations,
  409. "table": str(table_path),
  410. "tables": {
  411. "confidence": str(confidence_table),
  412. secondary_key: str(secondary_table),
  413. "std": str(std_table),
  414. },
  415. "plots": plot_paths,
  416. }
  417. write_json(output_dir / "physician_summary.json", out)
  418. return out
  419. def _normalize_dx(value: Any) -> str:
  420. if value is None or (isinstance(value, float) and np.isnan(value)):
  421. return ""
  422. v = str(value).strip().upper()
  423. if v in {"NL", "NORMAL"}:
  424. return "CN"
  425. return v
  426. def run_longitudinal(
  427. evaluation: BackendEvaluation,
  428. clinical_df: pd.DataFrame,
  429. output_dir: Path,
  430. ) -> dict[str, Any]:
  431. secondary_key = (
  432. "predictive_entropy"
  433. if evaluation.uncertainty_metric == "predictive_entropy"
  434. else "std"
  435. )
  436. secondary_label = (
  437. "Mean Predictive Uncertainty"
  438. if secondary_key == "predictive_entropy"
  439. else "Mean Standard Deviation"
  440. )
  441. required = ["Image Data ID", "PTID"]
  442. missing = [c for c in required if c not in clinical_df.columns]
  443. if missing:
  444. raise KeyError(f"Missing columns for longitudinal analysis: {missing}")
  445. diagnosis_col = None
  446. for candidate in ["Class", "DX", "Diagnosis"]:
  447. if candidate in clinical_df.columns:
  448. diagnosis_col = candidate
  449. break
  450. if diagnosis_col is None:
  451. raise KeyError(
  452. "No diagnosis column found. Expected one of: Class, DX, Diagnosis"
  453. )
  454. work = clinical_df[
  455. ["Image Data ID", "PTID", diagnosis_col]
  456. + [c for c in ["EXAMDATE"] if c in clinical_df.columns]
  457. ].copy()
  458. work["Image Data ID"] = pd.to_numeric(work["Image Data ID"], errors="coerce")
  459. work = work.dropna(subset=["Image Data ID", "PTID"])
  460. work["Image Data ID"] = work["Image Data ID"].astype(int)
  461. work["PTID"] = work["PTID"].astype(str).str.strip()
  462. work["diagnosis"] = work[diagnosis_col].map(_normalize_dx)
  463. if "EXAMDATE" in work.columns:
  464. work["EXAMDATE"] = pd.to_datetime(work["EXAMDATE"], errors="coerce")
  465. work = work.sort_values(["PTID", "EXAMDATE"], na_position="last")
  466. else:
  467. work = work.sort_values(["PTID", "Image Data ID"])
  468. eval_df = pd.DataFrame(
  469. {
  470. "Image Data ID": evaluation.image_ids.astype(int),
  471. "model_confidence": evaluation.uncertainty_confidence,
  472. "model_std": evaluation.uncertainty_std,
  473. "model_prob": evaluation.y_prob,
  474. }
  475. )
  476. merged = work.merge(eval_df, on="Image Data ID", how="inner")
  477. if merged.empty:
  478. raise ValueError("No overlapping Image Data ID rows for longitudinal analysis")
  479. patient_rows: list[dict[str, Any]] = []
  480. for ptid, group in merged.groupby("PTID"):
  481. diagnoses = [d for d in group["diagnosis"].tolist() if d]
  482. if len(diagnoses) < 2:
  483. continue
  484. first_dx = diagnoses[0]
  485. last_dx = diagnoses[-1]
  486. unique_dx = set(diagnoses)
  487. cohort = "other"
  488. if unique_dx.issubset({"CN"}):
  489. cohort = "stable_cn"
  490. elif unique_dx.issubset({"AD"}):
  491. cohort = "stable_ad"
  492. elif first_dx == "CN" and last_dx == "AD" and unique_dx.issubset({"CN", "AD"}):
  493. cohort = "cn_to_ad"
  494. patient_rows.append(
  495. {
  496. "PTID": ptid,
  497. "n_visits": int(len(group)),
  498. "first_dx": first_dx,
  499. "last_dx": last_dx,
  500. "cohort": cohort,
  501. "mean_confidence": float(group["model_confidence"].mean()),
  502. "mean_std": float(group["model_std"].mean()),
  503. "mean_prob": float(group["model_prob"].mean()),
  504. }
  505. )
  506. patient_df = pd.DataFrame(patient_rows)
  507. patient_df = patient_df[
  508. patient_df["cohort"].isin(["stable_cn", "stable_ad", "cn_to_ad"])
  509. ].copy()
  510. table_path = output_dir / "longitudinal_patient_summary.csv"
  511. patient_df.to_csv(table_path, index=False)
  512. cohort_df = (
  513. patient_df.groupby("cohort")
  514. .agg(
  515. n_patients=("PTID", "count"),
  516. mean_confidence=("mean_confidence", "mean"),
  517. mean_std=("mean_std", "mean"),
  518. mean_prob=("mean_prob", "mean"),
  519. )
  520. .reset_index()
  521. )
  522. cohort_table = output_dir / "longitudinal_cohort_summary.csv"
  523. cohort_df.to_csv(cohort_table, index=False)
  524. cohorts = ["stable_cn", "stable_ad", "cn_to_ad"]
  525. uncertainty_specs = [
  526. ("confidence", "mean_confidence", "Mean Confidence"),
  527. (secondary_key, "mean_std", secondary_label),
  528. ]
  529. plot_paths: dict[str, str] = {}
  530. for metric_name, metric_col, metric_label in uncertainty_specs:
  531. values = [
  532. np.asarray(
  533. patient_df.loc[patient_df["cohort"] == c, metric_col], dtype=float
  534. )
  535. for c in cohorts
  536. ]
  537. plot_path = plots_dir(output_dir) / f"longitudinal_cohort_{metric_name}.png"
  538. save_boxplot(
  539. data=values,
  540. tick_labels=cohorts,
  541. x_label="Longitudinal Cohort",
  542. y_label=metric_label,
  543. title=f"Longitudinal Cohort Comparison: {metric_label} ({evaluation.backend})",
  544. output_path=plot_path,
  545. )
  546. plot_paths[metric_name] = str(plot_path)
  547. uncertainty_by_cohort = cohort_df.melt(
  548. id_vars=["cohort", "n_patients"],
  549. value_vars=["mean_confidence", "mean_std"],
  550. var_name="uncertainty_type",
  551. value_name="mean_value",
  552. ).replace(
  553. {
  554. "uncertainty_type": {
  555. "mean_confidence": "confidence",
  556. "mean_std": secondary_key,
  557. }
  558. }
  559. )
  560. uncertainty_table = output_dir / "longitudinal_uncertainty_by_cohort.csv"
  561. uncertainty_by_cohort.to_csv(uncertainty_table, index=False)
  562. confidence_patient_table = (
  563. output_dir / "longitudinal_confidence_patient_summary.csv"
  564. )
  565. std_patient_table = output_dir / "longitudinal_std_patient_summary.csv"
  566. confidence_cohort_table = output_dir / "longitudinal_confidence_cohort_summary.csv"
  567. std_cohort_table = output_dir / "longitudinal_std_cohort_summary.csv"
  568. secondary_patient_table = (
  569. output_dir / f"longitudinal_{secondary_key}_patient_summary.csv"
  570. )
  571. secondary_cohort_table = (
  572. output_dir / f"longitudinal_{secondary_key}_cohort_summary.csv"
  573. )
  574. patient_df[
  575. ["PTID", "n_visits", "first_dx", "last_dx", "cohort", "mean_confidence"]
  576. ].to_csv(confidence_patient_table, index=False)
  577. patient_df[
  578. ["PTID", "n_visits", "first_dx", "last_dx", "cohort", "mean_std"]
  579. ].to_csv(std_patient_table, index=False)
  580. patient_df[
  581. ["PTID", "n_visits", "first_dx", "last_dx", "cohort", "mean_std"]
  582. ].to_csv(secondary_patient_table, index=False)
  583. cohort_df[["cohort", "n_patients", "mean_confidence"]].to_csv(
  584. confidence_cohort_table, index=False
  585. )
  586. cohort_df[["cohort", "n_patients", "mean_std"]].to_csv(
  587. std_cohort_table, index=False
  588. )
  589. cohort_df[["cohort", "n_patients", "mean_std"]].to_csv(
  590. secondary_cohort_table, index=False
  591. )
  592. out = {
  593. "n_patients_analyzed": int(len(patient_df)),
  594. "table_patient": str(table_path),
  595. "table_cohort": str(cohort_table),
  596. "table_uncertainty": str(uncertainty_table),
  597. "tables": {
  598. "confidence": {
  599. "patient": str(confidence_patient_table),
  600. "cohort": str(confidence_cohort_table),
  601. },
  602. secondary_key: {
  603. "patient": str(secondary_patient_table),
  604. "cohort": str(secondary_cohort_table),
  605. },
  606. "std": {
  607. "patient": str(std_patient_table),
  608. "cohort": str(std_cohort_table),
  609. },
  610. },
  611. "plots": plot_paths,
  612. }
  613. write_json(output_dir / "longitudinal_summary.json", out)
  614. return out