threshold_refac.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. import pandas as pd
  2. import numpy as np
  3. import os
  4. import tomli as toml
  5. from utils.data.datasets import prepare_datasets
  6. import utils.ensemble as ens
  7. import torch
  8. import matplotlib.pyplot as plt
  9. import sklearn.metrics as metrics
  10. from tqdm import tqdm
  11. import utils.metrics as met
  12. import itertools as it
  13. import matplotlib.ticker as ticker
  14. import glob
  15. import pickle as pk
  16. import warnings
  # Turn every warning into a raised exception for the rest of the run
  warnings.filterwarnings('error')

  # CONFIGURATION
  # The config file is read from the path in ADL_CONFIG_PATH when that
  # variable is set; otherwise 'config.toml' (relative to the working
  # directory) is used. The file is opened in binary mode for toml.load.
  _config_path = os.getenv('ADL_CONFIG_PATH')
  if _config_path is None:
      _config_path = 'config.toml'
  with open(_config_path, 'rb') as f:
      config = toml.load(f)

  # Root folder of the ensemble and the sub-folder this script writes into
  ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  V3_PATH = ENSEMBLE_PATH + '/v3'

  # Create the directory if it does not exist; exist_ok avoids the race
  # between a separate existence check and the creation itself
  os.makedirs(V3_PATH, exist_ok=True)
  def get_model_predictions(models, data):
      """Run every model on ``data`` and collect the outputs as NumPy arrays.

      Args:
          models: Dictionary with the model ids as keys and the models as
              values. Each model is switched to eval mode before it is called.
          data: Input passed unchanged to each model.

      Returns:
          Dictionary mapping each model id to that model's output, detached,
          moved to the CPU and converted to a NumPy array.
      """
      outputs_by_id = {}
      # Gradients are not needed for inference
      with torch.no_grad():
          for ident, net in models.items():
              net.eval()
              raw_output = net(data)
              outputs_by_id[ident] = raw_output.detach().cpu().numpy()
      return outputs_by_id
  def load_models_v2(folder, device):
      """Load every ``*.pt`` file in ``folder`` with ``torch.load``.

      Args:
          folder: Directory searched (non-recursively) for ``*.pt`` files.
          device: Passed to ``torch.load`` as ``map_location``.

      Returns:
          Dictionary keyed by model id, where the id is the part of the file
          name before the first underscore. Files sharing an id overwrite
          each other.

      Raises:
          FileNotFoundError: If no ``*.pt`` file is found in ``folder``.
      """
      checkpoint_paths = glob.glob(os.path.join(folder, '*.pt'))
      loaded = {
          os.path.basename(path).split('_')[0]: torch.load(path, map_location=device)
          for path in checkpoint_paths
      }
      if not loaded:
          raise FileNotFoundError('No models found in the specified directory: ' + folder)
      return loaded
  def preprocess_data(data, device):
      """Unsqueeze both tensors of a datapoint and move them to ``device``.

      Args:
          data: Pair ``(mri, xls)`` of tensors.
          device: Target device for both tensors.

      Returns:
          Tuple ``(mri, xls)`` where each tensor has gained a leading
          dimension of size one and lives on ``device``.
      """
      mri_tensor, xls_tensor = data
      return (
          mri_tensor.unsqueeze(0).to(device),
          xls_tensor.unsqueeze(0).to(device),
      )
  def ensemble_dataset_predictions(models, dataset, device):
      """Collect the predictions of every model for every datapoint.

      Args:
          models: Dictionary of models, forwarded to get_model_predictions.
          dataset: Sized iterable yielding ``(data, target)`` pairs.
          device: Device the inputs are moved to before inference.

      Returns:
          Dictionary keyed by the datapoint's position in ``dataset``. Each
          value is a tuple with the target (as a NumPy array) first and the
          dictionary of per-model predictions second.
      """
      results = {}
      progress = tqdm(enumerate(dataset), total=len(dataset))
      for index, (sample, label) in progress:
          # Unsqueeze the inputs and move them to the requested device
          prepared = preprocess_data(sample, device)
          label_array = label.detach().cpu().numpy()
          results[index] = (label_array, get_model_predictions(models, prepared))
      return results
  def select_individual_model(predictions, model_id):
      """Keep the predictions of a single model and drop all others.

      Args:
          predictions: Dictionary ``{data_id: (target, {model_id: prediction})}``.
          model_id: Model to keep. It is converted with ``str`` for the
              lookup, but used as given for the key of the result.

      Returns:
          Dictionary with the same keys and targets as ``predictions`` whose
          per-model dictionaries contain only ``model_id``.
      """
      return {
          data_id: (entry[0], {model_id: entry[1][str(model_id)]})
          for data_id, entry in predictions.items()
      }
  def select_subset_models(predictions, model_ids):
      """Keep the predictions of a subset of models and drop all others.

      Args:
          predictions: Dictionary ``{data_id: (target, {model_id: prediction})}``.
          model_ids: Iterable of model ids to keep; each must be a key of
              the per-model dictionaries.

      Returns:
          Dictionary with the same keys and targets as ``predictions`` whose
          per-model dictionaries contain only the requested ids.
      """
      return {
          data_id: (entry[0], {wanted: entry[1][wanted] for wanted in model_ids})
          for data_id, entry in predictions.items()
      }
  def calculate_statistics(predictions):
      """Compute per-datapoint statistics from a dictionary of predictions.

      Args:
          predictions: Dictionary ``{data_id: (target, {model_id: prediction})}``.

      Returns:
          DataFrame indexed by data id with the columns ``mean``, ``stdev``,
          ``entropy``, ``confidence``, ``correct``, ``predicted`` and
          ``actual``. It is empty (columns only) when ``predictions`` is empty.
      """
      column_names = [
          'mean',
          'stdev',
          'entropy',
          'confidence',
          'correct',
          'predicted',
          'actual',
      ]
      table = pd.DataFrame(columns=column_names)

      for data_id, entry in predictions.items():
          label = entry[0]
          member_outputs = list(entry[1].values())

          # Mean over the models; the stdev keeps only index 1 of the
          # squeezed per-output standard deviation
          averaged = np.squeeze(np.mean(member_outputs, axis=0))
          spread = np.squeeze(np.std(member_outputs, axis=0))[1]

          uncertainty = met.entropy(averaged)

          # Rescales the largest mean output; NOTE(review): this maps
          # [0.5, 1] to [0, 1], which presumably assumes two-class
          # probabilities - confirm against the models' output layer
          certainty = (np.max(averaged) - 0.5) * 2

          guess = np.argmax(averaged)
          truth = np.argmax(label)

          # One row per datapoint, in the column order declared above
          table.loc[data_id] = [
              averaged,
              spread,
              uncertainty,
              certainty,
              guess == truth,
              guess,
              truth,
          ]

      return table
  def conduct_threshold_analysis(
      statistics, statistic_name, low_to_high=True, quantiles=None
  ):
      """Evaluate accuracy and F1 on subsets selected by quantile cut-offs.

      For every quantile level the value of ``statistic_name`` at that level
      is used as a cut-off, the rows on one side of the cut-off are kept, and
      accuracy and F1 are computed on the kept rows.

      Args:
          statistics: DataFrame with the column ``statistic_name`` plus the
              columns ``correct``, ``predicted`` and ``actual``.
          statistic_name: Name of the column the cut-offs are taken from.
          low_to_high: When True keep rows strictly below the cut-off,
              otherwise keep rows greater than or equal to it.
          quantiles: Optional sequence of quantile levels in [0, 1]. Defaults
              to ``np.linspace(0.05, 0.95, num=18)``. Must be a sequence, not
              a scalar, so that the cut-offs come back as a Series.

      Returns:
          DataFrame indexed by quantile level (not by the cut-off value) with
          the columns ``accuracy`` and ``f1``. A level whose subset is empty
          keeps NaN in both columns.
      """
      if quantiles is None:
          quantiles = np.linspace(0.05, 0.95, num=18)

      # Series mapping quantile level -> cut-off value of the statistic
      cutoffs = statistics[statistic_name].quantile(q=quantiles)

      # Starts out filled with NaN; rows are overwritten as they are computed
      thresholds_pd = pd.DataFrame(index=cutoffs.index, columns=['accuracy', 'f1'])

      for percentile, value in cutoffs.items():
          if low_to_high:
              keep = statistics[statistic_name] < value
          else:
              keep = statistics[statistic_name] >= value
          filtered_statistics = statistics[keep]

          # Ties at the extreme can leave nothing on the kept side (the
          # comparison above is strict). Neither metric is defined on zero
          # samples, so the row is left as NaN instead of calling f1_score.
          if filtered_statistics.empty:
              continue

          accuracy = filtered_statistics['correct'].mean()
          predicted = filtered_statistics['predicted'].values
          actual = filtered_statistics['actual'].values
          f1 = metrics.f1_score(actual, predicted)

          thresholds_pd.loc[percentile] = [accuracy, f1]

      return thresholds_pd
  def plot_threshold_analysis(
      thresholds_metric, title, x_label, y_label, path, additional_set=None, flip=False
  ):
      """Plot a metric against its thresholds and save the figure to ``path``.

      Args:
          thresholds_metric: Series of metric values indexed by threshold;
              drawn as the 'Ensemble' line.
          title: Title of the plot.
          x_label: Label of the x axis.
          y_label: Label of the y axis.
          path: File the figure is saved to.
          additional_set: Optional second Series of the same form, drawn as
              the 'Individual' line (used for individual results).
          flip: When True the x axis is inverted.
      """
      fig, ax = plt.subplots()

      def draw(series, fmt, legend_name):
          # x values come from the index, y values from the data
          ax.plot(list(series.index), list(series.values), fmt, label=legend_name)

      draw(thresholds_metric, 'bo-', 'Ensemble')
      if additional_set is not None:
          draw(additional_set, 'rx-', 'Individual')

      if flip:
          ax.invert_xaxis()

      ax.set_title(title)
      ax.set_xlabel(x_label)
      ax.set_ylabel(y_label)
      ax.legend()
      # Tick labels are shown as percentages, with 1 corresponding to 100%
      ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))

      fig.savefig(path)
      plt.close(fig)
  def common_entries(*dcts):
      """Yield the entries whose key is present in every given dictionary.

      Code from https://stackoverflow.com/questions/16458340

      Args:
          *dcts: Any number of dictionaries.

      Yields:
          Tuples ``(key, value_in_first, value_in_second, ...)`` for each key
          shared by all dictionaries, in no particular order. Nothing is
          yielded when no dictionary is given.
      """
      if dcts:
          first, *others = dcts
          for shared_key in set(first).intersection(*others):
              yield (shared_key, *(d[shared_key] for d in dcts))
  def main():
      """Run the ensemble threshold analysis from predictions to plots.

      Steps, in order:
        1. Obtain the per-datapoint ensemble predictions, either by running
           the saved models on the test and validation datasets (and caching
           the result as a pickle in V3_PATH) or by reading that cache.
        2. Compute statistics and threshold tables for the ensemble.
        3. Print ECE/MCE for confidence, entropy and stdev, then the overall
           accuracy and F1 score.
        4. Repeat the statistics for the individual model configured under
           config['ensemble']['individual_id'].
        5. Save the six threshold plots into V3_PATH.
      """
      predictions_file = f'{V3_PATH}/ensemble_predictions.pk'

      if config['ensemble']['run_models']:
          # The models and datasets are only needed when predictions are
          # recomputed, so they are loaded here rather than unconditionally;
          # a run from the cache then works without the model files.
          device = torch.device(config['training']['device'])
          models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
          dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
              f'{ENSEMBLE_PATH}/val_dataset.pt'
          )
          # Get the predictions of the ensemble
          ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)
          # Save to file using pickle
          with open(predictions_file, 'wb') as f:
              pk.dump(ensemble_predictions, f)
      else:
          # Load the predictions from file.
          # NOTE: unpickling can execute arbitrary code; only read a cache
          # file that was written by this script.
          with open(predictions_file, 'rb') as f:
              ensemble_predictions = pk.load(f)

      # Get the statistics and thresholds of the ensemble
      ensemble_statistics = calculate_statistics(ensemble_predictions)
      stdev_thresholds = conduct_threshold_analysis(
          ensemble_statistics, 'stdev', low_to_high=True
      )
      entropy_thresholds = conduct_threshold_analysis(
          ensemble_statistics, 'entropy', low_to_high=True
      )
      confidence_thresholds = conduct_threshold_analysis(
          ensemble_statistics, 'confidence', low_to_high=False
      )

      # Compute ECE and MCE for each score column, then print them together
      calibration = []
      for display, column in (
          ('Confidence', 'confidence'),
          ('Entropy', 'entropy'),
          ('Stdev', 'stdev'),
      ):
          ece = met.ECE(
              ensemble_statistics['predicted'],
              ensemble_statistics[column],
              ensemble_statistics['actual'],
          )
          mce = met.MCE(
              ensemble_statistics['predicted'],
              ensemble_statistics[column],
              ensemble_statistics['actual'],
          )
          calibration.append((display, ece, mce))
      for display, ece, mce in calibration:
          print(f'{display} ECE: {ece}, {display} MCE: {mce}')

      # Print overall ensemble statistics
      print('Ensemble Statistics')
      print(f"Accuracy: {ensemble_statistics['correct'].mean()}")
      print(
          f"F1 Score: {metrics.f1_score(ensemble_statistics['actual'], ensemble_statistics['predicted'])}"
      )

      # Get the predictions, statistics and thresholds of an individual model
      indv_id = config['ensemble']['individual_id']
      indv_predictions = select_individual_model(ensemble_predictions, indv_id)
      indv_statistics = calculate_statistics(indv_predictions)

      # Stdev is not computed for the individual model, only entropy and
      # confidence thresholds
      indv_entropy_thresholds = conduct_threshold_analysis(
          indv_statistics, 'entropy', low_to_high=True
      )
      indv_confidence_thresholds = conduct_threshold_analysis(
          indv_statistics, 'confidence', low_to_high=False
      )

      # (display name, file stem, ensemble table, individual table, flip x axis)
      plot_specs = (
          ('Stdev', 'stdev', stdev_thresholds, None, True),
          ('Entropy', 'entropy', entropy_thresholds, indv_entropy_thresholds, True),
          (
              'Confidence',
              'confidence',
              confidence_thresholds,
              indv_confidence_thresholds,
              False,
          ),
      )
      # (table column, axis label, file name suffix)
      metric_specs = (
          ('accuracy', 'Accuracy', ''),
          ('f1', 'F1 Score', '_f1'),
      )
      for display, stem, ensemble_table, indv_table, flip in plot_specs:
          for column, metric_label, suffix in metric_specs:
              plot_threshold_analysis(
                  ensemble_table[column],
                  f'{display} Threshold Analysis for {metric_label}',
                  f'{display} Threshold',
                  metric_label,
                  f'{V3_PATH}/{stem}_threshold_analysis{suffix}.png',
                  additional_set=None if indv_table is None else indv_table[column],
                  flip=flip,
              )


  if __name__ == '__main__':
      main()