Nicholas Schense 1 месяц назад
Сommit
b7634119c8
8 измененных файлов с 453 добавлено и 0 удалено
  1. 0 0
      conflg.toml
  2. 115 0
      data/dataset.py
  3. 72 0
      model/cnn.py
  4. 179 0
      model/layers.py
  5. 2 0
      pyproject.toml
  6. 20 0
      train_model.py
  7. 65 0
      utils/config.py
  8. 0 0
      utils/training.py

+ 0 - 0
conflg.toml


+ 115 - 0
data/dataset.py

@@ -0,0 +1,115 @@
+import nibabel as nib
+import torch
+import torch.utils.data as data
+import pathlib as pl
+import pandas as pd
+
+
+from jaxtyping import Float
+from typing import Tuple, List, Callable
+from result import Ok, Err, Result
+
+
+class ADNIDataset(data.Dataset):  # type: ignore
+    """
+    A PyTorch Dataset class for loading
+    and processing MRI and Excel data from the ADNI dataset.
+    """
+
+    def __init__(
+        self,
+        mri_data: Float[torch.Tensor, "n_samples, width, height, depth"],
+        xls_data: Float[torch.Tensor, "n_samples, features"],
+    ):
+        """
+        Args:
+            mri_data (torch.Tensor): 4D tensor of MRI data with shape (n_samples, width, height, depth).
+            xls_data (torch.Tensor): 2D tensor of Excel data with shape (n_samples, features).
+        """
+        self.mri_data = mri_data
+        self.xls_data = xls_data
+
+    def __len__(self) -> int:
+        """
+        Returns the number of samples in the dataset.
+        """
+        return self.mri_data.shape[0]  # 0th dimension is the number of samples
+
+    def __getitem__(self, idx: int) -> Tuple[
+        Float[torch.Tensor, "width, height, depth"],
+        Float[torch.Tensor, "features"],
+    ]:
+        """
+        Returns a sample from the dataset at the given index.
+
+        Args:
+            idx (int): Index of the sample to retrieve.
+
+        Returns:
+            tuple: A tuple containing the MRI data and Excel data for the sample.
+        """
+
+        # Slices the data on the 0th dimension, corresponding to the sample index
+        mri_sample = self.mri_data[idx]
+        xls_sample = self.xls_data[idx]
+        return mri_sample, xls_sample
+
+
+def load_adni_data_from_file(
+    mri_files: List[pl.Path],  # List of nibablel files
+    xls_file: pl.Path,  # Path to the Excel file
+    xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
+) -> Result[ADNIDataset, str]:
+    """
+    Loads MRI and Excel data from the ADNI dataset.
+
+    Args:
+        mri_files (List[pl.Path]): List of paths to the MRI files.
+        xls_file (pl.Path): Path to the Excel file.
+
+    Returns:
+        Result[ADNIDataset, str]: A Result object containing the ADNIDataset or an error message.
+    """
+    # Load the MRI data
+    mri_data_unstacked = [
+        torch.from_numpy(nib.load(file).get_fdata()) for file in mri_files  # type: ignore # type checking does not work well with nibabel
+    ]
+    mri_data = torch.stack(
+        mri_data_unstacked
+    )  # Stack the list of tensors into a single tensor\
+
+    # Load the Excel data
+    xls_data = torch.from_numpy(  # type: ignore
+        xls_preprocessor(pd.read_excel(xls_file)).to_numpy()  # type: ignore
+    ).float()
+
+    # Check if the number of samples in MRI and Excel data match
+    if mri_data.shape[0] == xls_data.shape[0]:
+        return Ok(ADNIDataset(mri_data, xls_data))
+    else:
+        return Err("Loading MRI data failed")
+
+
+def divide_dataset(
+    dataset: ADNIDataset,
+    ratios: Tuple[float, float, float],
+    seed: int = 0,
+) -> Result[List[data.Subset[ADNIDataset]], str]:
+    """
+    Divides the dataset into training, validation, and test sets.
+
+    Args:
+        dataset (ADNIDataset): The dataset to divide.
+        train_ratio (float): The ratio of the training set.
+        val_ratio (float): The ratio of the validation set.
+        test_ratio (float): The ratio of the test set.
+
+    Returns:
+        Result[List[data.Subset[ADNIDataset]], str]: A Result object containing the subsets or an error message.
+    """
+    if sum(ratios) != 1.0:
+        return Err("Ratios must sum to 1.0")
+
+    # Set the random seed for reproducibility
+    gen = torch.Generator().manual_seed(seed)
+    return Ok(data.random_split(dataset, ratios, generator=gen))

