123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- import nibabel as nib
- import torch
- import torch.utils.data as data
- import pathlib as pl
- import pandas as pd
- from jaxtyping import Float
- from typing import Tuple, List, Callable
- from result import Ok, Err, Result
class ADNIDataset(data.Dataset):  # type: ignore
    """
    A PyTorch Dataset class pairing MRI volumes with tabular (Excel-derived)
    features from the ADNI dataset.

    Sample ``i`` is the tuple ``(mri_data[i], xls_data[i])``; both tensors are
    indexed along their first (sample) dimension.
    """

    def __init__(
        self,
        # jaxtyping dimension names are space-separated (commas are not separators).
        mri_data: Float[torch.Tensor, "n_samples width height depth"],
        xls_data: Float[torch.Tensor, "n_samples features"],
    ):
        """
        Args:
            mri_data (torch.Tensor): 4D tensor of MRI data with shape
                (n_samples, width, height, depth).
            xls_data (torch.Tensor): 2D tensor of Excel data with shape
                (n_samples, features).
        """
        # NOTE(review): the two tensors are not checked to share the same number
        # of samples here; __len__ relies on mri_data alone.
        self.mri_data = mri_data
        self.xls_data = xls_data

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset, i.e. the size of the
        first dimension of ``mri_data``.
        """
        return self.mri_data.shape[0]  # 0th dimension is the number of samples

    def __getitem__(self, idx: int) -> Tuple[
        Float[torch.Tensor, "width height depth"],
        Float[torch.Tensor, "features"],
    ]:
        """
        Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to retrieve.
        Returns:
            tuple: ``(mri_sample, xls_sample)``, the MRI volume and Excel
            feature row at ``idx``.
        """
        # Index the 0th (sample) dimension of each tensor.
        mri_sample = self.mri_data[idx]
        xls_sample = self.xls_data[idx]
        return mri_sample, xls_sample
def load_adni_data_from_file(
    mri_files: List[pl.Path],  # List of nibabel files, one per sample
    xls_file: pl.Path,  # Path to the Excel file
    xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
) -> Result[ADNIDataset, str]:
    """
    Loads MRI and Excel data from the ADNI dataset.

    Row ``i`` of the (preprocessed) Excel table is paired with ``mri_files[i]``.

    Args:
        mri_files (List[pl.Path]): Paths to the MRI files, loaded with nibabel.
        xls_file (pl.Path): Path to the Excel file.
        xls_preprocessor (Callable[[pd.DataFrame], pd.DataFrame]): Applied to the
            DataFrame returned by ``pd.read_excel`` before it is converted to a
            float32 tensor. Defaults to the identity.
    Returns:
        Result[ADNIDataset, str]: ``Ok(dataset)`` on success, or ``Err(message)``
        when no MRI files are given, the number of MRI files differs from the
        number of Excel rows, or the MRI volumes do not all have the same shape.

    Errors raised while reading the files themselves (nibabel / pandas) are not
    caught and propagate to the caller.
    """
    # torch.stack cannot stack an empty list; report it instead of crashing.
    if not mri_files:
        return Err("Loading MRI data failed: no MRI files were given")

    # Load the Excel data first: it is cheap compared to the MRI volumes, so a
    # sample-count mismatch is detected before any volume is read.
    xls_frame = xls_preprocessor(pd.read_excel(xls_file))  # type: ignore
    xls_data = torch.from_numpy(xls_frame.to_numpy()).float()  # type: ignore

    n_mri, n_xls = len(mri_files), xls_data.shape[0]
    if n_mri != n_xls:
        return Err(
            f"Loading MRI data failed: {n_mri} MRI files but {n_xls} rows of Excel data"
        )

    # Load the MRI data (type checking does not work well with nibabel).
    # NOTE(review): get_fdata() returns float64 by default, while the Excel tensor
    # above is float32 — confirm downstream code expects mixed dtypes.
    mri_volumes = [
        torch.from_numpy(nib.load(file).get_fdata())  # type: ignore
        for file in mri_files
    ]

    # torch.stack requires identical shapes; surface a mismatch as an Err.
    shapes = {tuple(volume.shape) for volume in mri_volumes}
    if len(shapes) > 1:
        return Err(
            f"Loading MRI data failed: MRI volumes have differing shapes {sorted(shapes)}"
        )

    # Stack the per-file volumes into a single (n_samples, ...) tensor.
    mri_data = torch.stack(mri_volumes)
    return Ok(ADNIDataset(mri_data, xls_data))
def divide_dataset(
    dataset: ADNIDataset,
    ratios: Tuple[float, float, float],
    seed: int = 0,
) -> Result[List[data.Subset[ADNIDataset]], str]:
    """
    Randomly divides the dataset into training, validation, and test sets.

    Args:
        dataset (ADNIDataset): The dataset to divide.
        ratios (Tuple[float, float, float]): Fractions of the samples assigned to
            the training, validation, and test sets, in that order. Each must be
            non-negative, and together they must sum to 1.0 (up to floating-point
            rounding).
        seed (int): Seed of the generator used for the random permutation, so the
            same seed yields the same split.
    Returns:
        Result[List[data.Subset[ADNIDataset]], str]: ``Ok`` with one subset per
        ratio, or ``Err`` with a message if the ratios are invalid.
    """
    import math

    if any(ratio < 0 for ratio in ratios):
        return Err("Ratios must be non-negative")
    # An exact `!= 1.0` test would reject valid inputs such as (0.7, 0.2, 0.1),
    # whose float sum is 0.9999999999999999.
    if not math.isclose(sum(ratios), 1.0):
        return Err("Ratios must sum to 1.0")

    # Convert the fractions to integer lengths: floor each share, then hand out
    # the leftover samples one at a time, round-robin. The lengths therefore
    # always sum to len(dataset), even when the float sum of the ratios is a
    # hair away from 1.0.
    n_samples = len(dataset)
    lengths = [math.floor(n_samples * ratio) for ratio in ratios]
    for i in range(n_samples - sum(lengths)):
        lengths[i % len(lengths)] += 1

    # Set the random seed for reproducibility
    gen = torch.Generator().manual_seed(seed)
    return Ok(data.random_split(dataset, lengths, generator=gen))
|