@@ -4,11 +4,11 @@ import torch.utils.data as data
 import pathlib as pl
 import pandas as pd
 from torch.utils.data import Subset, DataLoader
+import re
 
 
 from jaxtyping import Float
 from typing import Tuple, Iterator, Callable, List
-from result import Ok, Err, Result
 
 
 class ADNIDataset(data.Dataset): # type: ignore
@@ -19,17 +19,19 @@ class ADNIDataset(data.Dataset): # type: ignore
 
     def __init__(
         self,
-        mri_data: Float[torch.Tensor, "n_samples width height depth"],
+        mri_data: Float[torch.Tensor, "n_samples channels width height depth"],
         xls_data: Float[torch.Tensor, "n_samples features"],
+        expected_classes: Float[torch.Tensor, "n_samples classes"],
         device: str = "cuda",
     ):
         """
         Args:
-            mri_data (torch.Tensor): 4D tensor of MRI data with shape (n_samples, width, height, depth).
+            mri_data (torch.Tensor): 5D tensor of MRI data with shape (n_samples, channels, width, height, depth).
             xls_data (torch.Tensor): 2D tensor of Excel data with shape (n_samples, features).
         """
         self.mri_data = mri_data.float().to(device)
         self.xls_data = xls_data.float().to(device)
+        self.expected_classes = expected_classes.float().to(device)
 
     def __len__(self) -> int:
         """
@@ -38,8 +40,9 @@ class ADNIDataset(data.Dataset): # type: ignore
         return self.mri_data.shape[0] # 0th dimension is the number of samples
 
     def __getitem__(self, idx: int) -> Tuple[
-        Float[torch.Tensor, "width height depth"],
+        Float[torch.Tensor, "channels width height depth"],
         Float[torch.Tensor, "features"],
+        Float[torch.Tensor, "classes"],
     ]:
         """
         Returns a sample from the dataset at the given index.
@@ -54,7 +57,10 @@ class ADNIDataset(data.Dataset): # type: ignore
         # Slices the data on the 0th dimension, corresponding to the sample index
         mri_sample = self.mri_data[idx]
         xls_sample = self.xls_data[idx]
-        return mri_sample, xls_sample
+        # Return the per-sample class labels alongside the MRI and tabular samples
+        expected_classes = self.expected_classes[idx]
+
+        return mri_sample, xls_sample, expected_classes
 
 
 def load_adni_data_from_file(
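With this change __getitem__ returns three tensors instead of two, so every consumer that unpacks a sample (or a DataLoader batch) needs a third variable. A minimal sketch of the new calling convention, using the ADNIDataset defined in this file but with small made-up tensors on the CPU (all shapes below are illustrative, not taken from the patch):

import torch
from torch.utils.data import DataLoader

# Hypothetical toy inputs: (n_samples, channels, width, height, depth),
# (n_samples, features), and one-hot (n_samples, classes).
mri = torch.rand(8, 1, 16, 16, 16)
xls = torch.rand(8, 5)
labels = torch.eye(2)[torch.randint(0, 2, (8,))]

dataset = ADNIDataset(mri, xls, labels, device="cpu")
loader = DataLoader(dataset, batch_size=4)
for mri_batch, xls_batch, class_batch in loader:
    # Each batch is now a 3-tuple: MRI volumes, tabular features, class labels.
    print(mri_batch.shape, xls_batch.shape, class_batch.shape)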
@@ -62,7 +68,7 @@ def load_adni_data_from_file(
     xls_file: pl.Path, # Path to the Excel file
     device: str = "cuda",
     xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
-) -> Result[ADNIDataset, str]:
+) -> ADNIDataset:
     """
     Loads MRI and Excel data from the ADNI dataset.
 
@@ -73,31 +79,72 @@ def load_adni_data_from_file(
     Returns:
-        Result[ADNIDataset, str]: A Result object containing the ADNIDataset or an error message.
+        ADNIDataset: The loaded dataset. Raises ValueError if a filename cannot be parsed or matched against the tabular data.
     """
-    # Load the MRI data
-    mri_data_unstacked = [
-        torch.from_numpy(nib.load(file).get_fdata()) for file in mri_files # type: ignore # type checking does not work well with nibabel
-    ]
-    mri_data = torch.stack(
-        mri_data_unstacked
-    ) # Stack the list of tensors into a single tensor\
-
     # Load the Excel data
-    xls_data = torch.from_numpy( # type: ignore
-        xls_preprocessor(pd.read_excel(xls_file)).to_numpy() # type: ignore
-    ).float()
+    xls_values = xls_preprocessor(pd.read_csv(xls_file)) # type: ignore
 
-    # Check if the number of samples in MRI and Excel data match
-    if mri_data.shape[0] == xls_data.shape[0]:
-        return Ok(ADNIDataset(mri_data, xls_data, device=device))
-    else:
-        return Err("Loading MRI data failed")
+    # Load the MRI data
+    mri_data_unstacked: List[torch.Tensor] = []
+    expected_classes_unstacked: List[torch.Tensor] = []
+    xls_data_unstacked: List[torch.Tensor] = []
+    img_ids: List[int] = []
+    for file in mri_files:
+        filename = file.stem
+        match re.search(r".+?(?=_I)_I(\d+).+", filename):
+            case None:
+                raise ValueError(
+                    f"Filename {filename} does not match expected pattern."
+                )
+            case m:
+                img_id = int(m.group(1))
+
+        file_mri_data = torch.from_numpy(nib.load(file).get_fdata()) # type: ignore # type checking does not work well with nibabel
+
+        # Read the filename to determine the expected class
+        file_expected_class = torch.tensor([0.0, 0.0]) # Default to a tensor of zeros
+
+        if "AD" in filename:
+            file_expected_class = torch.tensor([1.0, 0.0])
+        elif "CN" in filename:
+            file_expected_class = torch.tensor([0.0, 1.0])
+
+        mri_data_unstacked.append(file_mri_data)
+        expected_classes_unstacked.append(file_expected_class)
+        # Extract the corresponding row from the Excel data using the img_id
+        xls_row = xls_values.loc[xls_values["Image Data ID"] == img_id]
+        if xls_row.empty:
+            raise ValueError(
+                f"No matching row found in Excel data for Image Data ID {img_id}."
+            )
+        elif len(xls_row) > 1:
+            raise ValueError(
+                f"Multiple rows found in Excel data for Image Data ID {img_id}."
+            )
+        file_xls_data = torch.tensor(
+            xls_row.drop(columns=["Image Data ID"]).values.flatten() # type: ignore
+        )
+
+        xls_data_unstacked.append(file_xls_data)
+        img_ids.append(img_id)
+
+    # Stack the list of tensors into a single tensor and unsqueeze along the channel dimension
+    mri_data = torch.stack(mri_data_unstacked).unsqueeze(1)
+
+    xls_data = torch.stack(
+        xls_data_unstacked
+    ) # Stack the list of tensors into a single tensor
+
+    expected_classes = torch.stack(
+        expected_classes_unstacked
+    ) # Stack the list of expected classes into a single tensor
+
+    return ADNIDataset(mri_data, xls_data, expected_classes, device=device)
 
 
 def divide_dataset(
     dataset: ADNIDataset,
     ratios: Tuple[float, float, float],
     seed: int,
-) -> Result[List[data.Subset[ADNIDataset]], str]:
+) -> List[data.Subset[ADNIDataset]]:
     """
     Divides the dataset into training, validation, and test sets.
 
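The rewritten loader pairs each scan with its tabular row by extracting the numeric image ID from the filename stem (the _I<digits> token) and looking it up in the "Image Data ID" column; the class label is then inferred from an "AD"/"CN" substring test on the stem, which assumes those tokens do not occur elsewhere in the filename. A small standalone sketch of the matching step, with a made-up filename and DataFrame used purely for illustration (the real ADNI naming convention is an assumption here, not something this patch guarantees):

import re

import pandas as pd

# Hypothetical filename stem and tabular contents, for illustration only.
filename = "002_S_0295_CN_I123456_brain"
xls_values = pd.DataFrame({"Image Data ID": [123456], "MMSE": [23.0]})

m = re.search(r".+?(?=_I)_I(\d+).+", filename)
assert m is not None, "filename does not match the expected pattern"
img_id = int(m.group(1))  # -> 123456

# Same lookup the loader performs for each scan.
row = xls_values.loc[xls_values["Image Data ID"] == img_id]
features = row.drop(columns=["Image Data ID"]).values.flatten()
print(img_id, features)  # 123456 [23.]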
@@ -111,11 +158,11 @@ def divide_dataset(
-        Result[List[data.Subset[ADNIDataset]], str]: A Result object containing the subsets or an error message.
+        List[data.Subset[ADNIDataset]]: The training, validation, and test subsets. Raises ValueError if the ratios do not sum to 1.0.
     """
     if sum(ratios) != 1.0:
-        return Err("Ratios must sum to 1.0")
+        raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")
 
     # Set the random seed for reproducibility
     gen = torch.Generator().manual_seed(seed)
-    return Ok(data.random_split(dataset, ratios, generator=gen))
+    return data.random_split(dataset, ratios, generator=gen)
 
 
 def initalize_dataloaders(
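Since divide_dataset now raises instead of returning a Result, call sites switch from matching on Ok/Err to ordinary exception handling. A brief sketch of an updated call site, assuming dataset is an ADNIDataset produced by load_adni_data_from_file; the ratios are chosen to be exactly representable in floating point because the unchanged sum check uses strict equality (for example, sum((0.7, 0.2, 0.1)) != 1.0 and would be rejected):

# Illustrative call site; "dataset" is assumed to exist already.
try:
    train_set, val_set, test_set = divide_dataset(dataset, (0.5, 0.25, 0.25), seed=42)
except ValueError as err:
    print(f"Could not split dataset: {err}")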
|