| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312 |
import math
import pathlib as pl
import random
import re
from typing import Callable, Dict, Iterator, List, Tuple

import nibabel as nib
import pandas as pd
import torch
import torch.utils.data as data
from jaxtyping import Float
from torch.utils.data import DataLoader, Subset
class ADNIDataset(data.Dataset):  # type: ignore
    """
    A PyTorch Dataset holding pre-loaded MRI volumes, tabular features and
    class labels from the ADNI dataset.

    All tensors are cast to float32 and moved to ``device`` once, at
    construction time; ``__getitem__`` only indexes into them.
    """

    def __init__(
        self,
        mri_data: Float[torch.Tensor, "n_samples channels width height depth"],
        xls_data: Float[torch.Tensor, "n_samples features"],
        expected_classes: Float[torch.Tensor, "n_samples classes"],
        filename_ids: List[int],
        device: str = "cuda",
    ):
        """
        Args:
            mri_data (torch.Tensor): 5D tensor of MRI data with shape
                (n_samples, channels, width, height, depth).
            xls_data (torch.Tensor): 2D tensor of tabular features with shape
                (n_samples, features).
            expected_classes (torch.Tensor): Per-sample labels with shape
                (n_samples, classes).
            filename_ids (List[int]): One image ID per sample, in the same
                order as the tensors. Stored as given (not copied).
            device (str): Device the tensors are moved to.

        Raises:
            ValueError: If the inputs do not all describe the same number of
                samples as ``mri_data``.
        """
        # Every input is indexed by the same sample index in __getitem__, so a
        # length mismatch would otherwise surface later as an IndexError or be
        # silently truncated (__len__ only looks at mri_data).
        n_samples = mri_data.shape[0]
        other_sizes = {
            "xls_data": xls_data.shape[0],
            "expected_classes": expected_classes.shape[0],
            "filename_ids": len(filename_ids),
        }
        mismatched = {
            name: size for name, size in other_sizes.items() if size != n_samples
        }
        if mismatched:
            raise ValueError(
                f"All inputs must describe {n_samples} samples (the first "
                f"dimension of mri_data); got mismatched sizes {mismatched}."
            )
        self.mri_data = mri_data.float().to(device)
        self.xls_data = xls_data.float().to(device)
        self.expected_classes = expected_classes.float().to(device)
        self.filename_ids = filename_ids

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.
        """
        return self.mri_data.shape[0]  # 0th dimension is the number of samples

    def __getitem__(self, idx: int) -> Tuple[
        Float[torch.Tensor, "channels width height depth"],
        Float[torch.Tensor, "features"],
        Float[torch.Tensor, "classes"],
        int,
    ]:
        """
        Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: ``(mri_sample, xls_sample, expected_classes, filename_id)``,
            i.e. the MRI volume, the tabular feature vector, the label vector
            and the image ID of that sample.
        """
        # Slice every stored tensor on its 0th (sample) dimension.
        mri_sample = self.mri_data[idx]
        xls_sample = self.xls_data[idx]
        expected_classes = self.expected_classes[idx]
        filename_id = self.filename_ids[idx]
        return mri_sample, xls_sample, expected_classes, filename_id
- def load_adni_data_from_file(
- mri_files: Iterator[pl.Path], # List of nibablel files
- xls_file: pl.Path, # Path to the Excel file
- device: str = "cuda",
- xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
- ) -> ADNIDataset:
- """
- Loads MRI and Excel data from the ADNI dataset.
- Args:
- mri_files (List[pl.Path]): List of paths to the MRI files.
- xls_file (pl.Path): Path to the Excel file.
- Returns:
- ADNIDataset: The loaded dataset.
- """
- # Load the Excel data
- xls_values = xls_preprocessor(pd.read_csv(xls_file)) # type: ignore
- # Load the MRI data
- mri_data_unstacked: List[torch.Tensor] = []
- expected_classes_unstacked: List[torch.Tensor] = []
- xls_data_unstacked: List[torch.Tensor] = []
- img_ids: List[int] = []
- for file in mri_files:
- filename = file.stem
- match re.search(r".+?(?=_I)_I(\d+).+", filename):
- case None:
- raise ValueError(
- f"Filename {filename} does not match expected pattern."
- )
- case m:
- img_id = int(m.group(1))
- file_mri_data = torch.from_numpy(nib.load(file).get_fdata()) # type: ignore # type checking does not work well with nibabel
- # Read the filename to determine the expected class
- file_expected_class = torch.tensor([0.0, 0.0]) # Default to a tensor of zeros
- if "AD" in filename:
- file_expected_class = torch.tensor([1.0, 0.0])
- elif "NL" in filename:
- file_expected_class = torch.tensor([0.0, 1.0])
- else:
- raise ValueError(
- f"Filename {filename} does not contain a valid class identifier (AD or CN)."
- )
- mri_data_unstacked.append(file_mri_data)
- expected_classes_unstacked.append(file_expected_class)
- # Extract the corresponding row from the Excel data using the img_id
- xls_row = xls_values.loc[xls_values["Image Data ID"] == img_id]
- if xls_row.empty:
- raise ValueError(
- f"No matching row found in Excel data for Image Data ID {img_id}."
- )
- elif len(xls_row) > 1:
- raise ValueError(
- f"Multiple rows found in Excel data for Image Data ID {img_id}."
- )
- file_xls_data = torch.tensor(
- xls_row.drop(columns=["Image Data ID"]).values.flatten() # type: ignore
- )
- xls_data_unstacked.append(file_xls_data)
- img_ids.append(img_id)
- mri_data = torch.stack(mri_data_unstacked).unsqueeze(1)
- # Stack the list of tensors into a single tensor and unsqueeze along the channel dimension
- xls_data = torch.stack(
- xls_data_unstacked
- ) # Stack the list of tensors into a single tensor
- expected_classes = torch.stack(
- expected_classes_unstacked
- ) # Stack the list of expected classes into a single tensor
- return ADNIDataset(mri_data, xls_data, expected_classes, img_ids, device=device)
def divide_dataset(
    dataset: ADNIDataset,
    ratios: Tuple[float, float, float],
    seed: int,
) -> List[data.Subset[ADNIDataset]]:
    """
    Randomly divides the dataset into training, validation, and test sets.

    Args:
        dataset (ADNIDataset): The dataset to divide.
        ratios (Tuple[float, float, float]): Fractions of the samples for the
            training, validation and test sets. Must be non-negative and sum
            to 1.0 (compared with floating-point tolerance).
        seed (int): Seed for the torch generator; the same seed and dataset
            length always give the same split.

    Returns:
        List[data.Subset[ADNIDataset]]: One subset per ratio, in the same order
        as ``ratios``.

    Raises:
        ValueError: If a ratio is negative or the ratios do not sum to 1.0.
    """
    if any(ratio < 0 for ratio in ratios):
        raise ValueError(f"Ratios must be non-negative, got {ratios}.")
    # Exact float equality rejects valid inputs: 0.7 + 0.2 + 0.1 evaluates to
    # 0.9999999999999999 in binary floating point.
    if not math.isclose(sum(ratios), 1.0):
        raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")
    # Convert fractions to integer lengths with the same rule random_split uses
    # for fractional input (floor each, then hand out the remainder one at a
    # time in order). The resulting split matches passing the fractions
    # directly, but does not rely on how random_split treats float sums that
    # land slightly above 1.
    n_samples = len(dataset)
    lengths = [math.floor(n_samples * ratio) for ratio in ratios]
    for i in range(n_samples - sum(lengths)):
        lengths[i % len(lengths)] += 1
    # Set the random seed for reproducibility
    gen = torch.Generator().manual_seed(seed)
    return data.random_split(dataset, lengths, generator=gen)
def initalize_dataloaders(
    datasets: List[Subset[ADNIDataset]],
    batch_size: int = 64,
) -> List[DataLoader[ADNIDataset]]:
    """
    Wraps each dataset in a shuffling DataLoader.

    Args:
        datasets (List[Subset[ADNIDataset]]): Datasets (e.g. train/val/test
            subsets) to wrap.
        batch_size (int): Number of samples per batch, shared by every loader.

    Returns:
        List[DataLoader[ADNIDataset]]: One DataLoader per input dataset, in
        the same order. Every loader is created with ``shuffle=True``,
        including any validation or test loader.
    """

    def make_loader(subset):
        # Identical settings for every dataset in the list.
        return DataLoader(subset, shuffle=True, batch_size=batch_size)

    return list(map(make_loader, datasets))
def _map_images_to_patients(ptids: List[Tuple[int, str]]) -> Dict[int, str]:
    """Build a validated Image Data ID -> PTID mapping from (image_id, ptid) pairs."""
    image_to_patient: Dict[int, str] = {}
    for image_id, patient_id in ptids:
        image_id_int = int(image_id)
        patient_id_str = str(patient_id).strip()
        # A float NaN (how pandas represents missing cells) stringifies to "nan".
        if not patient_id_str or patient_id_str.lower() == "nan":
            raise ValueError(f"Invalid PTID for Image Data ID {image_id_int}.")
        # Repeated pairs are fine; the same image mapped to two patients is not.
        previous = image_to_patient.setdefault(image_id_int, patient_id_str)
        if previous != patient_id_str:
            raise ValueError(
                f"Conflicting PTIDs for Image Data ID {image_id_int}: "
                f"{previous} vs {patient_id_str}."
            )
    return image_to_patient


def divide_dataset_by_patient_id(
    dataset: ADNIDataset,
    ptids: List[Tuple[int, str]],
    ratios: Tuple[float, float, float],
    seed: int,
) -> List[data.Subset[ADNIDataset]]:
    """
    Divides the dataset into training, validation, and test sets based on patient IDs.
    Ensures that all samples from the same patient are in the same set.

    Args:
        dataset (ADNIDataset): The dataset to divide; only its ``filename_ids``
            are read.
        ptids (List[Tuple[int, str]]): Pairs of (image file ID, patient ID).
        ratios (Tuple[float, float, float]): Fractions of *patients* for the
            training, validation and test sets. Must be non-negative and sum
            to 1.0 (compared with floating-point tolerance).
        seed (int): The random seed for reproducibility.

    Returns:
        List[data.Subset[ADNIDataset]]: The training, validation and test subsets.

    Raises:
        ValueError: On invalid ratios, an empty ``ptids`` list, an invalid or
            conflicting PTID, or a dataset image ID with no PTID.

    Notes:
        This split is grouped by PTID, so all images from the same patient are
        assigned to exactly one partition to avoid patient-level leakage across
        train/val/test. Patients are counted with ``int`` truncation for train
        and val; the test set receives all remaining patients. Patient IDs are
        sorted before the seeded shuffle, so the split depends only on the seed
        and the set of patients, not on the order files were loaded in.
    """
    if any(ratio < 0 for ratio in ratios):
        raise ValueError(f"Ratios must be non-negative, got {ratios}.")
    # Exact float equality rejects valid inputs: 0.7 + 0.2 + 0.1 evaluates to
    # 0.9999999999999999 in binary floating point.
    if not math.isclose(sum(ratios), 1.0):
        raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")
    if not ptids:
        raise ValueError("ptids list cannot be empty.")
    image_to_patient = _map_images_to_patients(ptids)

    # Group dataset sample indices by the patient they belong to.
    patient_to_indices: Dict[str, List[int]] = {}
    for idx, image_id in enumerate(dataset.filename_ids):
        if image_id not in image_to_patient:
            raise ValueError(
                f"Missing PTID mapping for dataset Image Data ID {image_id}."
            )
        patient_to_indices.setdefault(image_to_patient[image_id], []).append(idx)

    # Sorting first makes the shuffle independent of dataset order (which
    # follows the order files were discovered on disk).
    shuffled_patients = sorted(patient_to_indices)
    random.Random(seed).shuffle(shuffled_patients)
    n_patients = len(shuffled_patients)
    train_cutoff = int(n_patients * ratios[0])
    val_cutoff = train_cutoff + int(n_patients * ratios[1])

    # Each patient occurs exactly once in shuffled_patients and each sample
    # index belongs to exactly one patient, so these three slices partition
    # both the patients and the samples -- no overlap or coverage check needed.
    patient_partitions = (
        shuffled_patients[:train_cutoff],
        shuffled_patients[train_cutoff:val_cutoff],
        shuffled_patients[val_cutoff:],
    )
    return [
        Subset(
            dataset,
            [idx for patient_id in patients for idx in patient_to_indices[patient_id]],
        )
        for patients in patient_partitions
    ]
|