longitudinal_comparison.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. from utils.config import config
  2. import pathlib as pl
  3. import numpy as np
  4. import pandas as pd
  5. from collections.abc import Iterable
  6. RESET_COLOR = "\033[0m"
  7. DIAGNOSIS_COLORS = {
  8. "CN": "\033[92m",
  9. "sMCI": "\033[93m",
  10. "pMCI": "\033[96m",
  11. "AD": "\033[91m",
  12. }
  13. adni_data = pd.read_csv(config["analysis"]["adni_path"])
  14. # Strip leading and trailing spaces from column names
  15. adni_data = adni_data.rename(columns=lambda x: x.strip())
  16. adni_data["EXAMDATE"] = pd.to_datetime(
  17. adni_data["EXAMDATE"].astype("string").str.strip(),
  18. format="%m/%d/%y",
  19. errors="coerce",
  20. )
  21. adni_data = adni_data.sort_values(["PTID", "EXAMDATE"], na_position="last")
  22. plots_dir = (
  23. pl.Path("../output") / pl.Path(config["analysis"]["evaluation_name"]) / "plots"
  24. )
  25. plots_dir.mkdir(parents=True, exist_ok=True)
  26. # The goal for this analysis is to compare the confidence of the model on longitudinal data, where we analyze patients who have multiple images and switched from MCI or CN to AD to patients who stayed stable.
  27. # First filter to only inlude patients with multiple images
  28. multiple_images = adni_data.groupby("PTID").filter(lambda x: len(x) > 1)
  29. # Filter by patients who switched from MCI or CN to AD, accounting for an unknown number of images per patient.
  30. def is_missing_diagnosis(diagnosis: object) -> bool:
  31. if diagnosis is None or diagnosis is pd.NA:
  32. return True
  33. if isinstance(diagnosis, float):
  34. return bool(np.isnan(diagnosis))
  35. return False
  36. def normalize_diagnoses(diagnoses: Iterable[object]) -> np.ndarray:
  37. normalized = [
  38. str(diagnosis).strip()
  39. for diagnosis in diagnoses
  40. if not is_missing_diagnosis(diagnosis)
  41. ]
  42. return np.array(normalized, dtype=str)
  43. def format_diagnoses(diagnoses: Iterable[object]) -> str:
  44. normalized = normalize_diagnoses(diagnoses)
  45. colored = [
  46. (
  47. f"{DIAGNOSIS_COLORS[diagnosis]}{diagnosis}{RESET_COLOR}"
  48. if diagnosis in DIAGNOSIS_COLORS
  49. else diagnosis
  50. )
  51. for diagnosis in normalized
  52. ]
  53. return f"[{', '.join(colored)}]"
  54. # Print a list of patients with multiple images and their diagnoses
  55. print("Patients with multiple images and their diagnoses:")
  56. for ptid, group in multiple_images.groupby("PTID"):
  57. diagnoses = group["Class"].values
  58. print(f"PTID: {ptid}, Diagnoses: {format_diagnoses(diagnoses)}")
  59. def has_mci_dx(diagnoses: Iterable[object]) -> bool:
  60. diagnoses = normalize_diagnoses(diagnoses)
  61. return bool(
  62. np.any((diagnoses == "sMCI") | (diagnoses == "pMCI") | (diagnoses == "CN"))
  63. )
  64. def has_ad_dx(diagnoses: Iterable[object]) -> bool:
  65. diagnoses = normalize_diagnoses(diagnoses)
  66. return bool(np.any(diagnoses == "AD"))
  67. def has_nc_dx(diagnoses: Iterable[object]) -> bool:
  68. diagnoses = normalize_diagnoses(diagnoses)
  69. return bool(np.any(diagnoses == "CN"))
  70. # Switched from MCI to CN or AD, or from CN to MCI or AD
  71. def switched_class(diagnoses: Iterable[object]) -> bool:
  72. return (has_mci_dx(diagnoses) and has_ad_dx(diagnoses)) or (
  73. has_nc_dx(diagnoses) and (has_mci_dx(diagnoses) or has_ad_dx(diagnoses))
  74. )
  75. patients_switched = multiple_images.groupby("PTID").filter(
  76. lambda x: switched_class(x["Class"].values)
  77. )
  78. print("\nPatients who switched from MCI or CN to AD or vice versa:")
  79. for ptid, group in patients_switched.groupby("PTID"):
  80. diagnoses = group["Class"].values
  81. print(f"PTID: {ptid}, Diagnoses: {format_diagnoses(diagnoses)}")
  82. # print length of patients with multiple images, and how many switched from MCI or CN to AD or vice versa
  83. print(f"\nTotal patients with multiple images: {multiple_images['PTID'].nunique()}")
  84. print(
  85. f"Total patients who switched from MCI or CN to AD or vice versa: {patients_switched['PTID'].nunique()}"
  86. )
  87. # Filter just for the patients who switched from CN to AD
  88. patients_switched_cn_to_ad = multiple_images.groupby("PTID").filter(
  89. lambda x: has_nc_dx(x["Class"].values) and has_ad_dx(x["Class"].values)
  90. )
  91. print("\nPatients who switched from CN to AD:")
  92. for ptid, group in patients_switched_cn_to_ad.groupby("PTID"):
  93. diagnoses = group["Class"].values
  94. print(f"PTID: {ptid}, Diagnoses: {format_diagnoses(diagnoses)}")