training.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import torch
  2. import pandas as pd
  3. from tqdm import tqdm
  4. def train_epoch(model, train_loader, val_loader, criterion, optimizer, config, epoch):
  5. model.train()
  6. train_loss = 0
  7. val_loss = 0
  8. for i, (data, target) in tqdm(
  9. enumerate(train_loader),
  10. total=len(train_loader),
  11. desc=" - Epoch " + str(epoch + 1) + "/" + str(config["training"]["max_epochs"]),
  12. unit="batch",
  13. disable=config["operation"]["silent"],
  14. ):
  15. optimizer.zero_grad()
  16. output = model(data)
  17. loss = criterion(output, target)
  18. loss.backward()
  19. optimizer.step()
  20. train_loss += loss.item()
  21. train_loss /= len(train_loader)
  22. model.eval()
  23. with torch.no_grad():
  24. for i, (data, target) in enumerate(val_loader):
  25. output = model(data)
  26. loss = criterion(output, target)
  27. val_loss += loss.item()
  28. val_loss /= len(val_loader)
  29. return train_loss, val_loss
  30. def evaluate_accuracy(model, loader):
  31. model.eval()
  32. correct = 0
  33. total = 0
  34. predictions = []
  35. actual = []
  36. with torch.no_grad():
  37. for data, target in loader:
  38. output = model(data)
  39. _, predicted = torch.max(output.data, 1)
  40. _, expected = torch.max(target.data, 1)
  41. total += target.size(0)
  42. correct += (predicted == expected).sum().item()
  43. out = output[:, 1].tolist()
  44. predictions.extend(out)
  45. act = target[:, 1].tolist()
  46. actual.extend(act)
  47. return correct / total, predictions, actual
  48. def train_model(model, train_loader, val_loader, criterion, optimizer, config):
  49. history = pd.DataFrame(
  50. columns=["Epoch", "Train Loss", "Val Loss", "Train Acc", "Val Acc"]
  51. ).set_index("Epoch")
  52. for epoch in range(config["training"]["max_epochs"]):
  53. train_loss, val_loss = train_epoch(
  54. model, train_loader, val_loader, criterion, optimizer, config, epoch
  55. )
  56. train_acc, _, _ = evaluate_accuracy(model, train_loader)
  57. val_acc, _, _ = evaluate_accuracy(model, val_loader)
  58. if config["operation"]["silent"] is False:
  59. print(
  60. f" --- Epoch {epoch + 1} - Train Loss: {round(train_loss, 3)}, Val Loss: {round(val_loss, 3)}, Train Acc: {round(train_acc, 3) * 100}%, Val Acc: {round(val_acc, 3) * 100}%"
  61. )
  62. history.loc[epoch] = [train_loss, val_loss, train_acc, val_acc]
  63. return history