# test_models.py
# Load each saved model from the configured model directory, evaluate it on
# the test set, and write evaluation plots.
  1. print("--- INITIALIZING LIBRARIES ---")
  2. from utils.training import train_model, test_model, initalize_dataloaders, plot_confusion_matrix, plot_roc_curve, plot_image_selection
  3. import tomli as tl
  4. import torch
  5. import os
  6. from utils.models import CNN_Net
  7. print("--- LIBRARIES INITIALIZED ---")
  8. #GET CONFIG SETTINGS
  9. if os.getenv('ADL_CONFIG_PATH') is None:
  10. with open ('config.toml', 'rb') as f:
  11. config = tl.load(f)
  12. else:
  13. with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
  14. config = tl.load(f)
  15. device = torch.device(config['cuda']['device'])
  16. #For each file in the model directory, run model tests and save results
  17. plot_path = config['paths']['plot_output']
  18. model_path = config['paths']['model_output']
  19. test_output_path = config['paths']['testing_record_output']
  20. #get all files in model directory
  21. model_files = os.listdir(model_path)
  22. #for each model in the model path, determine timestamp from file name and load the model, then test the model
  23. print("--- TESTING MODELS ---")
  24. for model_file in model_files:
  25. #get model name from file name
  26. model_name = model_file[:model_file.find("_")]
  27. #get timestamp from file name
  28. timestamp = model_file[(model_file.find("t-") + 2): model_file.find("_", model_file.find("t-"))]
  29. #get seed from file name
  30. seed = int(model_file[(model_file.find("s-") + 2): model_file.find("_", model_file.find("s-"))])
  31. print(" - Testing Model: " + timestamp + ", Seed: ", seed)
  32. print(" * Loading Dataset")
  33. _, _, test_loader, test_set = initalize_dataloaders(config['paths']['mri_data'], config['paths']['xls_data'], config['dataset']['validation_split'], seed, cuda_device=torch.device('cpu'), batch_size=config['training']['batch_size'])
  34. print(" * Loading Model")
  35. model = torch.load(model_path + model_file)
  36. model.eval()
  37. print(" * Testing Model")
  38. predicted, actual, correct, incorrect, max_preds, max_actuals = test_model(model, test_loader, cuda_device=device)
  39. print(" * Accuracy: " + str(correct / (correct + incorrect)))
  40. plot_confusion_matrix(max_preds, max_actuals, model_name, timestamp, plot_path)
  41. plot_roc_curve(predicted, actual, model_name, timestamp, plot_path)
  42. plot_image_selection(model, test_set, model_name, timestamp, plot_path, cuda_device=device)