| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
from collections.abc import Iterable
import pathlib as pl

import numpy as np
import pandas as pd

from utils.config import config

# ANSI escape sequences used to colour diagnosis labels in terminal output.
# RESET_COLOR restores the terminal's default colour after each label.
RESET_COLOR = "\033[0m"
DIAGNOSIS_COLORS = {
    "CN": "\033[92m",  # bright green
    "sMCI": "\033[93m",  # bright yellow
    "pMCI": "\033[96m",  # bright cyan
    "AD": "\033[91m",  # bright red
}
# Load the ADNI table named in the analysis config. Header cells may carry
# stray surrounding whitespace, so column names are stripped before any
# column is looked up by name.
adni_data = pd.read_csv(config["analysis"]["adni_path"]).rename(columns=str.strip)

# EXAMDATE is parsed as month/day/two-digit-year text (after trimming
# whitespace); values that do not match that format become NaT instead of
# raising.
adni_data["EXAMDATE"] = pd.to_datetime(
    adni_data["EXAMDATE"].astype("string").str.strip(),
    errors="coerce",
    format="%m/%d/%y",
)

# Order rows by patient, then chronologically; undated (NaT) exams go last.
adni_data = adni_data.sort_values(by=["PTID", "EXAMDATE"], na_position="last")
# Directory for this evaluation's plots: ../output/<evaluation_name>/plots.
# NOTE(review): "../output" is resolved against the current working directory,
# not this file's location — confirm the script is always launched from its
# own folder.
plots_dir = pl.Path("../output", config["analysis"]["evaluation_name"], "plots")
plots_dir.mkdir(exist_ok=True, parents=True)
# Goal: compare the model's confidence on longitudinal data, contrasting
# patients who have multiple images and switched from MCI or CN to AD with
# patients whose diagnosis stayed stable.
# Step 1: keep only patients that have more than one row (image) in the table.
multiple_images = adni_data.groupby("PTID").filter(
    lambda patient_rows: patient_rows.shape[0] >= 2
)
# Step 2 (below): find patients who switched from MCI or CN to AD, allowing
# for any number of images per patient.
def is_missing_diagnosis(diagnosis: object) -> bool:
    """Return True if *diagnosis* should be treated as an absent label.

    A value counts as missing when it is a scalar that pandas considers NA
    (``None``, ``pd.NA``, ``pd.NaT`` or a NaN of any float type, including
    NumPy float scalars), or when it is a string that is empty or contains
    only whitespace. Non-scalar values (lists, arrays) are never reported as
    missing.
    """
    # pd.isna covers NaN of every float type plus pd.NA / pd.NaT / None.
    # Guard with is_scalar so containers never reach pd.isna, which would
    # return an array instead of a single truth value.
    if pd.api.types.is_scalar(diagnosis) and pd.isna(diagnosis):
        return True
    # Blank cells can survive CSV parsing as whitespace (e.g. " "); they carry
    # no diagnosis and would otherwise become empty labels after stripping.
    return isinstance(diagnosis, str) and not diagnosis.strip()
def normalize_diagnoses(diagnoses: Iterable[object]) -> np.ndarray:
    """Return the non-missing diagnoses as a NumPy string array.

    Entries for which ``is_missing_diagnosis`` is true are skipped. Every other
    entry is converted with ``str()`` and stripped of surrounding whitespace.
    Input order is preserved.
    """
    kept = []
    for entry in diagnoses:
        if is_missing_diagnosis(entry):
            continue
        kept.append(str(entry).strip())
    return np.array(kept, dtype=str)
def format_diagnoses(diagnoses: Iterable[object]) -> str:
    """Render diagnoses as ``[a, b, ...]`` for terminal output.

    The input is cleaned with ``normalize_diagnoses`` first. Labels that have
    an entry in ``DIAGNOSIS_COLORS`` are wrapped in that ANSI colour followed
    by ``RESET_COLOR``; any other label is emitted uncoloured.
    """
    pieces = []
    for label in normalize_diagnoses(diagnoses):
        color = DIAGNOSIS_COLORS.get(label)
        pieces.append(label if color is None else f"{color}{label}{RESET_COLOR}")
    return "[" + ", ".join(pieces) + "]"
