training.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import torch
  2. import torch.nn as nn
  3. from torch.utils.data import DataLoader
  4. import xarray as xr
  5. from data.dataset import ADNIDataset
  6. from typing import Tuple
  7. from tqdm import tqdm
  8. import numpy as np
  9. type TrainMetrics = Tuple[
  10. float, float, float, float
  11. ] # (train_loss, val_loss, train_acc, val_acc)
  12. type TestMetrics = Tuple[float, float] # (test_loss, test_acc)
  def test_model(
      model: nn.Module,
      test_loader: DataLoader[ADNIDataset],
      criterion: nn.Module,
  ) -> TestMetrics:
      """
      Evaluates the model on the test dataset.

      Args:
          model (nn.Module): The model to test. It is called once per batch
              with a single ``(mri, xls)`` tuple.
          test_loader (DataLoader[ADNIDataset]): DataLoader yielding
              ``(mri, xls, targets)`` batches of the test dataset.
          criterion (nn.Module): Loss function to compute the loss.

      Returns:
          TestMetrics: A tuple ``(test_loss, test_acc)`` holding the
          per-sample average loss and the fraction of target elements that
          were predicted correctly. Both are 0.0 if the loader yields no data.
      """
      model.eval()

      loss_sum = 0.0
      n_samples = 0
      correct = 0
      total = 0

      with torch.no_grad():
          for mri, xls, targets in tqdm(
              test_loader, desc="Testing", total=len(test_loader), unit="batch"
          ):
              outputs = model((mri, xls))
              loss = criterion(outputs, targets)

              # Weight each batch loss by its sample count so a smaller final
              # batch does not skew the average.
              # NOTE(review): assumes `criterion` returns the batch mean
              # (the default `reduction="mean"`) - confirm against the caller.
              batch_size = mri.size(0)
              loss_sum += loss.item() * batch_size
              n_samples += batch_size

              # Calculate accuracy
              # NOTE(review): thresholding at 0.5 presumes `outputs` are
              # probabilities rather than raw logits - confirm.
              predicted = (outputs > 0.5).float()
              correct += (predicted == targets).sum().item()
              total += targets.numel()

      # Guard against an empty loader instead of dividing by zero.
      test_loss = loss_sum / n_samples if n_samples > 0 else 0.0
      test_acc = correct / total if total > 0 else 0.0
      return test_loss, test_acc
  def train_epoch(
      model: nn.Module,
      train_loader: DataLoader[ADNIDataset],
      val_loader: DataLoader[ADNIDataset],
      optimizer: torch.optim.Optimizer,
      criterion: nn.Module,
  ) -> Tuple[float, float, float, float]:
      """
      Trains the model for one epoch and evaluates it on the validation set.

      Args:
          model (nn.Module): The model to train. It is called once per batch
              with a single ``(mri, xls)`` tuple.
          train_loader (DataLoader[ADNIDataset]): DataLoader yielding
              ``(mri, xls, targets)`` batches of the training dataset.
          val_loader (DataLoader[ADNIDataset]): DataLoader yielding
              ``(mri, xls, targets)`` batches of the validation dataset.
          optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
          criterion (nn.Module): Loss function to compute the loss.

      Returns:
          Tuple[float, float, float, float]: A tuple containing the training
          loss, validation loss, training accuracy, and validation accuracy.
          Losses are per-sample averages; accuracies are the fraction of
          target elements predicted correctly. The training metrics come from
          the forward passes made while training (model in train mode), not
          from a separate pass after the epoch. Any metric whose loader yields
          no data is reported as 0.0.
      """
      # ---- Training pass ----
      model.train()
      train_loss_sum = 0.0
      train_samples = 0
      train_correct = 0
      train_total = 0

      for mri, xls, targets in tqdm(
          train_loader, desc="Training", total=len(train_loader), unit="batch"
      ):
          optimizer.zero_grad()
          outputs = model((mri, xls))
          loss = criterion(outputs, targets)
          loss.backward()
          optimizer.step()

          # Weight each batch loss by its sample count so the epoch loss is a
          # true per-sample average.
          # NOTE(review): assumes `criterion` returns the batch mean
          # (the default `reduction="mean"`) - confirm against the caller.
          batch_size = mri.size(0)
          train_loss_sum += loss.item() * batch_size
          train_samples += batch_size

          # Training accuracy has to be accumulated here; it cannot be
          # derived from the validation counters.
          # NOTE(review): thresholding at 0.5 presumes `outputs` are
          # probabilities rather than raw logits - confirm.
          predicted = (outputs.detach() > 0.5).float()
          train_correct += (predicted == targets).sum().item()
          train_total += targets.numel()

      # ---- Validation pass ----
      model.eval()
      val_loss_sum = 0.0
      val_samples = 0
      val_correct = 0
      val_total = 0

      with torch.no_grad():
          for mri, xls, targets in tqdm(
              val_loader,
              desc="Validation",
              total=len(val_loader),
              unit="batch",
          ):
              outputs = model((mri, xls))
              loss = criterion(outputs, targets)

              batch_size = mri.size(0)
              val_loss_sum += loss.item() * batch_size
              val_samples += batch_size

              # Calculate accuracy
              predicted = (outputs > 0.5).float()
              val_correct += (predicted == targets).sum().item()
              val_total += targets.numel()

      # Guard every ratio against an empty loader instead of dividing by zero.
      train_loss = train_loss_sum / train_samples if train_samples > 0 else 0.0
      val_loss = val_loss_sum / val_samples if val_samples > 0 else 0.0
      train_acc = train_correct / train_total if train_total > 0 else 0.0
      val_acc = val_correct / val_total if val_total > 0 else 0.0
      return train_loss, val_loss, train_acc, val_acc
  def train_model(
      model: nn.Module,
      train_loader: DataLoader[ADNIDataset],
      val_loader: DataLoader[ADNIDataset],
      optimizer: torch.optim.Optimizer,
      criterion: nn.Module,
      num_epochs: int,
      learning_rate: float,
  ) -> Tuple[nn.Module, xr.DataArray]:
      """
      Trains the model using the provided training and validation data loaders.

      Args:
          model (nn.Module): The model to train.
          train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
          val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
          optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
          criterion (nn.Module): Loss function to compute the loss.
          num_epochs (int): Number of epochs to train the model.
          learning_rate (float): Not used by this function; the learning rate
              in effect is whatever ``optimizer`` was configured with. The
              parameter is kept so existing callers keep working.

      Returns:
          Tuple[nn.Module, xr.DataArray]: The trained model and the training
          history. The history has dims ``("epoch", "metric")`` with metrics
          ``train_loss``, ``val_loss``, ``train_acc`` and ``val_acc``.

      Side effects:
          When ``num_epochs > 4``, writes ``model_epoch_<n>.pth`` (the model's
          ``state_dict``) to the current working directory at 25%, 50% and 75%
          of the epochs, and prints a progress line after every epoch.
      """
      # Record the training history
      # We record the Training Loss, Validation Loss, Training Accuracy, and Validation Accuracy
      # use a (num_epochs, 4) shape ndarray to store the history before creating the DataArray
      nhist = np.zeros((num_epochs, 4), dtype=np.float32)

      # 1-based epoch numbers at which a checkpoint is written. A plain
      # `% (num_epochs // 4)` test fires far too often when num_epochs is not
      # a multiple of 4 (e.g. every epoch for num_epochs=7), so the quarter
      # marks are computed explicitly.
      if num_epochs > 4:
          checkpoint_epochs = {(num_epochs * quarter) // 4 for quarter in (1, 2, 3)}
      else:
          checkpoint_epochs = set()

      for epoch in range(num_epochs):
          train_loss, val_loss, train_acc, val_acc = train_epoch(
              model,
              train_loader,
              val_loader,
              optimizer,
              criterion,
          )

          # Update the history
          nhist[epoch] = (train_loss, val_loss, train_acc, val_acc)

          print(
              f"Epoch [{epoch + 1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}"
          )

          # If we are at 25, 50, or 75% of the epochs, save the model
          if (epoch + 1) in checkpoint_epochs:
              torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")
              print(f"Model saved at epoch {epoch + 1}")

      # return the trained model and the training history
      history = xr.DataArray(
          data=nhist,
          dims=["epoch", "metric"],
          coords={
              "epoch": range(num_epochs),
              "metric": ["train_loss", "val_loss", "train_acc", "val_acc"],
          },
      )
      return model, history