"""Metadata and patient-label utilities for the PET/SUV dataset."""

from __future__ import annotations

import pathlib
import re

import pandas as pd

# Shared prefix of every patient ID; the numeric suffix is formatted with ``:03d``.
_PATIENT_ID_PREFIX = "NIX-LJU-D2002-IRAE-A"

# Matches ``<patient_id>/VISIT_<visit>/<patient_id>_VISIT_<visit>_<organ>_<modality>.nii.gz``.
# The back-references force the directory names and the file name to agree on
# patient ID and visit number.
_NIFTI_PATH_PATTERN = re.compile(
    r"(?P<patient_id>NIX-LJU-D\d+-IRAE-A\d+)/VISIT_(?P<visit>\d+)/"
    r"(?P=patient_id)_VISIT_(?P=visit)_(?P<organ>.+)_(?P<modality>PET|SEG)\.nii\.gz$"
)


def get_meta_data(data_raw_path: str | pathlib.Path) -> pd.DataFrame:
    """Collect PET/SEG file metadata from a raw dataset directory.

    The function expects paths with the form::

        <patient_id>/VISIT_<visit>/<patient_id>_VISIT_<visit>_<organ>_<modality>.nii.gz

    where ``modality`` is either ``PET`` or ``SEG``. Files whose path does not
    match this form are reported on stdout (``Could not parse: ...``) and
    skipped.

    The returned dataframe is indexed by ``patient_id``, ``organ`` and
    ``visit`` and contains PET/SEG filenames and paths in a wide format.
    A modality that is missing for a patient-organ-visit is left as NA.

    Parameters
    ----------
    data_raw_path:
        Root directory containing the NIfTI files; searched recursively.

    Returns
    -------
    pandas.DataFrame
        Metadata table with columns ``PET_filename``, ``PET_path``,
        ``SEG_filename`` and ``SEG_path``. When no file can be parsed, an
        empty table with the same columns and index levels is returned.

    Raises
    ------
    ValueError
        If more than one file maps to the same
        ``(patient_id, organ, visit, modality)``.
    """
    index_cols = ["patient_id", "organ", "visit"]
    expected_cols = ["PET_filename", "PET_path", "SEG_filename", "SEG_path"]

    rows: list[dict[str, object]] = []
    for item in pathlib.Path(data_raw_path).rglob("*.nii.gz"):
        # Match on the POSIX form so the ``/`` separators in the pattern also
        # work for Windows paths; the original string form is still stored.
        match = _NIFTI_PATH_PATTERN.search(item.as_posix())
        if match is None:
            print(f"Could not parse: {item}")
            continue
        rows.append(
            {
                "filename": item.name,
                "path": str(item),
                "patient_id": match.group("patient_id"),
                "visit": int(match.group("visit")),
                "organ": match.group("organ"),
                "modality": match.group("modality"),
            }
        )

    if not rows:
        # Keep the documented schema so helpers that read the index levels
        # (e.g. flag_AE_patients) still work on an empty dataset.
        empty_index = pd.MultiIndex.from_arrays(
            [[] for _ in index_cols], names=index_cols
        )
        return pd.DataFrame(columns=expected_cols, index=empty_index)

    df = pd.DataFrame(rows)

    # pivot() below cannot hold two files for the same cell, so report
    # duplicates explicitly instead of letting pivot fail with a generic error.
    counts = df.groupby(index_cols + ["modality"]).size()
    duplicates = counts[counts > 1]
    if not duplicates.empty:
        raise ValueError(f"Duplicate images found:\n{duplicates}")

    df_wide = df.pivot(index=index_cols, columns="modality", values=["filename", "path"])

    # Flatten MultiIndex columns: (field, modality) -> MODALITY_field.
    df_wide.columns = [f"{modality}_{field}" for field, modality in df_wide.columns]

    # A modality absent from the whole dataset produces no column at all.
    for col in expected_cols:
        if col not in df_wide.columns:
            df_wide[col] = pd.NA

    return df_wide[expected_cols].sort_index()