+ 72 - 0
model/cnn.py

@@ -0,0 +1,72 @@
+from typing import Tuple
+from torch import nn
+import torch
+import model.layers as ly
+from jaxtyping import Float
+
+
+class CNN_Image_Section(nn.Module):
+    def __init__(self, image_channels: int, droprate: float = 0.0):
+        super().__init__()
+        # Initial Convolutional Blocks
+        self.conv1 = ly.ConvBlock(
+            image_channels,
+            192,
+            (11, 13, 11),
+            stride=(4, 4, 4),
+            droprate=droprate,
+            pool=False,
+        )
+        self.conv2 = ly.ConvBlock(192, 384, (5, 6, 5), droprate=droprate, pool=False)
+
+        # Midflow Block
+        self.midflow = ly.MidFlowBlock(384, droprate)
+
+        # Split Convolutional Block
+        self.splitconv = ly.SplitConvBlock(384, 192, 96, 1, droprate)
+
+        # Fully Connected Block
+        self.fc_image = ly.FullConnBlock(227136, 20, droprate=droprate)
+
+    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.midflow(x)
+        x = self.splitconv(x)
+        x = torch.flatten(x, 1)
+        x = self.fc_image(x)
+
+        return x
+
+
+class CNN3D(nn.Module):
+    def __init__(
+        self,
+        image_channels: int,
+        clin_data_channels: int,
+        num_classes: int,
+        droprate: float = 0.0,
+    ):
+        super().__init__()
+
+        self.image_section = CNN_Image_Section(image_channels, droprate=droprate)
+        self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
+        self.fc_clin2 = ly.FullConnBlock(64, 20, droprate=droprate)
+
+        self.dense1 = nn.Linear(20 + 20, 10)
+        self.dense2 = nn.Linear(10, num_classes)
+        self.softmax = nn.Softmax(dim=1)
+
+    def forward(
+        self, x_in: Tuple[Float[torch.Tensor, "N C D H W"], Float[torch.Tensor, "N F"]]
+    ):
+        image_data, clin_data = x_in
+        image_out = self.image_section(image_data)
+        clin_out = self.fc_clin2(self.fc_clin1(clin_data))
+
+        combined = torch.cat((image_out, clin_out), dim=1)
+        x = self.dense1(combined)
+        x = self.dense2(x)
+        x = self.softmax(x)
+
+        return x

+ 179 - 0
model/layers.py

