import math
import re
import pathlib as pl
from typing import Callable, Iterable, Iterator, List, Sequence, Tuple

import nibabel as nib
import pandas as pd
import torch
import torch.utils.data as data
from jaxtyping import Float
from torch.utils.data import DataLoader, Subset

# Extracts the numeric image ID from a file stem: the digits following the
# first "_I" that is preceded by at least one character and followed by at
# least one digit, e.g. "..._S13408_I13722" -> "13722". Unlike a pattern that
# also demands trailing characters, this keeps every digit when the ID ends
# the stem.
_IMAGE_ID_PATTERN = re.compile(r".+?_I(\d+)")


class ADNIDataset(data.Dataset):  # type: ignore
    """
    A PyTorch Dataset holding MRI volumes, tabular ("Excel") features and class
    targets for the ADNI dataset.

    All tensors are converted to float32 and moved to ``device`` once, at
    construction time, so ``__getitem__`` only slices already-placed tensors.
    """

    def __init__(
        self,
        mri_data: Float[torch.Tensor, "n_samples channels width height depth"],
        xls_data: Float[torch.Tensor, "n_samples features"],
        expected_classes: Float[torch.Tensor, "n_samples classes"],
        device: str = "cuda",
    ):
        """
        Args:
            mri_data (torch.Tensor): 5D tensor of MRI data with shape
                (n_samples, channels, width, height, depth).
            xls_data (torch.Tensor): 2D tensor of tabular data with shape
                (n_samples, features).
            expected_classes (torch.Tensor): Per-sample class targets, indexed
                along dimension 0 like the other two tensors.
            device (str): Device the tensors are moved to.

        Raises:
            ValueError: If the three tensors disagree on the number of samples.
        """
        n_samples = mri_data.shape[0]
        if xls_data.shape[0] != n_samples or expected_classes.shape[0] != n_samples:
            raise ValueError(
                "mri_data, xls_data and expected_classes must have the same "
                f"number of samples, got {n_samples}, {xls_data.shape[0]} and "
                f"{expected_classes.shape[0]}."
            )
        self.mri_data = mri_data.float().to(device)
        self.xls_data = xls_data.float().to(device)
        self.expected_classes = expected_classes.float().to(device)

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.
        """
        return self.mri_data.shape[0]  # 0th dimension is the number of samples

    def __getitem__(self, idx: int) -> Tuple[
        Float[torch.Tensor, "channels width height depth"],
        Float[torch.Tensor, "features"],
        Float[torch.Tensor, "classes"],
    ]:
        """
        Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: (mri_sample, xls_sample, expected_classes) for that index.
        """
        # Every tensor is sliced on the 0th (sample) dimension.
        mri_sample = self.mri_data[idx]
        xls_sample = self.xls_data[idx]
        expected_classes = self.expected_classes[idx]
        return mri_sample, xls_sample, expected_classes


def _extract_image_id(filename: str) -> int:
    """
    Return the integer image ID encoded as ``_I<digits>`` in ``filename``.

    Raises:
        ValueError: If ``filename`` contains no such token.
    """
    match = _IMAGE_ID_PATTERN.search(filename)
    if match is None:
        raise ValueError(f"Filename {filename} does not match expected pattern.")
    return int(match.group(1))


def _expected_class_from_filename(filename: str) -> torch.Tensor:
    """
    Return the two-element class target derived from ``filename``.

    A stem containing "AD" maps to [1, 0]; otherwise one containing "CN" maps
    to [0, 1]; anything else maps to [0, 0].
    """
    # NOTE(review): this is a plain substring test checked in the order AD, CN.
    # Any stem containing "AD" (including an "ADNI" prefix, if present) is
    # labelled AD — confirm the filename convention makes this unambiguous.
    if "AD" in filename:
        return torch.tensor([1.0, 0.0])
    if "CN" in filename:
        return torch.tensor([0.0, 1.0])
    return torch.tensor([0.0, 0.0])


def _lookup_xls_features(xls_values: pd.DataFrame, img_id: int) -> torch.Tensor:
    """
    Return the single row whose "Image Data ID" equals ``img_id`` as a
    flattened 1D tensor, with the ID column dropped.

    Raises:
        ValueError: If no row, or more than one row, matches ``img_id``.
    """
    xls_row = xls_values.loc[xls_values["Image Data ID"] == img_id]
    if xls_row.empty:
        raise ValueError(
            f"No matching row found in Excel data for Image Data ID {img_id}."
        )
    if len(xls_row) > 1:
        raise ValueError(
            f"Multiple rows found in Excel data for Image Data ID {img_id}."
        )
    return torch.tensor(
        xls_row.drop(columns=["Image Data ID"]).to_numpy().flatten()  # type: ignore
    )


def load_adni_data_from_file(
    mri_files: Iterable[pl.Path],  # Paths to MRI files readable by nibabel
    xls_file: pl.Path,  # Path to the tabular file (parsed with pd.read_csv)
    device: str = "cuda",
    xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
) -> ADNIDataset:
    """
    Loads MRI volumes and their matching tabular rows into an ADNIDataset.

    For each MRI file, the image ID is parsed from the file stem (``_I<digits>``).
    The class target is derived from the stem, and the single tabular row with
    that "Image Data ID" supplies the features.

    Args:
        mri_files (Iterable[pl.Path]): Paths to the MRI files. All volumes must
            share one shape, because they are combined with ``torch.stack``.
        xls_file (pl.Path): Path to the tabular data. It is read with
            ``pd.read_csv``, so it must be CSV-formatted.
        device (str): Device passed through to ``ADNIDataset``.
        xls_preprocessor (Callable): Applied to the loaded DataFrame before any
            row lookup; defaults to the identity.

    Returns:
        ADNIDataset: MRI data with a channel dimension of size 1 inserted at
        dim 1, the tabular features, and the class targets.

    Raises:
        ValueError: If ``mri_files`` is empty, a filename lacks an image ID, or
            an image ID matches zero or several tabular rows.
    """
    xls_values = xls_preprocessor(pd.read_csv(xls_file))  # type: ignore

    mri_data_unstacked: List[torch.Tensor] = []
    expected_classes_unstacked: List[torch.Tensor] = []
    xls_data_unstacked: List[torch.Tensor] = []

    for file in mri_files:
        filename = file.stem
        img_id = _extract_image_id(filename)

        # Convert each volume to float32 right away. ADNIDataset converts to
        # float32 anyway, so the values are identical. Converting per file
        # means the full-precision copies are never stacked together, which
        # lowers peak memory.
        mri_data_unstacked.append(
            torch.from_numpy(nib.load(file).get_fdata()).float()  # type: ignore # type checking does not work well with nibabel
        )
        expected_classes_unstacked.append(_expected_class_from_filename(filename))
        xls_data_unstacked.append(_lookup_xls_features(xls_values, img_id))

    if not mri_data_unstacked:
        raise ValueError("No MRI files were provided.")

    # Stack into (n_samples, width, height, depth), then add the channel dim.
    mri_data = torch.stack(mri_data_unstacked).unsqueeze(1)
    xls_data = torch.stack(xls_data_unstacked)
    expected_classes = torch.stack(expected_classes_unstacked)

    return ADNIDataset(mri_data, xls_data, expected_classes, device=device)


def divide_dataset(
    dataset: ADNIDataset,
    ratios: Sequence[float],
    seed: int,
) -> List[data.Subset[ADNIDataset]]:
    """
    Randomly divides the dataset into subsets (e.g. train/validation/test).

    Args:
        dataset (ADNIDataset): The dataset to divide.
        ratios (Sequence[float]): Fraction of samples for each subset, e.g.
            (0.7, 0.2, 0.1). Must be non-negative and sum to 1.0 up to
            floating-point rounding.
        seed (int): Seed for the generator that drives the random permutation.

    Returns:
        List[data.Subset[ADNIDataset]]: One subset per ratio, in order.

    Raises:
        ValueError: If any ratio is negative or the ratios do not sum to 1.0.
    """
    if any(ratio < 0 for ratio in ratios):
        raise ValueError(f"Ratios must be non-negative, got {ratios}.")
    # Exact `!= 1.0` would reject common inputs such as (0.7, 0.2, 0.1),
    # whose float sum is 0.9999999999999999.
    if not math.isclose(sum(ratios), 1.0):
        raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")

    # Turn the fractions into integer lengths ourselves. Each length is floored
    # first, then the leftover samples are handed out one at a time,
    # round-robin. This is the rule random_split documents for fractional
    # lengths. Doing it here also covers a float sum that lands just above 1.0,
    # which random_split would not treat as fractions.
    n_samples = len(dataset)
    lengths = [math.floor(n_samples * ratio) for ratio in ratios]
    for i in range(n_samples - sum(lengths)):
        lengths[i % len(lengths)] += 1

    # Set the random seed for reproducibility
    gen = torch.Generator().manual_seed(seed)
    return data.random_split(dataset, lengths, generator=gen)


def initalize_dataloaders(
    datasets: List[Subset[ADNIDataset]],
    batch_size: int = 64,
    *,
    shuffle: bool = True,
) -> List[DataLoader[ADNIDataset]]:
    """
    Initializes one DataLoader per dataset.

    Args:
        datasets (List[Subset[ADNIDataset]]): Datasets to create DataLoaders for.
        batch_size (int): The batch size for every DataLoader.
        shuffle (bool): Whether each DataLoader reshuffles every epoch. Defaults
            to True (the previous fixed behaviour). Pass False for deterministic
            evaluation order.

    Returns:
        List[DataLoader[ADNIDataset]]: DataLoaders in the same order as ``datasets``.
    """
    return [
        DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
        for dataset in datasets
    ]


# Correctly spelled alias; the original name is kept for existing callers.
initialize_dataloaders = initalize_dataloaders