| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- # pyright: basic
- from __future__ import annotations
- from pathlib import Path
- from typing import Any
- import pandas as pd
- from torch.utils.data import ConcatDataset, DataLoader, Subset
- from data.dataset import (
- ADNIDataset,
- divide_dataset_by_patient_id,
- initalize_dataloaders,
- load_adni_data_from_file,
- )
def xls_preprocess(df: pd.DataFrame) -> pd.DataFrame:
    """Select and encode the tabular covariates kept for each image.

    Keeps only the ``Image Data ID``, ``Sex`` and ``Age (current)`` columns.
    ``Sex`` values are whitespace-stripped and encoded ``"M"`` -> 0,
    ``"F"`` -> 1. Any other ``Sex`` value is kept as its stripped string. This
    includes missing values, which ``astype(str)`` turns into ``"nan"``. The
    encoding is applied to ``Sex`` only; the other columns are returned as-is.

    Args:
        df: Image metadata table. Header cells may carry stray whitespace.

    Returns:
        A new DataFrame; ``df`` itself is not modified.

    Raises:
        KeyError: If any of the three required columns is missing.
    """
    # Tolerate stray whitespace in header cells; the PTID lookup in this module
    # normalises the same sheet's headers the same way.
    frame = df.rename(
        columns=lambda col: col.strip() if isinstance(col, str) else col
    )
    data = frame[["Image Data ID", "Sex", "Age (current)"]].copy()
    # Restrict the M/F mapping to the Sex column so IDs/ages are never touched.
    # infer_objects() restores an integer dtype when every value was M or F,
    # independent of whether this pandas version downcasts inside replace().
    data["Sex"] = (
        data["Sex"]
        .astype(str)
        .str.strip()
        .replace({"M": 0, "F": 1})
        .infer_objects()
    )
    return data
- def _patient_ids(xls_file: Path) -> list[tuple[int, str]]:
- ptid_df = pd.read_csv(xls_file)
- ptid_df.columns = ptid_df.columns.str.strip()
- ptid_df = ptid_df[["Image Data ID", "PTID"]].dropna(
- subset=["Image Data ID", "PTID"]
- )
- ptid_df["Image Data ID"] = ptid_df["Image Data ID"].astype(int)
- ptid_df["PTID"] = ptid_df["PTID"].astype(str).str.strip()
- ptid_df = ptid_df[ptid_df["PTID"] != ""]
- return list(zip(ptid_df["Image Data ID"].tolist(), ptid_df["PTID"].tolist()))
def build_dataset(config: dict[str, Any], root_dir: Path) -> tuple[ADNIDataset, Path]:
    """Load the ADNI dataset described by ``config``.

    Reads ``config["data"]["mri_files_path"]`` (a directory scanned for
    ``*.nii`` files), ``config["data"]["xls_file_path"]`` (the metadata file)
    and ``config["training"]["device"]``. Both paths are joined onto
    ``root_dir`` and resolved. The metadata is preprocessed with
    ``xls_preprocess``.

    Returns:
        ``(dataset, xls_file)``, where ``xls_file`` is the resolved metadata
        path. Callers reuse it to look up patient IDs when splitting.

    Raises:
        KeyError: If any of the config keys above is missing.
    """
    data_cfg = config["data"]
    # Path.glob yields entries in filesystem order, which is arbitrary and
    # differs between machines; sort so the dataset's item order (and hence
    # anything index-based downstream) is reproducible.
    mri_files = sorted(
        (root_dir / data_cfg["mri_files_path"]).resolve().glob("*.nii")
    )
    xls_file = (root_dir / data_cfg["xls_file_path"]).resolve()
    dataset = load_adni_data_from_file(
        mri_files,
        xls_file,
        device=config["training"]["device"],
        xls_preprocessor=xls_preprocess,
    )
    return dataset, xls_file
def build_dataset_splits(
    config: dict[str, Any],
    dataset: ADNIDataset,
    xls_file: Path,
    seed: int,
) -> list[Subset[ADNIDataset]]:
    """Partition ``dataset`` into subsets using patient IDs from ``xls_file``.

    The ``(image id, PTID)`` pairs are read from ``xls_file``. The split
    proportions come from ``config["data"]["data_splits"]``, converted to a
    tuple. Both are forwarded with ``seed`` to
    ``divide_dataset_by_patient_id``, which does the partitioning. Presumably
    all images of one patient land in the same subset; confirm in that
    helper.
    """
    id_pairs = _patient_ids(xls_file)
    split_fractions = tuple(config["data"]["data_splits"])
    return divide_dataset_by_patient_id(
        dataset, id_pairs, split_fractions, seed=seed
    )
def build_dataset_and_test_loader(
    config: dict[str, Any],
    root_dir: Path,
    seed: int,
) -> tuple[ADNIDataset, DataLoader]:
    """Build the full dataset and a loader over its test split.

    The batch size is ``int(config["training"]["batch_size"])``. The splits
    are produced with ``seed``. ``initalize_dataloaders`` returns three
    loaders; only the third (test) one is kept and the other two are
    discarded.

    Returns:
        ``(dataset, test_loader)``.
    """
    dataset, xls_file = build_dataset(config, root_dir)
    _train_loader, _val_loader, test_loader = initalize_dataloaders(
        build_dataset_splits(config, dataset, xls_file, seed=seed),
        batch_size=int(config["training"]["batch_size"]),
    )
    return dataset, test_loader
def build_holdout_loader(
    config: dict[str, Any],
    root_dir: Path,
    seed: int,
) -> DataLoader:
    """Return one loader over the validation and test splits combined.

    The dataset and splits are built exactly as for training (same ``config``
    and ``seed``). The validation and test subsets are then concatenated,
    validation samples first, and served one sample per batch in a fixed
    order (``shuffle=False``).
    """
    dataset, xls_file = build_dataset(config, root_dir)
    splits = build_dataset_splits(config, dataset, xls_file, seed=seed)
    _, val_loader, test_loader = initalize_dataloaders(splits, batch_size=1)
    combined = ConcatDataset([val_loader.dataset, test_loader.dataset])
    # Reuse the collate function of the loaders from initalize_dataloaders so
    # holdout batches have the same structure as test-loader batches. If that
    # helper leaves collate_fn unset, this is torch's default collate for a
    # batched loader, i.e. the same as omitting the argument.
    return DataLoader(
        combined,
        batch_size=1,
        shuffle=False,
        collate_fn=test_loader.collate_fn,
    )
|