threshold_refac.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. import pandas as pd
  2. import numpy as np
  3. import os
  4. import tomli as toml
  5. from utils.data.datasets import prepare_datasets
  6. import utils.ensemble as ens
  7. import torch
  8. import matplotlib.pyplot as plt
  9. import sklearn.metrics as metrics
  10. from tqdm import tqdm
  11. import utils.metrics as met
  12. import itertools as it
  13. import matplotlib.ticker as ticker
  14. import glob
  15. import pickle as pk
  16. import warnings
  17. warnings.filterwarnings('error')
  18. # CONFIGURATION
  19. if os.getenv('ADL_CONFIG_PATH') is None:
  20. with open('config.toml', 'rb') as f:
  21. config = toml.load(f)
  22. else:
  23. with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
  24. config = toml.load(f)
  25. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  26. V3_PATH = ENSEMBLE_PATH + '/v3'
  27. # Create the directory if it does not exist
  28. if not os.path.exists(V3_PATH):
  29. os.makedirs(V3_PATH)
  30. # Models is a dictionary with the model ids as keys and the model data as values
  31. def get_model_predictions(models, data):
  32. predictions = {}
  33. for model_id, model in models.items():
  34. model.eval()
  35. with torch.no_grad():
  36. # Get the predictions
  37. output = model(data)
  38. predictions[model_id] = output.detach().cpu().numpy()
  39. return predictions
  40. def load_models_v2(folder, device):
  41. glob_path = os.path.join(folder, '*.pt')
  42. model_files = glob.glob(glob_path)
  43. model_dict = {}
  44. for model_file in model_files:
  45. model = torch.load(model_file, map_location=device)
  46. model_id = os.path.basename(model_file).split('_')[0]
  47. model_dict[model_id] = model
  48. if len(model_dict) == 0:
  49. raise FileNotFoundError('No models found in the specified directory: ' + folder)
  50. return model_dict
  51. # Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
  52. def preprocess_data(data, device):
  53. mri, xls = data
  54. mri = mri.unsqueeze(0).to(device)
  55. xls = xls.unsqueeze(0).to(device)
  56. return (mri, xls)
  57. def ensemble_dataset_predictions(models, dataset, device):
  58. # For each datapoint, get the predictions of each model
  59. predictions = {}
  60. for i, (data, target) in tqdm(enumerate(dataset), total=len(dataset)):
  61. # Preprocess data
  62. data = preprocess_data(data, device)
  63. # Predictions is a dicionary of tuples, with the target as the first and the model predicions dictionary as the second
  64. # The key is the id of the image
  65. predictions[i] = (
  66. target.detach().cpu().numpy(),
  67. get_model_predictions(models, data),
  68. )
  69. return predictions
  70. # Given a dictionary of predictions, select one model and eliminate the rest
  71. def select_individual_model(predictions, model_id):
  72. selected_model_predictions = {}
  73. for key, value in predictions.items():
  74. selected_model_predictions[key] = (
  75. value[0],
  76. {model_id: value[1][str(model_id)]},
  77. )
  78. return selected_model_predictions
  79. # Given a dictionary of predictions, select a subset of models and eliminate the rest
  80. def select_subset_models(predictions, model_ids):
  81. selected_model_predictions = {}
  82. for key, value in predictions.items():
  83. selected_model_predictions[key] = (
  84. value[0],
  85. {model_id: value[1][model_id] for model_id in model_ids},
  86. )
  87. return selected_model_predictions
  88. # Given a dictionary of predictions, calculate statistics (stdev, mean, entropy, correctness) for each result
  89. # Returns a dataframe of the form {data_id: (mean, stdev, entropy, confidence, correct, predicted, actual)}
  90. def calculate_statistics(predictions):
  91. # Create DataFrame with columns for each statistic
  92. stats_df = pd.DataFrame(
  93. columns=[
  94. 'mean',
  95. 'stdev',
  96. 'entropy',
  97. 'confidence',
  98. 'correct',
  99. 'predicted',
  100. 'actual',
  101. ]
  102. )
  103. # First, loop through each prediction
  104. for key, value in predictions.items():
  105. target = value[0]
  106. model_predictions = list(value[1].values())
  107. # Calculate the mean and stdev of predictions
  108. mean = np.squeeze(np.mean(model_predictions, axis=0))
  109. stdev = np.squeeze(np.std(model_predictions, axis=0))[1]
  110. # Calculate the entropy of the predictions
  111. entropy = met.entropy(mean)
  112. # Calculate confidence
  113. confidence = (np.max(mean) - 0.5) * 2
  114. # Calculate predicted and actual
  115. predicted = np.argmax(mean)
  116. actual = np.argmax(target)
  117. # Determine if the prediction is correct
  118. correct = predicted == actual
  119. # Add the statistics to the dataframe
  120. stats_df.loc[key] = [
  121. mean,
  122. stdev,
  123. entropy,
  124. confidence,
  125. correct,
  126. predicted,
  127. actual,
  128. ]
  129. return stats_df
  130. # Takes in a dataframe of the form {data_id: statistic, ...} and calculates the thresholds for the statistic
  131. # Output of the form DataFrame(index=threshold, columns=[accuracy, f1])
  132. def conduct_threshold_analysis(statistics, statistic_name, low_to_high=True):
  133. # Gives a dataframe
  134. percentile_df = statistics[statistic_name].quantile(
  135. q=np.linspace(0.05, 0.95, num=18)
  136. )
  137. # Dictionary of form {threshold: {metric: value}}
  138. thresholds_pd = pd.DataFrame(index=percentile_df.index, columns=['accuracy', 'f1'])
  139. for percentile, value in percentile_df.items():
  140. # Filter the statistics
  141. if low_to_high:
  142. filtered_statistics = statistics[statistics[statistic_name] < value]
  143. else:
  144. filtered_statistics = statistics[statistics[statistic_name] >= value]
  145. # Calculate accuracy and f1 score
  146. accuracy = filtered_statistics['correct'].mean()
  147. # Calculate F1 score
  148. predicted = filtered_statistics['predicted'].values
  149. actual = filtered_statistics['actual'].values
  150. f1 = metrics.f1_score(actual, predicted)
  151. # Add the metrics to the dataframe
  152. thresholds_pd.loc[percentile] = [accuracy, f1]
  153. return thresholds_pd
  154. # Takes a dictionary of the form {threshold: {metric: value}} for a given statistic and plots the metric against the threshold.
  155. # Can plot an additional line if given (used for individual results)
  156. def plot_threshold_analysis(
  157. thresholds_metric, title, x_label, y_label, path, additional_set=None, flip=False
  158. ):
  159. # Initialize the plot
  160. fig, ax = plt.subplots()
  161. # Get the thresholds and metrics
  162. thresholds = list(thresholds_metric.index)
  163. metric = list(thresholds_metric.values)
  164. # Plot the metric against the threshold
  165. plt.plot(thresholds, metric, 'bo-', label='Ensemble')
  166. if additional_set is not None:
  167. # Get the thresholds and metrics
  168. thresholds = list(additional_set.index)
  169. metric = list(additional_set.values)
  170. # Plot the metric against the threshold
  171. plt.plot(thresholds, metric, 'rx-', label='Individual')
  172. if flip:
  173. ax.invert_xaxis()
  174. # Add labels
  175. plt.title(title)
  176. plt.xlabel(x_label)
  177. plt.ylabel(y_label)
  178. plt.legend()
  179. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
  180. plt.savefig(path)
  181. plt.close()
  182. # Code from https://stackoverflow.com/questions/16458340
  183. # Returns the intersections of multiple dictionaries
  184. def common_entries(*dcts):
  185. if not dcts:
  186. return
  187. for i in set(dcts[0]).intersection(*dcts[1:]):
  188. yield (i,) + tuple(d[i] for d in dcts)
  189. #Given ensemble statistics, calculate overall stats (ECE, MCE, Brier Score, NLL)
  190. def calculate_overall_statistics(ensemble_statistics):
  191. predicted = ensemble_statistics['predicted']
  192. actual = ensemble_statistics['actual']
  193. # New dataframe to store the statistics
  194. stats_df = pd.DataFrame(columns=['stat', 'ECE', 'MCE', 'Brier Score', 'NLL']).set_index('stat')
  195. # Loop through and calculate the ECE, MCE, Brier Score, and NLL
  196. for stat in ['confidence', 'entropy', 'stdev', 'raw_confidence']:
  197. ece = met.ECE(predicted, ensemble_statistics[stat], actual)
  198. mce = met.MCE(predicted, ensemble_statistics[stat], actual)
  199. brier = met.brier_binary(ensemble_statistics[stat], actual)
  200. nll = met.nll_binary(ensemble_statistics[stat], actual)
  201. stats_df.loc[stat] = [ece, mce, brier, nll]
  202. return stats_df
  203. def main():
  204. # Load the models
  205. device = torch.device(config['training']['device'])
  206. models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
  207. # Load Dataset
  208. dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
  209. f'{ENSEMBLE_PATH}/val_dataset.pt'
  210. )
  211. if config['ensemble']['run_models']:
  212. # Get thre predicitons of the ensemble
  213. ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)
  214. # Save to file using pickle
  215. with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
  216. pk.dump(ensemble_predictions, f)
  217. else:
  218. # Load the predictions from file
  219. with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
  220. ensemble_predictions = pk.load(f)
  221. # Get the statistics and thresholds of the ensemble
  222. ensemble_statistics = calculate_statistics(ensemble_predictions)
  223. stdev_thresholds = conduct_threshold_analysis(
  224. ensemble_statistics, 'stdev', low_to_high=True
  225. )
  226. entropy_thresholds = conduct_threshold_analysis(
  227. ensemble_statistics, 'entropy', low_to_high=True
  228. )
  229. confidence_thresholds = conduct_threshold_analysis(
  230. ensemble_statistics, 'confidence', low_to_high=False
  231. )
  232. raw_confidence = ensemble_statistics['confidence'].apply(lambda x: (x / 2) + 0.5)
  233. ensemble_statistics.insert(4, 'raw_confidence', raw_confidence)
  234. # Calculate overall statistics
  235. overall_statistics = calculate_overall_statistics(ensemble_statistics)
  236. # Print overall statistics
  237. print(overall_statistics)
  238. # Print overall ensemble statistics
  239. print('Ensemble Statistics')
  240. print(f"Accuracy: {ensemble_statistics['correct'].mean()}")
  241. print(
  242. f"F1 Score: {metrics.f1_score(ensemble_statistics['actual'], ensemble_statistics['predicted'])}"
  243. )
  244. # Get the predictions, statistics and thresholds an individual model
  245. indv_id = config['ensemble']['individual_id']
  246. indv_predictions = select_individual_model(ensemble_predictions, indv_id)
  247. indv_statistics = calculate_statistics(indv_predictions)
  248. # Calculate entropy and confidence thresholds for individual model
  249. indv_entropy_thresholds = conduct_threshold_analysis(
  250. indv_statistics, 'entropy', low_to_high=True
  251. )
  252. indv_confidence_thresholds = conduct_threshold_analysis(
  253. indv_statistics, 'confidence', low_to_high=False
  254. )
  255. # Plot the threshold analysis for standard deviation
  256. plot_threshold_analysis(
  257. stdev_thresholds['accuracy'],
  258. 'Stdev Threshold Analysis for Accuracy',
  259. 'Stdev Threshold',
  260. 'Accuracy',
  261. f'{V3_PATH}/stdev_threshold_analysis.png',
  262. flip=True,
  263. )
  264. plot_threshold_analysis(
  265. stdev_thresholds['f1'],
  266. 'Stdev Threshold Analysis for F1 Score',
  267. 'Stdev Threshold',
  268. 'F1 Score',
  269. f'{V3_PATH}/stdev_threshold_analysis_f1.png',
  270. flip=True,
  271. )
  272. # Plot the threshold analysis for entropy
  273. plot_threshold_analysis(
  274. entropy_thresholds['accuracy'],
  275. 'Entropy Threshold Analysis for Accuracy',
  276. 'Entropy Threshold',
  277. 'Accuracy',
  278. f'{V3_PATH}/entropy_threshold_analysis.png',
  279. indv_entropy_thresholds['accuracy'],
  280. flip=True,
  281. )
  282. plot_threshold_analysis(
  283. entropy_thresholds['f1'],
  284. 'Entropy Threshold Analysis for F1 Score',
  285. 'Entropy Threshold',
  286. 'F1 Score',
  287. f'{V3_PATH}/entropy_threshold_analysis_f1.png',
  288. indv_entropy_thresholds['f1'],
  289. flip=True,
  290. )
  291. # Plot the threshold analysis for confidence
  292. plot_threshold_analysis(
  293. confidence_thresholds['accuracy'],
  294. 'Confidence Threshold Analysis for Accuracy',
  295. 'Confidence Threshold',
  296. 'Accuracy',
  297. f'{V3_PATH}/confidence_threshold_analysis.png',
  298. indv_confidence_thresholds['accuracy'],
  299. )
  300. plot_threshold_analysis(
  301. confidence_thresholds['f1'],
  302. 'Confidence Threshold Analysis for F1 Score',
  303. 'Confidence Threshold',
  304. 'F1 Score',
  305. f'{V3_PATH}/confidence_threshold_analysis_f1.png',
  306. indv_confidence_thresholds['f1'],
  307. )
  308. if __name__ == '__main__':
  309. main()