test.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import numpy as np
  2. import torch
  3. from torch.utils.data import DataLoader
  4. from datareader import DataReader
  5. from sklearn.metrics import roc_auc_score, roc_curve
  6. class Testing():
  7. def __init__(self, main_path_to_data):
  8. self.main_path_to_data = main_path_to_data
  9. def test(self, test_info, model, path_to_model_weights):
  10. """Function for testing the model on test_info.
  11. Args:
  12. test_info (list): list of paths to 10 central slices per patient (ordered)
  13. model (nn.Module): architecture of the model
  14. path_to_model_weights (string): absolute path to the model weights
  15. Returns:
  16. auc (float): AUC calculated on test set
  17. fpr (ndarray): increasing false positive rates such that element i is the false positive rate of predictions with score >= thresholds[i]
  18. tpr (ndarray): increasing true positive rates such that element i is the true positive rate of predictions with score >= thresholds[i]
  19. thresholds (ndarray): decreasing thresholds on the decision function used to compute fpr and tpr
  20. trues (list): ground truth labels (0/1)
  21. predictions (list): predictions [0,1]
  22. """
  23. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  24. # 1. Load trained model and set it to eval mode
  25. model.to(device)
  26. model.load_state_dict(torch.load(path_to_model_weights))
  27. model.eval()
  28. # 2. Create datalodaer
  29. test_datareader = DataReader(self.main_path_to_data, test_info)
  30. test_generator = DataLoader(test_datareader, batch_size=10, shuffle=False, pin_memory = True, num_workers=2)
  31. # 3. Calculate metrics
  32. predictions = []
  33. trues = []
  34. for item_test in test_generator:
  35. # Load images (x) and labels (y)
  36. x, y = item_test
  37. x = x.to(device)
  38. y = y.to(device)
  39. # Forward pass
  40. with torch.no_grad():
  41. y_hat = model.forward(x)
  42. y_hat = torch.sigmoid(y_hat) # In training we are using BCEWithLogitsLoss for improved performance (sigmoid is already embedded), here we have to add it
  43. predictions.append(np.mean(y_hat.cpu().numpy()))
  44. trues.append(y.cpu().numpy()[0])
  45. auc = roc_auc_score(trues, predictions)
  46. fpr, tpr, thresholds = roc_curve(trues, predictions, pos_label=1)
  47. return auc, fpr, tpr, thresholds, trues, predictions