data_pipeline.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # pyright: basic
  2. from __future__ import annotations
  3. from pathlib import Path
  4. from typing import Any
  5. import pandas as pd
  6. from torch.utils.data import ConcatDataset, DataLoader, Subset
  7. from data.dataset import (
  8. ADNIDataset,
  9. divide_dataset_by_patient_id,
  10. initalize_dataloaders,
  11. load_adni_data_from_file,
  12. )
  def xls_preprocess(df: pd.DataFrame) -> pd.DataFrame:
      """Select the metadata columns used by the dataset and encode ``Sex``.

      Keeps only ``"Image Data ID"``, ``"Sex"`` and ``"Age (current)"``.
      ``Sex`` values are converted to ``str`` and stripped of surrounding
      whitespace. Then ``"M"`` is encoded as 0 and ``"F"`` as 1. Any other
      value is left as its stripped string; missing values become ``"nan"``
      because of the ``astype(str)``.

      Args:
          df: Raw metadata table. The input is not modified.

      Returns:
          A new DataFrame containing the three selected columns.

      Raises:
          KeyError: If any of the three required columns is missing.
      """
      data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
      # Encode only the Sex column. Calling ``replace`` on the whole frame
      # would also rewrite any "M"/"F" cell in the ID or age columns.
      data["Sex"] = data["Sex"].astype(str).str.strip().replace({"M": 0, "F": 1})
      return data
  def _patient_ids(xls_file: Path) -> list[tuple[int, str]]:
      """Return ``(image_id, ptid)`` pairs read from the metadata file.

      The file is parsed with ``pd.read_csv``, so it must be CSV-formatted
      despite the parameter name. Header names are whitespace-stripped.
      Rows missing either ``Image Data ID`` or ``PTID`` are dropped, as are
      rows whose ``PTID`` is blank after stripping. Pairs keep the row order
      of the file.

      Raises:
          KeyError: If either required column is absent.
          ValueError: If an ``Image Data ID`` value cannot be cast to ``int``.
      """
      id_col, ptid_col = "Image Data ID", "PTID"
      table = pd.read_csv(xls_file)
      table.columns = table.columns.str.strip()
      table = table[[id_col, ptid_col]].dropna(subset=[id_col, ptid_col])

      image_ids = table[id_col].astype(int)
      ptids = table[ptid_col].astype(str).str.strip()
      # Whitespace-only PTIDs survive ``dropna`` but are empty after stripping.
      non_blank = ptids != ""
      return list(zip(image_ids[non_blank].tolist(), ptids[non_blank].tolist()))
  def build_dataset(config: dict[str, Any], root_dir: Path) -> tuple[ADNIDataset, Path]:
      """Load the ADNI dataset described by ``config``.

      MRI volumes are the ``*.nii`` files directly inside
      ``config["data"]["mri_files_path"]``. Metadata is read from
      ``config["data"]["xls_file_path"]`` and preprocessed with
      ``xls_preprocess``. Both paths are joined onto ``root_dir`` and resolved.
      Absolute paths in the config therefore win, per ``pathlib`` joining
      rules.

      Args:
          config: Parsed configuration. Reads ``data.mri_files_path``,
              ``data.xls_file_path`` and ``training.device``.
          root_dir: Base directory for relative config paths.

      Returns:
          The loaded dataset and the resolved metadata file path.
      """
      mri_dir = (root_dir / config["data"]["mri_files_path"]).resolve()
      # ``Path.glob`` yields entries in arbitrary, filesystem-dependent order.
      # Sorting makes the file order identical across machines and runs, so
      # the dataset's item order (which the seeded splits presumably index
      # into; confirm in data.dataset) is reproducible. A list can also be
      # iterated more than once, unlike the generator.
      mri_files = sorted(mri_dir.glob("*.nii"))
      xls_file = (root_dir / config["data"]["xls_file_path"]).resolve()
      dataset = load_adni_data_from_file(
          mri_files,
          xls_file,
          device=config["training"]["device"],
          xls_preprocessor=xls_preprocess,
      )
      return dataset, xls_file
  def build_dataset_splits(
      config: dict[str, Any],
      dataset: ADNIDataset,
      xls_file: Path,
      seed: int,
  ) -> list[Subset[ADNIDataset]]:
      """Partition ``dataset`` into subsets using patient IDs from ``xls_file``.

      The image-to-patient pairs come from ``_patient_ids(xls_file)``. The
      split specification is ``config["data"]["data_splits"]`` converted to a
      tuple (presumably per-split proportions; confirm). Both are handed to
      ``divide_dataset_by_patient_id`` together with ``seed``. Judging by
      that helper's name, all images of one patient should land in the same
      subset; verify in data.dataset.
      """
      image_to_patient = _patient_ids(xls_file)
      split_spec = tuple(config["data"]["data_splits"])
      return divide_dataset_by_patient_id(
          dataset, image_to_patient, split_spec, seed=seed
      )
  def build_dataset_and_test_loader(
      config: dict[str, Any],
      root_dir: Path,
      seed: int,
  ) -> tuple[ADNIDataset, DataLoader]:
      """Load the full dataset and return it together with its test loader.

      The dataset is split with ``build_dataset_splits`` using ``seed``.
      Loaders are then created by ``initalize_dataloaders`` with batch size
      ``int(config["training"]["batch_size"])``. Only the third loader (the
      test split) is returned; the other two are discarded.
      """
      full_dataset, metadata_path = build_dataset(config, root_dir)
      subsets = build_dataset_splits(config, full_dataset, metadata_path, seed=seed)
      loader_batch_size = int(config["training"]["batch_size"])
      _train_loader, _val_loader, test_loader = initalize_dataloaders(
          subsets,
          batch_size=loader_batch_size,
      )
      return full_dataset, test_loader
  def build_holdout_loader(
      config: dict[str, Any],
      root_dir: Path,
      seed: int,
  ) -> DataLoader:
      """Return one unshuffled, batch-size-1 loader over validation + test.

      The dataset is built and split exactly as for training (same ``seed``).
      The datasets behind the second (validation) and third (test) loaders
      from ``initalize_dataloaders`` are concatenated, validation first. The
      result is served one sample per batch in a fixed order.
      """
      full_dataset, metadata_path = build_dataset(config, root_dir)
      subsets = build_dataset_splits(config, full_dataset, metadata_path, seed=seed)
      _train_loader, val_loader, test_loader = initalize_dataloaders(subsets, batch_size=1)
      held_out_parts = [loader.dataset for loader in (val_loader, test_loader)]
      return DataLoader(ConcatDataset(held_out_parts), batch_size=1, shuffle=False)