# List every patient with more than one image together with the diagnosis on
# each of their rows (missing labels omitted, known labels coloured).
print("Patients with multiple images and their diagnoses:")
for ptid, group in multiple_images.groupby("PTID"):
    diagnoses = group["Class"].values
    print("PTID: {}, Diagnoses: {}".format(ptid, format_diagnoses(diagnoses)))
def has_mci_dx(diagnoses: Iterable[object]) -> bool:
    """Return True if any diagnosis is an MCI subtype ("sMCI" or "pMCI").

    Diagnoses are cleaned with ``normalize_diagnoses`` first, so missing
    entries are ignored and labels are compared after whitespace stripping.

    "CN" is deliberately not counted here. ``switched_class`` combines this
    check with ``has_nc_dx``. When CN was included, any patient with a CN visit
    (even ``["CN", "CN"]``) satisfied both checks and was reported as having
    switched class.
    """
    normalized = normalize_diagnoses(diagnoses)
    return bool(np.any((normalized == "sMCI") | (normalized == "pMCI")))
def has_ad_dx(diagnoses: Iterable[object]) -> bool:
    """Return True if any non-missing, whitespace-stripped diagnosis is "AD"."""
    return "AD" in normalize_diagnoses(diagnoses).tolist()
def has_nc_dx(diagnoses: Iterable[object]) -> bool:
    """Return True if any non-missing, whitespace-stripped diagnosis is "CN"."""
    return any(label == "CN" for label in normalize_diagnoses(diagnoses))
# A patient "switched" when their records mix diagnostic classes: MCI with CN
# or AD, or CN with MCI or AD. Only co-occurrence of labels is tested, not the
# temporal order in which they appear.
def switched_class(diagnoses: Iterable[object]) -> bool:
    """Return True if *diagnoses* contain labels from more than one class.

    Class membership is decided by ``has_mci_dx`` (the MCI check),
    ``has_ad_dx`` ("AD") and ``has_nc_dx`` ("CN"). The result is True when the
    MCI check co-occurs with AD, or when CN co-occurs with MCI or AD.
    """
    # Materialize the input once. Each helper iterates it, so a one-shot
    # iterator (e.g. a generator) would be empty for every check after the first.
    diagnoses = list(diagnoses)
    mci = has_mci_dx(diagnoses)
    ad = has_ad_dx(diagnoses)
    cn = has_nc_dx(diagnoses)
    return (mci and ad) or (cn and (mci or ad))
# Keep every row of each multi-image patient whose diagnoses span more than one
# class according to switched_class.
# NOTE(review): the printed labels below mention only transitions involving AD,
# but switched_class also accepts CN<->MCI records — consider rewording.
patients_switched = multiple_images.groupby("PTID").filter(
    lambda patient_rows: switched_class(patient_rows["Class"].values)
)

print("\nPatients who switched from MCI or CN to AD or vice versa:")
for ptid, group in patients_switched.groupby("PTID"):
    diagnoses = group["Class"].values
    print("PTID: {}, Diagnoses: {}".format(ptid, format_diagnoses(diagnoses)))

# Summary: number of distinct patients with multiple images, and how many of
# them were flagged as having switched class.
print(
    "\nTotal patients with multiple images: {}".format(
        multiple_images["PTID"].nunique()
    )
)
print(
    "Total patients who switched from MCI or CN to AD or vice versa: {}".format(
        patients_switched["PTID"].nunique()
    )
)
# Multi-image patients whose records contain both a CN and an AD diagnosis.
# NOTE(review): only co-occurrence is checked, so the "CN to AD" direction in
# the heading below is assumed rather than verified against EXAMDATE order.
patients_switched_cn_to_ad = multiple_images.groupby("PTID").filter(
    lambda patient_rows: all(
        check(patient_rows["Class"].values) for check in (has_nc_dx, has_ad_dx)
    )
)
print("\nPatients who switched from CN to AD:")
for ptid, group in patients_switched_cn_to_ad.groupby("PTID"):
    diagnoses = group["Class"].values
    print("PTID: {}, Diagnoses: {}".format(ptid, format_diagnoses(diagnoses)))
|