|
@@ -3,10 +3,11 @@ import torch
|
|
import torch.utils.data as data
|
|
import torch.utils.data as data
|
|
import pathlib as pl
|
|
import pathlib as pl
|
|
import pandas as pd
|
|
import pandas as pd
|
|
|
|
+from torch.utils.data import Subset, DataLoader
|
|
|
|
|
|
|
|
|
|
from jaxtyping import Float
|
|
from jaxtyping import Float
|
|
-from typing import Tuple, List, Callable
|
|
|
|
|
|
+from typing import Tuple, Iterator, Callable, List
|
|
from result import Ok, Err, Result
|
|
from result import Ok, Err, Result
|
|
|
|
|
|
|
|
|
|
@@ -18,16 +19,17 @@ class ADNIDataset(data.Dataset): # type: ignore
|
|
|
|
|
|
def __init__(
|
|
def __init__(
|
|
self,
|
|
self,
|
|
- mri_data: Float[torch.Tensor, "n_samples, width, height, depth"],
|
|
|
|
- xls_data: Float[torch.Tensor, "n_samples, features"],
|
|
|
|
|
|
+ mri_data: Float[torch.Tensor, "n_samples width height depth"],
|
|
|
|
+ xls_data: Float[torch.Tensor, "n_samples features"],
|
|
|
|
+ device: str = "cuda",
|
|
):
|
|
):
|
|
"""
|
|
"""
|
|
Args:
|
|
Args:
|
|
mri_data (torch.Tensor): 4D tensor of MRI data with shape (n_samples, width, height, depth).
|
|
mri_data (torch.Tensor): 4D tensor of MRI data with shape (n_samples, width, height, depth).
|
|
xls_data (torch.Tensor): 2D tensor of Excel data with shape (n_samples, features).
|
|
xls_data (torch.Tensor): 2D tensor of Excel data with shape (n_samples, features).
|
|
"""
|
|
"""
|
|
- self.mri_data = mri_data
|
|
|
|
- self.xls_data = xls_data
|
|
|
|
|
|
+ self.mri_data = mri_data.float().to(device)
|
|
|
|
+ self.xls_data = xls_data.float().to(device)
|
|
|
|
|
|
def __len__(self) -> int:
|
|
def __len__(self) -> int:
|
|
"""
|
|
"""
|
|
@@ -36,7 +38,7 @@ class ADNIDataset(data.Dataset): # type: ignore
|
|
return self.mri_data.shape[0] # 0th dimension is the number of samples
|
|
return self.mri_data.shape[0] # 0th dimension is the number of samples
|
|
|
|
|
|
def __getitem__(self, idx: int) -> Tuple[
|
|
def __getitem__(self, idx: int) -> Tuple[
|
|
- Float[torch.Tensor, "width, height, depth"],
|
|
|
|
|
|
+ Float[torch.Tensor, "width height depth"],
|
|
Float[torch.Tensor, "features"],
|
|
Float[torch.Tensor, "features"],
|
|
]:
|
|
]:
|
|
"""
|
|
"""
|
|
@@ -56,8 +58,9 @@ class ADNIDataset(data.Dataset): # type: ignore
|
|
|
|
|
|
|
|
|
|
def load_adni_data_from_file(
|
|
def load_adni_data_from_file(
|
|
- mri_files: List[pl.Path], # List of nibablel files
|
|
|
|
|
|
+ mri_files: Iterator[pl.Path], # List of nibablel files
|
|
xls_file: pl.Path, # Path to the Excel file
|
|
xls_file: pl.Path, # Path to the Excel file
|
|
|
|
+ device: str = "cuda",
|
|
xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
|
|
xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
|
|
) -> Result[ADNIDataset, str]:
|
|
) -> Result[ADNIDataset, str]:
|
|
"""
|
|
"""
|
|
@@ -85,7 +88,7 @@ def load_adni_data_from_file(
|
|
|
|
|
|
# Check if the number of samples in MRI and Excel data match
|
|
# Check if the number of samples in MRI and Excel data match
|
|
if mri_data.shape[0] == xls_data.shape[0]:
|
|
if mri_data.shape[0] == xls_data.shape[0]:
|
|
- return Ok(ADNIDataset(mri_data, xls_data))
|
|
|
|
|
|
+ return Ok(ADNIDataset(mri_data, xls_data, device=device))
|
|
else:
|
|
else:
|
|
return Err("Loading MRI data failed")
|
|
return Err("Loading MRI data failed")
|
|
|
|
|
|
@@ -93,7 +96,7 @@ def load_adni_data_from_file(
|
|
def divide_dataset(
|
|
def divide_dataset(
|
|
dataset: ADNIDataset,
|
|
dataset: ADNIDataset,
|
|
ratios: Tuple[float, float, float],
|
|
ratios: Tuple[float, float, float],
|
|
- seed: int = 0,
|
|
|
|
|
|
+ seed: int,
|
|
) -> Result[List[data.Subset[ADNIDataset]], str]:
|
|
) -> Result[List[data.Subset[ADNIDataset]], str]:
|
|
"""
|
|
"""
|
|
Divides the dataset into training, validation, and test sets.
|
|
Divides the dataset into training, validation, and test sets.
|
|
@@ -113,3 +116,22 @@ def divide_dataset(
|
|
# Set the random seed for reproducibility
|
|
# Set the random seed for reproducibility
|
|
gen = torch.Generator().manual_seed(seed)
|
|
gen = torch.Generator().manual_seed(seed)
|
|
return Ok(data.random_split(dataset, ratios, generator=gen))
|
|
return Ok(data.random_split(dataset, ratios, generator=gen))
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def initalize_dataloaders(
|
|
|
|
+ datasets: List[Subset[ADNIDataset]],
|
|
|
|
+ batch_size: int = 64,
|
|
|
|
+) -> List[DataLoader[ADNIDataset]]:
|
|
|
|
+ """
|
|
|
|
+ Initializes the DataLoader for the given datasets.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ datasets (List[Subset[ADNIDataset]]): List of datasets to create DataLoaders for.
|
|
|
|
+ batch_size (int): The batch size for the DataLoader.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ List[DataLoader[ADNIDataset]]: A list of DataLoaders for the datasets.
|
|
|
|
+ """
|
|
|
|
+ return [
|
|
|
|
+ DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in datasets
|
|
|
|
+ ]
|