# pyright: basic from __future__ import annotations from pathlib import Path from typing import Any import numpy as np import pandas as pd from .data_access import BackendEvaluation from .runtime import write_json def _normalize_dx(value: Any) -> str: if value is None or (isinstance(value, float) and np.isnan(value)): return "" v = str(value).strip().upper() if v in {"NL", "NORMAL"}: return "CN" return v def _is_mci(value: str) -> bool: return "MCI" in str(value).strip().upper() def _assign_cohort(diagnoses: list[str]) -> str: if len(diagnoses) < 2: return "insufficient_visits" first_dx = diagnoses[0] last_dx = diagnoses[-1] unique_dx = set(diagnoses) if unique_dx.issubset({"CN"}): return "stable_cn" if unique_dx.issubset({"AD"}): return "stable_ad" if all(_is_mci(dx) for dx in unique_dx): return "stable_mci" if first_dx == "CN" and "AD" in unique_dx and last_dx == "AD": return "cn_to_ad" if first_dx == "CN" and _is_mci(last_dx): return "cn_to_mci" if _is_mci(first_dx) and last_dx == "AD": return "mci_to_ad" return "other" def _prepare_clinical_longitudinal_table(clinical_df: pd.DataFrame) -> pd.DataFrame: required = ["Image Data ID", "PTID"] missing = [c for c in required if c not in clinical_df.columns] if missing: raise KeyError(f"Missing columns for longitudinal analysis: {missing}") diagnosis_col = None for candidate in ["Class", "DX", "Diagnosis"]: if candidate in clinical_df.columns: diagnosis_col = candidate break if diagnosis_col is None: raise KeyError( "No diagnosis column found. Expected one of: Class, DX, Diagnosis" ) work = clinical_df[ ["Image Data ID", "PTID", diagnosis_col] + [c for c in ["EXAMDATE"] if c in clinical_df.columns] ].copy() work["Image Data ID"] = pd.to_numeric(work["Image Data ID"], errors="coerce") work = work.dropna(subset=["Image Data ID", "PTID"]) work["Image Data ID"] = work["Image Data ID"].astype(int) work["PTID"] = work["PTID"].astype(str).str.strip() work["diagnosis"] = work[diagnosis_col].map(_normalize_dx) if "EXAMDATE" in work.columns: work["EXAMDATE"] = pd.to_datetime( work["EXAMDATE"], errors="coerce", format="mixed", ) work = work.sort_values(["PTID", "EXAMDATE"], na_position="last") else: work = work.sort_values(["PTID", "Image Data ID"]) return work def _build_patient_table(work: pd.DataFrame) -> pd.DataFrame: rows: list[dict[str, Any]] = [] for ptid, group in work.groupby("PTID"): diagnoses = [d for d in group["diagnosis"].tolist() if d] cohort = _assign_cohort(diagnoses) first_dx = diagnoses[0] if diagnoses else "" last_dx = diagnoses[-1] if diagnoses else "" rows.append( { "PTID": ptid, "n_visits": int(len(group)), "n_labeled_visits": int(len(diagnoses)), "first_dx": first_dx, "last_dx": last_dx, "cohort": cohort, } ) if not rows: return pd.DataFrame( columns=[ "PTID", "n_visits", "n_labeled_visits", "first_dx", "last_dx", "cohort", ] ) return pd.DataFrame(rows) def _count_table(df: pd.DataFrame, column: str, count_name: str) -> pd.DataFrame: if column not in df.columns or df.empty: return pd.DataFrame({column: [], count_name: []}) return ( df.groupby(column, dropna=False) .size() .rename(count_name) .reset_index() .sort_values(count_name, ascending=False) ) def run_longitudinal_breakdown_audit( evaluation: BackendEvaluation, clinical_df: pd.DataFrame, output_dir: Path, ) -> dict[str, Any]: output_dir.mkdir(parents=True, exist_ok=True) full_table = _prepare_clinical_longitudinal_table(clinical_df) evaluable_table = full_table[ full_table["Image Data ID"].isin(evaluation.image_ids.astype(int)) ].copy() full_patient = _build_patient_table(full_table) evaluable_patient = _build_patient_table(evaluable_table) full_patient_min2 = full_patient[full_patient["n_labeled_visits"] >= 2].copy() evaluable_patient_min2 = evaluable_patient[ evaluable_patient["n_labeled_visits"] >= 2 ].copy() full_diagnosis_counts = _count_table(full_table, "diagnosis", "n_rows") evaluable_diagnosis_counts = _count_table(evaluable_table, "diagnosis", "n_rows") full_transition_counts = ( full_patient_min2.groupby(["first_dx", "last_dx"]) .size() .rename("n_patients") .reset_index() ) full_transition_counts = full_transition_counts.sort_values( "n_patients", ascending=False ) evaluable_transition_counts = ( evaluable_patient_min2.groupby(["first_dx", "last_dx"]) .size() .rename("n_patients") .reset_index() ) evaluable_transition_counts = evaluable_transition_counts.sort_values( "n_patients", ascending=False ) full_cohort_counts = _count_table(full_patient_min2, "cohort", "n_patients") evaluable_cohort_counts = _count_table( evaluable_patient_min2, "cohort", "n_patients" ) paths = { "full_diagnosis_counts": output_dir / "longitudinal_diagnosis_counts_all.csv", "evaluable_diagnosis_counts": output_dir / "longitudinal_diagnosis_counts_evaluable.csv", "full_transition_counts": output_dir / "longitudinal_transition_counts_all.csv", "evaluable_transition_counts": output_dir / "longitudinal_transition_counts_evaluable.csv", "full_cohort_counts": output_dir / "longitudinal_cohort_counts_all.csv", "evaluable_cohort_counts": output_dir / "longitudinal_cohort_counts_evaluable.csv", "full_patient_table": output_dir / "longitudinal_patient_breakdown_all.csv", "evaluable_patient_table": output_dir / "longitudinal_patient_breakdown_evaluable.csv", "summary_md": output_dir / "longitudinal_breakdown_summary.md", "summary_json": output_dir / "longitudinal_breakdown_summary.json", } full_diagnosis_counts.to_csv(paths["full_diagnosis_counts"], index=False) evaluable_diagnosis_counts.to_csv(paths["evaluable_diagnosis_counts"], index=False) full_transition_counts.to_csv(paths["full_transition_counts"], index=False) evaluable_transition_counts.to_csv( paths["evaluable_transition_counts"], index=False ) full_cohort_counts.to_csv(paths["full_cohort_counts"], index=False) evaluable_cohort_counts.to_csv(paths["evaluable_cohort_counts"], index=False) full_patient.to_csv(paths["full_patient_table"], index=False) evaluable_patient.to_csv(paths["evaluable_patient_table"], index=False) summary = { "backend": evaluation.backend, "n_evaluation_images": int(len(evaluation.image_ids)), "n_clinical_rows_total": int(len(full_table)), "n_clinical_rows_evaluable": int(len(evaluable_table)), "n_patients_total": ( int(full_patient["PTID"].nunique()) if not full_patient.empty else 0 ), "n_patients_evaluable": ( int(evaluable_patient["PTID"].nunique()) if not evaluable_patient.empty else 0 ), "n_patients_total_min2_visits": int(len(full_patient_min2)), "n_patients_evaluable_min2_visits": int(len(evaluable_patient_min2)), "n_patients_evaluable_single_visit_only": ( int((evaluable_patient["n_labeled_visits"] < 2).sum()) if not evaluable_patient.empty else 0 ), "tables": { "full_diagnosis_counts": str(paths["full_diagnosis_counts"]), "evaluable_diagnosis_counts": str(paths["evaluable_diagnosis_counts"]), "full_transition_counts": str(paths["full_transition_counts"]), "evaluable_transition_counts": str(paths["evaluable_transition_counts"]), "full_cohort_counts": str(paths["full_cohort_counts"]), "evaluable_cohort_counts": str(paths["evaluable_cohort_counts"]), "full_patient_table": str(paths["full_patient_table"]), "evaluable_patient_table": str(paths["evaluable_patient_table"]), }, } write_json(paths["summary_json"], summary) cohort_text = "(none)" if not evaluable_cohort_counts.empty: cohort_text = "\n".join( [ f"- {str(r['cohort'])}: {int(r['n_patients'])}" for _, r in evaluable_cohort_counts.iterrows() ] ) lines = [ f"# Longitudinal Breakdown Audit ({evaluation.backend})", "", "This report explains the patient/cohort counts used by the longitudinal analysis.", "", "## Key Counts", f"- Evaluation images available: {summary['n_evaluation_images']}", f"- Clinical rows total: {summary['n_clinical_rows_total']}", f"- Clinical rows overlapping evaluated images: {summary['n_clinical_rows_evaluable']}", f"- Patients total (any visits): {summary['n_patients_total']}", f"- Patients with overlapping evaluated images: {summary['n_patients_evaluable']}", f"- Patients with >=2 labeled visits (full clinical): {summary['n_patients_total_min2_visits']}", f"- Patients with >=2 labeled visits (evaluable overlap): {summary['n_patients_evaluable_min2_visits']}", f"- Overlap patients with only 1 labeled visit: {summary['n_patients_evaluable_single_visit_only']}", "", "## Evaluable Cohort Counts (>=2 labeled visits)", cohort_text, "", "## Interpretation", "- The longitudinal plot uses only patients with at least two labeled visits after intersection with the evaluated image IDs.", "- If MCI transition cohorts are missing, it typically means those patients do not have enough overlapping evaluated visits (or labels) in this run.", "- Use the CSV tables below to inspect diagnosis and transition distributions in full clinical data versus evaluable overlap.", "", "## Output Tables", f"- {paths['full_diagnosis_counts'].name}", f"- {paths['evaluable_diagnosis_counts'].name}", f"- {paths['full_transition_counts'].name}", f"- {paths['evaluable_transition_counts'].name}", f"- {paths['full_cohort_counts'].name}", f"- {paths['evaluable_cohort_counts'].name}", f"- {paths['full_patient_table'].name}", f"- {paths['evaluable_patient_table'].name}", ] paths["summary_md"].write_text("\n".join(lines), encoding="utf-8") return { **summary, "summary_markdown": str(paths["summary_md"]), "summary_json": str(paths["summary_json"]), }