Ver código fonte

Added saving of images and predictions

Nicholas Schense 10 meses atrás
pai
commit
dd8d9e5f98
4 arquivos alterados com 74 adições e 19 exclusões
  1. 1 1
      config.toml
  2. 2 2
      main.py
  3. 8 5
      test_models.py
  4. 63 11
      utils/training.py

+ 1 - 1
config.toml

@@ -15,7 +15,7 @@ device = 1
 
 [training]
 batch_size = 64
-epochs = 10
+epochs = 30
 learning_rate = 0.0001
 runs = 1
 

+ 2 - 2
main.py

@@ -55,7 +55,7 @@ for seed in seeds:
     time_stamp = datetime.now().strftime('%Y%m%d+%H%M%S')
 
     #initialize dataloaders, train model, and test model
-    train_loader, val_loader, test_loader = initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=cuda_device, batch_size=config['training']['batch_size'])
+    train_loader, val_loader, test_loader, _ = initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=cuda_device, batch_size=config['training']['batch_size'])
     
 
     print("--- TRAINING MODEL ---")
@@ -64,7 +64,7 @@ for seed in seeds:
         
     train_results = train_model(model_CNN, seed, time_stamp, epochs, train_loader, val_loader, saved_model_path, model_name, optimizer, criterion, cuda_device=cuda_device)
     print("--- TESTING MODEL ---")
-    predicted, actual, correct, incorrect = test_model(model_CNN, test_loader, cuda_device=cuda_device)
+    predicted, actual, correct, incorrect, _, _ = test_model(model_CNN, test_loader, cuda_device=cuda_device)
     
     print("Accuracy: " + str(correct / (correct + incorrect)))
     

+ 8 - 5
test_models.py

@@ -1,5 +1,5 @@
 print("--- INITIALIZING LIBRARIES ---")
-from utils.training import train_model, test_model, initalize_dataloaders, plot_confusion_matrix, plot_roc_curve
+from utils.training import train_model, test_model, initalize_dataloaders, plot_confusion_matrix, plot_roc_curve, plot_image_selection
 import tomli as tl
 import torch
 import os
@@ -15,7 +15,7 @@ else:
     with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
         config = tl.load(f)
         
-cuda_device = torch.device(config['cuda']['device'])
+device = torch.device(config['cuda']['device'])
 
 #For each file in the model directory, run model tests and save results
 plot_path = config['paths']['plot_output']
@@ -40,17 +40,20 @@ for model_file in model_files:
     print("  - Testing Model: " + timestamp + ", Seed: ", seed)
     print("    * Loading Dataset")
     
-    _, _, test_loader = initalize_dataloaders(config['paths']['mri_data'], config['paths']['xls_data'], config['dataset']['validation_split'], seed, cuda_device=torch.device('cpu'), batch_size=config['training']['batch_size'])
+    _, _, test_loader, test_set = initalize_dataloaders(config['paths']['mri_data'], config['paths']['xls_data'], config['dataset']['validation_split'], seed, cuda_device=torch.device('cpu'), batch_size=config['training']['batch_size'])
     
     print("    * Loading Model")
     model = torch.load(model_path + model_file)
+    model.eval()
     
     print("    * Testing Model")
-    predicted, actual, correct, incorrect = test_model(model, test_loader, cuda_device=cuda_device)
+    predicted, actual, correct, incorrect, max_preds, max_actuals = test_model(model, test_loader, cuda_device=device)
     print("    * Accuracy: " + str(correct / (correct + incorrect)))
         
-    plot_confusion_matrix(predicted, actual, model_name, timestamp, plot_path)
+    plot_confusion_matrix(max_preds, max_actuals, model_name, timestamp, plot_path)
     plot_roc_curve(predicted, actual, model_name, timestamp, plot_path)
+    plot_image_selection(model, test_set, model_name, timestamp, plot_path, cuda_device=device)
+    
     
     
 

+ 63 - 11
utils/training.py

@@ -109,29 +109,42 @@ def test_model(model, test_loader, cuda_device=torch.device('cuda:0')):
     
     predictions = []
     actual = []
