# threshold.py

import pandas as pd
import numpy as np
import os
import tomli as toml
from utils.data.datasets import prepare_datasets
import utils.ensemble as ens
import torch
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from tqdm import tqdm
import utils.metrics as met

# If True, run the models to generate fresh predictions; if False, load the
# cached CSVs from a previous run.
RUN = False

# CONFIGURATION
if os.getenv('ADL_CONFIG_PATH') is None:
    with open('config.toml', 'rb') as f:
        config = toml.load(f)
else:
    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
        config = toml.load(f)

ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
V2_PATH = ENSEMBLE_PATH + '/v2'
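# For reference, a hypothetical config.toml shape that satisfies the lookups
# above and below (the key names come from this script; the values are
# placeholders, not the project's actual settings):
#
#   [paths]
#   model_output = './model_output/'
#
#   [ensemble]
#   name = 'my_ensemble'
#
#   [training]
#   device = 'cuda'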
# A model's output is a 1x2 tensor holding the softmax of the two classes.
# This helper converts it to a (predicted class, confidence) pair, rescaling
# the winning softmax value from [0.5, 1] to [0, 1].
# NOTE: defined for reference; not called anywhere in this script.
def output_to_confidence(result):
    predicted_class = torch.argmax(result).item()
    confidence = (torch.max(result).item() - 0.5) * 2
    return torch.Tensor([predicted_class, confidence])
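# A quick sanity check of the helper (illustrative values, not real model
# output): the larger softmax entry picks the class, and its margin over 0.5
# becomes the confidence.
#   output_to_confidence(torch.tensor([0.2, 0.8]))  # -> tensor([1.0000, 0.6000])
#   output_to_confidence(torch.tensor([0.5, 0.5]))  # -> tensor([0.0000, 0.0000])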
# Runs every model in the ensemble over the combined test and validation sets
# and returns the raw per-model results, the ensemble confidences and standard
# deviations, and the first model's individual predictions. (Saving to disk
# happens in the RUN block below, not here.)
def get_predictions(config):
    models, model_descs = ens.load_models(
        f'{ENSEMBLE_PATH}/models/',
        config['training']['device'],
    )
    models = [model.to(config['training']['device']) for model in models]
    test_set = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
        f'{ENSEMBLE_PATH}/val_dataset.pt'
    )

    # [([model results], labels)]
    results = []
    # [(class_1, class_2, true_label)] for the first model only
    indv_results = []

    for i, (data, target) in tqdm(
        enumerate(test_set),
        total=len(test_set),
        desc='Getting predictions',
        unit='sample',
    ):
        mri, xls = data
        mri = mri.unsqueeze(0).to(config['training']['device'])
        xls = xls.unsqueeze(0).to(config['training']['device'])
        data = (mri, xls)
        res = []
        for j, model in enumerate(models):
            model.eval()
            with torch.no_grad():
                output = model(data)
            output = output.tolist()
            # Track the first model separately as an individual baseline
            if j == 0:
                indv_results.append((output[0][0], output[0][1], target[1].item()))
            res.append(output)
        results.append((res, target.tolist()))
    # results is a list of (list of model outputs, true label) tuples. Convert
    # it to two lists of tuples: one with the ensemble predicted class,
    # ensemble confidence and true label, and one with the ensemble predicted
    # class, ensemble standard deviation and true label.

    # [(ensemble predicted class, ensemble confidence, true label, class_1, class_2)]
    confidences = []
    # [(ensemble predicted class, ensemble standard deviation, true label, class_1, class_2)]
    stdevs = []
    for result in results:
        model_results, true_label = result
        # Get the ensemble mean and variance with numpy, as these are lists
        mean = np.mean(model_results, axis=0)
        variance = np.var(model_results, axis=0)

        # Calculate confidence and standard deviation
        confidence = (np.max(mean) - 0.5) * 2
        stdev = np.sqrt(variance)

        # Get the predicted class and its standard deviation
        predicted_class = np.argmax(mean)
        pc_stdev = np.squeeze(stdev)[predicted_class]

        # Get the individual class scores and the true label
        class_1 = mean[0][0]
        class_2 = mean[0][1]
        true_label = true_label[1]

        confidences.append((predicted_class, confidence, true_label, class_1, class_2))
        stdevs.append((predicted_class, pc_stdev, true_label, class_1, class_2))

    return results, confidences, stdevs, indv_results
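# A toy illustration of the aggregation above (the numbers are invented, not
# taken from the experiment): three model outputs for one sample are averaged,
# and the per-class standard deviation measures their disagreement.
#   toy = [[[0.7, 0.3]], [[0.6, 0.4]], [[0.8, 0.2]]]
#   np.mean(toy, axis=0)          # -> [[0.7, 0.3]], the ensemble softmax
#   np.sqrt(np.var(toy, axis=0))  # -> [[0.0816, 0.0816]], the disagreement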
if RUN:
    results, confs, stdevs, indv_results = get_predictions(config)

    # Convert to pandas dataframes
    confs_df = pd.DataFrame(
        confs,
        columns=['predicted_class', 'confidence', 'true_label', 'class_1', 'class_2'],
    )
    stdevs_df = pd.DataFrame(
        stdevs, columns=['predicted_class', 'stdev', 'true_label', 'class_1', 'class_2']
    )
    indv_df = pd.DataFrame(indv_results, columns=['class_1', 'class_2', 'true_label'])

    if not os.path.exists(V2_PATH):
        os.makedirs(V2_PATH)

    # index=False keeps the CSVs round-trippable with the read_csv calls below
    confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv', index=False)
    stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv', index=False)
    indv_df.to_csv(f'{V2_PATH}/individual_results.csv', index=False)
else:
    confs_df = pd.read_csv(f'{V2_PATH}/ensemble_confidences.csv')
    stdevs_df = pd.read_csv(f'{V2_PATH}/ensemble_stdevs.csv')
    indv_df = pd.read_csv(f'{V2_PATH}/individual_results.csv')
# Plot confidence vs standard deviation, coloring each point by whether the
# ensemble prediction was correct
correct_conf = confs_df[confs_df['predicted_class'] == confs_df['true_label']]
incorrect_conf = confs_df[confs_df['predicted_class'] != confs_df['true_label']]
correct_stdev = stdevs_df[stdevs_df['predicted_class'] == stdevs_df['true_label']]
incorrect_stdev = stdevs_df[stdevs_df['predicted_class'] != stdevs_df['true_label']]

plt.scatter(correct_conf['confidence'], correct_stdev['stdev'], color='green')
plt.scatter(incorrect_conf['confidence'], incorrect_stdev['stdev'], color='red')
plt.xlabel('Confidence')
plt.ylabel('Standard Deviation')
plt.title('Confidence vs Standard Deviation')
plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
plt.close()
# Calculate individual model accuracy: determine the predicted class from the
# larger of the two class scores
indv_df['predicted_class'] = indv_df[['class_1', 'class_2']].idxmax(axis=1)
indv_df['predicted_class'] = indv_df['predicted_class'].apply(
    lambda x: 0 if x == 'class_1' else 1
)
indv_df['correct'] = indv_df['predicted_class'] == indv_df['true_label']
accuracy_indv = indv_df['correct'].mean()
f1_indv = met.F1(
    indv_df['predicted_class'].to_numpy(), indv_df['true_label'].to_numpy()
)
# AUC uses the class-2 score as the positive-class probability
auc_indv = metrics.roc_auc_score(
    indv_df['true_label'].to_numpy(), indv_df['class_2'].to_numpy()
)
# Calculate decile thresholds (0th, 10th, ..., 100th percentile) for
# confidence and standard deviation
quantiles_conf = confs_df['confidence'].quantile(
    np.linspace(0, 1, 11), interpolation='lower'
)
quantiles_stdev = stdevs_df['stdev'].quantile(
    np.linspace(0, 1, 11), interpolation='lower'
)
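# Each result is a Series mapping quantile level to threshold, e.g. (values
# illustrative only):
#   quantiles_conf
#   0.0    0.02
#   0.1    0.31
#   ...
#   1.0    0.99
# so the entry at level q is the confidence below which a fraction q of the
# samples fall.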
accuracies_conf = []
# Use the quantile thresholds to calculate coverage: at each percentile, keep
# only the samples whose confidence is at or above the threshold, then score
# what remains
for percentile, threshold in quantiles_conf.items():
    filt = confs_df[confs_df['confidence'] >= threshold]
    accuracy = (
        filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
    )
    f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
    accuracies_conf.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})

accuracies_df = pd.DataFrame(accuracies_conf)
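# A minimal, self-contained sketch of the coverage idea above (the numbers
# are made up for illustration, not taken from the experiment): thresholding
# on confidence trades coverage for accuracy.
#   demo = pd.DataFrame({
#       'confidence': [0.9, 0.8, 0.3, 0.1],
#       'predicted_class': [1, 0, 1, 0],
#       'true_label': [1, 0, 0, 1],
#   })
#   kept = demo[demo['confidence'] >= 0.5]  # retain the 2 most confident samples
#   (kept['predicted_class'] == kept['true_label']).mean()  # -> 1.0, vs 0.5 overall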
# Plot the coverage
plt.plot(accuracies_df['percentile'], accuracies_df['accuracy'], label='Ensemble')
plt.plot(
    accuracies_df['percentile'],
    [accuracy_indv] * len(accuracies_df['percentile']),
    label='Individual',
    linestyle='--',
)
plt.xlabel('Percentile')
plt.ylabel('Accuracy')
plt.title('Coverage (Confidence)')
plt.legend()
plt.savefig(f'{V2_PATH}/coverage_conf.png')
plt.close()

# Plot coverage vs F1 for confidence
plt.plot(accuracies_df['percentile'], accuracies_df['f1'], label='Ensemble')
plt.plot(
    accuracies_df['percentile'],
    [f1_indv] * len(accuracies_df['percentile']),
    label='Individual',
    linestyle='--',
)
plt.xlabel('Percentile')
plt.ylabel('F1')
plt.title('Coverage F1 (Confidence)')
plt.legend()
plt.savefig(f'{V2_PATH}/coverage_f1_conf.png')
plt.close()
# Repeat for standard deviation; here lower is better, so keep the samples
# whose stdev is at or below each threshold
accuracies_stdev = []
for percentile, threshold in quantiles_stdev.items():
    filt = stdevs_df[stdevs_df['stdev'] <= threshold]
    accuracy = (
        filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
    )
    f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
    accuracies_stdev.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})

accuracies_stdev_df = pd.DataFrame(accuracies_stdev)
# Plot the coverage
plt.plot(
    accuracies_stdev_df['percentile'], accuracies_stdev_df['accuracy'], label='Ensemble'
)
plt.plot(
    accuracies_stdev_df['percentile'],
    [accuracy_indv] * len(accuracies_stdev_df['percentile']),
    label='Individual',
    linestyle='--',
)
plt.xlabel('Percentile')
plt.ylabel('Accuracy')
plt.title('Coverage (Standard Deviation)')
plt.legend()
# Invert the x-axis so the strictest thresholds (lowest stdev, fewest samples
# kept) sit on the right, matching the confidence plots
plt.gca().invert_xaxis()
plt.savefig(f'{V2_PATH}/coverage_stdev.png')
plt.close()

# Plot coverage vs F1 for standard deviation
plt.plot(accuracies_stdev_df['percentile'], accuracies_stdev_df['f1'], label='Ensemble')
plt.plot(
    accuracies_stdev_df['percentile'],
    [f1_indv] * len(accuracies_stdev_df['percentile']),
    label='Individual',
    linestyle='--',
)
plt.xlabel('Percentile')
plt.ylabel('F1')
plt.title('Coverage F1 (Standard Deviation)')
plt.legend()
plt.gca().invert_xaxis()
plt.savefig(f'{V2_PATH}/coverage_f1_stdev.png')
plt.close()
# Print overall ensemble accuracy and F1
overall_accuracy = (
    confs_df[confs_df['predicted_class'] == confs_df['true_label']].shape[0]
    / confs_df.shape[0]
)
overall_f1 = met.F1(
    confs_df['predicted_class'].to_numpy(), confs_df['true_label'].to_numpy()
)
print(f'Overall accuracy: {overall_accuracy}, Overall F1: {overall_f1}')