
completed initial version of training

Nicholas Schense 4 weeks ago
parent
commit
a52bd0e038
4 files changed, with 365 additions and 72 deletions
  1. data/dataset.py (+31 −9)
  2. train_model.py (+136 −7)
  3. utils/config.py (+18 −56)
  4. utils/training.py (+180 −0)

+ 31 - 9
data/dataset.py

@@ -3,10 +3,11 @@ import torch
 import torch.utils.data as data
 import pathlib as pl
 import pandas as pd
+from torch.utils.data import Subset, DataLoader
 
 
 from jaxtyping import Float
-from typing import Tuple, List, Callable
+from typing import Tuple, Iterator, Callable, List
 from result import Ok, Err, Result
 
 
@@ -18,16 +19,17 @@ class ADNIDataset(data.Dataset):  # type: ignore
 
     def __init__(
         self,
-        mri_data: Float[torch.Tensor, "n_samples, width, height, depth"],
-        xls_data: Float[torch.Tensor, "n_samples, features"],
+        mri_data: Float[torch.Tensor, "n_samples width height depth"],
+        xls_data: Float[torch.Tensor, "n_samples features"],
+        device: str = "cuda",
     ):
         """
         Args:
             mri_data (torch.Tensor): 4D tensor of MRI data with shape (n_samples, width, height, depth).
             xls_data (torch.Tensor): 2D tensor of Excel data with shape (n_samples, features).
         """
-        self.mri_data = mri_data
-        self.xls_data = xls_data
+        self.mri_data = mri_data.float().to(device)
+        self.xls_data = xls_data.float().to(device)
 
     def __len__(self) -> int:
         """
@@ -36,7 +38,7 @@ class ADNIDataset(data.Dataset):  # type: ignore
         return self.mri_data.shape[0]  # 0th dimension is the number of samples
 
     def __getitem__(self, idx: int) -> Tuple[
-        Float[torch.Tensor, "width, height, depth"],
+        Float[torch.Tensor, "width height depth"],
         Float[torch.Tensor, "features"],
     ]:
         """
@@ -56,8 +58,9 @@ class ADNIDataset(data.Dataset):  # type: ignore
 
 
 def load_adni_data_from_file(
-    mri_files: List[pl.Path],  # List of nibablel files
+    mri_files: Iterator[pl.Path],  # Iterator of nibabel-readable NIfTI file paths
     xls_file: pl.Path,  # Path to the Excel file
+    device: str = "cuda",
     xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
 ) -> Result[ADNIDataset, str]:
     """
@@ -85,7 +88,7 @@ def load_adni_data_from_file(
 
     # Check if the number of samples in MRI and Excel data match
     if mri_data.shape[0] == xls_data.shape[0]:
-        return Ok(ADNIDataset(mri_data, xls_data))
+        return Ok(ADNIDataset(mri_data, xls_data, device=device))
     else:
         return Err("Loading MRI data failed")
 
@@ -93,7 +96,7 @@ def load_adni_data_from_file(
 def divide_dataset(
     dataset: ADNIDataset,
     ratios: Tuple[float, float, float],
-    seed: int = 0,
+    seed: int,
 ) -> Result[List[data.Subset[ADNIDataset]], str]:
     """
     Divides the dataset into training, validation, and test sets.
@@ -113,3 +116,22 @@ def divide_dataset(
     # Set the random seed for reproducibility
     gen = torch.Generator().manual_seed(seed)
     return Ok(data.random_split(dataset, ratios, generator=gen))
+
+
+def initialize_dataloaders(
+    datasets: List[Subset[ADNIDataset]],
+    batch_size: int = 64,
+) -> List[DataLoader[ADNIDataset]]:
+    """
+    Initializes the DataLoader for the given datasets.
+
+    Args:
+        datasets (List[Subset[ADNIDataset]]): List of datasets to create DataLoaders for.
+        batch_size (int): The batch size for the DataLoader.
+
+    Returns:
+        List[DataLoader[ADNIDataset]]: A list of DataLoaders for the datasets.
+    """
+    return [
+        DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in datasets
+    ]
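
A minimal usage sketch of the new loading pipeline (the paths, the 70/20/10 split, and the batch size are hypothetical; device="cpu" keeps the example runnable without a GPU):

    import pathlib as pl
    from result import Ok, Err
    from data.dataset import (
        load_adni_data_from_file,
        divide_dataset,
        initialize_dataloaders,
    )

    mri_files = pl.Path("data/mri").glob("*.nii")  # hypothetical data directory
    xls_file = pl.Path("data/clinical.xlsx")  # hypothetical clinical sheet

    match load_adni_data_from_file(mri_files, xls_file, device="cpu"):
        case Ok(dataset):
            pass
        case Err(msg):
            raise SystemExit(f"Loading failed: {msg}")

    # random_split accepts fractional ratios summing to 1.0
    match divide_dataset(dataset, (0.7, 0.2, 0.1), seed=42):
        case Ok(subsets):
            train_loader, val_loader, test_loader = initialize_dataloaders(
                subsets, batch_size=32
            )
        case Err(msg):
            raise SystemExit(f"Split failed: {msg}")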

+ 136 - 7
train_model.py

@@ -1,20 +1,149 @@
 # Torch
-import torch
 import torch.nn as nn
+import torch
 import torch.optim as optim
 
 # Config
-from utils.config import Config
+from utils.config import config
 import pathlib as pl
+from result import Ok, Err
+import json
 
 
 # Custom modules
 from model.cnn import CNN3D
-from data.dataset import ADNIDataset, load_adni_data_from_file, divide_dataset
+from utils.training import train_model, test_model
+from data.dataset import (
+    load_adni_data_from_file,
+    divide_dataset,
+    initialize_dataloaders,
+)
 
+# Load data
+mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
+xls_file = pl.Path(config["data"]["xls_file_path"])
 
-# Load config
-conf = Config()
 
-# Load data
-mri_files = pl.Path(conf["data"]["mri_files"]).glob("*.nii")
+# Load the data
+
+match load_adni_data_from_file(
+    mri_files, xls_file, device=config["training"]["device"]
+):
+    case Ok(d):
+        dataset = d
+        print("Data loaded successfully")
+    case Err(e):
+        print(f"Error loading data: {e}")
+        exit(-1)
+
+
+# Divide the dataset into training and validation sets
+# tomllib never produces None, so probe for a missing key with .get
+if config["data"].get("seed") is None:
+    print("Warning: No seed provided for dataset division, using default seed 0")
+    config["data"]["seed"] = 0
+
+match divide_dataset(
+    dataset, config["data"]["train_val_split"], seed=config["data"]["seed"]
+):
+    case Ok(s):
+        if len(s) != 3:
+            print(f"Error: Expected 3 subsets (train, val, test), got {len(s)}")
+            exit(-1)
+        datasets = s
+        print("Dataset divided successfully")
+    case Err(e):
+        print(f"Error dividing dataset: {e}")
+        exit(-1)
+
+
+# Initialize the dataloaders
+train_loader, val_loader, test_loader = initialize_dataloaders(
+    datasets, batch_size=config["training"]["batch_size"]
+)
+
+# Save seed to output config file
+output_config_path = pl.Path(config["output"]["path"]) / "config.json"
+if not output_config_path.parent.exists():
+    output_config_path.parent.mkdir(parents=True, exist_ok=True)
+
+
+with open(output_config_path, "w") as f:
+    # Save as JSON
+    json.dump(config, f, indent=4)
+print(f"Configuration saved to {output_config_path}")
+
+# Set up the ensemble training loop
+
+for run_num in range(config["training"]["ensemble_runs"]):
+    print(f"Starting run {run_num + 1}/{config['training']['ensemble_runs']}")
+
+    # Initialize the model
+    model = (
+        CNN3D(
+            image_channels=config["data"]["image_channels"],
+            clin_data_channels=config["data"]["clin_data_channels"],
+            num_classes=config["data"]["num_classes"],
+            droprate=config["training"]["drop_rate"],
+        )
+        .float()
+        .to(config["training"]["device"])
+    )
+
+    # Set up the optimizer and loss function
+    optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
+    criterion = nn.BCELoss()
+
+    # Train model
+    model, history = train_model(
+        model=model,
+        train_loader=train_loader,
+        val_loader=val_loader,
+        optimizer=optimizer,
+        criterion=criterion,
+        num_epochs=config["training"]["num_epochs"],
+        learning_rate=config["training"]["learning_rate"],
+    )
+
+    # Test model
+    test_loss, test_acc = test_model(
+        model=model,
+        test_loader=test_loader,
+        criterion=criterion,
+    )
+
+    print(
+        f"Run {run_num + 1}/{config['training']['ensemble_runs']} - "
+        f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
+    )
+
+    # Save the model
+    model_save_path = pl.Path(config["output"]["path"]) / f"model_run_{run_num + 1}.pt"
+    torch.save(model.state_dict(), model_save_path)
+    print(f"Model saved to {model_save_path}")
+
+    # Save the training history
+    history_save_path = (
+        pl.Path(config["output"]["path"]) / f"history_run_{run_num + 1}.nc"
+    )
+
+    history.to_netcdf(history_save_path, mode="w")  # type: ignore
+    print(f"Training history saved to {history_save_path}")
+
+    # Save test results
+    test_results_save_path = (
+        pl.Path(config["output"]["path"]) / f"test_results_run_{run_num + 1}.json"
+    )
+    with open(test_results_save_path, "w") as f:
+        json.dump(
+            {
+                "test_loss": test_loss,
+                "test_accuracy": test_acc,
+            },
+            f,
+            indent=4,
+        )
+    print(f"Test results saved to {test_results_save_path}")
+    print(f"Run {run_num + 1}/{config['training']['ensemble_runs']} completed\n")
+
+# Completion message
+print(f"All runs completed. Models and results saved to {config['output']['path']}")

+ 18 - 56
utils/config.py

@@ -1,65 +1,27 @@
-import typing
-from pathlib import Path
-import tomli
+# This file serves as a singleton for the configuration settings of the project
+
+import tomllib
 import os
+import pathlib as pl
+from typing import Any
 
 
-@typing.no_type_check
-class SingletonMeta(type):
+def get_config() -> dict[str, Any]:
     """
-    Singleton metaclass to ensure only one instance of a class exists.
+    Load the configuration file and return the settings as a dictionary.
     """
+    match os.getenv("ANN_CONFIG_PATH"):
+        case None:
+            config_path = pl.Path(__file__).parent.parent / "config.toml"
+        case str(path):
+            config_path = pl.Path(path)
 
-    _instances = {}
-
-    @typing.no_type_check
-    def __call__(cls, *args, **kwargs):
-        if cls not in cls._instances:
-            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
-        return cls._instances[cls]
-
-
-class Config(metaclass=SingletonMeta):
-    def __init__(
-        self,
-        config_path: Path | None = Path(__file__).parent.parent / "config.toml",
-    ):
-        """
-        Initialize the configuration object.
-
-        Args:
-            config_path (Path): Path to the configuration file.
-        """
-
-        if config_path is None and "ADL_CONFIG_PATH" in os.environ:
-            self.config_path = Path(os.environ["ADL_CONFIG_PATH"])
-        elif config_path is not None:
-            self.config_path = config_path
-        else:
-            raise ValueError("Either config_path or ADL_CONFIG_PATH must be provided")
-
-        self.loaded_config_path = None
-        self._load_config()
-
-    def _load_config(self):
-        """
-        Load the configuration from the specified file.
-        """
-        with open(self.config_path, "rb") as f:
-            config = tomli.load(f)
-            self.loaded_config_path = str(self.config_path)
-            print(f"Loaded config from {self.config_path}")
-
-        self._config_dict = config
+    if not config_path.exists():
+        raise FileNotFoundError(f"Config file not found at {config_path}")
+    with open(config_path, "rb") as f:
+        config = tomllib.load(f)
 
-    def __getitem__(self, key: str):
-        """
-        Get a configuration value by key.
+    return config
 
-        Args:
-            key (str): The key of the configuration value.
 
-        Returns:
-            The configuration value.
-        """
-        return self._config_dict.get(key, None)
+config = get_config()
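
Because config is built at import time, the module acts as a read-only singleton: every importer sees the same dict, so ANN_CONFIG_PATH must be set before the first import of utils.config. A small sketch (the path and the key looked up are hypothetical):

    import os

    # Must happen before utils.config is imported anywhere
    os.environ["ANN_CONFIG_PATH"] = "/tmp/experiment_a/config.toml"

    from utils.config import config

    print(config["training"]["batch_size"])  # assumes this key exists in the TOML file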

+ 180 - 0
utils/training.py

@@ -0,0 +1,180 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.utils.data import DataLoader
+import xarray as xr
+from data.dataset import ADNIDataset
+from typing import Tuple
+from tqdm import tqdm
+
+type TrainMetrics = Tuple[
+    float, float, float, float
+]  # (train_loss, val_loss, train_acc, val_acc)
+
+type TestMetrics = Tuple[float, float]  # (test_loss, test_acc)
+
+
+def test_model(
+    model: nn.Module,
+    test_loader: DataLoader[ADNIDataset],
+    criterion: nn.Module,
+) -> TestMetrics:
+    """
+    Tests the model on the test dataset.
+
+    Args:
+        model (nn.Module): The model to test.
+        test_loader (DataLoader[ADNIDataset]): DataLoader for the test dataset.
+        criterion (nn.Module): Loss function to compute the loss.
+
+    Returns:
+        TestMetrics: A tuple containing the test loss and test accuracy.
+    """
+    model.eval()
+    test_loss = 0.0
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for _, (inputs, targets) in tqdm(
+            enumerate(test_loader), desc="Testing", total=len(test_loader), unit="batch"
+        ):
+            outputs = model(inputs)
+            loss = criterion(outputs, targets)
+            test_loss += loss.item() * inputs.size(0)
+
+            # Calculate accuracy
+            predicted = (outputs > 0.5).float()
+            correct += (predicted == targets).sum().item()
+            total += targets.numel()
+
+    # Normalize by sample count: the loss was accumulated weighted by batch size
+    test_loss /= len(test_loader.dataset)
+    test_acc = correct / total if total > 0 else 0.0
+    return test_loss, test_acc
+
+
+def train_epoch(
+    model: nn.Module,
+    train_loader: DataLoader[ADNIDataset],
+    val_loader: DataLoader[ADNIDataset],
+    optimizer: torch.optim.Optimizer,
+    criterion: nn.Module,
+) -> TrainMetrics:
+    """
+    Trains the model for one epoch and evaluates it on the validation set.
+
+    Args:
+        model (nn.Module): The model to train.
+        train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
+        val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
+        optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
+        criterion (nn.Module): Loss function to compute the loss.
+
+    Returns:
+        TrainMetrics: A tuple of (train_loss, val_loss, train_acc, val_acc).
+    """
+    model.train()
+    train_loss = 0.0
+    train_correct = 0
+    train_total = 0
+
+    # Training loop
+    for _, (inputs, targets) in tqdm(
+        enumerate(train_loader), desc="Training", total=len(train_loader), unit="batch"
+    ):
+        optimizer.zero_grad()
+        outputs = model(inputs)
+        loss = criterion(outputs, targets)
+        loss.backward()
+        optimizer.step()
+        train_loss += loss.item() * inputs.size(0)
+
+        # Track training accuracy alongside the loss
+        predicted = (outputs > 0.5).float()
+        train_correct += (predicted == targets).sum().item()
+        train_total += targets.numel()
+
+    # Normalize by sample count: the loss was accumulated weighted by batch size
+    train_loss /= len(train_loader.dataset)
+    train_acc = train_correct / train_total if train_total > 0 else 0.0
+
+    model.eval()
+    val_loss = 0.0
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for _, (inputs, targets) in tqdm(
+            enumerate(val_loader),
+            desc="Validation",
+            total=len(val_loader),
+            unit="batch",
+        ):
+            outputs = model(inputs)
+            loss = criterion(outputs, targets)
+            val_loss += loss.item() * inputs.size(0)
+
+            # Calculate accuracy
+            predicted = (outputs > 0.5).float()
+            correct += (predicted == targets).sum().item()
+            total += targets.numel()
+
+    val_loss /= len(val_loader.dataset)
+    val_acc = correct / total if total > 0 else 0.0
+    return train_loss, val_loss, train_acc, val_acc
+
+
+def train_model(
+    model: nn.Module,
+    train_loader: DataLoader[ADNIDataset],
+    val_loader: DataLoader[ADNIDataset],
+    optimizer: torch.optim.Optimizer,
+    criterion: nn.Module,
+    num_epochs: int,
+    learning_rate: float,
+) -> Tuple[nn.Module, xr.DataArray]:
+    """
+    Trains the model using the provided training and validation data loaders.
+
+    Args:
+        model (nn.Module): The model to train.
+        train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
+        val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
+        optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
+        criterion (nn.Module): Loss function to compute the loss.
+        num_epochs (int): Number of epochs to train the model.
+        learning_rate (float): Learning rate (unused here; the optimizer is configured by the caller).
+
+    Returns:
+        Tuple[nn.Module, xr.DataArray]: The trained model and the training history.
+    """
+
+    # Record the training history
+    # We record the Epoch, Training Loss, Validation Loss, Training Accuracy, and Validation Accuracy
+    history = xr.DataArray(
+        data=np.zeros((num_epochs, 4)),  # filled in epoch by epoch
+        dims=["epoch", "metric"],
+        coords={
+            "epoch": range(num_epochs),
+            "metric": ["train_loss", "val_loss", "train_acc", "val_acc"],
+        },
+    )
+
+    for epoch in range(num_epochs):
+        train_loss, val_loss, train_acc, val_acc = train_epoch(
+            model,
+            train_loader,
+            val_loader,
+            optimizer,
+            criterion,
+        )
+
+        # Update the history (label-based assignment goes through .loc)
+        history.loc[{"epoch": epoch}] = [train_loss, val_loss, train_acc, val_acc]
+
+        print(
+            f"Epoch [{epoch + 1}/{num_epochs}], "
+            f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
+            f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}"
+        )
+
+        # Save a checkpoint at each quarter of the run (25, 50, 75, and 100% of the epochs)
+        if num_epochs >= 4 and (epoch + 1) % (num_epochs // 4) == 0:
+            torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")
+            print(f"Model saved at epoch {epoch + 1}")
+
+    # Return the trained model and the training history
+    return model, history
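
The per-run history files written by train_model.py can be reloaded with xarray for later analysis. A sketch assuming a single-variable .nc file, as produced by DataArray.to_netcdf above:

    import xarray as xr

    history = xr.open_dataarray("history_run_1.nc")  # dims: (epoch, metric)

    # Label-based selection along the metric dimension
    val_loss = history.sel(metric="val_loss")
    best_epoch = int(val_loss.argmin("epoch"))
    print(f"Best val loss {float(val_loss.min()):.4f} at epoch {best_epoch + 1}")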