+    
+    max_preds = []
+    max_actuals = []
 
     with torch.no_grad():
         length = len(test_loader)
         for i, data in tqdm(enumerate(test_loader, 0), total=length, desc="Testing", unit="batch"):
-            mri, xls, label = data
+            mri, xls, labels = data
 
             mri = mri.to(cuda_device).float()
             xls = xls.to(cuda_device).float()
-            label = label.to(cuda_device).float()
+            labels = labels.to(cuda_device).float()
 
             outputs = model((mri, xls))
 
-            _, predicted = torch.max(outputs.data, 1)
-            _, labels = torch.max(label.data, 1)
+            _, m_predicted = torch.max(outputs.data, 1)
+            _, m_labels = torch.max(labels.data, 1)
 
-            incorrect += (predicted != labels).sum().item()
-            correct += (predicted == labels).sum().item()
+            incorrect += (m_predicted != m_labels).sum().item()
+            correct += (m_predicted == m_labels).sum().item()
+            
+            #We just want the positive class, since there are only 2 classes and we use softmax
+            pos_outputs = outputs[:, 1]
+            pos_labels = labels[:, 1]
             
             
-            predictions.extend(predicted.tolist())
-            actual.extend(labels.tolist())
+            predictions.extend(pos_outputs.tolist())
+            actual.extend(pos_labels.tolist())
                 
-    return predictions, actual, correct, incorrect
+            _, max_pred = torch.max(outputs.data, 1)
+            _, max_actual = torch.max(labels.data, 1)
+            
+            max_preds.extend(max_pred.tolist())
+            max_actuals.extend(max_actual.tolist())
+            
+    return predictions, actual, correct, incorrect, max_preds, max_actuals
 
 def initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=torch.device('cuda:0'), batch_size=64):
     training_data, val_data, test_data = prepare_datasets(mri_path, xls_path, val_split, seed)
@@ -140,7 +153,7 @@ def initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=torch
     test_dataloader = DataLoader(test_data, batch_size=(batch_size // 4), shuffle=True, generator=torch.Generator(device=cuda_device))
     val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=cuda_device))
 
-    return train_dataloader, val_dataloader, test_dataloader
+    return train_dataloader, val_dataloader, test_dataloader, test_data
 
 
 def plot_results(train_acc, train_loss, val_acc, val_loss, model_name, timestamp, plot_path):
@@ -188,10 +201,49 @@ def plot_roc_curve(predicted, actual, model_name, timestamp, plot_path):
     np.array(actual, dtype=np.float64)
     
     fpr, tpr, _ = roc_curve(actual, predicted)
-    print(fpr, tpr)
     auc = roc_auc_score(actual, predicted)
     plt.figure()
     RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=auc).plot()
     plt.savefig(plot_path + model_name + "_t-" + timestamp + "_roc_curve.png")
     plt.close()
+    
+    
+def plot_image_selection(model, test_set, model_name, timestamp, plot_path, cuda_device=torch.device('cuda:0')):
+    #Plot a bevy of random images from the test set and their predictions for the positive class
+    if not os.path.exists(plot_path):
+        os.makedirs(plot_path)
+        
+    #Get random images
+    images = []
+    
+    for i in range(8):
+        images.append(test_set[np.random.randint(0, len(test_set))])
+        
+    #Now that we have our images, create a subplot for each image
+    plt.figure()
+    fig, axs = plt.subplots(2, 4)
+    
+    for i, image in enumerate(images):
+        mri, xls, label = image
+        
+        
+        mri = mri.to(cuda_device).float()
+        xls = xls.to(cuda_device).float()
+        label = label[1]
+        
+        mri = mri.unsqueeze(0)
+        xls = xls.unsqueeze(0)
+        
+        output = model((mri, xls))
+        
+        prediction = output[:, 1]
+    
+        sliced_image = torch.permute(torch.select(torch.squeeze(mri, 0), 3, 80), (1, 2, 0)).cpu().numpy()
+        axs[i // 4, i % 4].imshow(sliced_image, cmap="gray")
+        axs[i // 4, i % 4].set_title("Pr: " + str(round(prediction.item(), 3)) + ", \nAc: " + str(label.item()))
+        
+    plt.savefig(plot_path + model_name + "_t-" + timestamp + "_image_selection.png")
+    plt.close()
+
+