|
@@ -0,0 +1,180 @@
|
|
|
+import torch
|
|
|
+import torch.nn as nn
|
|
|
+from torch.utils.data import DataLoader
|
|
|
+import xarray as xr
|
|
|
+from data.dataset import ADNIDataset
|
|
|
+from typing import Tuple
|
|
|
+from tqdm import tqdm
|
|
|
+
|
|
|
+type TrainMetrics = Tuple[
|
|
|
+ float, float, float, float
|
|
|
+] # (train_loss, val_loss, train_acc, val_acc)
|
|
|
+
|
|
|
+type TestMetrics = Tuple[float, float] # (test_loss, test_acc)
|
|
|
+
|
|
|
+
|
|
|
+def test_model(
|
|
|
+ model: nn.Module,
|
|
|
+ test_loader: DataLoader[ADNIDataset],
|
|
|
+ criterion: nn.Module,
|
|
|
+) -> TestMetrics:
|
|
|
+ """
|
|
|
+ Tests the model on the test dataset.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ model (nn.Module): The model to test.
|
|
|
+ test_loader (DataLoader[ADNIDataset]): DataLoader for the test dataset.
|
|
|
+ criterion (nn.Module): Loss function to compute the loss.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ TrainMetrics: A tuple containing the test loss and test accuracy.
|
|
|
+ """
|
|
|
+ model.eval()
|
|
|
+ test_loss = 0.0
|
|
|
+ correct = 0
|
|
|
+ total = 0
|
|
|
+
|
|
|
+ with torch.no_grad():
|
|
|
+ for _, (inputs, targets) in tqdm(
|
|
|
+ enumerate(test_loader), desc="Testing", total=len(test_loader), unit="batch"
|
|
|
+ ):
|
|
|
+ outputs = model(inputs)
|
|
|
+ loss = criterion(outputs, targets)
|
|
|
+ test_loss += loss.item() * inputs.size(0)
|
|
|
+
|
|
|
+ # Calculate accuracy
|
|
|
+ predicted = (outputs > 0.5).float()
|
|
|
+ correct += (predicted == targets).sum().item()
|
|
|
+ total += targets.numel()
|
|
|
+
|
|
|
+ test_loss /= len(test_loader)
|
|
|
+ test_acc = correct / total if total > 0 else 0.0
|
|
|
+ return test_loss, test_acc
|
|
|
+
|
|
|
+
|
|
|
+def train_epoch(
|
|
|
+ model: nn.Module,
|
|
|
+ train_loader: DataLoader[ADNIDataset],
|
|
|
+ val_loader: DataLoader[ADNIDataset],
|
|
|
+ optimizer: torch.optim.Optimizer,
|
|
|
+ criterion: nn.Module,
|
|
|
+) -> Tuple[float, float, float, float]:
|
|
|
+ """
|
|
|
+ Trains the model for one epoch and evaluates it on the validation set.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ model (nn.Module): The model to train.
|
|
|
+ train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
|
|
|
+ val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
|
|
|
+ optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
|
|
|
+ criterion (nn.Module): Loss function to compute the loss.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Tuple[float, float, float, float]: A tuple containing the training loss, validation loss, training accuracy, and validation accuracy.
|
|
|
+ """
|
|
|
+ model.train()
|
|
|
+ train_loss = 0.0
|
|
|
+
|
|
|
+ # Training loop
|
|
|
+ for _, (inputs, targets) in tqdm(
|
|
|
+ enumerate(train_loader), desc="Training", total=len(train_loader), unit="batch"
|
|
|
+ ):
|
|
|
+ optimizer.zero_grad()
|
|
|
+ outputs = model(inputs)
|
|
|
+ loss = criterion(outputs, targets)
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+ train_loss += loss.item() * inputs.size(0)
|
|
|
+ train_loss /= len(train_loader)
|
|
|
+
|
|
|
+ model.eval()
|
|
|
+ val_loss = 0.0
|
|
|
+ correct = 0
|
|
|
+ total = 0
|
|
|
+
|
|
|
+ with torch.no_grad():
|
|
|
+ for _, (inputs, targets) in tqdm(
|
|
|
+ enumerate(val_loader),
|
|
|
+ desc="Validation",
|
|
|
+ total=len(val_loader),
|
|
|
+ unit="batch",
|
|
|
+ ):
|
|
|
+ outputs = model(inputs)
|
|
|
+ loss = criterion(outputs, targets)
|
|
|
+ val_loss += loss.item() * inputs.size(0)
|
|
|
+
|
|
|
+ # Calculate accuracy
|
|
|
+ predicted = (outputs > 0.5).float()
|
|
|
+ correct += (predicted == targets).sum().item()
|
|
|
+ total += targets.numel()
|
|
|
+
|
|
|
+ val_loss /= len(val_loader)
|
|
|
+ val_acc = correct / total if total > 0 else 0.0
|
|
|
+ train_acc = correct / total if total > 0 else 0.0
|
|
|
+ return train_loss, val_loss, train_acc, val_acc
|
|
|
+
|
|
|
+
|
|
|
+def train_model(
|
|
|
+ model: nn.Module,
|
|
|
+ train_loader: DataLoader[ADNIDataset],
|
|
|
+ val_loader: DataLoader[ADNIDataset],
|
|
|
+ optimizer: torch.optim.Optimizer,
|
|
|
+ criterion: nn.Module,
|
|
|
+ num_epochs: int,
|
|
|
+ learning_rate: float,
|
|
|
+) -> Tuple[nn.Module, xr.DataArray]:
|
|
|
+ """
|
|
|
+ Trains the model using the provided training and validation data loaders.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ model (nn.Module): The model to train.
|
|
|
+ train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
|
|
|
+ val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
|
|
|
+ num_epochs (int): Number of epochs to train the model.
|
|
|
+ learning_rate (float): Learning rate for the optimizer.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Result[nn.Module, str]: A Result object containing the trained model or an error message.
|
|
|
+ """
|
|
|
+
|
|
|
+ # Record the training history
|
|
|
+ # We record the Epoch, Training Loss, Validation Loss, Training Accuracy, and Validation Accuracy
|
|
|
+ history = xr.DataArray(
|
|
|
+ data=[],
|
|
|
+ dims=["epoch", "metric"],
|
|
|
+ coords={
|
|
|
+ "epoch": range(num_epochs),
|
|
|
+ "metric": ["train_loss", "val_loss", "train_acc", "val_acc"],
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ for epoch in range(num_epochs):
|
|
|
+ train_loss, val_loss, train_acc, val_acc = train_epoch(
|
|
|
+ model,
|
|
|
+ train_loader,
|
|
|
+ val_loader,
|
|
|
+ optimizer,
|
|
|
+ criterion,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Update the history
|
|
|
+ history[
|
|
|
+ {
|
|
|
+ "epoch": epoch,
|
|
|
+ "metric": ["train_loss", "val_loss", "train_acc", "val_acc"],
|
|
|
+ }
|
|
|
+ ] = [train_loss, val_loss, train_acc, val_acc]
|
|
|
+
|
|
|
+ print(
|
|
|
+ f"Epoch [{epoch + 1}/{num_epochs}], "
|
|
|
+ f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
|
|
|
+ f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}"
|
|
|
+ )
|
|
|
+
|
|
|
+ # If we are at 25, 50, or 75% of the epochs, save the model
|
|
|
+ if (epoch + 1) % (num_epochs // 4) == 0:
|
|
|
+ torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")
|
|
|
+ print(f"Model saved at epoch {epoch + 1}")
|
|
|
+
|
|
|
+ # return the trained model and the traning history
|
|
|
+ return model, history
|