Bläddra i källkod

Actually working now

Nicholas Schense 1 år sedan
förälder
incheckning
af17419096
1 ändrade filer med 28 tillägg och 4 borttagningar
  1. 28 4
      main.py

+ 28 - 4
main.py

@@ -40,7 +40,7 @@ print("Pytorch Version: " + torch. __version__)
 # data & training properties:
 val_split = 0.2     # % of val and test, rest will be train
 runs = 1
-epochs = 10
+epochs = 5
 time_stamp = timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
 seeds = [np.random.randint(0, 1000) for _ in range(runs)]
 
@@ -53,6 +53,7 @@ local_path = '/export/home/nschense/alzheimers/Pytorch_CNN-RNN'
 xls_path = local_path + '/LP_ADNIMERGE.csv'
 saved_model_path = local_path + 'saved_models/'
 
+DEBUG = False
 
 # TODO: Datasets include multiple labels, such as medical info
 
@@ -99,6 +100,9 @@ def evaluate_model(seed):
 
             outputs = model_CNN((mri, xls))
 
+            if DEBUG:
+                print(outputs.shape, label.shape)
+
             loss = criterion(outputs, label)
             loss.backward()
             optimizer.step()
@@ -110,15 +114,35 @@ def evaluate_model(seed):
                 running_loss = 0.0
         epoch_number += 1
 
+
+    print("--- TESTING MODEL ---")
     #Test model
     correct = 0
     total = 0
 
     with torch.no_grad():
-        for data in test_dataloader:
-            images, labels = data
-            outputs = model_CNN(images)
+        length = len(test_dataloader)
+        for i, data in tqdm(enumerate(test_dataloader, 0), total=length, desc="Testing", unit="batch"):
+            mri, xls, label = data
+
+            mri = mri.to(cuda_device).float()
+            xls = xls.to(cuda_device).float()
+            label = label.to(cuda_device).float()
+
+
+            outputs = model_CNN((mri, xls))
+
+            if DEBUG:
+                print(outputs.shape, label.shape)
+
             _, predicted = torch.max(outputs.data, 1)
+            _, labels = torch.max(label.data, 1)
+
+            if DEBUG:
+                print("Predicted: ", predicted)
+                print("Labels: ", labels)
+
+
             total += labels.size(0)
             correct += (predicted == labels).sum().item()