123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
import math
import pathlib as pl
import re
from typing import Callable, Iterator, List, Tuple

import nibabel as nib
import pandas as pd
import torch
import torch.utils.data as data
from jaxtyping import Float
from torch.utils.data import DataLoader, Subset
class ADNIDataset(data.Dataset):  # type: ignore
    """
    A PyTorch Dataset class for loading
    and processing MRI and Excel data from the ADNI dataset.

    Each sample is a triple ``(mri, xls, expected_classes)``. All three
    tensors are converted to float32 and moved to ``device`` once, at
    construction time.
    """

    def __init__(
        self,
        mri_data: Float[torch.Tensor, "n_samples channels width height depth"],
        xls_data: Float[torch.Tensor, "n_samples features"],
        expected_classes: Float[torch.Tensor, "n_samples classes"],
        device: str = "cuda",
    ) -> None:
        """
        Args:
            mri_data (torch.Tensor): 5D tensor of MRI data with shape
                (n_samples, channels, width, height, depth).
            xls_data (torch.Tensor): 2D tensor of Excel data with shape
                (n_samples, features).
            expected_classes (torch.Tensor): Per-sample class targets; the
                first dimension is the sample index.
            device (str): Device the tensors are moved to.

        Raises:
            ValueError: If the three tensors do not share the same number
                of samples along their first dimension.
        """
        n_samples = mri_data.shape[0]
        if xls_data.shape[0] != n_samples or expected_classes.shape[0] != n_samples:
            # __len__ is derived from mri_data alone, so a mismatch would
            # otherwise surface later as an IndexError or a silently
            # truncated dataset.
            raise ValueError(
                "mri_data, xls_data and expected_classes must have the same "
                f"number of samples, got {mri_data.shape[0]}, "
                f"{xls_data.shape[0]} and {expected_classes.shape[0]}."
            )
        self.mri_data = mri_data.float().to(device)
        self.xls_data = xls_data.float().to(device)
        self.expected_classes = expected_classes.float().to(device)

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.
        """
        return self.mri_data.shape[0]  # 0th dimension is the number of samples

    def __getitem__(self, idx: int) -> Tuple[
        Float[torch.Tensor, "channels width height depth"],
        Float[torch.Tensor, "features"],
        Float[torch.Tensor, "classes"],
    ]:
        """
        Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: The MRI data, the Excel data and the expected classes
            for the sample, in that order.
        """
        # Every tensor is sliced on its 0th dimension, the sample index.
        return (
            self.mri_data[idx],
            self.xls_data[idx],
            self.expected_classes[idx],
        )
# Captures the digits that follow the first "_I" marker in a filename.
# Nothing is required after the digits, so an ID at the very end of the
# stem (e.g. "..._I45108") is captured in full.
_IMAGE_ID_PATTERN = re.compile(r"_I(\d+)")


def load_adni_data_from_file(
    mri_files: Iterator[pl.Path],  # Iterable of nibabel-readable files
    xls_file: pl.Path,  # Path to the tabular metadata file
    device: str = "cuda",
    xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
) -> ADNIDataset:
    """
    Loads MRI and Excel data from the ADNI dataset.

    For every MRI file the image ID is parsed from the filename
    (``..._I<digits>...``), the single metadata row with that
    ``"Image Data ID"`` is looked up, and the class target is derived from
    the filename: ``[1, 0]`` if it contains ``"AD"``, ``[0, 1]`` if it
    contains ``"CN"``, otherwise ``[0, 0]``.

    Args:
        mri_files (Iterator[pl.Path]): Paths to the MRI files.
        xls_file (pl.Path): Path to the metadata table. It is read with
            ``pandas.read_csv``.
        device (str): Device passed on to ``ADNIDataset``.
        xls_preprocessor (Callable): Applied to the loaded DataFrame before
            any lookup. The result must contain an ``"Image Data ID"``
            column comparable to the integer parsed from the filename.

    Returns:
        ADNIDataset: Dataset holding the stacked MRI volumes (with a channel
        dimension inserted at position 1), metadata rows and class targets.

    Raises:
        ValueError: If no MRI file is given, a filename carries no image ID,
            or the metadata has zero or several rows for an image ID.
    """
    # Load the Excel data
    xls_values = xls_preprocessor(pd.read_csv(xls_file))  # type: ignore

    mri_volumes: List[torch.Tensor] = []
    class_targets: List[torch.Tensor] = []
    xls_rows: List[torch.Tensor] = []

    for file in mri_files:
        filename = file.stem
        id_match = _IMAGE_ID_PATTERN.search(filename)
        if id_match is None:
            raise ValueError(
                f"Filename {filename} does not match expected pattern."
            )
        img_id = int(id_match.group(1))

        # Validate the metadata before touching the (large) MRI volume so
        # that inconsistent inputs fail fast.
        xls_row = xls_values.loc[xls_values["Image Data ID"] == img_id]
        if xls_row.empty:
            raise ValueError(
                f"No matching row found in Excel data for Image Data ID {img_id}."
            )
        if len(xls_row) > 1:
            raise ValueError(
                f"Multiple rows found in Excel data for Image Data ID {img_id}."
            )
        xls_rows.append(
            torch.tensor(
                xls_row.drop(columns=["Image Data ID"]).values.flatten()  # type: ignore
            )
        )

        # Read the filename to determine the expected class.
        # NOTE(review): "AD" is a plain substring test and would also match a
        # filename starting with "ADNI" - confirm the naming scheme in use.
        if "AD" in filename:
            file_expected_class = torch.tensor([1.0, 0.0])
        elif "CN" in filename:
            file_expected_class = torch.tensor([0.0, 1.0])
        else:
            file_expected_class = torch.tensor([0.0, 0.0])  # Neither class
        class_targets.append(file_expected_class)

        mri_volumes.append(
            torch.from_numpy(nib.load(file).get_fdata())  # type: ignore # type checking does not work well with nibabel
        )

    if not mri_volumes:
        # torch.stack would otherwise fail with an opaque error message.
        raise ValueError("No MRI files were provided.")

    # Stack into single tensors; the MRI stack gets a channel dimension.
    mri_data = torch.stack(mri_volumes).unsqueeze(1)
    xls_data = torch.stack(xls_rows)
    expected_classes = torch.stack(class_targets)
    return ADNIDataset(mri_data, xls_data, expected_classes, device=device)
def divide_dataset(
    dataset: ADNIDataset,
    ratios: Tuple[float, float, float],
    seed: int,
) -> List[data.Subset[ADNIDataset]]:
    """
    Divides the dataset into training, validation, and test sets.

    Args:
        dataset (ADNIDataset): The dataset to divide.
        ratios (Tuple[float, float, float]): Fractions of the dataset for
            the training, validation and test sets. They must be
            non-negative and sum to 1.0 within floating-point tolerance.
        seed (int): Seed for the random generator driving the split, so
            the same seed reproduces the same subsets.

    Returns:
        List[data.Subset[ADNIDataset]]: One subset per ratio, in order.

    Raises:
        ValueError: If a ratio is negative or the ratios do not sum to 1.0.
    """
    if any(ratio < 0 for ratio in ratios):
        raise ValueError(f"Ratios must be non-negative, got {ratios}.")
    # Exact comparison would reject valid input: 0.7 + 0.2 + 0.1 evaluates to
    # 0.9999999999999999 in binary floating point.
    if not math.isclose(sum(ratios), 1.0):
        raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")

    # Convert fractions to absolute lengths ourselves. random_split only
    # treats its argument as fractions when the sum is <= 1, so a sum a hair
    # above 1.0 would otherwise be misread as absolute lengths. Leftover
    # samples are handed out round-robin, mirroring random_split.
    n_samples = len(dataset)
    lengths = [math.floor(n_samples * ratio) for ratio in ratios]
    for i in range(n_samples - sum(lengths)):
        lengths[i % len(lengths)] += 1

    # Set the random seed for reproducibility
    gen = torch.Generator().manual_seed(seed)
    return data.random_split(dataset, lengths, generator=gen)
def initalize_dataloaders(
    datasets: List[Subset[ADNIDataset]],
    batch_size: int = 64,
    shuffle: bool = True,
) -> List[DataLoader[ADNIDataset]]:
    """
    Initializes the DataLoader for the given datasets.

    Args:
        datasets (List[Subset[ADNIDataset]]): List of datasets to create
            DataLoaders for.
        batch_size (int): The batch size for the DataLoader.
        shuffle (bool): Whether each DataLoader reshuffles its data every
            epoch. Defaults to True; pass False for a deterministic order,
            e.g. when building evaluation loaders.

    Returns:
        List[DataLoader[ADNIDataset]]: One DataLoader per dataset, in the
        same order as ``datasets``.
    """
    return [
        DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
        for dataset in datasets
    ]
|