- import torch
- import torch.nn as nn
- import torch.optim as optim
- import torchvision
- from torch.utils.data import DataLoader
- import pandas as pd
def train_epoch(model, train_loader, val_loader, criterion, optimizer):
    """Run one optimization epoch over train_loader, then one validation pass.

    Args:
        model: torch module, called as ``model(data)``.
        train_loader: iterable of ``(data, target)`` batches used for training.
        val_loader: iterable of ``(data, target)`` batches, evaluated without
            gradients (model weights are not updated here).
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        ``(train_loss, val_loss)``: mean per-batch loss for each phase.

    NOTE(review): batches are used as-is — assumes model and data are already
    on the same device; confirm against the caller.
    """
    model.train()
    train_loss = 0.0
    # The original enumerated the loaders but never used the index.
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # Validation: eval mode (affects dropout/batch-norm) and no grad tracking.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            val_loss += criterion(output, target).item()
    val_loss /= len(val_loader)

    return train_loss, val_loss
def evaluate_accuracy(model, loader):
    """Compute classification accuracy and collect per-sample class-1 scores.

    Assumes targets are 2-D one-hot tensors of shape ``(N, C)`` — the code
    indexes ``target[:, 1]`` — TODO confirm against the data pipeline.

    Args:
        model: torch module producing per-class scores ``(N, C)``.
        loader: iterable of ``(data, target)`` batches.

    Returns:
        ``(accuracy, predictions, actual)`` where ``accuracy`` is
        ``correct / total``, ``predictions`` is the list of raw class-1
        scores, and ``actual`` is the list of class-1 target entries.
    """
    model.eval()
    correct = 0
    total = 0
    predictions = []
    actual = []

    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            # BUG FIX: the original compared the 1-D `predicted` tensor
            # directly against the 2-D one-hot `target`, which broadcasts
            # element-wise against every class column (or raises a shape
            # error) instead of comparing class indices. Reduce the one-hot
            # target to label indices first.
            correct += (predicted == target.argmax(dim=1)).sum().item()

            # Raw class-1 scores and class-1 target entries (e.g. for ROC).
            predictions.extend(output[:, 1].tolist())
            actual.extend(target[:, 1].tolist())

    return correct / total, predictions, actual
def train_model(model, train_loader, val_loader, criterion, optimizer, config):
    """Train for ``config["training"]["max_epochs"]`` epochs, recording metrics.

    Each epoch runs ``train_epoch`` and then evaluates accuracy on both the
    train and validation loaders. Progress is printed unless
    ``config["operation"]["silent"]`` is truthy.

    Args:
        model: torch module to train.
        train_loader, val_loader: batch iterables of ``(data, target)``.
        criterion: loss function.
        optimizer: optimizer stepping ``model``'s parameters.
        config: nested dict with ``training.max_epochs`` and
            ``operation.silent``.

    Returns:
        pandas DataFrame indexed by 0-based epoch with columns
        "Train Loss", "Val Loss", "Train Acc", "Val Acc".
    """
    history = pd.DataFrame(
        columns=["Epoch", "Train Loss", "Val Loss", "Train Acc", "Val Acc"]
    ).set_index("Epoch")

    for epoch in range(config["training"]["max_epochs"]):
        train_loss, val_loss = train_epoch(
            model, train_loader, val_loader, criterion, optimizer
        )
        # Idiom fix: truthiness test instead of `... is False` (PEP 8).
        if not config["operation"]["silent"]:
            print(f"Epoch {epoch + 1} - Train Loss: {train_loss} - Val Loss: {val_loss}")

        train_acc, _, _ = evaluate_accuracy(model, train_loader)
        val_acc, _, _ = evaluate_accuracy(model, val_loader)

        # NOTE(review): the index stays 0-based while the printout is 1-based;
        # kept as-is so the returned DataFrame's index is unchanged for callers.
        history.loc[epoch] = [train_loss, val_loss, train_acc, val_acc]

    return history