# pyright: basic
from __future__ import annotations

from pathlib import Path
from typing import Any

import pandas as pd
from torch.utils.data import ConcatDataset, DataLoader, Subset

from data.dataset import (
    ADNIDataset,
    divide_dataset_by_patient_id,
    initalize_dataloaders,
    load_adni_data_from_file,
)


def xls_preprocess(df: pd.DataFrame) -> pd.DataFrame:
    """Reduce the metadata table to the three columns used here.

    Keeps "Image Data ID", "Sex" and "Age (current)", converts "Sex" to
    whitespace-stripped strings, then replaces the values "M" with 0 and
    "F" with 1.

    Args:
        df: Metadata table; must contain the three columns named above.

    Returns:
        A new DataFrame; ``df`` itself is not modified.
    """
    wanted = ["Image Data ID", "Sex", "Age (current)"]
    subset = df[wanted]
    sex_text = subset["Sex"].astype(str).str.strip()
    # NOTE: the replacement is applied frame-wide, not only to "Sex"; any
    # cell equal to "M" or "F" in the other two columns is mapped as well.
    return subset.assign(Sex=sex_text).replace({"M": 0, "F": 1})


def _patient_ids(xls_file: Path) -> list[tuple[int, str]]:
    """Read ``(image id, patient id)`` pairs from the metadata CSV.

    Column names are stripped of surrounding whitespace before lookup.
    Rows missing either "Image Data ID" or "PTID" are dropped, image ids
    are cast to ``int``, and rows whose stripped "PTID" is empty are
    discarded.

    Args:
        xls_file: Path to the metadata file; it is parsed with
            ``pd.read_csv``.

    Returns:
        One ``(Image Data ID, PTID)`` tuple per remaining row, in file order.
    """
    key_columns = ["Image Data ID", "PTID"]

    table = pd.read_csv(xls_file)
    table.columns = table.columns.str.strip()
    complete_rows = table[key_columns].dropna(subset=key_columns)

    # The integer cast runs on every complete row, before blank PTIDs are
    # filtered out, so a non-numeric id fails even on a row dropped later.
    image_ids = complete_rows["Image Data ID"].astype(int)
    patient_ids = complete_rows["PTID"].astype(str).str.strip()
    has_patient = patient_ids != ""

    return list(
        zip(image_ids[has_patient].tolist(), patient_ids[has_patient].tolist())
    )


def build_dataset(config: dict[str, Any], root_dir: Path) -> tuple[ADNIDataset, Path]:
    """Load the ADNI dataset described by ``config``.

    Args:
        config: Reads ``config["data"]["mri_files_path"]``,
            ``config["data"]["xls_file_path"]`` and
            ``config["training"]["device"]``.
        root_dir: Directory that both configured paths are joined onto.

    Returns:
        The dataset returned by ``load_adni_data_from_file`` and the
        resolved path of the metadata file.
    """
    data_cfg = config["data"]
    mri_dir = (root_dir / data_cfg["mri_files_path"]).resolve()
    xls_file = (root_dir / data_cfg["xls_file_path"]).resolve()

    # ``glob`` yields lazily and only matches ``*.nii`` directly inside
    # ``mri_dir``; the generator is handed to the loader as-is.
    loaded = load_adni_data_from_file(
        mri_dir.glob("*.nii"),
        xls_file,
        device=config["training"]["device"],
        xls_preprocessor=xls_preprocess,
    )
    return loaded, xls_file


def build_dataset_splits(
    config: dict[str, Any],
    dataset: ADNIDataset,
    xls_file: Path,
    seed: int,
) -> list[Subset[ADNIDataset]]:
    """Split ``dataset`` via ``divide_dataset_by_patient_id``.

    Args:
        config: Reads ``config["data"]["data_splits"]``, passed on as a tuple.
        dataset: Dataset to divide.
        xls_file: Metadata file the ``(image id, patient id)`` pairs are
            read from.
        seed: Forwarded to ``divide_dataset_by_patient_id`` as ``seed``.

    Returns:
        Whatever ``divide_dataset_by_patient_id`` returns (annotated as a
        list of subsets).
    """
    id_pairs = _patient_ids(xls_file)
    split_spec = tuple(config["data"]["data_splits"])
    return divide_dataset_by_patient_id(dataset, id_pairs, split_spec, seed=seed)


def build_dataset_and_test_loader(
    config: dict[str, Any],
    root_dir: Path,
    seed: int,
) -> tuple[ADNIDataset, DataLoader]:
    """Build the dataset and return it with the third dataloader.

    Args:
        config: Configuration consumed by ``build_dataset`` and
            ``build_dataset_splits``; ``config["training"]["batch_size"]``
            is also read and converted to ``int``.
        root_dir: Base directory for the configured data paths.
        seed: Seed forwarded to the split step.

    Returns:
        The full dataset and the last of the three loaders produced by
        ``initalize_dataloaders`` (bound to ``test_loader`` here).
    """
    dataset, xls_file = build_dataset(config, root_dir)
    subsets = build_dataset_splits(config, dataset, xls_file, seed=seed)

    batch_size = int(config["training"]["batch_size"])
    loaders = initalize_dataloaders(subsets, batch_size=batch_size)
    _train_loader, _val_loader, test_loader = loaders

    return dataset, test_loader
def build_holdout_loader(
    config: dict[str, Any],
    root_dir: Path,
    seed: int,
) -> DataLoader:
    """Return one unshuffled, batch-size-1 loader over two of the splits.

    The dataset is built and split as in the other builders of this module.
    ``initalize_dataloaders`` is called with ``batch_size=1``, and the
    datasets behind its second and third loaders (bound to ``val`` and
    ``test`` names here) are concatenated.

    Args:
        config: Configuration consumed by ``build_dataset`` and
            ``build_dataset_splits``.
        root_dir: Base directory for the configured data paths.
        seed: Seed forwarded to the split step.

    Returns:
        A ``DataLoader`` over the concatenated datasets with
        ``batch_size=1`` and ``shuffle=False``.
    """
    full_dataset, metadata_path = build_dataset(config, root_dir)
    subsets = build_dataset_splits(config, full_dataset, metadata_path, seed=seed)

    loaders = initalize_dataloaders(subsets, batch_size=1)
    _train_loader, val_loader, test_loader = loaders

    # Concatenation order is the second loader's dataset, then the third's.
    holdout_parts = [val_loader.dataset, test_loader.dataset]
    holdout = ConcatDataset(holdout_parts)
    return DataLoader(holdout, batch_size=1, shuffle=False)