longitudinal_audit.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # pyright: basic
  2. from __future__ import annotations
  3. from pathlib import Path
  4. from typing import Any
  5. import numpy as np
  6. import pandas as pd
  7. from .data_access import BackendEvaluation
  8. from .runtime import write_json
  9. def _normalize_dx(value: Any) -> str:
  10. if value is None or (isinstance(value, float) and np.isnan(value)):
  11. return ""
  12. v = str(value).strip().upper()
  13. if v in {"NL", "NORMAL"}:
  14. return "CN"
  15. return v
  16. def _is_mci(value: str) -> bool:
  17. return "MCI" in str(value).strip().upper()
  18. def _assign_cohort(diagnoses: list[str]) -> str:
  19. if len(diagnoses) < 2:
  20. return "insufficient_visits"
  21. first_dx = diagnoses[0]
  22. last_dx = diagnoses[-1]
  23. unique_dx = set(diagnoses)
  24. if unique_dx.issubset({"CN"}):
  25. return "stable_cn"
  26. if unique_dx.issubset({"AD"}):
  27. return "stable_ad"
  28. if all(_is_mci(dx) for dx in unique_dx):
  29. return "stable_mci"
  30. if first_dx == "CN" and "AD" in unique_dx and last_dx == "AD":
  31. return "cn_to_ad"
  32. if first_dx == "CN" and _is_mci(last_dx):
  33. return "cn_to_mci"
  34. if _is_mci(first_dx) and last_dx == "AD":
  35. return "mci_to_ad"
  36. return "other"
  37. def _prepare_clinical_longitudinal_table(clinical_df: pd.DataFrame) -> pd.DataFrame:
  38. required = ["Image Data ID", "PTID"]
  39. missing = [c for c in required if c not in clinical_df.columns]
  40. if missing:
  41. raise KeyError(f"Missing columns for longitudinal analysis: {missing}")
  42. diagnosis_col = None
  43. for candidate in ["Class", "DX", "Diagnosis"]:
  44. if candidate in clinical_df.columns:
  45. diagnosis_col = candidate
  46. break
  47. if diagnosis_col is None:
  48. raise KeyError(
  49. "No diagnosis column found. Expected one of: Class, DX, Diagnosis"
  50. )
  51. work = clinical_df[
  52. ["Image Data ID", "PTID", diagnosis_col]
  53. + [c for c in ["EXAMDATE"] if c in clinical_df.columns]
  54. ].copy()
  55. work["Image Data ID"] = pd.to_numeric(work["Image Data ID"], errors="coerce")
  56. work = work.dropna(subset=["Image Data ID", "PTID"])
  57. work["Image Data ID"] = work["Image Data ID"].astype(int)
  58. work["PTID"] = work["PTID"].astype(str).str.strip()
  59. work["diagnosis"] = work[diagnosis_col].map(_normalize_dx)
  60. if "EXAMDATE" in work.columns:
  61. work["EXAMDATE"] = pd.to_datetime(
  62. work["EXAMDATE"],
  63. errors="coerce",
  64. format="mixed",
  65. )
  66. work = work.sort_values(["PTID", "EXAMDATE"], na_position="last")
  67. else:
  68. work = work.sort_values(["PTID", "Image Data ID"])
  69. return work
  70. def _build_patient_table(work: pd.DataFrame) -> pd.DataFrame:
  71. rows: list[dict[str, Any]] = []
  72. for ptid, group in work.groupby("PTID"):
  73. diagnoses = [d for d in group["diagnosis"].tolist() if d]
  74. cohort = _assign_cohort(diagnoses)
  75. first_dx = diagnoses[0] if diagnoses else ""
  76. last_dx = diagnoses[-1] if diagnoses else ""
  77. rows.append(
  78. {
  79. "PTID": ptid,
  80. "n_visits": int(len(group)),
  81. "n_labeled_visits": int(len(diagnoses)),
  82. "first_dx": first_dx,
  83. "last_dx": last_dx,
  84. "cohort": cohort,
  85. }
  86. )
  87. if not rows:
  88. return pd.DataFrame(
  89. columns=[
  90. "PTID",
  91. "n_visits",
  92. "n_labeled_visits",
  93. "first_dx",
  94. "last_dx",
  95. "cohort",
  96. ]
  97. )
  98. return pd.DataFrame(rows)
  99. def _count_table(df: pd.DataFrame, column: str, count_name: str) -> pd.DataFrame:
  100. if column not in df.columns or df.empty:
  101. return pd.DataFrame({column: [], count_name: []})
  102. return (
  103. df.groupby(column, dropna=False)
  104. .size()
  105. .rename(count_name)
  106. .reset_index()
  107. .sort_values(count_name, ascending=False)
  108. )
  109. def run_longitudinal_breakdown_audit(
  110. evaluation: BackendEvaluation,
  111. clinical_df: pd.DataFrame,
  112. output_dir: Path,
  113. ) -> dict[str, Any]:
  114. output_dir.mkdir(parents=True, exist_ok=True)
  115. full_table = _prepare_clinical_longitudinal_table(clinical_df)
  116. evaluable_table = full_table[
  117. full_table["Image Data ID"].isin(evaluation.image_ids.astype(int))
  118. ].copy()
  119. full_patient = _build_patient_table(full_table)
  120. evaluable_patient = _build_patient_table(evaluable_table)
  121. full_patient_min2 = full_patient[full_patient["n_labeled_visits"] >= 2].copy()
  122. evaluable_patient_min2 = evaluable_patient[
  123. evaluable_patient["n_labeled_visits"] >= 2
  124. ].copy()
  125. full_diagnosis_counts = _count_table(full_table, "diagnosis", "n_rows")
  126. evaluable_diagnosis_counts = _count_table(evaluable_table, "diagnosis", "n_rows")
  127. full_transition_counts = (
  128. full_patient_min2.groupby(["first_dx", "last_dx"])
  129. .size()
  130. .rename("n_patients")
  131. .reset_index()
  132. )
  133. full_transition_counts = full_transition_counts.sort_values(
  134. "n_patients", ascending=False
  135. )
  136. evaluable_transition_counts = (
  137. evaluable_patient_min2.groupby(["first_dx", "last_dx"])
  138. .size()
  139. .rename("n_patients")
  140. .reset_index()
  141. )
  142. evaluable_transition_counts = evaluable_transition_counts.sort_values(
  143. "n_patients", ascending=False
  144. )
  145. full_cohort_counts = _count_table(full_patient_min2, "cohort", "n_patients")
  146. evaluable_cohort_counts = _count_table(
  147. evaluable_patient_min2, "cohort", "n_patients"
  148. )
  149. paths = {
  150. "full_diagnosis_counts": output_dir / "longitudinal_diagnosis_counts_all.csv",
  151. "evaluable_diagnosis_counts": output_dir
  152. / "longitudinal_diagnosis_counts_evaluable.csv",
  153. "full_transition_counts": output_dir / "longitudinal_transition_counts_all.csv",
  154. "evaluable_transition_counts": output_dir
  155. / "longitudinal_transition_counts_evaluable.csv",
  156. "full_cohort_counts": output_dir / "longitudinal_cohort_counts_all.csv",
  157. "evaluable_cohort_counts": output_dir
  158. / "longitudinal_cohort_counts_evaluable.csv",
  159. "full_patient_table": output_dir / "longitudinal_patient_breakdown_all.csv",
  160. "evaluable_patient_table": output_dir
  161. / "longitudinal_patient_breakdown_evaluable.csv",
  162. "summary_md": output_dir / "longitudinal_breakdown_summary.md",
  163. "summary_json": output_dir / "longitudinal_breakdown_summary.json",
  164. }
  165. full_diagnosis_counts.to_csv(paths["full_diagnosis_counts"], index=False)
  166. evaluable_diagnosis_counts.to_csv(paths["evaluable_diagnosis_counts"], index=False)
  167. full_transition_counts.to_csv(paths["full_transition_counts"], index=False)
  168. evaluable_transition_counts.to_csv(
  169. paths["evaluable_transition_counts"], index=False
  170. )
  171. full_cohort_counts.to_csv(paths["full_cohort_counts"], index=False)
  172. evaluable_cohort_counts.to_csv(paths["evaluable_cohort_counts"], index=False)
  173. full_patient.to_csv(paths["full_patient_table"], index=False)
  174. evaluable_patient.to_csv(paths["evaluable_patient_table"], index=False)
  175. summary = {
  176. "backend": evaluation.backend,
  177. "n_evaluation_images": int(len(evaluation.image_ids)),
  178. "n_clinical_rows_total": int(len(full_table)),
  179. "n_clinical_rows_evaluable": int(len(evaluable_table)),
  180. "n_patients_total": (
  181. int(full_patient["PTID"].nunique()) if not full_patient.empty else 0
  182. ),
  183. "n_patients_evaluable": (
  184. int(evaluable_patient["PTID"].nunique())
  185. if not evaluable_patient.empty
  186. else 0
  187. ),
  188. "n_patients_total_min2_visits": int(len(full_patient_min2)),
  189. "n_patients_evaluable_min2_visits": int(len(evaluable_patient_min2)),
  190. "n_patients_evaluable_single_visit_only": (
  191. int((evaluable_patient["n_labeled_visits"] < 2).sum())
  192. if not evaluable_patient.empty
  193. else 0
  194. ),
  195. "tables": {
  196. "full_diagnosis_counts": str(paths["full_diagnosis_counts"]),
  197. "evaluable_diagnosis_counts": str(paths["evaluable_diagnosis_counts"]),
  198. "full_transition_counts": str(paths["full_transition_counts"]),
  199. "evaluable_transition_counts": str(paths["evaluable_transition_counts"]),
  200. "full_cohort_counts": str(paths["full_cohort_counts"]),
  201. "evaluable_cohort_counts": str(paths["evaluable_cohort_counts"]),
  202. "full_patient_table": str(paths["full_patient_table"]),
  203. "evaluable_patient_table": str(paths["evaluable_patient_table"]),
  204. },
  205. }
  206. write_json(paths["summary_json"], summary)
  207. cohort_text = "(none)"
  208. if not evaluable_cohort_counts.empty:
  209. cohort_text = "\n".join(
  210. [
  211. f"- {str(r['cohort'])}: {int(r['n_patients'])}"
  212. for _, r in evaluable_cohort_counts.iterrows()
  213. ]
  214. )
  215. lines = [
  216. f"# Longitudinal Breakdown Audit ({evaluation.backend})",
  217. "",
  218. "This report explains the patient/cohort counts used by the longitudinal analysis.",
  219. "",
  220. "## Key Counts",
  221. f"- Evaluation images available: {summary['n_evaluation_images']}",
  222. f"- Clinical rows total: {summary['n_clinical_rows_total']}",
  223. f"- Clinical rows overlapping evaluated images: {summary['n_clinical_rows_evaluable']}",
  224. f"- Patients total (any visits): {summary['n_patients_total']}",
  225. f"- Patients with overlapping evaluated images: {summary['n_patients_evaluable']}",
  226. f"- Patients with >=2 labeled visits (full clinical): {summary['n_patients_total_min2_visits']}",
  227. f"- Patients with >=2 labeled visits (evaluable overlap): {summary['n_patients_evaluable_min2_visits']}",
  228. f"- Overlap patients with only 1 labeled visit: {summary['n_patients_evaluable_single_visit_only']}",
  229. "",
  230. "## Evaluable Cohort Counts (>=2 labeled visits)",
  231. cohort_text,
  232. "",
  233. "## Interpretation",
  234. "- The longitudinal plot uses only patients with at least two labeled visits after intersection with the evaluated image IDs.",
  235. "- If MCI transition cohorts are missing, it typically means those patients do not have enough overlapping evaluated visits (or labels) in this run.",
  236. "- Use the CSV tables below to inspect diagnosis and transition distributions in full clinical data versus evaluable overlap.",
  237. "",
  238. "## Output Tables",
  239. f"- {paths['full_diagnosis_counts'].name}",
  240. f"- {paths['evaluable_diagnosis_counts'].name}",
  241. f"- {paths['full_transition_counts'].name}",
  242. f"- {paths['evaluable_transition_counts'].name}",
  243. f"- {paths['full_cohort_counts'].name}",
  244. f"- {paths['evaluable_cohort_counts'].name}",
  245. f"- {paths['full_patient_table'].name}",
  246. f"- {paths['evaluable_patient_table'].name}",
  247. ]
  248. paths["summary_md"].write_text("\n".join(lines), encoding="utf-8")
  249. return {
  250. **summary,
  251. "summary_markdown": str(paths["summary_md"]),
  252. "summary_json": str(paths["summary_json"]),
  253. }