import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, roc_curve

from datareader import DataReader


class Testing:
    def __init__(self, main_path_to_data):
        self.main_path_to_data = main_path_to_data

    def test(self, test_info, model, path_to_model_weights):
- """Function for testing the model on test_info.
- Args:
- test_info (list): list of paths to 10 central slices per patient (ordered)
- model (nn.Module): architecture of the model
- path_to_model_weights (string): absolute path to the model weights
- Returns:
- auc (float): AUC calculated on test set
- fpr (ndarray): increasing false positive rates such that element i is the false positive rate of predictions with score >= thresholds[i]
- tpr (ndarray): increasing true positive rates such that element i is the true positive rate of predictions with score >= thresholds[i]
- thresholds (ndarray): decreasing thresholds on the decision function used to compute fpr and tpr
- trues (list): ground truth labels (0/1)
- predictions (list): predictions [0,1]
- """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 1. Load the trained weights and set the model to eval mode
        model.to(device)
        model.load_state_dict(torch.load(path_to_model_weights, map_location=device))
        model.eval()

        # 2. Create the test dataloader (batch_size=10, so each batch holds one patient's 10 slices)
        test_datareader = DataReader(self.main_path_to_data, test_info)
        test_generator = DataLoader(test_datareader, batch_size=10, shuffle=False, pin_memory=True, num_workers=2)
        # 3. Run inference and collect one prediction per patient
        predictions = []
        trues = []

        for item_test in test_generator:
            # Load images (x) and labels (y); one batch holds the 10 central slices of a single patient
            x, y = item_test
            x = x.to(device)
            y = y.to(device)

            # Forward pass (gradients are not needed at test time)
            with torch.no_grad():
                y_hat = model(x)
                y_hat = torch.sigmoid(y_hat)  # training uses BCEWithLogitsLoss (sigmoid embedded in the loss), so apply it explicitly here

            # Average the 10 slice-level probabilities into a single patient-level prediction;
            # all slices of a patient share the same label, so the first entry suffices
            predictions.append(np.mean(y_hat.cpu().numpy()))
            trues.append(y.cpu().numpy()[0])

        # 4. Compute ROC metrics over all patients
        auc = roc_auc_score(trues, predictions)
        fpr, tpr, thresholds = roc_curve(trues, predictions, pos_label=1)

        return auc, fpr, tpr, thresholds, trues, predictions
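

# Illustrative usage sketch (not part of the original code): the data root, the
# weights path, the test_info contents, and the resnet18 stand-in architecture
# below are placeholder assumptions; substitute the model and paths actually
# used in training.
if __name__ == "__main__":
    from torchvision.models import resnet18

    model = resnet18(num_classes=1)  # any nn.Module emitting a single logit per slice

    testing = Testing("/path/to/preprocessed/data")  # hypothetical data root
    test_info = []  # fill with the ordered paths to the 10 central slices of each test patient

    auc, fpr, tpr, thresholds, trues, predictions = testing.test(
        test_info, model, "/path/to/trained_weights.pth"  # hypothetical weights file
    )
    print(f"Test AUC: {auc:.3f}")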