import torch import torch.nn as nn from torch.utils.data import DataLoader import xarray as xr from data.dataset import ADNIDataset from typing import Tuple from tqdm import tqdm type TrainMetrics = Tuple[ float, float, float, float ] # (train_loss, val_loss, train_acc, val_acc) type TestMetrics = Tuple[float, float] # (test_loss, test_acc) def test_model( model: nn.Module, test_loader: DataLoader[ADNIDataset], criterion: nn.Module, ) -> TestMetrics: """ Tests the model on the test dataset. Args: model (nn.Module): The model to test. test_loader (DataLoader[ADNIDataset]): DataLoader for the test dataset. criterion (nn.Module): Loss function to compute the loss. Returns: TrainMetrics: A tuple containing the test loss and test accuracy. """ model.eval() test_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for _, (inputs, targets) in tqdm( enumerate(test_loader), desc="Testing", total=len(test_loader), unit="batch" ): outputs = model(inputs) loss = criterion(outputs, targets) test_loss += loss.item() * inputs.size(0) # Calculate accuracy predicted = (outputs > 0.5).float() correct += (predicted == targets).sum().item() total += targets.numel() test_loss /= len(test_loader) test_acc = correct / total if total > 0 else 0.0 return test_loss, test_acc def train_epoch( model: nn.Module, train_loader: DataLoader[ADNIDataset], val_loader: DataLoader[ADNIDataset], optimizer: torch.optim.Optimizer, criterion: nn.Module, ) -> Tuple[float, float, float, float]: """ Trains the model for one epoch and evaluates it on the validation set. Args: model (nn.Module): The model to train. train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset. val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset. optimizer (torch.optim.Optimizer): Optimizer for updating model parameters. criterion (nn.Module): Loss function to compute the loss. Returns: Tuple[float, float, float, float]: A tuple containing the training loss, validation loss, training accuracy, and validation accuracy. """ model.train() train_loss = 0.0 # Training loop for _, (inputs, targets) in tqdm( enumerate(train_loader), desc="Training", total=len(train_loader), unit="batch" ): optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() train_loss += loss.item() * inputs.size(0) train_loss /= len(train_loader) model.eval() val_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for _, (inputs, targets) in tqdm( enumerate(val_loader), desc="Validation", total=len(val_loader), unit="batch", ): outputs = model(inputs) loss = criterion(outputs, targets) val_loss += loss.item() * inputs.size(0) # Calculate accuracy predicted = (outputs > 0.5).float() correct += (predicted == targets).sum().item() total += targets.numel() val_loss /= len(val_loader) val_acc = correct / total if total > 0 else 0.0 train_acc = correct / total if total > 0 else 0.0 return train_loss, val_loss, train_acc, val_acc def train_model( model: nn.Module, train_loader: DataLoader[ADNIDataset], val_loader: DataLoader[ADNIDataset], optimizer: torch.optim.Optimizer, criterion: nn.Module, num_epochs: int, learning_rate: float, ) -> Tuple[nn.Module, xr.DataArray]: """ Trains the model using the provided training and validation data loaders. Args: model (nn.Module): The model to train. train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset. val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset. num_epochs (int): Number of epochs to train the model. learning_rate (float): Learning rate for the optimizer. Returns: Result[nn.Module, str]: A Result object containing the trained model or an error message. """ # Record the training history # We record the Epoch, Training Loss, Validation Loss, Training Accuracy, and Validation Accuracy history = xr.DataArray( data=[], dims=["epoch", "metric"], coords={ "epoch": range(num_epochs), "metric": ["train_loss", "val_loss", "train_acc", "val_acc"], }, ) for epoch in range(num_epochs): train_loss, val_loss, train_acc, val_acc = train_epoch( model, train_loader, val_loader, optimizer, criterion, ) # Update the history history[ { "epoch": epoch, "metric": ["train_loss", "val_loss", "train_acc", "val_acc"], } ] = [train_loss, val_loss, train_acc, val_acc] print( f"Epoch [{epoch + 1}/{num_epochs}], " f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, " f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}" ) # If we are at 25, 50, or 75% of the epochs, save the model if (epoch + 1) % (num_epochs // 4) == 0: torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth") print(f"Model saved at epoch {epoch + 1}") # return the trained model and the traning history return model, history