def flag_corrupted_files(df: pd.DataFrame) -> pd.DataFrame:
    """Add an ``is_corrupted`` flag for known corrupted patient-organ-visits.

    Known corrupted entries that are not present in ``df`` are listed on
    stdout as a warning (in a deterministic, sorted order).

    Parameters
    ----------
    df:
        Metadata dataframe indexed by ``patient_id``, ``organ`` and ``visit``.

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` with an added boolean column ``is_corrupted``.
    """
    df = df.copy()

    # (patient number, organ, visit) triples known to be corrupted.
    corrupted = {
        (13, "Lung", 1),
        (14, "Lung", 2),
        (24, "Lung", 0),
        (1, "Colon", 0),
        (16, "Colon", 0),
    }

    # Sorted so the warning output does not depend on set iteration order.
    corrupted_ids = [
        (f"{_PATIENT_ID_PREFIX}{patient_id:03d}", organ, visit)
        for patient_id, organ, visit in sorted(corrupted)
    ]

    df["is_corrupted"] = False

    existing_ids = [idx for idx in corrupted_ids if idx in df.index]
    missing_ids = [idx for idx in corrupted_ids if idx not in df.index]

    if missing_ids:
        print("Warning: these corrupted IDs were not found in the dataframe:")
        for idx in missing_ids:
            print(idx)

    if existing_ids:
        df.loc[existing_ids, "is_corrupted"] = True

    return df


def flag_AE_patients(df: pd.DataFrame) -> pd.DataFrame:
    """Add an ``is_AE_patient`` flag for known AE patient-organ pairs.

    The AE label is assigned at patient-organ level and therefore applies to
    all visits of a given patient-organ pair.

    Parameters
    ----------
    df:
        Metadata dataframe indexed by ``patient_id``, ``organ`` and ``visit``.

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` with an added boolean column ``is_AE_patient``.
    """
    df = df.copy()

    # Patient numbers labelled as AE, per organ.
    ae_patients = {
        "Thyroid": [1, 2, 14, 17, 18, 20, 21, 22],
        "Lung": [1, 2, 4, 20],
        "Colon": [5, 7, 18, 28],
    }

    ae_keys = {
        (f"{_PATIENT_ID_PREFIX}{patient:03d}", organ)
        for organ, patients in ae_patients.items()
        for patient in patients
    }

    patient_ids = df.index.get_level_values("patient_id")
    organs = df.index.get_level_values("organ")

    # Explicit bool dtype keeps the column boolean even for an empty frame, so
    # it can be used directly as a mask (see get_AE_statistics).
    df["is_AE_patient"] = pd.Series(
        [(patient_id, organ) in ae_keys for patient_id, organ in zip(patient_ids, organs)],
        index=df.index,
        dtype=bool,
    )

    return df


def get_AE_statistics(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.Series, int]:
    """Summarize AE patient-organ labels.

    Parameters
    ----------
    df:
        Metadata dataframe indexed by ``patient_id``, ``organ`` and ``visit``
        with a boolean column ``is_AE_patient``.

    Returns
    -------
    tuple
        ``(AE_patient_organs, AE_counts_by_organ, n_AE_patient_organ_pairs)``:
        the unique AE ``(patient_id, organ)`` pairs sorted by organ then
        patient, the number of such pairs per organ, and their total count.

    Raises
    ------
    ValueError
        If ``df`` has no ``is_AE_patient`` column.
    """
    if "is_AE_patient" not in df.columns:
        raise ValueError("DataFrame must contain column 'is_AE_patient'.")

    # Several visits share one patient-organ pair, hence drop_duplicates.
    AE_patient_organs = (
        df.loc[df["is_AE_patient"]]
        .index.to_frame(index=False)[["patient_id", "organ"]]
        .drop_duplicates()
        .sort_values(["organ", "patient_id"])
        .reset_index(drop=True)
    )

    AE_counts_by_organ = (
        AE_patient_organs.groupby("organ")
        .size()
        .rename("n_AE_patient_organ_pairs")
        .sort_index()
    )

    n_AE_patient_organ_pairs = len(AE_patient_organs)

    return AE_patient_organs, AE_counts_by_organ, n_AE_patient_organ_pairs