import torch import torch.nn as nn from torch.utils.data import DataLoader import xarray as xr from data.dataset import ADNIDataset from typing import Tuple from tqdm import tqdm import numpy as np type TrainMetrics = Tuple[ float, float, float, float ] # (train_loss, val_loss, train_acc, val_acc) type TestMetrics = Tuple[float, float] # (test_loss, test_acc) def test_model( model: nn.Module, test_loader: DataLoader[ADNIDataset], criterion: nn.Module, ) -> TestMetrics: """ Tests the model on the test dataset. Args: model (nn.Module): The model to test. test_loader (DataLoader[ADNIDataset]): DataLoader for the test dataset. criterion (nn.Module): Loss function to compute the loss. Returns: TrainMetrics: A tuple containing the test loss and test accuracy. """ model.eval() test_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for _, (mri, xls, targets) in tqdm( enumerate(test_loader), desc="Testing", total=len(test_loader), unit="batch" ): outputs = model((mri, xls)) loss = criterion(outputs, targets) test_loss += loss.item() * (mri.size(0) + xls.size(0)) # Calculate accuracy predicted = (outputs > 0.5).float() correct += (predicted == targets).sum().item() total += targets.numel() test_loss /= len(test_loader) test_acc = correct / total if total > 0 else 0.0 return test_loss, test_acc def train_epoch( model: nn.Module, train_loader: DataLoader[ADNIDataset], val_loader: DataLoader[ADNIDataset], optimizer: torch.optim.Optimizer, criterion: nn.Module, ) -> Tuple[float, float, float, float]: """ Trains the model for one epoch and evaluates it on the validation set. Args: model (nn.Module): The model to train. train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset. val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset. optimizer (torch.optim.Optimizer): Optimizer for updating model parameters. criterion (nn.Module): Loss function to compute the loss. Returns: Tuple[float, float, float, float]: A tuple containing the training loss, validation loss, training accuracy, and validation accuracy. """ model.train() train_loss = 0.0 # Training loop for _, (mri, xls, targets) in tqdm( enumerate(train_loader), desc="Training", total=len(train_loader), unit="batch" ): optimizer.zero_grad() outputs = model((mri, xls)) loss = criterion(outputs, targets) loss.backward() optimizer.step() train_loss += loss.item() * (mri.size(0) + xls.size(0)) train_loss /= len(train_loader) model.eval() val_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for _, (mri, xls, targets) in tqdm( enumerate(val_loader), desc="Validation", total=len(val_loader), unit="batch", ): outputs = model((mri, xls)) loss = criterion(outputs, targets) val_loss += loss.item() * (mri.size(0) + xls.size(0)) # Calculate accuracy predicted = (outputs > 0.5).float() correct += (predicted == targets).sum().item() total += targets.numel() val_loss /= len(val_loader) val_acc = correct / total if total > 0 else 0.0 train_acc = correct / total if total > 0 else 0.0 return train_loss, val_loss, train_acc, val_acc def train_model( model: nn.Module, train_loader: DataLoader[ADNIDataset], val_loader: DataLoader[ADNIDataset], optimizer: torch.optim.Optimizer, criterion: nn.Module, num_epochs: int, learning_rate: float, ) -> Tuple[nn.Module, xr.DataArray]: """ Trains the model using the provided training and validation data loaders. Args: model (nn.Module): The model to train. train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset. val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset. num_epochs (int): Number of epochs to train the model. learning_rate (float): Learning rate for the optimizer. Returns: Result[nn.Module, str]: A Result object containing the trained model or an error message. """ # Record the training history # We record the Epoch, Training Loss, Validation Loss, Training Accuracy, and Validation Accuracy # use a (num_epochs, 4) shape ndarray to store the history before creating the DataArray nhist = np.zeros((num_epochs, 4), dtype=np.float32) for epoch in range(num_epochs): train_loss, val_loss, train_acc, val_acc = train_epoch( model, train_loader, val_loader, optimizer, criterion, ) # Update the history nhist[epoch, 0] = train_loss nhist[epoch, 1] = val_loss nhist[epoch, 2] = train_acc nhist[epoch, 3] = val_acc print( f"Epoch [{epoch + 1}/{num_epochs}], " f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, " f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}" ) # If we are at 25, 50, or 75% of the epochs, save the model if num_epochs > 4: if (epoch + 1) % (num_epochs // 4) == 0: torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth") print(f"Model saved at epoch {epoch + 1}") # return the trained model and the traning history history = xr.DataArray( data=nhist, dims=["epoch", "metric"], coords={ "epoch": range(num_epochs), "metric": ["train_loss", "val_loss", "train_acc", "val_acc"], }, ) return model, history