training.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import torch
  2. import torch.nn as nn
  3. from torch.utils.data import DataLoader
  4. import xarray as xr
  5. from data.dataset import ADNIDataset
  6. from typing import Tuple
  7. from tqdm import tqdm
  8. import numpy as np
  9. import pathlib as pl
  10. import pandas as pd
  11. type TrainMetrics = Tuple[
  12. float, float, float, float
  13. ] # (train_loss, val_loss, train_acc, val_acc)
  14. type TestMetrics = Tuple[float, float] # (test_loss, test_acc)
  15. def test_model(
  16. model: nn.Module,
  17. test_loader: DataLoader[ADNIDataset],
  18. criterion: nn.Module,
  19. ) -> TestMetrics:
  20. """
  21. Tests the model on the test dataset.
  22. Args:
  23. model (nn.Module): The model to test.
  24. test_loader (DataLoader[ADNIDataset]): DataLoader for the test dataset.
  25. criterion (nn.Module): Loss function to compute the loss.
  26. Returns:
  27. TrainMetrics: A tuple containing the test loss and test accuracy.
  28. """
  29. model.eval()
  30. test_loss = 0.0
  31. correct = 0
  32. total = 0
  33. with torch.no_grad():
  34. for _, (mri, xls, targets) in tqdm(
  35. enumerate(test_loader), desc="Testing", total=len(test_loader), unit="batch"
  36. ):
  37. outputs = model((mri, xls))
  38. loss = criterion(outputs, targets)
  39. test_loss += loss.item() * (mri.size(0) + xls.size(0))
  40. # Calculate accuracy
  41. predicted = (outputs > 0.5).float()
  42. correct += (predicted == targets).sum().item()
  43. total += targets.numel()
  44. test_loss /= len(test_loader)
  45. test_acc = correct / total if total > 0 else 0.0
  46. return test_loss, test_acc
  47. def train_epoch(
  48. model: nn.Module,
  49. train_loader: DataLoader[ADNIDataset],
  50. val_loader: DataLoader[ADNIDataset],
  51. optimizer: torch.optim.Optimizer,
  52. criterion: nn.Module,
  53. ) -> Tuple[float, float, float, float]:
  54. """
  55. Trains the model for one epoch and evaluates it on the validation set.
  56. Args:
  57. model (nn.Module): The model to train.
  58. train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
  59. val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
  60. optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
  61. criterion (nn.Module): Loss function to compute the loss.
  62. Returns:
  63. Tuple[float, float, float, float]: A tuple containing the training loss, validation loss, training accuracy, and validation accuracy.
  64. """
  65. model.train()
  66. train_loss = 0.0
  67. # Training loop
  68. for _, (mri, xls, targets) in tqdm(
  69. enumerate(train_loader), desc="Training", total=len(train_loader), unit="batch"
  70. ):
  71. optimizer.zero_grad()
  72. outputs = model((mri, xls))
  73. loss = criterion(outputs, targets)
  74. loss.backward()
  75. optimizer.step()
  76. train_loss += loss.item() * (mri.size(0) + xls.size(0))
  77. train_loss /= len(train_loader)
  78. model.eval()
  79. val_loss = 0.0
  80. correct = 0
  81. total = 0
  82. with torch.no_grad():
  83. for _, (mri, xls, targets) in tqdm(
  84. enumerate(val_loader),
  85. desc="Validation",
  86. total=len(val_loader),
  87. unit="batch",
  88. ):
  89. outputs = model((mri, xls))
  90. loss = criterion(outputs, targets)
  91. val_loss += loss.item() * (mri.size(0) + xls.size(0))
  92. # Calculate accuracy
  93. predicted = (outputs > 0.5).float()
  94. correct += (predicted == targets).sum().item()
  95. total += targets.numel()
  96. val_loss /= len(val_loader)
  97. val_acc = correct / total if total > 0 else 0.0
  98. train_acc = correct / total if total > 0 else 0.0
  99. return train_loss, val_loss, train_acc, val_acc
  100. def train_model(
  101. model: nn.Module,
  102. train_loader: DataLoader[ADNIDataset],
  103. val_loader: DataLoader[ADNIDataset],
  104. optimizer: torch.optim.Optimizer,
  105. criterion: nn.Module,
  106. num_epochs: int,
  107. output_path: pl.Path,
  108. ) -> Tuple[nn.Module, pd.DataFrame]:
  109. """
  110. Trains the model using the provided training and validation data loaders.
  111. Args:
  112. model (nn.Module): The model to train.
  113. train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
  114. val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
  115. num_epochs (int): Number of epochs to train the model.
  116. learning_rate (float): Learning rate for the optimizer.
  117. Returns:
  118. Result[nn.Module, str]: A Result object containing the trained model or an error message.
  119. """
  120. # Record the training history
  121. # We record the Epoch, Training Loss, Validation Loss, Training Accuracy, and Validation Accuracy
  122. # use a (num_epochs, 4) shape ndarray to store the history before creating the DataArray
  123. nhist = np.zeros((num_epochs, 4), dtype=np.float32)
  124. for epoch in range(num_epochs):
  125. train_loss, val_loss, train_acc, val_acc = train_epoch(
  126. model,
  127. train_loader,
  128. val_loader,
  129. optimizer,
  130. criterion,
  131. )
  132. # Update the history
  133. nhist[epoch, 0] = train_loss
  134. nhist[epoch, 1] = val_loss
  135. nhist[epoch, 2] = train_acc
  136. nhist[epoch, 3] = val_acc
  137. print(
  138. f"Epoch [{epoch + 1}/{num_epochs}], "
  139. f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
  140. f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}"
  141. )
  142. # If we are at 25, 50, or 75% of the epochs, save the model
  143. if num_epochs > 4:
  144. if (epoch + 1) % (num_epochs // 4) == 0:
  145. model_save_path = (
  146. output_path / "intermediate_models" / f"model_epoch_{epoch + 1}.pt"
  147. )
  148. torch.save(model.state_dict(), model_save_path)
  149. print(f"Model saved at epoch {epoch + 1}")
  150. # return the trained model and the traning history
  151. history = pd.DataFrame(
  152. data=nhist.astype(np.float32),
  153. columns=["train_loss", "val_loss", "train_acc", "val_acc"],
  154. index=np.arange(1, num_epochs + 1),
  155. )
  156. return model, history