12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
"""Evaluate every saved model in the configured model directory.

For each model file, the model name, training timestamp and dataset seed are
recovered from the file name, the test split is rebuilt with that seed, the
model is run over it, the accuracy is printed, and confusion-matrix, ROC-curve
and sample-image plots are written to the configured plot directory.

The config is read from the path in the ADL_CONFIG_PATH environment variable,
falling back to ``config.toml`` in the working directory.
"""
print("--- INITIALIZING LIBRARIES ---")
import os

import tomli as tl
import torch

from utils.training import (
    initalize_dataloaders,
    plot_confusion_matrix,
    plot_image_selection,
    plot_roc_curve,
    test_model,
    train_model,
)
# NOTE(review): CNN_Net is not referenced below; presumably it is kept so the
# pickled model class is resolvable by torch.load - confirm before removing.
from utils.models import CNN_Net

print("--- LIBRARIES INITIALIZED ---")


def _filename_field(file_name, marker):
    """Return the text between ``marker`` and the next "_" in ``file_name``.

    Returns None when the marker is absent or not followed by "_", instead of
    letting ``str.find``'s -1 produce a silently wrong slice.
    """
    start = file_name.find(marker)
    if start == -1:
        return None
    value_start = start + len(marker)
    end = file_name.find("_", value_start)
    if end == -1:
        return None
    return file_name[value_start:end]


def _parse_model_filename(file_name):
    """Split a model file name into ``(model_name, timestamp, seed)``.

    The model name is the text before the first "_", the timestamp follows
    "t-" and the seed follows "s-"; each field is terminated by "_".
    Returns None if the file name does not follow this scheme.
    """
    timestamp = _filename_field(file_name, "t-")
    seed_text = _filename_field(file_name, "s-")
    if timestamp is None or seed_text is None:
        return None
    try:
        seed = int(seed_text)
    except ValueError:
        return None
    # Both fields above are "_"-terminated, so find("_") cannot return -1 here.
    model_name = file_name[:file_name.find("_")]
    return model_name, timestamp, seed


# GET CONFIG SETTINGS
# An unset (or empty) ADL_CONFIG_PATH falls back to the local config.toml.
config_path = os.getenv('ADL_CONFIG_PATH') or 'config.toml'
with open(config_path, 'rb') as f:
    config = tl.load(f)

device = torch.device(config['cuda']['device'])
plot_path = config['paths']['plot_output']
model_path = config['paths']['model_output']
# NOTE(review): read (so the key must exist) but not used anywhere in this script.
test_output_path = config['paths']['testing_record_output']

# Only regular files can be models; sort so the run order is reproducible
# (os.listdir order is arbitrary).
model_files = sorted(
    name for name in os.listdir(model_path)
    if os.path.isfile(os.path.join(model_path, name))
)

# For each model file: parse its name, load it, test it, and plot the results.
print("--- TESTING MODELS ---")
for model_file in model_files:
    parsed = _parse_model_filename(model_file)
    if parsed is None:
        print(" - Skipping " + model_file + ": file name lacks 't-<timestamp>_' / 's-<seed>_' fields")
        continue
    model_name, timestamp, seed = parsed

    print(" - Testing Model: " + timestamp + ", Seed: " + str(seed))
    print(" * Loading Dataset")
    # Rebuilt per model because the seed from the file name is passed to the
    # loader factory (presumably to reproduce the training-time split).
    _, _, test_loader, test_set = initalize_dataloaders(
        config['paths']['mri_data'],
        config['paths']['xls_data'],
        config['dataset']['validation_split'],
        seed,
        cuda_device=torch.device('cpu'),
        batch_size=config['training']['batch_size'],
    )

    print(" * Loading Model")
    # os.path.join works whether or not model_output ends with a separator.
    # map_location puts the weights on the configured device regardless of
    # where they were saved. weights_only=False is required to unpickle a
    # whole module on PyTorch >= 2.6 (where the default became True); it
    # executes pickle code, so only load model files from a trusted source.
    model = torch.load(
        os.path.join(model_path, model_file),
        map_location=device,
        weights_only=False,
    )
    model.eval()

    print(" * Testing Model")
    predicted, actual, correct, incorrect, max_preds, max_actuals = test_model(
        model, test_loader, cuda_device=device
    )
    total = correct + incorrect
    if total:
        print(" * Accuracy: " + str(correct / total))
    else:
        # Avoid ZeroDivisionError when the test split is empty.
        print(" * Accuracy: n/a (empty test set)")

    plot_confusion_matrix(max_preds, max_actuals, model_name, timestamp, plot_path)
    plot_roc_curve(predicted, actual, model_name, timestamp, plot_path)
    plot_image_selection(model, test_set, model_name, timestamp, plot_path, cuda_device=device)
|