training.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import torch
  2. import torch.nn as nn
  3. from torch.utils.data import DataLoader
  4. import xarray as xr
  5. from data.dataset import ADNIDataset
  6. from typing import Tuple
  7. from tqdm import tqdm
  8. type TrainMetrics = Tuple[
  9. float, float, float, float
  10. ] # (train_loss, val_loss, train_acc, val_acc)
  11. type TestMetrics = Tuple[float, float] # (test_loss, test_acc)
  12. def test_model(
  13. model: nn.Module,
  14. test_loader: DataLoader[ADNIDataset],
  15. criterion: nn.Module,
  16. ) -> TestMetrics:
  17. """
  18. Tests the model on the test dataset.
  19. Args:
  20. model (nn.Module): The model to test.
  21. test_loader (DataLoader[ADNIDataset]): DataLoader for the test dataset.
  22. criterion (nn.Module): Loss function to compute the loss.
  23. Returns:
  24. TrainMetrics: A tuple containing the test loss and test accuracy.
  25. """
  26. model.eval()
  27. test_loss = 0.0
  28. correct = 0
  29. total = 0
  30. with torch.no_grad():
  31. for _, (inputs, targets) in tqdm(
  32. enumerate(test_loader), desc="Testing", total=len(test_loader), unit="batch"
  33. ):
  34. outputs = model(inputs)
  35. loss = criterion(outputs, targets)
  36. test_loss += loss.item() * inputs.size(0)
  37. # Calculate accuracy
  38. predicted = (outputs > 0.5).float()
  39. correct += (predicted == targets).sum().item()
  40. total += targets.numel()
  41. test_loss /= len(test_loader)
  42. test_acc = correct / total if total > 0 else 0.0
  43. return test_loss, test_acc
  44. def train_epoch(
  45. model: nn.Module,
  46. train_loader: DataLoader[ADNIDataset],
  47. val_loader: DataLoader[ADNIDataset],
  48. optimizer: torch.optim.Optimizer,
  49. criterion: nn.Module,
  50. ) -> Tuple[float, float, float, float]:
  51. """
  52. Trains the model for one epoch and evaluates it on the validation set.
  53. Args:
  54. model (nn.Module): The model to train.
  55. train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
  56. val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
  57. optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
  58. criterion (nn.Module): Loss function to compute the loss.
  59. Returns:
  60. Tuple[float, float, float, float]: A tuple containing the training loss, validation loss, training accuracy, and validation accuracy.
  61. """
  62. model.train()
  63. train_loss = 0.0
  64. # Training loop
  65. for _, (inputs, targets) in tqdm(
  66. enumerate(train_loader), desc="Training", total=len(train_loader), unit="batch"
  67. ):
  68. optimizer.zero_grad()
  69. outputs = model(inputs)
  70. loss = criterion(outputs, targets)
  71. loss.backward()
  72. optimizer.step()
  73. train_loss += loss.item() * inputs.size(0)
  74. train_loss /= len(train_loader)
  75. model.eval()
  76. val_loss = 0.0
  77. correct = 0
  78. total = 0
  79. with torch.no_grad():
  80. for _, (inputs, targets) in tqdm(
  81. enumerate(val_loader),
  82. desc="Validation",
  83. total=len(val_loader),
  84. unit="batch",
  85. ):
  86. outputs = model(inputs)
  87. loss = criterion(outputs, targets)
  88. val_loss += loss.item() * inputs.size(0)
  89. # Calculate accuracy
  90. predicted = (outputs > 0.5).float()
  91. correct += (predicted == targets).sum().item()
  92. total += targets.numel()
  93. val_loss /= len(val_loader)
  94. val_acc = correct / total if total > 0 else 0.0
  95. train_acc = correct / total if total > 0 else 0.0
  96. return train_loss, val_loss, train_acc, val_acc
  97. def train_model(
  98. model: nn.Module,
  99. train_loader: DataLoader[ADNIDataset],
  100. val_loader: DataLoader[ADNIDataset],
  101. optimizer: torch.optim.Optimizer,
  102. criterion: nn.Module,
  103. num_epochs: int,
  104. learning_rate: float,
  105. ) -> Tuple[nn.Module, xr.DataArray]:
  106. """
  107. Trains the model using the provided training and validation data loaders.
  108. Args:
  109. model (nn.Module): The model to train.
  110. train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
  111. val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
  112. num_epochs (int): Number of epochs to train the model.
  113. learning_rate (float): Learning rate for the optimizer.
  114. Returns:
  115. Result[nn.Module, str]: A Result object containing the trained model or an error message.
  116. """
  117. # Record the training history
  118. # We record the Epoch, Training Loss, Validation Loss, Training Accuracy, and Validation Accuracy
  119. history = xr.DataArray(
  120. data=[],
  121. dims=["epoch", "metric"],
  122. coords={
  123. "epoch": range(num_epochs),
  124. "metric": ["train_loss", "val_loss", "train_acc", "val_acc"],
  125. },
  126. )
  127. for epoch in range(num_epochs):
  128. train_loss, val_loss, train_acc, val_acc = train_epoch(
  129. model,
  130. train_loader,
  131. val_loader,
  132. optimizer,
  133. criterion,
  134. )
  135. # Update the history
  136. history[
  137. {
  138. "epoch": epoch,
  139. "metric": ["train_loss", "val_loss", "train_acc", "val_acc"],
  140. }
  141. ] = [train_loss, val_loss, train_acc, val_acc]
  142. print(
  143. f"Epoch [{epoch + 1}/{num_epochs}], "
  144. f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
  145. f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}"
  146. )
  147. # If we are at 25, 50, or 75% of the epochs, save the model
  148. if (epoch + 1) % (num_epochs // 4) == 0:
  149. torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")
  150. print(f"Model saved at epoch {epoch + 1}")
  151. # return the trained model and the traning history
  152. return model, history