@@ -0,0 +1,179 @@
+from torch import nn
+from jaxtyping import Float
+import torch
+from typing import Tuple
+
+
+class SepConv3d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Tuple[int, int, int],
+        stride: int | Tuple[int, int, int] = 1,
+        padding: int | str = 0,
+        bias: bool = False,
+    ):
+        super(SepConv3d, self).__init__()
+        self.depthwise = nn.Conv3d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            groups=out_channels,
+            padding=padding,
+            bias=bias,
+            stride=stride,
+        )
+
+    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
+        x = self.depthwise(x)
+        return x
+
+
+class SplitConvBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        mid_channels: int,
+        out_channels: int,
+        split_dim: int,
+        drop_rate: float,
+    ):
+        super(SplitConvBlock, self).__init__()
+
+        self.split_dim = split_dim
+
+        self.leftconv_1 = SepConvBlock(
+            in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
+        )
+        self.rightconv_1 = SepConvBlock(
+            in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
+        )
+
+        self.leftconv_2 = SepConvBlock(
+            mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
+        )
+        self.rightconv_2 = SepConvBlock(
+            mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
+        )
+
+    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
+        (left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
+
+        self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
+        self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)
+
+        left = self.leftblock(left)
+        right = self.rightblock(right)
+        a = torch.cat((left, right), dim=self.split_dim)
+        return a
+
+
+class MidFlowBlock(nn.Module):
+    def __init__(self, channels: int, drop_rate: float):
+        super(MidFlowBlock, self).__init__()
+
+        self.conv1 = ConvBlock(
+            channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
+        )
+        self.conv2 = ConvBlock(
+            channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
+        )
+        self.conv3 = ConvBlock(
+            channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
+        )
+
+        # self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
+        self.block = self.conv1
+
+    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
+        a = nn.ELU()(self.block(x) + x)
+        return a
+
+
+class ConvBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Tuple[int, int, int],
+        stride: Tuple[int, int, int] = (1, 1, 1),
+        padding: str = "valid",
+        droprate: float = 0.0,
+        pool: bool = False,
+    ):
+        super(ConvBlock, self).__init__()
+        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
+        self.norm = nn.BatchNorm3d(out_channels)
+        self.elu = nn.ELU()
+        self.dropout = nn.Dropout(droprate)
+
+        if pool:
+            self.maxpool = nn.MaxPool3d(3, stride=2)
+        else:
+            self.maxpool = None
+
+    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
+        a = self.conv(x)
+        a = self.norm(a)
+        a = self.elu(a)
+
+        if self.maxpool:
+            a = self.maxpool(a)
+
+        a = self.dropout(a)
+
+        return a
+
+
+class FullConnBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, droprate: float = 0.0):
+        super(FullConnBlock, self).__init__()
+        self.dense = nn.Linear(in_channels, out_channels)
+        self.norm = nn.BatchNorm1d(out_channels)
+        self.elu = nn.ELU()
+        self.dropout = nn.Dropout(droprate)
+
+    def forward(self, x: Float[torch.Tensor, "N C"]):
+        x = self.dense(x)
+        x = self.norm(x)
+        x = self.elu(x)
+        x = self.dropout(x)
+        return x
+
+
+class SepConvBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Tuple[int, int, int],
+        stride: Tuple[int, int, int] = (1, 1, 1),,
+        padding: str | int="valid",
+        droprate: float = 0.0,
+        pool: bool =False,
+    ):
+        super(SepConvBlock, self).__init__()
+        self.conv = SepConv3d(in_channels, out_channels, kernel_size, stride, padding)
+        self.norm = nn.BatchNorm3d(out_channels)
+        self.elu = nn.ELU()
+        self.dropout = nn.Dropout(droprate)
+
+        if pool:
+            self.maxpool = nn.MaxPool3d(3, stride=2)
+        else:
+            self.maxpool = None
+
+    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
+        x = self.conv(x)
+        x = self.norm(x)
+        x = self.elu(x)
+
+        if self.maxpool:
+            x = self.maxpool(x)
+
+        x = self.dropout(x)
+
+        return x  
+
+

+ 2 - 0
pyproject.toml

@@ -0,0 +1,2 @@
+[tool.mypy]
+exclude = [".venv/**"]

+ 20 - 0
train_model.py

@@ -0,0 +1,20 @@
+# Torch
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+# Config
+from utils.config import Config
+import pathlib as pl
+
+
+# Custom modules
+from model.cnn import CNN3D
+from data.dataset import ADNIDataset, load_adni_data_from_file, divide_dataset
+
+
+# Load config
+conf = Config()
+
+# Load data
+mri_files = pl.Path(conf["data"]["mri_files"]).glob("*.nii")

+ 65 - 0
utils/config.py

@@ -0,0 +1,65 @@
+import typing
+from pathlib import Path
+import tomli
+import os
+
+
+@typing.no_type_check
+class SingletonMeta(type):
+    """
+    Singleton metaclass to ensure only one instance of a class exists.
+    """
+
+    _instances = {}
+
+    @typing.no_type_check
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+
+class Config(metaclass=SingletonMeta):
+    def __init__(
+        self,
+        config_path: Path | None = Path(__file__).parent.parent / "config.toml",
+    ):
+        """
+        Initialize the configuration object.
+
+        Args:
+            config_path (Path): Path to the configuration file.
+        """
+
+        if config_path is None and "ADL_CONFIG_PATH" in os.environ:
+            self.config_path = Path(os.environ["ADL_CONFIG_PATH"])
+        elif config_path is not None:
+            self.config_path = config_path
+        else:
+            raise ValueError("Either config_path or ADL_CONFIG_PATH must be provided")
+
+        self.loaded_config_path = None
+        self._load_config()
+
+    def _load_config(self):
+        """
+        Load the configuration from the specified file.
+        """
+        with open(self.config_path, "rb") as f:
+            config = tomli.load(f)
+            self.loaded_config_path = str(self.config_path)
+            print(f"Loaded config from {self.config_path}")
+
+        self._config_dict = config
+
+    def __getitem__(self, key: str):
+        """
+        Get a configuration value by key.
+
+        Args:
+            key (str): The key of the configuration value.
+
+        Returns:
+            The configuration value.
+        """
+        return self._config_dict.get(key, None)

+ 0 - 0
utils/training.py