import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import xarray as xr
from data.dataset import ADNIDataset
from typing import Tuple
from tqdm import tqdm
import numpy as np

type TrainMetrics = Tuple[
    float, float, float, float
]  # (train_loss, val_loss, train_acc, val_acc)
type TestMetrics = Tuple[float, float]  # (test_loss, test_acc)


def test_model(
    model: nn.Module,
    test_loader: DataLoader[ADNIDataset],
    criterion: nn.Module,
) -> TestMetrics:
    """
    Tests the model on the test dataset.

    Args:
        model (nn.Module): The model to test.
        test_loader (DataLoader[ADNIDataset]): DataLoader for the test dataset.
        criterion (nn.Module): Loss function to compute the loss.

    Returns:
        TestMetrics: A tuple containing the test loss and test accuracy.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_samples = 0
    with torch.no_grad():
        for mri, xls, targets in tqdm(test_loader, desc="Testing", unit="batch"):
            outputs = model((mri, xls))
            loss = criterion(outputs, targets)
            # Weight the batch loss by the batch size so the final average
            # is per sample rather than per batch
            test_loss += loss.item() * mri.size(0)
            num_samples += mri.size(0)
            # Threshold the outputs at 0.5 to get hard predictions
            predicted = (outputs > 0.5).float()
            correct += (predicted == targets).sum().item()
            total += targets.numel()
    test_loss /= max(num_samples, 1)
    test_acc = correct / total if total > 0 else 0.0
    return test_loss, test_acc
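

# A minimal usage sketch for test_model. The dataset arguments, batch size, and
# loss function below are illustrative assumptions (a binary classifier whose
# outputs are probabilities), not something this module fixes:
#
#     test_data = ADNIDataset(...)  # hypothetical constructor arguments
#     test_loader = DataLoader(test_data, batch_size=16, shuffle=False)
#     criterion = nn.BCELoss()
#     test_loss, test_acc = test_model(model, test_loader, criterion)
#     print(f"Test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}")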


def train_epoch(
    model: nn.Module,
    train_loader: DataLoader[ADNIDataset],
    val_loader: DataLoader[ADNIDataset],
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
) -> TrainMetrics:
    """
    Trains the model for one epoch and evaluates it on the validation set.

    Args:
        model (nn.Module): The model to train.
        train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
        val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
        optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
        criterion (nn.Module): Loss function to compute the loss.

    Returns:
        TrainMetrics: A tuple of (train_loss, val_loss, train_acc, val_acc).
    """
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    train_samples = 0
    # Training loop
    for mri, xls, targets in tqdm(train_loader, desc="Training", unit="batch"):
        optimizer.zero_grad()
        outputs = model((mri, xls))
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Weight the batch loss by the batch size for a per-sample average
        train_loss += loss.item() * mri.size(0)
        train_samples += mri.size(0)
        # Track training accuracy from the same forward pass
        predicted = (outputs > 0.5).float()
        train_correct += (predicted == targets).sum().item()
        train_total += targets.numel()
    train_loss /= max(train_samples, 1)
    train_acc = train_correct / train_total if train_total > 0 else 0.0
    # Validation loop
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    val_samples = 0
    with torch.no_grad():
        for mri, xls, targets in tqdm(val_loader, desc="Validation", unit="batch"):
            outputs = model((mri, xls))
            loss = criterion(outputs, targets)
            val_loss += loss.item() * mri.size(0)
            val_samples += mri.size(0)
            predicted = (outputs > 0.5).float()
            val_correct += (predicted == targets).sum().item()
            val_total += targets.numel()
    val_loss /= max(val_samples, 1)
    val_acc = val_correct / val_total if val_total > 0 else 0.0
    return train_loss, val_loss, train_acc, val_acc


def train_model(
    model: nn.Module,
    train_loader: DataLoader[ADNIDataset],
    val_loader: DataLoader[ADNIDataset],
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    num_epochs: int,
    learning_rate: float,
) -> Tuple[nn.Module, xr.DataArray]:
    """
    Trains the model using the provided training and validation data loaders.

    Args:
        model (nn.Module): The model to train.
        train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
        val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
        optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
        criterion (nn.Module): Loss function to compute the loss.
        num_epochs (int): Number of epochs to train the model.
        learning_rate (float): Learning rate; unused here, since the optimizer
            arrives preconfigured.

    Returns:
        Tuple[nn.Module, xr.DataArray]: The trained model and the training history.
    """
    # Record the training history: one row per epoch holding
    # (train_loss, val_loss, train_acc, val_acc). A (num_epochs, 4) ndarray
    # stores the values before the DataArray is created.
    nhist = np.zeros((num_epochs, 4), dtype=np.float32)
    for epoch in range(num_epochs):
        train_loss, val_loss, train_acc, val_acc = train_epoch(
            model,
            train_loader,
            val_loader,
            optimizer,
            criterion,
        )
        # Update the history
        nhist[epoch, 0] = train_loss
        nhist[epoch, 1] = val_loss
        nhist[epoch, 2] = train_acc
        nhist[epoch, 3] = val_acc
        print(
            f"Epoch [{epoch + 1}/{num_epochs}], "
            f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
            f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}"
        )
        # Save a checkpoint at each quarter of training (25%, 50%, 75%, 100%)
        if num_epochs > 4 and (epoch + 1) % (num_epochs // 4) == 0:
            torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")
            print(f"Model saved at epoch {epoch + 1}")
    # Return the trained model and the training history
    history = xr.DataArray(
        data=nhist,
        dims=["epoch", "metric"],
        coords={
            "epoch": range(num_epochs),
            "metric": ["train_loss", "val_loss", "train_acc", "val_acc"],
        },
    )
    return model, history
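

# A minimal end-to-end sketch. The optimizer, loss function, and epoch count
# below are illustrative assumptions, not choices this module prescribes:
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#     criterion = nn.BCELoss()
#     model, history = train_model(
#         model, train_loader, val_loader, optimizer, criterion,
#         num_epochs=20, learning_rate=1e-4,
#     )
#     # history is an xarray DataArray indexed by epoch and metric name,
#     # so individual curves can be pulled out by label:
#     val_curve = history.sel(metric="val_loss")
#     best_epoch = int(val_curve.argmin())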