training.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torchvision
  5. from torch.utils.data import DataLoader
  6. import pandas as pd
  7. from tqdm import tqdm
  8. def train_epoch(model, train_loader, val_loader, criterion, optimizer, config, epoch):
  9. model.train()
  10. train_loss = 0
  11. val_loss = 0
  12. for i, (data, target) in tqdm(
  13. enumerate(train_loader),
  14. total=len(train_loader),
  15. desc="Epoch " + str(epoch + 1) + "/" + str(config["training"]["max_epochs"]),
  16. unit="batch",
  17. disable=config["operation"]["silent"],
  18. ):
  19. optimizer.zero_grad()
  20. output = model(data)
  21. loss = criterion(output, target)
  22. loss.backward()
  23. optimizer.step()
  24. train_loss += loss.item()
  25. train_loss /= len(train_loader)
  26. model.eval()
  27. with torch.no_grad():
  28. for i, (data, target) in enumerate(val_loader):
  29. output = model(data)
  30. loss = criterion(output, target)
  31. val_loss += loss.item()
  32. val_loss /= len(val_loader)
  33. return train_loss, val_loss
  34. def evaluate_accuracy(model, loader):
  35. model.eval()
  36. correct = 0
  37. total = 0
  38. predictions = []
  39. actual = []
  40. with torch.no_grad():
  41. for data, target in loader:
  42. output = model(data)
  43. _, predicted = torch.max(output.data, 1)
  44. total += target.size(0)
  45. correct += (predicted == target).sum().item()
  46. out = output[:, 1].tolist()
  47. predictions.extend(out)
  48. act = target[:, 1].tolist()
  49. actual.extend(act)
  50. return correct / total, predictions, actual
  51. def train_model(model, train_loader, val_loader, criterion, optimizer, config):
  52. history = pd.DataFrame(
  53. columns=["Epoch", "Train Loss", "Val Loss", "Train Acc", "Val Acc"]
  54. ).set_index("Epoch")
  55. for epoch in range(config["training"]["max_epochs"]):
  56. train_loss, val_loss = train_epoch(
  57. model, train_loader, val_loader, criterion, optimizer, config, epoch
  58. )
  59. if config["operation"]["silent"] is False:
  60. print(
  61. f"Epoch {epoch + 1} - Train Loss: {train_loss} - Val Loss: {val_loss}"
  62. )
  63. train_acc, _, _ = evaluate_accuracy(model, train_loader)
  64. val_acc, _, _ = evaluate_accuracy(model, val_loader)
  65. history.loc[epoch] = [train_loss, val_loss, train_acc, val_acc]
  66. return history