# threshold.py — ensemble confidence/stdev/entropy threshold analysis
  1. import pandas as pd
  2. import numpy as np
  3. import os
  4. import tomli as toml
  5. from utils.data.datasets import prepare_datasets
  6. import utils.ensemble as ens
  7. import torch
  8. import matplotlib.pyplot as plt
  9. import sklearn.metrics as metrics
  10. from tqdm import tqdm
  11. import utils.metrics as met
  12. import itertools as it
  13. import matplotlib.ticker as ticker
  14. RUN = True
  15. # CONFIGURATION
  16. if os.getenv('ADL_CONFIG_PATH') is None:
  17. with open('config.toml', 'rb') as f:
  18. config = toml.load(f)
  19. else:
  20. with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
  21. config = toml.load(f)
  22. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  23. V2_PATH = ENSEMBLE_PATH + '/v2'
  24. # Result is a 1x2 tensor, with the softmax of the 2 predicted classes
  25. # Want to convert to a predicted class and a confidence
  26. def output_to_confidence(result):
  27. predicted_class = torch.argmax(result).item()
  28. confidence = (torch.max(result).item() - 0.5) * 2
  29. return torch.Tensor([predicted_class, confidence])
  30. # This function conducts tests on the models and returns the results, as well as saving the predictions and metrics
  31. def get_predictions(config):
  32. models, model_descs = ens.load_models(
  33. f'{ENSEMBLE_PATH}/models/',
  34. config['training']['device'],
  35. )
  36. models = [model.to(config['training']['device']) for model in models]
  37. test_set = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
  38. f'{ENSEMBLE_PATH}/val_dataset.pt'
  39. )
  40. print(f'Loaded {len(test_set)} samples')
  41. # [([model results], labels)]
  42. results = []
  43. # [(class_1, class_2, true_label)]
  44. indv_results = []
  45. for i, (data, target) in tqdm(
  46. enumerate(test_set),
  47. total=len(test_set),
  48. desc='Getting predictions',
  49. unit='sample',
  50. ):
  51. mri, xls = data
  52. mri = mri.unsqueeze(0).to(config['training']['device'])
  53. xls = xls.unsqueeze(0).to(config['training']['device'])
  54. data = (mri, xls)
  55. res = []
  56. for j, model in enumerate(models):
  57. model.eval()
  58. with torch.no_grad():
  59. output = model(data)
  60. output = output.tolist()
  61. if j == 0:
  62. indv_results.append((output[0][0], output[0][1], target[1].item()))
  63. res.append(output)
  64. results.append((res, target.tolist()))
  65. # The results are a list of tuples, where each tuple contains a list of model outputs and the true label
  66. # We want to convert this to 2 list of tuples, one with the ensemble predicted class, ensemble confidence and true label
  67. # And one with the ensemble predicted class, ensemble standard deviation and true label
  68. # [(ensemble predicted class, ensemble confidence, true label)]
  69. confidences = []
  70. # [(ensemble predicted class, ensemble standard deviation, true label)]
  71. stdevs = []
  72. # [(ensemble predicted class, ensemble entropy, true label)]
  73. entropies = []
  74. for result in results:
  75. model_results, true_label = result
  76. # Get the ensemble mean and variance with numpy, as these are lists
  77. mean = np.mean(model_results, axis=0)
  78. variance = np.var(model_results, axis=0)
  79. # Calculate the entropy
  80. entropy = -1 * np.sum(mean * np.log(mean))
  81. # Calculate confidence and standard deviation
  82. confidence = (np.max(mean) - 0.5) * 2
  83. stdev = np.sqrt(variance)
  84. # Get the predicted class
  85. predicted_class = np.argmax(mean)
  86. # Get the confidence and standard deviation of the predicted class
  87. print(stdev)
  88. pc_stdev = np.squeeze(stdev)[predicted_class]
  89. # Get the individual classes
  90. class_1 = mean[0][0]
  91. class_2 = mean[0][1]
  92. # Get the true label
  93. true_label = true_label[1]
  94. confidences.append((predicted_class, confidence, true_label, class_1, class_2))
  95. stdevs.append((predicted_class, pc_stdev, true_label, class_1, class_2))
  96. entropies.append((predicted_class, entropy, true_label, class_1, class_2))
  97. return results, confidences, stdevs, entropies, indv_results
  98. if RUN:
  99. results, confs, stdevs, entropies, indv_results = get_predictions(config)
  100. # Convert to pandas dataframes
  101. confs_df = pd.DataFrame(
  102. confs,
  103. columns=['predicted_class', 'confidence', 'true_label', 'class_1', 'class_2'],
  104. )
  105. stdevs_df = pd.DataFrame(
  106. stdevs, columns=['predicted_class', 'stdev', 'true_label', 'class_1', 'class_2']
  107. )
  108. entropies_df = pd.DataFrame(
  109. entropies,
  110. columns=['predicted_class', 'entropy', 'true_label', 'class_1', 'class_2'],
  111. )
  112. indv_df = pd.DataFrame(indv_results, columns=['class_1', 'class_2', 'true_label'])
  113. if not os.path.exists(V2_PATH):
  114. os.makedirs(V2_PATH)
  115. confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
  116. stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
  117. entropies_df.to_csv(f'{V2_PATH}/ensemble_entropies.csv')
  118. indv_df.to_csv(f'{V2_PATH}/individual_results.csv')
  119. else:
  120. confs_df = pd.read_csv(f'{V2_PATH}/ensemble_confidences.csv')
  121. stdevs_df = pd.read_csv(f'{V2_PATH}/ensemble_stdevs.csv')
  122. entropies_df = pd.read_csv(f'{V2_PATH}/ensemble_entropies.csv')
  123. indv_df = pd.read_csv(f'{V2_PATH}/individual_results.csv')
  124. # Plot confidence vs standard deviation, and change color of dots based on if they are correct
  125. correct_conf = confs_df[confs_df['predicted_class'] == confs_df['true_label']]
  126. incorrect_conf = confs_df[confs_df['predicted_class'] != confs_df['true_label']]
  127. correct_stdev = stdevs_df[stdevs_df['predicted_class'] == stdevs_df['true_label']]
  128. incorrect_stdev = stdevs_df[stdevs_df['predicted_class'] != stdevs_df['true_label']]
  129. plot, ax = plt.subplots()
  130. plt.scatter(
  131. correct_conf['confidence'],
  132. correct_stdev['stdev'],
  133. color='green',
  134. label='Correct Prediction',
  135. )
  136. plt.scatter(
  137. incorrect_conf['confidence'],
  138. incorrect_stdev['stdev'],
  139. color='red',
  140. label='Incorrect Prediction',
  141. )
  142. plt.xlabel('Confidence (Raw Value)')
  143. plt.ylabel('Standard Deviation (Raw Value)')
  144. plt.title('Confidence vs Standard Deviation')
  145. plt.legend()
  146. plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
  147. plt.close()
  148. # Calculate individual model accuracy
  149. # Determine predicted class
  150. indv_df['predicted_class'] = indv_df[['class_1', 'class_2']].idxmax(axis=1)
  151. indv_df['predicted_class'] = indv_df['predicted_class'].apply(
  152. lambda x: 0 if x == 'class_1' else 1
  153. )
  154. indv_df['correct'] = indv_df['predicted_class'] == indv_df['true_label']
  155. accuracy_indv = indv_df['correct'].mean()
  156. f1_indv = met.F1(
  157. indv_df['predicted_class'].to_numpy(), indv_df['true_label'].to_numpy()
  158. )
  159. auc_indv = metrics.roc_auc_score(
  160. indv_df['true_label'].to_numpy(), indv_df['class_2'].to_numpy()
  161. )
  162. # Calculate percentiles for confidence and standard deviation
  163. quantiles_conf = confs_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  164. 'confidence'
  165. ]
  166. quantiles_stdev = stdevs_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  167. 'stdev'
  168. ]
  169. accuracies_conf = []
  170. # Use the quantiles to calculate the coverage
  171. iter_conf = it.islice(quantiles_conf.items(), 0, None)
  172. for quantile in iter_conf:
  173. percentile = quantile[0]
  174. filt = confs_df[confs_df['confidence'] >= quantile[1]]
  175. accuracy = (
  176. filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
  177. )
  178. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  179. accuracies_conf.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
  180. accuracies_df = pd.DataFrame(accuracies_conf)
  181. # Plot the coverage
  182. fig, ax = plt.subplots()
  183. plt.plot(accuracies_df['percentile'], accuracies_df['accuracy'], 'ob', label='Ensemble')
  184. plt.plot(
  185. accuracies_df['percentile'],
  186. [accuracy_indv] * len(accuracies_df['percentile']),
  187. 'xr',
  188. label='Individual (on entire dataset)',
  189. )
  190. plt.xlabel('Minimum Confidence Percentile (Low to High)')
  191. plt.ylabel('Accuracy')
  192. plt.title('Confidence Accuracy Coverage Plot')
  193. plt.legend()
  194. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  195. plt.savefig(f'{V2_PATH}/coverage_conf.png')
  196. plt.close()
  197. # Plot coverage vs F1 for confidence
  198. fig, ax = plt.subplots()
  199. plt.plot(accuracies_df['percentile'], accuracies_df['f1'], 'ob', label='Ensemble')
  200. plt.plot(
  201. accuracies_df['percentile'],
  202. [f1_indv] * len(accuracies_df['percentile']),
  203. 'xr',
  204. label='Individual (on entire dataset)',
  205. )
  206. plt.xlabel('Minimum Confidence Percentile (Low to High)')
  207. plt.ylabel('F1')
  208. plt.title('Confidence F1 Coverage Plot')
  209. plt.legend()
  210. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  211. plt.savefig(f'{V2_PATH}/coverage_f1_conf.png')
  212. plt.close()
  213. # Repeat for standard deviation
  214. accuracies_stdev = []
  215. iter_stdev = it.islice(quantiles_stdev.items(), 0, None)
  216. for quantile in iter_stdev:
  217. percentile = quantile[0]
  218. filt = stdevs_df[stdevs_df['stdev'] <= quantile[1]]
  219. accuracy = (
  220. filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
  221. )
  222. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  223. accuracies_stdev.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
  224. accuracies_stdev_df = pd.DataFrame(accuracies_stdev)
  225. # Plot the coverage
  226. fig, ax = plt.subplots()
  227. plt.plot(
  228. accuracies_stdev_df['percentile'],
  229. accuracies_stdev_df['accuracy'],
  230. 'ob',
  231. label='Ensemble',
  232. )
  233. plt.plot(
  234. accuracies_stdev_df['percentile'],
  235. [accuracy_indv] * len(accuracies_stdev_df['percentile']),
  236. 'xr',
  237. label='Individual (on entire dataset)',
  238. )
  239. plt.xlabel('Maximum Standard Deviation Percentile (High to Low)')
  240. plt.ylabel('Accuracy')
  241. plt.title('Standard Deviation Accuracy Coverage Plot')
  242. plt.legend()
  243. plt.gca().invert_xaxis()
  244. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  245. plt.savefig(f'{V2_PATH}/coverage_stdev.png')
  246. plt.close()
  247. # Plot coverage vs F1 for standard deviation
  248. fig, ax = plt.subplots()
  249. plt.plot(
  250. accuracies_stdev_df['percentile'], accuracies_stdev_df['f1'], 'ob', label='Ensemble'
  251. )
  252. plt.plot(
  253. accuracies_stdev_df['percentile'],
  254. [f1_indv] * len(accuracies_stdev_df['percentile']),
  255. 'xr',
  256. label='Individual (on entire dataset)',
  257. )
  258. plt.xlabel('Maximum Standard Deviation Percentile (High to Low)')
  259. plt.ylabel('F1')
  260. plt.title('Standard Deviation F1 Coverage Plot')
  261. plt.legend()
  262. plt.gca().invert_xaxis()
  263. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  264. plt.savefig(f'{V2_PATH}/coverage_f1_stdev.png')
  265. plt.close()
  266. # Print overall accuracy
  267. overall_accuracy = (
  268. confs_df[confs_df['predicted_class'] == confs_df['true_label']].shape[0]
  269. / confs_df.shape[0]
  270. )
  271. overall_f1 = met.F1(
  272. confs_df['predicted_class'].to_numpy(), confs_df['true_label'].to_numpy()
  273. )
  274. # Calculate ECE and MCE
  275. conf_ece = met.ECE(
  276. confs_df['predicted_class'].to_numpy(),
  277. confs_df['confidence'].to_numpy(),
  278. confs_df['true_label'].to_numpy(),
  279. )
  280. stdev_ece = met.ECE(
  281. stdevs_df['predicted_class'].to_numpy(),
  282. stdevs_df['stdev'].to_numpy(),
  283. stdevs_df['true_label'].to_numpy(),
  284. )
  285. print(f'Overall accuracy: {overall_accuracy}, Overall F1: {overall_f1},')
  286. print(f'Confidence ECE: {conf_ece}')
  287. print(f'Standard Deviation ECE: {stdev_ece}')
  288. # Repeat for entropy
  289. quantiles_entropy = entropies_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  290. 'entropy'
  291. ]
  292. accuracies_entropy = []
  293. iter_entropy = it.islice(quantiles_entropy.items(), 0, None)
  294. for quantile in iter_entropy:
  295. percentile = quantile[0]
  296. filt = entropies_df[entropies_df['entropy'] <= quantile[1]]
  297. accuracy = (
  298. filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
  299. )
  300. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  301. accuracies_entropy.append(
  302. {'percentile': percentile, 'accuracy': accuracy, 'f1': f1}
  303. )
  304. accuracies_entropy_df = pd.DataFrame(accuracies_entropy)
  305. # Plot the coverage
  306. fig, ax = plt.subplots()
  307. plt.plot(
  308. accuracies_entropy_df['percentile'],
  309. accuracies_entropy_df['accuracy'],
  310. 'ob',
  311. label='Ensemble',
  312. )
  313. plt.plot(
  314. accuracies_entropy_df['percentile'],
  315. [accuracy_indv] * len(accuracies_entropy_df['percentile']),
  316. 'xr',
  317. label='Individual (on entire dataset)',
  318. )
  319. plt.xlabel('Maximum Entropy Percentile (High to Low)')
  320. plt.ylabel('Accuracy')
  321. plt.title('Entropy Accuracy Coverage Plot')
  322. plt.legend()
  323. plt.gca().invert_xaxis()
  324. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  325. plt.savefig(f'{V2_PATH}/coverage_entropy.png')
  326. plt.close()
  327. # Plot coverage vs F1 for entropy
  328. fig, ax = plt.subplots()
  329. plt.plot(
  330. accuracies_entropy_df['percentile'],
  331. accuracies_entropy_df['f1'],
  332. 'ob',
  333. label='Ensemble',
  334. )
  335. plt.plot(
  336. accuracies_entropy_df['percentile'],
  337. [f1_indv] * len(accuracies_entropy_df['percentile']),
  338. 'xr',
  339. label='Individual (on entire dataset)',
  340. )
  341. plt.xlabel('Maximum Entropy Percentile (High to Low)')
  342. plt.ylabel('F1')
  343. plt.title('Entropy F1 Coverage Plot')
  344. plt.legend()
  345. plt.gca().invert_xaxis()
  346. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  347. plt.savefig(f'{V2_PATH}/coverage_f1_entropy.png')
  348. plt.close()