training.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torchvision
  5. from torch.utils.data import DataLoader
  6. import pandas as pd
  def train_epoch(model, train_loader, val_loader, criterion, optimizer):
      """Run one optimisation pass over train_loader, then score val_loader.

      Args:
          model: Module called as ``model(data)``; switched to train mode for
              the optimisation pass and left in eval mode on return.
          train_loader: Iterable of ``(data, target)`` batches used for the
              gradient updates.
          val_loader: Iterable of ``(data, target)`` batches scored without
              gradient tracking.
          criterion: Callable ``criterion(output, target)`` returning a scalar
              loss tensor.
          optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called
              once per training batch.

      Returns:
          ``(train_loss, val_loss)``: the per-batch mean of ``loss.item()``
          over each loader. A loader that yields no batches produces
          ``float("nan")`` instead of raising ``ZeroDivisionError``.
      """
      model.train()
      train_total = 0.0
      train_batches = 0
      for data, target in train_loader:
          optimizer.zero_grad()
          loss = criterion(model(data), target)
          loss.backward()
          optimizer.step()
          train_total += loss.item()
          # Count batches ourselves so loaders without __len__ also work.
          train_batches += 1

      model.eval()
      val_total = 0.0
      val_batches = 0
      with torch.no_grad():
          for data, target in val_loader:
              val_total += criterion(model(data), target).item()
              val_batches += 1

      # The mean over zero batches is undefined; report NaN rather than crash.
      train_loss = train_total / train_batches if train_batches else float("nan")
      val_loss = val_total / val_batches if val_batches else float("nan")
      return train_loss, val_loss
  def evaluate_accuracy(model, loader):
      """Compute classification accuracy and collect class-1 scores and labels.

      Args:
          model: Module called as ``model(data)``; put into eval mode here.
          loader: Iterable of ``(data, target)`` batches.

      Returns:
          ``(accuracy, predictions, actual)`` where ``accuracy`` is
          ``correct / total`` (``float("nan")`` if the loader yields no
          samples), ``predictions`` is the flat list of ``output[:, 1]`` values
          and ``actual`` is the flat list of matching target values.

      Notes:
          Assumes ``output`` is 2-D with at least two columns, since column 1
          is read as the reported score -- TODO confirm against the model.
          Targets may be 1-D class indices or 2-D per-class rows; the previous
          code indexed ``target[:, 1]`` while also comparing ``target``
          directly with the 1-D predicted indices, which cannot both be valid
          for the same tensor.
      """
      model.eval()
      correct = 0
      total = 0
      predictions = []
      actual = []
      with torch.no_grad():
          for data, target in loader:
              output = model(data)
              predicted = output.argmax(dim=1)
              if target.dim() > 1:
                  # Per-class rows: compare on the arg-max class and report
                  # column 1, mirroring what is taken from the output.
                  labels = target.argmax(dim=1)
                  reported = target[:, 1]
              else:
                  # Class indices: the label itself is the reported value.
                  labels = target
                  reported = target
              total += labels.size(0)
              correct += (predicted == labels).sum().item()
              predictions.extend(output[:, 1].tolist())
              actual.extend(reported.tolist())
      # Accuracy over zero samples is undefined; report NaN rather than crash.
      accuracy = correct / total if total else float("nan")
      return accuracy, predictions, actual
  def train_model(model, train_loader, val_loader, criterion, optimizer, config):
      """Train for a configured number of epochs and return per-epoch metrics.

      Args:
          model: Model passed through to ``train_epoch`` and
              ``evaluate_accuracy``.
          train_loader: Batches used for training and for the train accuracy.
          val_loader: Batches used for the validation loss and accuracy.
          criterion: Loss callable forwarded to ``train_epoch``.
          optimizer: Optimizer forwarded to ``train_epoch``.
          config: Mapping read at ``config["training"]["max_epochs"]`` (number
              of epochs) and ``config["operation"]["silent"]`` (a truthy value
              suppresses the per-epoch progress line).

      Returns:
          A ``pandas.DataFrame`` indexed by ``"Epoch"`` (0-based, whereas the
          printed progress line is 1-based) with the columns ``"Train Loss"``,
          ``"Val Loss"``, ``"Train Acc"`` and ``"Val Acc"``.
      """
      columns = ["Train Loss", "Val Loss", "Train Acc", "Val Acc"]
      rows = []
      for epoch in range(config["training"]["max_epochs"]):
          train_loss, val_loss = train_epoch(
              model, train_loader, val_loader, criterion, optimizer
          )
          # Truthiness instead of `is False`, so 0 / None also mean "not silent".
          if not config["operation"]["silent"]:
              print(f"Epoch {epoch + 1} - Train Loss: {train_loss} - Val Loss: {val_loss}")
          train_acc, _, _ = evaluate_accuracy(model, train_loader)
          val_acc, _, _ = evaluate_accuracy(model, val_loader)
          rows.append([train_loss, val_loss, train_acc, val_acc])

      # Build the frame once rather than enlarging an empty one row by row.
      epoch_index = pd.RangeIndex(len(rows), name="Epoch")
      return pd.DataFrame(rows, columns=columns, index=epoch_index)