threshold_refac.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import pandas as pd
  2. import numpy as np
  3. import os
  4. import tomli as toml
  5. from utils.data.datasets import prepare_datasets
  6. import utils.ensemble as ens
  7. import torch
  8. import matplotlib.pyplot as plt
  9. import sklearn.metrics as metrics
  10. from tqdm import tqdm
  11. import utils.metrics as met
  12. import itertools as it
  13. import matplotlib.ticker as ticker
  14. import glob
  15. # CONFIGURATION
  16. if os.getenv('ADL_CONFIG_PATH') is None:
  17. with open('config.toml', 'rb') as f:
  18. config = toml.load(f)
  19. else:
  20. with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
  21. config = toml.load(f)
  22. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  23. V2_PATH = ENSEMBLE_PATH + '/v2'
  24. # Models is a dictionary with the model ids as keys and the model data as values
  25. def get_model_predictions(models, data):
  26. predictions = {}
  27. for model_id, model in models.items():
  28. model.eval()
  29. with torch.no_grad:
  30. # Get the predictions
  31. output = model(data)
  32. predictions[model_id] = output.detach().cpu().numpy()
  33. return predictions
  34. def load_models_v2(folder, device):
  35. glob_path = os.path.join(folder, '*.pt')
  36. model_files = glob(glob_path)
  37. model_dict = {}
  38. for model_file in model_files:
  39. model = torch.load(model_file, map_location=device)
  40. model_id = os.path.basename(model_file).split('_')[0]
  41. model_dict[model_id] = model
  42. return model_dict
  43. # Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
  44. def preprocess_data(data, device):
  45. mri, xls = data
  46. mri = mri.unsqueeze(0).to(device)
  47. xls = xls.unsqueeze(0).to(device)
  48. return (mri, xls)
  49. def ensemble_dataset_predictions(models, dataset, device):
  50. # For each datapoint, get the predictions of each model
  51. predictions = {}
  52. for i, (data, target) in tqdm(enumerate(dataset), total=len(dataset)):
  53. # Preprocess data
  54. data = preprocess_data(data, device)
  55. # Predictions is a dicionary of tuples, with the target as the first and the model predicions dictionary as the second
  56. # The key is the id of the image
  57. predictions[i] = (
  58. target.detach().cpu().numpy(),
  59. get_model_predictions(models, data),
  60. )
  61. return predictions
  62. # Given a dictionary of predictions, select one model and eliminate the rest
  63. def select_individual_model(predictions, model_id):
  64. selected_model_predictions = {}
  65. for key, value in predictions.items():
  66. selected_model_predictions[key] = (value[0], {model_id: value[1][model_id]})
  67. return selected_model_predictions
  68. # Given a dictionary of predictions, select a subset of models and eliminate the rest
  69. def select_subset_models(predictions, model_ids):
  70. selected_model_predictions = {}
  71. for key, value in predictions.items():
  72. selected_model_predictions[key] = (
  73. value[0],
  74. {model_id: value[1][model_id] for model_id in model_ids},
  75. )
  76. return selected_model_predictions
  77. # Given a dictionary of predictions, calculate statistics (stdev, mean, entropy, accuracy, f1) for each result
  78. def calculate_statistics(predictions):