import math
import pathlib as pl
from typing import Callable, List, Tuple

import nibabel as nib
import pandas as pd
import torch
import torch.utils.data as data
from jaxtyping import Float
from result import Err, Ok, Result


class ADNIDataset(data.Dataset):  # type: ignore
    """
    A PyTorch Dataset pairing MRI data with tabular (Excel) data from the ADNI dataset.

    Sample ``i`` is the pair ``(mri_data[i], xls_data[i])``; both tensors are
    indexed along their 0th (sample) dimension.
    """

    def __init__(
        self,
        mri_data: Float[torch.Tensor, "n_samples, width, height, depth"],
        xls_data: Float[torch.Tensor, "n_samples, features"],
    ):
        """
        Args:
            mri_data (torch.Tensor): 4D tensor of MRI data with shape
                (n_samples, width, height, depth).
            xls_data (torch.Tensor): 2D tensor of Excel data with shape
                (n_samples, features).

        Raises:
            ValueError: If the two tensors disagree on the number of samples
                (their 0th dimension).
        """
        # __len__ uses the MRI count only, so a mismatch would make indexing
        # fail on xls_data (or silently ignore extra rows); reject it up front.
        if mri_data.shape[0] != xls_data.shape[0]:
            raise ValueError(
                f"mri_data has {mri_data.shape[0]} samples but xls_data has "
                f"{xls_data.shape[0]}"
            )
        self.mri_data = mri_data
        self.xls_data = xls_data

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.
        """
        return self.mri_data.shape[0]  # 0th dimension is the number of samples

    def __getitem__(self, idx: int) -> Tuple[
        Float[torch.Tensor, "width, height, depth"],
        Float[torch.Tensor, "features"],
    ]:
        """
        Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: A tuple containing the MRI data and Excel data for the sample.
        """
        # Slice both tensors on the 0th dimension, corresponding to the sample index
        mri_sample = self.mri_data[idx]
        xls_sample = self.xls_data[idx]
        return mri_sample, xls_sample


def load_adni_data_from_file(
    mri_files: List[pl.Path],  # List of files readable by nibabel
    xls_file: pl.Path,  # Path to the Excel file
    xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
) -> Result[ADNIDataset, str]:
    """
    Loads MRI and Excel data from the ADNI dataset.

    Args:
        mri_files (List[pl.Path]): Paths to the MRI files, one per sample, in the
            same order as the rows of the Excel data.
        xls_file (pl.Path): Path to the Excel file.
        xls_preprocessor (Callable[[pd.DataFrame], pd.DataFrame]): Applied to the
            DataFrame read from ``xls_file`` before it is converted to a float
            tensor (identity by default). Presumably it must leave only numeric
            columns so the conversion succeeds — confirm against callers.

    Returns:
        Result[ADNIDataset, str]: ``Ok`` with the dataset, or ``Err`` with a message
        if no MRI files were given, the volumes cannot be stacked, or the number of
        MRI samples differs from the number of Excel rows.
    """
    if not mri_files:
        # torch.stack raises on an empty list; report it through the Result instead.
        return Err("No MRI files were provided")

    # Load each volume as a tensor.
    # NOTE(review): MRI tensors keep the dtype returned by get_fdata(); unlike the
    # Excel data they are not cast with .float() — confirm this is intended.
    mri_data_unstacked = [
        torch.from_numpy(nib.load(file).get_fdata())  # type: ignore  # type checking does not work well with nibabel
        for file in mri_files
    ]
    try:
        # Stack the list of tensors into a single (n_samples, ...) tensor.
        mri_data = torch.stack(mri_data_unstacked)
    except RuntimeError as err:
        # torch.stack requires every volume to have the same shape.
        return Err(f"Could not stack MRI volumes (shapes must match): {err}")

    # Load the Excel data, preprocess it, and convert it to a float tensor.
    xls_data = torch.from_numpy(  # type: ignore
        xls_preprocessor(pd.read_excel(xls_file)).to_numpy()  # type: ignore
    ).float()

    # Every MRI sample needs exactly one matching Excel row.
    if mri_data.shape[0] != xls_data.shape[0]:
        return Err(
            f"Loading MRI data failed: {mri_data.shape[0]} MRI samples but "
            f"{xls_data.shape[0]} Excel rows"
        )
    return Ok(ADNIDataset(mri_data, xls_data))


def divide_dataset(
    dataset: ADNIDataset,
    ratios: Tuple[float, float, float],
    seed: int = 0,
) -> Result[List[data.Subset[ADNIDataset]], str]:
    """
    Randomly divides the dataset into training, validation, and test sets.

    Args:
        dataset (ADNIDataset): The dataset to divide.
        ratios (Tuple[float, float, float]): Fractions of the dataset assigned to
            the training, validation, and test sets. Each must lie in [0, 1] and
            together they must sum to 1.0 (within floating-point tolerance).
        seed (int): Seed for the random generator, making the split reproducible.

    Returns:
        Result[List[data.Subset[ADNIDataset]], str]: ``Ok`` with the subsets in the
        same order as ``ratios``, or ``Err`` with a message if the ratios are invalid.
    """
    if any(r < 0 or r > 1 for r in ratios):
        return Err("Ratios must each be between 0 and 1")
    # An exact `== 1.0` test rejects valid inputs such as (0.7, 0.2, 0.1),
    # whose floating-point sum is 0.9999999999999999.
    if not math.isclose(sum(ratios), 1.0):
        return Err("Ratios must sum to 1.0")

    # Convert fractions to integer lengths here (floor each share, then hand out
    # the leftover samples round-robin) so sums a hair above 1.0 are handled too.
    n_samples = len(dataset)
    lengths = [math.floor(n_samples * r) for r in ratios]
    for i in range(n_samples - sum(lengths)):
        lengths[i % len(lengths)] += 1

    # Set the random seed for reproducibility
    gen = torch.Generator().manual_seed(seed)
    return Ok(data.random_split(dataset, lengths, generator=gen))