Ver código fonte

Added average loss curve in training

Ruben 5 meses atrás
pai
commit
a452f1345c
3 arquivos alterados com 36 adições e 16 exclusões
  1. 3 3
      main.py
  2. 2 2
      original_model/utils/models.py
  3. 31 11
      utils/newCNN.py

+ 3 - 3
main.py

@@ -110,14 +110,14 @@ val_dataloader = DataLoader(val_data, batch_size=properties['batch_size'], shuff
 #     x = x+1
 
 
-train = False
-predict = True
+train = True
+predict = False
 CNN = CNN_Net(train_dataloader, prps=properties, final_layer_size=2)
 CNN.cuda()
 
 # RUN CNN
 if(train):
-    CNN.train_model(train_dataloader, CNN_filepath, epochs=10)
+    CNN.train_model(train_dataloader, CNN_filepath, epochs=5)
     CNN.evaluate_model(val_dataloader)
 
 else:

+ 2 - 2
original_model/utils/models.py

@@ -74,8 +74,8 @@ class CNN_Net ():
         self.optimizer = self.params.optimizer
         self.model.compile(optimizer = self.optimizer, loss = 'sparse_categorical_crossentropy', metrics =['acc']) 
         self.model.summary()
-        
-        history = self.model.fit_generator (data_flow_train,
+
+        history = self.model.fit_generator(data_flow_train,
                    steps_per_epoch = train_samples/self.params.CNN_batch_size,
                    epochs = self.params.epochs,
                    callbacks = callback,

+ 31 - 11
utils/newCNN.py

@@ -6,7 +6,8 @@ import utils.newCNN_Layers as CustomLayers
 import torch.nn.functional as F
 import torch.optim as optim
 import utils.CNN_methods as CNN
-import copy
+import pandas as pd
+import matplotlib.pyplot as plt
 
 
 class CNN_Net(nn.Module):
@@ -54,13 +55,16 @@ class CNN_Net(nn.Module):
     # TRAIN
     def train_model(self, trainloader, PATH, epochs):
         self.train()
-        criterion = nn.CrossEntropyLoss()
+        criterion = nn.CrossEntropyLoss(reduction='mean')
         optimizer = optim.Adam(self.parameters(), lr=1e-5)
 
-        for epoch in epochs:  # loop over the dataset multiple times
-            print(f"Training... {epoch}/{epochs}")
+        losses = pd.DataFrame(columns=['Epoch', 'Avg_loss'])
+
+        for epoch in range(epochs+1):  # loop over the dataset multiple times
+            print(f"Epoch {epoch}/{epochs}")
             running_loss = 0.0
-            for i, data in enumerate(trainloader, 0):
+
+            for i, data in enumerate(trainloader, 0):   # loops over batches
                 # get the inputs; data is a list of [inputs, labels]
                 inputs, labels = data[0].to(self.device), data[1].to(self.device)
 
@@ -69,17 +73,34 @@ class CNN_Net(nn.Module):
 
                 # forward + backward + optimize
                 outputs = self.forward(inputs)
-                loss = criterion(outputs, labels)
+                loss = criterion(outputs, labels)   # This loss is the mean of losses for the batch
                 loss.backward()
                 optimizer.step()
 
-                # print statistics
+                # adds average batch loss to running loss
                 running_loss += loss.item()
-                if i % 2000 == 1999:  # print every 2000 mini-batches
-                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
-                    running_loss = 0.0
+
+            avg_loss = running_loss / len(trainloader)      # Running_loss / number of batches
+            print(f"Avg. loss: {avg_loss}")
+            losses = losses.append({'Epoch':int(epoch), 'Avg_loss':avg_loss}, ignore_index=True)
+
+
+
+            # TODO COMPUTE LOSS ON VALIDATION
+            # TODO ADD TIME PER EPOCH, CALCULATE EXPECTED REMAINING TIME
 
         print('Finished Training')
+        print(losses)
+
+        # MAKES GRAPH
+        plt.plot(losses['Epoch'], losses['Avg_loss'])
+        plt.xlabel('Epoch')
+        plt.ylabel('Average Loss')
+        plt.title('Average Loss vs Epoch On Training')
+        plt.show()
+
+        plt.savefig('avgloss_epoch_curve.png')
+
         torch.save(self.state_dict(), PATH)
 
     # TEST
@@ -96,7 +117,6 @@ class CNN_Net(nn.Module):
                 # the class with the highest energy is what we choose as prediction
                 _, predicted = torch.max(outputs.data, 1)
                 total += labels.size(0)
-                print(f"Predicted class vals: {predicted}")
                 correct += (predicted == labels).sum().item()
 
         print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')