# training.py
  1. import torch
  2. import torch.nn as nn
  3. from torch.utils.data import DataLoader
  4. from data.dataset import ADNIDataset
  5. from typing import Callable, Tuple, cast
  6. from tqdm import tqdm
  7. import numpy as np
  8. import pathlib as pl
  9. import pandas as pd
# Type aliases (PEP 695 `type` statements, Python 3.12+) for the metric
# tuples returned by the training/testing helpers below.
type TrainMetrics = Tuple[
    float, float, float, float
]  # (train_loss, val_loss, train_acc, val_acc)
type TestMetrics = Tuple[float, float]  # (test_loss, test_acc)
# Callable that, given a model, returns its accumulated KL divergence term
# (for Bayesian layers), or None when the model contributes no KL loss.
type KLLossFn = Callable[[nn.Module], torch.Tensor | None]
  15. def test_model(
  16. model: nn.Module,
  17. test_loader: DataLoader[ADNIDataset],
  18. criterion: nn.Module,
  19. ) -> TestMetrics:
  20. """
  21. Tests the model on the test dataset.
  22. Args:
  23. model (nn.Module): The model to test.
  24. test_loader (DataLoader[ADNIDataset]): DataLoader for the test dataset.
  25. criterion (nn.Module): Loss function to compute the loss.
  26. Returns:
  27. TrainMetrics: A tuple containing the test loss and test accuracy.
  28. """
  29. model.eval()
  30. test_loss = 0.0
  31. correct = 0
  32. total = 0
  33. with torch.no_grad():
  34. for _, (mri, xls, targets, _) in tqdm(
  35. enumerate(test_loader), desc="Testing", total=len(test_loader), unit="batch"
  36. ):
  37. outputs = model((mri, xls))
  38. loss = criterion(outputs, targets)
  39. test_loss += loss.item() * (mri.size(0) + xls.size(0))
  40. # Calculate accuracy
  41. predicted = (outputs > 0.5).float()
  42. correct += (predicted == targets).sum().item()
  43. total += targets.numel()
  44. test_loss /= len(test_loader)
  45. test_acc = correct / total if total > 0 else 0.0
  46. return test_loss, test_acc
  47. def train_epoch(
  48. model: nn.Module,
  49. train_loader: DataLoader[ADNIDataset],
  50. val_loader: DataLoader[ADNIDataset],
  51. optimizer: torch.optim.Optimizer,
  52. criterion: nn.Module,
  53. ) -> Tuple[float, float, float, float]:
  54. """
  55. Trains the model for one epoch and evaluates it on the validation set.
  56. Args:
  57. model (nn.Module): The model to train.
  58. train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
  59. val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
  60. optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
  61. criterion (nn.Module): Loss function to compute the loss.
  62. Returns:
  63. Tuple[float, float, float, float]: A tuple containing the training loss, validation loss, training accuracy, and validation accuracy.
  64. """
  65. model.train()
  66. train_loss = 0.0
  67. # Training loop
  68. for _, (mri, xls, targets, _) in tqdm(
  69. enumerate(train_loader), desc="Training", total=len(train_loader), unit="batch"
  70. ):
  71. optimizer.zero_grad()
  72. outputs = model((mri, xls))
  73. loss = criterion(outputs, targets)
  74. loss.backward() # type: ignore[reportUnknownMemberType]
  75. optimizer.step()
  76. train_loss += loss.item() * (mri.size(0) + xls.size(0))
  77. train_loss /= len(train_loader)
  78. model.eval()
  79. val_loss = 0.0
  80. correct = 0
  81. total = 0
  82. with torch.no_grad():
  83. for _, (mri, xls, targets, _) in tqdm(
  84. enumerate(val_loader),
  85. desc="Validation",
  86. total=len(val_loader),
  87. unit="batch",
  88. ):
  89. outputs = model((mri, xls))
  90. loss = criterion(outputs, targets)
  91. val_loss += loss.item() * (mri.size(0) + xls.size(0))
  92. # Calculate accuracy
  93. predicted = (outputs > 0.5).float()
  94. correct += (predicted == targets).sum().item()
  95. total += targets.numel()
  96. val_loss /= len(val_loader)
  97. val_acc = correct / total if total > 0 else 0.0
  98. train_acc = correct / total if total > 0 else 0.0
  99. return train_loss, val_loss, train_acc, val_acc
  100. def train_model(
  101. model: nn.Module,
  102. train_loader: DataLoader[ADNIDataset],
  103. val_loader: DataLoader[ADNIDataset],
  104. optimizer: torch.optim.Optimizer,
  105. criterion: nn.Module,
  106. num_epochs: int,
  107. output_path: pl.Path,
  108. ) -> Tuple[nn.Module, pd.DataFrame]:
  109. """
  110. Trains the model using the provided training and validation data loaders.
  111. Args:
  112. model (nn.Module): The model to train.
  113. train_loader (DataLoader[ADNIDataset]): DataLoader for the training dataset.
  114. val_loader (DataLoader[ADNIDataset]): DataLoader for the validation dataset.
  115. num_epochs (int): Number of epochs to train the model.
  116. learning_rate (float): Learning rate for the optimizer.
  117. Returns:
  118. Result[nn.Module, str]: A Result object containing the trained model or an error message.
  119. """
  120. # Record the training history
  121. # We record the Epoch, Training Loss, Validation Loss, Training Accuracy, and Validation Accuracy
  122. # use a (num_epochs, 4) shape ndarray to store the history before creating the DataArray
  123. nhist = np.zeros((num_epochs, 4), dtype=np.float32)
  124. for epoch in range(num_epochs):
  125. train_loss, val_loss, train_acc, val_acc = train_epoch(
  126. model,
  127. train_loader,
  128. val_loader,
  129. optimizer,
  130. criterion,
  131. )
  132. # Update the history
  133. nhist[epoch, 0] = train_loss
  134. nhist[epoch, 1] = val_loss
  135. nhist[epoch, 2] = train_acc
  136. nhist[epoch, 3] = val_acc
  137. print(
  138. f"Epoch [{epoch + 1}/{num_epochs}], "
  139. f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
  140. f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}"
  141. )
  142. # If we are at 25, 50, or 75% of the epochs, save the model
  143. if num_epochs > 4:
  144. if (epoch + 1) % (num_epochs // 4) == 0:
  145. model_save_path = (
  146. output_path / "intermediate_models" / f"model_epoch_{epoch + 1}.pt"
  147. )
  148. torch.save(model.state_dict(), model_save_path)
  149. print(f"Model saved at epoch {epoch + 1}")
  150. # return the trained model and the traning history
  151. history = pd.DataFrame(
  152. data=nhist.astype(np.float32),
  153. columns=["train_loss", "val_loss", "train_acc", "val_acc"],
  154. index=np.arange(1, num_epochs + 1),
  155. )
  156. return model, history
  157. def test_model_bayesian(
  158. model: nn.Module,
  159. test_loader: DataLoader[ADNIDataset],
  160. criterion: nn.Module,
  161. get_kl_loss: KLLossFn,
  162. ) -> TestMetrics:
  163. """
  164. Tests a Bayesian model on the test dataset with KL-augmented loss.
  165. """
  166. model.eval()
  167. test_loss = 0.0
  168. correct = 0
  169. total = 0
  170. with torch.no_grad():
  171. for _, (mri, xls, targets, _) in tqdm(
  172. enumerate(test_loader), desc="Testing", total=len(test_loader), unit="batch"
  173. ):
  174. outputs = model((mri, xls))
  175. data_loss = cast(torch.Tensor, criterion(outputs, targets))
  176. batch_size = mri.size(0)
  177. kl_term = get_kl_loss(model)
  178. kl_loss = (
  179. kl_term / batch_size
  180. if kl_term is not None
  181. else torch.tensor(0.0, device=outputs.device)
  182. )
  183. loss: torch.Tensor = data_loss + kl_loss
  184. test_loss += loss.item() * (mri.size(0) + xls.size(0))
  185. predicted = (outputs > 0.5).float()
  186. correct += (predicted == targets).sum().item()
  187. total += targets.numel()
  188. test_loss /= len(test_loader)
  189. test_acc = correct / total if total > 0 else 0.0
  190. return test_loss, test_acc
  191. def train_epoch_bayesian(
  192. model: nn.Module,
  193. train_loader: DataLoader[ADNIDataset],
  194. val_loader: DataLoader[ADNIDataset],
  195. optimizer: torch.optim.Optimizer,
  196. criterion: nn.Module,
  197. get_kl_loss: KLLossFn,
  198. ) -> TrainMetrics:
  199. """
  200. Trains a Bayesian model for one epoch and evaluates on validation data.
  201. """
  202. model.train()
  203. train_loss = 0.0
  204. for _, (mri, xls, targets, _) in tqdm(
  205. enumerate(train_loader), desc="Training", total=len(train_loader), unit="batch"
  206. ):
  207. optimizer.zero_grad()
  208. outputs = model((mri, xls))
  209. data_loss = cast(torch.Tensor, criterion(outputs, targets))
  210. batch_size = mri.size(0)
  211. kl_term = get_kl_loss(model)
  212. kl_loss = (
  213. kl_term / batch_size
  214. if kl_term is not None
  215. else torch.tensor(0.0, device=outputs.device)
  216. )
  217. loss: torch.Tensor = data_loss + kl_loss
  218. loss.backward() # type: ignore[reportUnknownMemberType]
  219. optimizer.step()
  220. train_loss += loss.item() * (mri.size(0) + xls.size(0))
  221. train_loss /= len(train_loader)
  222. model.eval()
  223. val_loss = 0.0
  224. correct = 0
  225. total = 0
  226. with torch.no_grad():
  227. for _, (mri, xls, targets, _) in tqdm(
  228. enumerate(val_loader),
  229. desc="Validation",
  230. total=len(val_loader),
  231. unit="batch",
  232. ):
  233. outputs = model((mri, xls))
  234. data_loss = cast(torch.Tensor, criterion(outputs, targets))
  235. batch_size = mri.size(0)
  236. kl_term = get_kl_loss(model)
  237. kl_loss = (
  238. kl_term / batch_size
  239. if kl_term is not None
  240. else torch.tensor(0.0, device=outputs.device)
  241. )
  242. loss: torch.Tensor = data_loss + kl_loss
  243. val_loss += loss.item() * (mri.size(0) + xls.size(0))
  244. predicted = (outputs > 0.5).float()
  245. correct += (predicted == targets).sum().item()
  246. total += targets.numel()
  247. val_loss /= len(val_loader)
  248. val_acc = correct / total if total > 0 else 0.0
  249. train_acc = correct / total if total > 0 else 0.0
  250. return train_loss, val_loss, train_acc, val_acc
  251. def train_model_bayesian(
  252. model: nn.Module,
  253. train_loader: DataLoader[ADNIDataset],
  254. val_loader: DataLoader[ADNIDataset],
  255. optimizer: torch.optim.Optimizer,
  256. criterion: nn.Module,
  257. num_epochs: int,
  258. output_path: pl.Path,
  259. get_kl_loss: KLLossFn,
  260. ) -> Tuple[nn.Module, pd.DataFrame]:
  261. """
  262. Trains a Bayesian model with KL-augmented objective.
  263. """
  264. nhist = np.zeros((num_epochs, 4), dtype=np.float32)
  265. for epoch in range(num_epochs):
  266. train_loss, val_loss, train_acc, val_acc = train_epoch_bayesian(
  267. model,
  268. train_loader,
  269. val_loader,
  270. optimizer,
  271. criterion,
  272. get_kl_loss,
  273. )
  274. nhist[epoch, 0] = train_loss
  275. nhist[epoch, 1] = val_loss
  276. nhist[epoch, 2] = train_acc
  277. nhist[epoch, 3] = val_acc
  278. print(
  279. f"Epoch [{epoch + 1}/{num_epochs}], "
  280. f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
  281. f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}"
  282. )
  283. if num_epochs > 4:
  284. if (epoch + 1) % (num_epochs // 4) == 0:
  285. model_save_path = (
  286. output_path / "intermediate_models" / f"model_epoch_{epoch + 1}.pt"
  287. )
  288. torch.save(model.state_dict(), model_save_path)
  289. print(f"Model saved at epoch {epoch + 1}")
  290. history = pd.DataFrame(
  291. data=nhist.astype(np.float32),
  292. columns=["train_loss", "val_loss", "train_acc", "val_acc"],
  293. index=np.arange(1, num_epochs + 1),
  294. )
  295. return model, history