Nicholas Schense 4 days ago
parent
commit
ef6fea17e2
9 changed files with 175 additions and 99 deletions
  1. BIN          .DS_Store
  2. + 2  - 0     .gitignore
  3. + 20 - 0     config.toml
  4. + 72 - 25    data/dataset.py
  5. + 0  - 0     evaluate_models.py (renamed from conflg.toml)
  6. + 4  - 6     model/layers.py
  7. + 45 - 41    train_model.py
  8. + 1  - 1     utils/config.py
  9. + 31 - 26    utils/training.py

+ 2 - 0
.gitignore

@@ -0,0 +1,2 @@
+.venv/
+__pycache__/

+ 20 - 0
config.toml

@@ -0,0 +1,20 @@
+[data]
+mri_files_path = "../data/PET_volumes_customtemplate_float32/"
+xls_file_path = "../data/LP_ADNIMERGE.csv"
+seed = 1
+data_splits = [0.7, 0.2, 0.1] # train, validation, test
+image_channels = 1
+clin_data_channels = 2
+num_classes = 2 # AD, NL
+
+
+[training]
+device = "mps" # "cpu", "cuda", "mps"
+batch_size = 32
+ensemble_size = 1
+droprate = 0.1
+learning_rate = 0.001
+num_epochs = 1
+
+[output]
+path = "../models/"

+ 72 - 25
data/dataset.py

@@ -4,11 +4,11 @@ import torch.utils.data as data
 import pathlib as pl
 import pandas as pd
 from torch.utils.data import Subset, DataLoader
+import re
 
 
 from jaxtyping import Float
 from typing import Tuple, Iterator, Callable, List
-from result import Ok, Err, Result
 
 
 class ADNIDataset(data.Dataset):  # type: ignore
@@ -19,17 +19,19 @@ class ADNIDataset(data.Dataset):  # type: ignore
 
     def __init__(
         self,
-        mri_data: Float[torch.Tensor, "n_samples width height depth"],
+        mri_data: Float[torch.Tensor, "n_samples channels width height depth"],
         xls_data: Float[torch.Tensor, "n_samples features"],
+        expected_classes: Float[torch.Tensor, "n_samples classes"],
         device: str = "cuda",
     ):
         """
         Args:
-            mri_data (torch.Tensor): 4D tensor of MRI data with shape (n_samples, width, height, depth).
+            mri_data (torch.Tensor): 5D tensor of MRI data with shape (n_samples, channels, width, height, depth).
             xls_data (torch.Tensor): 2D tensor of Excel data with shape (n_samples, features).
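+            expected_classes (torch.Tensor): 2D tensor of one-hot class labels with shape (n_samples, classes).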
         """
         self.mri_data = mri_data.float().to(device)
         self.xls_data = xls_data.float().to(device)
+        self.expected_classes = expected_classes.float().to(device)
 
     def __len__(self) -> int:
         """
@@ -38,8 +40,9 @@ class ADNIDataset(data.Dataset):  # type: ignore
         return self.mri_data.shape[0]  # 0th dimension is the number of samples
 
     def __getitem__(self, idx: int) -> Tuple[
-        Float[torch.Tensor, "width height depth"],
+        Float[torch.Tensor, "channels width height depth"],
         Float[torch.Tensor, "features"],
+        Float[torch.Tensor, "classes"],
     ]:
         """
         Returns a sample from the dataset at the given index.
@@ -54,7 +57,10 @@ class ADNIDataset(data.Dataset):  # type: ignore
         # Slices the data on the 0th dimension, corresponding to the sample index
         mri_sample = self.mri_data[idx]
         xls_sample = self.xls_data[idx]
-        return mri_sample, xls_sample
+        # Return the class-label tensor for this sample as well
+        expected_classes = self.expected_classes[idx]
+
+        return mri_sample, xls_sample, expected_classes
 
 
 def load_adni_data_from_file(
@@ -62,7 +68,7 @@ def load_adni_data_from_file(
-    xls_file: pl.Path,  # Path to the Excel file
+    xls_file: pl.Path,  # Path to the clinical-data CSV file
     device: str = "cuda",
     xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
-) -> Result[ADNIDataset, str]:
+) -> ADNIDataset:
     """
     Loads MRI and Excel data from the ADNI dataset.
 
@@ -73,31 +79,72 @@ def load_adni_data_from_file(
     Returns:
-        Result[ADNIDataset, str]: A Result object containing the ADNIDataset or an error message.
+        ADNIDataset: The loaded dataset. Raises ValueError for unmatched filenames or clinical rows.
     """
-    # Load the MRI data
-    mri_data_unstacked = [
-        torch.from_numpy(nib.load(file).get_fdata()) for file in mri_files  # type: ignore # type checking does not work well with nibabel
-    ]
-    mri_data = torch.stack(
-        mri_data_unstacked
-    )  # Stack the list of tensors into a single tensor\
-
     # Load the Excel data
-    xls_data = torch.from_numpy(  # type: ignore
-        xls_preprocessor(pd.read_excel(xls_file)).to_numpy()  # type: ignore
-    ).float()
+    xls_values = xls_preprocessor(pd.read_csv(xls_file))  # type: ignore
 
-    # Check if the number of samples in MRI and Excel data match
-    if mri_data.shape[0] == xls_data.shape[0]:
-        return Ok(ADNIDataset(mri_data, xls_data, device=device))
-    else:
-        return Err("Loading MRI data failed")
+    # Load the MRI data
+    mri_data_unstacked: List[torch.Tensor] = []
+    expected_classes_unstacked: List[torch.Tensor] = []
+    xls_data_unstacked: List[torch.Tensor] = []
+    img_ids: List[int] = []
+    for file in mri_files:
+        filename = file.stem
+        # Image IDs are encoded in the filename as "_I<digits>"
+        match re.search(r"_I(\d+)", filename):
+            case None:
+                raise ValueError(
+                    f"Filename {filename} does not match expected pattern."
+                )
+            case m:
+                img_id = int(m.group(1))
+
+        file_mri_data = torch.from_numpy(nib.load(file).get_fdata())  # type: ignore # type checking does not work well with nibabel
+
+        # Read the filename to determine the expected class, one-hot as [AD, CN].
+        # NOTE: assumes "AD"/"CN" appear in the filename only as the class token.
+        if "AD" in filename:
+            file_expected_class = torch.tensor([1.0, 0.0])
+        elif "CN" in filename:
+            file_expected_class = torch.tensor([0.0, 1.0])
+        else:
+            raise ValueError(
+                f"Filename {filename} contains neither 'AD' nor 'CN'."
+            )
+
+        mri_data_unstacked.append(file_mri_data)
+        expected_classes_unstacked.append(file_expected_class)
+        # Extract the corresponding row from the Excel data using the img_id
+        xls_row = xls_values.loc[xls_values["Image Data ID"] == img_id]
+        if xls_row.empty:
+            raise ValueError(
+                f"No matching row found in Excel data for Image Data ID {img_id}."
+            )
+        elif len(xls_row) > 1:
+            raise ValueError(
+                f"Multiple rows found in Excel data for Image Data ID {img_id}."
+            )
+        file_xls_data = torch.tensor(
+            xls_row.drop(columns=["Image Data ID"]).values.flatten()  # type: ignore
+        )
+
+        xls_data_unstacked.append(file_xls_data)
+        img_ids.append(img_id)
+
+    # Stack into a single tensor and add the channel dimension: (n_samples, 1, W, H, D)
+    mri_data = torch.stack(mri_data_unstacked).unsqueeze(1)
+
+    xls_data = torch.stack(
+        xls_data_unstacked
+    )  # Stack the list of tensors into a single tensor
+
+    expected_classes = torch.stack(
+        expected_classes_unstacked
+    )  # Stack the list of expected classes into a single tensor
+
+    return ADNIDataset(mri_data, xls_data, expected_classes, device=device)
 
 
 def divide_dataset(
     dataset: ADNIDataset,
     ratios: Tuple[float, float, float],
     seed: int,
-) -> Result[List[data.Subset[ADNIDataset]], str]:
+) -> List[data.Subset[ADNIDataset]]:
     """
     Divides the dataset into training, validation, and test sets.
 
@@ -111,11 +158,11 @@ def divide_dataset(
-        Result[List[data.Subset[ADNIDataset]], str]: A Result object containing the subsets or an error message.
+        List[data.Subset[ADNIDataset]]: The train, validation, and test subsets. Raises ValueError if the ratios do not sum to 1.0.
     """
-    if sum(ratios) != 1.0:
-        return Err("Ratios must sum to 1.0")
+    # Compare with a tolerance: sum([0.7, 0.2, 0.1]) != 1.0 in float arithmetic
+    if abs(sum(ratios) - 1.0) > 1e-6:
+        raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")
 
     # Set the random seed for reproducibility
     gen = torch.Generator().manual_seed(seed)
-    return Ok(data.random_split(dataset, ratios, generator=gen))
+    return data.random_split(dataset, ratios, generator=gen)
 
 
 def initalize_dataloaders(
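
A hedged usage sketch of the reworked loader (paths and the ".nii" glob are assumptions based on config.toml; the preprocessor mirrors xls_pre from train_model.py below):

    import pandas as pd
    import pathlib as pl
    from data.dataset import load_adni_data_from_file, divide_dataset

    def prep(df: pd.DataFrame) -> pd.DataFrame:
        # Reduce to numeric columns plus the "Image Data ID" join key
        df = df[["Image Data ID", "Sex", "Age (current)"]].copy()
        df["Sex"] = df["Sex"].str.strip()
        return df.replace({"M": 0, "F": 1})

    mri_files = sorted(pl.Path("../data/PET_volumes_customtemplate_float32/").glob("*.nii"))
    dataset = load_adni_data_from_file(
        mri_files, pl.Path("../data/LP_ADNIMERGE.csv"), device="cpu", xls_preprocessor=prep
    )
    train_set, val_set, test_set = divide_dataset(dataset, (0.7, 0.2, 0.1), seed=1)
    mri, xls, label = dataset[0]  # shapes: (1, W, H, D), (features,), (2,)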

+ 0 - 0
conflg.toml → evaluate_models.py


+ 4 - 6
model/layers.py

@@ -148,10 +148,10 @@ class SepConvBlock(nn.Module):
         in_channels: int,
         out_channels: int,
         kernel_size: Tuple[int, int, int],
-        stride: Tuple[int, int, int] = (1, 1, 1),,
-        padding: str | int="valid",
+        stride: Tuple[int, int, int] = (1, 1, 1),
+        padding: str | int = "valid",
         droprate: float = 0.0,
-        pool: bool =False,
+        pool: bool = False,
     ):
         super(SepConvBlock, self).__init__()
         self.conv = SepConv3d(in_channels, out_channels, kernel_size, stride, padding)
@@ -174,6 +174,4 @@ class SepConvBlock(nn.Module):
 
         x = self.dropout(x)
 
-        return x  
-
-
+        return x
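
A quick shape-check sketch for the tidied block (SepConv3d's internals are outside this diff, so the output size shown simply assumes a plain "valid" 3D convolution with no pooling):

    import torch
    from model.layers import SepConvBlock

    block = SepConvBlock(in_channels=1, out_channels=8, kernel_size=(3, 3, 3), droprate=0.1)
    x = torch.randn(2, 1, 32, 32, 32)  # (batch, channels, width, height, depth)
    y = block(x)
    print(y.shape)  # expected torch.Size([2, 8, 30, 30, 30]) under that assumption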

+ 45 - 41
train_model.py

@@ -6,7 +6,7 @@ import torch.optim as optim
 # Config
 from utils.config import config
 import pathlib as pl
-from result import Ok, Err
+import pandas as pd
 import json
 
 
@@ -26,34 +26,33 @@ xls_file = pl.Path(config["data"]["xls_file_path"])
 
 # Load the data
 
-match load_adni_data_from_file(
-    mri_files, xls_file, device=config["training"]["device"]
-):
-    case Ok(d):
-        dataset = d
-        print("Data loaded successfully")
-    case Err(e):
-        print(f"Error loading data: {e}")
-        exit(-1)
 
+def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Preprocess the clinical-data DataFrame loaded from the ADNIMERGE CSV.
+    This function can be customized to filter or modify the DataFrame as needed.
+    """
+
+    # .copy() avoids pandas' SettingWithCopyWarning on the column assignment below.
+    # "Image Data ID" stays a plain column: load_adni_data_from_file matches rows
+    # on it with a boolean mask, so it must not be moved into the index.
+    data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
+    data["Sex"] = data["Sex"].str.strip()  # type: ignore
+    data = data.replace({"M": 0, "F": 1})  # type: ignore
+
+    return data
+
+
+dataset = load_adni_data_from_file(
+    mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
+)
 
 # Divide the dataset into training and validation sets
 if config["data"]["seed"] is None:
     print("Warning: No seed provided for dataset division, using default seed 0")
     config["data"]["seed"] = 0
 
-match divide_dataset(
-    dataset, config["data"]["train_val_split"], seed=config["data"]["seed"]
-):
-    case Ok(s):
-        if len(s) != 3:
-            print(f"Error: Expected 3 subsets (train, val, test), got {len(s)}")
-            exit(-1)
-        datasets = s
-        print("Dataset divided successfully")
-    case Err(e):
-        print(f"Error dividing dataset: {e}")
-        exit(-1)
+datasets = divide_dataset(
+    dataset, config["data"]["data_splits"], seed=config["data"]["seed"]
+)
 
 
 # Initialize the dataloaders
@@ -62,7 +61,7 @@ train_loader, val_loader, test_loader = initalize_dataloaders(
 )
 
 # Save seed to output config file
-output_config_path = pl.Path(config["output"]["path"] / "config.json")
+output_config_path = pl.Path(config["output"]["path"]) / "config.json"
 if not output_config_path.parent.exists():
     output_config_path.parent.mkdir(parents=True, exist_ok=True)
 
@@ -74,8 +73,8 @@ print(f"Configuration saved to {output_config_path}")
 
 # Set up the ensemble training loop
 
-for run_num in range(config["training"]["ensemble_runs"]):
-    print(f"Starting run {run_num + 1}/{config['training']['ensemble_runs']}")
+for run_num in range(config["training"]["ensemble_size"]):
+    print(f"Starting run {run_num + 1}/{config['training']['ensemble_size']}")
 
     # Initialize the model
     model = (
@@ -83,7 +82,7 @@ for run_num in range(config["training"]["ensemble_runs"]):
             image_channels=config["data"]["image_channels"],
             clin_data_channels=config["data"]["clin_data_channels"],
             num_classes=config["data"]["num_classes"],
-            droprate=config["training"]["drop_rate"],
+            droprate=config["training"]["droprate"],
         )
         .float()
         .to(config["training"]["device"])
@@ -112,38 +111,43 @@ for run_num in range(config["training"]["ensemble_runs"]):
     )
 
     print(
-        f"Run {run_num + 1}/{config['training']['ensemble_runs']} - "
+        f"Run {run_num + 1}/{config['training']['ensemble_size']} - "
         f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
     )
 
     # Save the model
-    model_save_path = pl.Path(config["output"]["path"] / f"model_run_{run_num + 1}.pt")
+    model_save_path = pl.Path(config["output"]["path"]) / f"model_run_{run_num + 1}.pt"
     torch.save(model.state_dict(), model_save_path)
     print(f"Model saved to {model_save_path}")
 
     # Save the training history
-    history_save_path = pl.Path(
-        config["output"]["path"] / f"history_run_{run_num + 1}.nc"
+    history_save_path = (
+        pl.Path(config["output"]["path"]) / f"history_run_{run_num + 1}.nc"
     )
 
     history.to_netcdf(history_save_path, mode="w")  # type: ignore
     print(f"Training history saved to {history_save_path}")
 
-    # Save test results
-    test_results_save_path = pl.Path(
-        config["output"]["path"] / f"test_results_run_{run_num + 1}.json"
-    )
-    with open(test_results_save_path, "w") as f:
-        json.dump(
+    # Save test results by appending to the results file
+    test_results_save_path = pl.Path(config["output"]["path"]) / "results.json"
+    test_results_save_path.touch(exist_ok=True)  # "r+" requires an existing file
+    with open(test_results_save_path, "r+") as f:
+        try:
+            results = json.load(f)
+        except json.JSONDecodeError:
+            # If the file is empty or not a valid JSON, initialize an empty list
+            print("No previous results found, initializing results list.")
+            results = []
+
+        results.append(  # type: ignore
             {
+                "run": run_num + 1,
                 "test_loss": test_loss,
                 "test_accuracy": test_acc,
-            },
-            f,
-            indent=4,
+            }
         )
-    print(f"Test results saved to {test_results_save_path}")
-    print(f"Run {run_num + 1}/{config['training']['ensemble_runs']} completed\n")
+        f.seek(0)
+        json.dump(results, f, indent=4)
+        f.truncate()  # drop stale bytes if the previous contents were longer
+    print(f"Run {run_num + 1}/{config['training']['ensemble_size']} completed\n")
 
 # Completion message
 print(f"All runs completed. Models and results saved to {config['output']['path']}")

+ 1 - 1
utils/config.py

@@ -17,7 +17,7 @@ def get_config() -> dict[str, Any]:
             config_path = pl.Path(path)
 
     if not config_path.exists():
-        raise FileNotFoundError(f"Config file not found at {config_path}")
+        config_path = pl.Path(__file__).parent.parent / "config.toml"
     with open(config_path, "rb") as f:
         config = tomllib.load(f)
 

+ 31 - 26
utils/training.py

@@ -5,6 +5,7 @@ import xarray as xr
 from data.dataset import ADNIDataset
 from typing import Tuple
 from tqdm import tqdm
+import numpy as np
 
 type TrainMetrics = Tuple[
     float, float, float, float
@@ -35,12 +36,12 @@ def test_model(
     total = 0
 
     with torch.no_grad():
-        for _, (inputs, targets) in tqdm(
+        for _, (mri, xls, targets) in tqdm(
             enumerate(test_loader), desc="Testing", total=len(test_loader), unit="batch"
         ):
-            outputs = model(inputs)
+            outputs = model((mri, xls))
             loss = criterion(outputs, targets)
-            test_loss += loss.item() * inputs.size(0)
+            test_loss += loss.item() * mri.size(0)  # mri and xls share the batch dimension
 
             # Calculate accuracy
             predicted = (outputs > 0.5).float()
@@ -76,15 +77,15 @@ def train_epoch(
     train_loss = 0.0
 
     # Training loop
-    for _, (inputs, targets) in tqdm(
+    for _, (mri, xls, targets) in tqdm(
         enumerate(train_loader), desc="Training", total=len(train_loader), unit="batch"
     ):
         optimizer.zero_grad()
-        outputs = model(inputs)
+        outputs = model((mri, xls))
         loss = criterion(outputs, targets)
         loss.backward()
         optimizer.step()
-        train_loss += loss.item() * inputs.size(0)
+        train_loss += loss.item() * mri.size(0)  # mri and xls share the batch dimension
-    train_loss /= len(train_loader)
+    train_loss /= len(train_loader.dataset)  # per-sample average, matching the weighting above
 
     model.eval()
@@ -93,15 +94,15 @@ def train_epoch(
     total = 0
 
     with torch.no_grad():
-        for _, (inputs, targets) in tqdm(
+        for _, (mri, xls, targets) in tqdm(
             enumerate(val_loader),
             desc="Validation",
             total=len(val_loader),
             unit="batch",
         ):
-            outputs = model(inputs)
+            outputs = model((mri, xls))
             loss = criterion(outputs, targets)
-            val_loss += loss.item() * inputs.size(0)
+            val_loss += loss.item() * mri.size(0)  # mri and xls share the batch dimension
 
             # Calculate accuracy
             predicted = (outputs > 0.5).float()
@@ -139,14 +140,9 @@ def train_model(
 
     # Record the training history
     # We record the Epoch, Training Loss, Validation Loss, Training Accuracy, and Validation Accuracy
-    history = xr.DataArray(
-        data=[],
-        dims=["epoch", "metric"],
-        coords={
-            "epoch": range(num_epochs),
-            "metric": ["train_loss", "val_loss", "train_acc", "val_acc"],
-        },
-    )
+    # use a (num_epochs, 4) shape ndarray to store the history before creating the DataArray
+
+    nhist = np.zeros((num_epochs, 4), dtype=np.float32)
 
     for epoch in range(num_epochs):
         train_loss, val_loss, train_acc, val_acc = train_epoch(
@@ -158,12 +154,10 @@ def train_model(
         )
 
         # Update the history
-        history[
-            {
-                "epoch": epoch,
-                "metric": ["train_loss", "val_loss", "train_acc", "val_acc"],
-            }
-        ] = [train_loss, val_loss, train_acc, val_acc]
+        nhist[epoch, 0] = train_loss
+        nhist[epoch, 1] = val_loss
+        nhist[epoch, 2] = train_acc
+        nhist[epoch, 3] = val_acc
 
         print(
             f"Epoch [{epoch + 1}/{num_epochs}], "
@@ -172,9 +166,20 @@ def train_model(
         )
 
         # If we are at 25, 50, or 75% of the epochs, save the model
-        if (epoch + 1) % (num_epochs // 4) == 0:
-            torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")
-            print(f"Model saved at epoch {epoch + 1}")
+        if num_epochs > 4:
+            if (epoch + 1) % (num_epochs // 4) == 0:
+                torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")
+                print(f"Model saved at epoch {epoch + 1}")
 
     # return the trained model and the training history
+
+    history = xr.DataArray(
+        data=nhist,
+        dims=["epoch", "metric"],
+        coords={
+            "epoch": range(num_epochs),
+            "metric": ["train_loss", "val_loss", "train_acc", "val_acc"],
+        },
+    )
+
     return model, history
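
A sketch of slicing the returned history DataArray (metric names as defined just above; `model, history = train_model(...)` assumed):

    best_epoch = int(history.sel(metric="val_loss").argmin())
    print(f"best validation loss at epoch {best_epoch + 1}")
    print(float(history.sel(metric="val_acc").isel(epoch=best_epoch)))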