|
@@ -6,7 +6,8 @@ import utils.newCNN_Layers as CustomLayers
|
|
|
import torch.nn.functional as F
|
|
|
import torch.optim as optim
|
|
|
import utils.CNN_methods as CNN
|
|
|
-import copy
|
|
|
+import pandas as pd
|
|
|
+import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
class CNN_Net(nn.Module):
|
|
@@ -54,13 +55,16 @@ class CNN_Net(nn.Module):
|
|
|
# TRAIN
|
|
|
def train_model(self, trainloader, PATH, epochs):
|
|
|
self.train()
|
|
|
- criterion = nn.CrossEntropyLoss()
|
|
|
+ criterion = nn.CrossEntropyLoss(reduction='mean')
|
|
|
optimizer = optim.Adam(self.parameters(), lr=1e-5)
|
|
|
|
|
|
- for epoch in epochs: # loop over the dataset multiple times
|
|
|
- print(f"Training... {epoch}/{epochs}")
|
|
|
+ losses = pd.DataFrame(columns=['Epoch', 'Avg_loss'])
|
|
|
+
|
|
|
+ for epoch in range(epochs+1): # loop over the dataset multiple times
|
|
|
+ print(f"Epoch {epoch}/{epochs}")
|
|
|
running_loss = 0.0
|
|
|
- for i, data in enumerate(trainloader, 0):
|
|
|
+
|
|
|
+ for i, data in enumerate(trainloader, 0): # loops over batches
|
|
|
# get the inputs; data is a list of [inputs, labels]
|
|
|
inputs, labels = data[0].to(self.device), data[1].to(self.device)
|
|
|
|
|
@@ -69,17 +73,34 @@ class CNN_Net(nn.Module):
|
|
|
|
|
|
# forward + backward + optimize
|
|
|
outputs = self.forward(inputs)
|
|
|
- loss = criterion(outputs, labels)
|
|
|
+ loss = criterion(outputs, labels) # This loss is the mean of losses for the batch
|
|
|
loss.backward()
|
|
|
optimizer.step()
|
|
|
|
|
|
- # print statistics
|
|
|
+ # adds average batch loss to running loss
|
|
|
running_loss += loss.item()
|
|
|
- if i % 2000 == 1999: # print every 2000 mini-batches
|
|
|
- print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
|
|
|
- running_loss = 0.0
|
|
|
+
|
|
|
+ avg_loss = running_loss / len(trainloader) # Running_loss / number of batches
|
|
|
+ print(f"Avg. loss: {avg_loss}")
|
|
|
+ losses = losses.append({'Epoch':int(epoch), 'Avg_loss':avg_loss}, ignore_index=True)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ # TODO COMPUTE LOSS ON VALIDATION
|
|
|
+ # TODO ADD TIME PER EPOCH, CALCULATE EXPECTED REMAINING TIME
|
|
|
|
|
|
print('Finished Training')
|
|
|
+ print(losses)
|
|
|
+
|
|
|
+ # MAKES GRAPH
|
|
|
+ plt.plot(losses['Epoch'], losses['Avg_loss'])
|
|
|
+ plt.xlabel('Epoch')
|
|
|
+ plt.ylabel('Average Loss')
|
|
|
+ plt.title('Average Loss vs Epoch On Training')
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+ plt.savefig('avgloss_epoch_curve.png')
|
|
|
+
|
|
|
torch.save(self.state_dict(), PATH)
|
|
|
|
|
|
# TEST
|
|
@@ -96,7 +117,6 @@ class CNN_Net(nn.Module):
|
|
|
# the class with the highest energy is what we choose as prediction
|
|
|
_, predicted = torch.max(outputs.data, 1)
|
|
|
total += labels.size(0)
|
|
|
- print(f"Predicted class vals: {predicted}")
|
|
|
correct += (predicted == labels).sum().item()
|
|
|
|
|
|
print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
|