threshold.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. import pandas as pd
  2. import numpy as np
  3. import os
  4. import tomli as toml
  5. import utils.ensemble as ens
  6. import torch
  7. import matplotlib.pyplot as plt
  8. import sklearn.metrics as metrics
  9. from tqdm import tqdm
  10. import utils.metrics as met
  11. import itertools as it
  12. import matplotlib.ticker as ticker
  13. # Define plotting helper function
  14. def plot_coverage(
  15. percentiles,
  16. ensemble_results,
  17. individual_results,
  18. title,
  19. x_lablel,
  20. y_label,
  21. save_path,
  22. flip=False,
  23. ):
  24. fig, ax = plt.subplots()
  25. plt.plot(
  26. percentiles,
  27. ensemble_results,
  28. 'ob',
  29. label='Ensemble',
  30. )
  31. plt.plot(
  32. percentiles,
  33. individual_results,
  34. 'xr',
  35. label='Individual (on entire dataset)',
  36. )
  37. plt.xlabel(x_lablel)
  38. plt.ylabel(y_label)
  39. plt.title(title)
  40. plt.legend()
  41. if flip:
  42. plt.gca().invert_xaxis()
  43. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  44. plt.savefig(save_path)
  45. plt.close()
  46. RUN = False
  47. # CONFIGURATION
  48. if os.getenv('ADL_CONFIG_PATH') is None:
  49. with open('config.toml', 'rb') as f:
  50. config = toml.load(f)
  51. else:
  52. with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
  53. config = toml.load(f)
  54. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  55. V2_PATH = ENSEMBLE_PATH + '/v2'
  56. # Result is a 1x2 tensor, with the softmax of the 2 predicted classes
  57. # Want to convert to a predicted class and a confidence
  58. def output_to_confidence(result):
  59. predicted_class = torch.argmax(result).item()
  60. confidence = (torch.max(result).item() - 0.5) * 2
  61. return torch.Tensor([predicted_class, confidence])
  62. # This function conducts tests on the models and returns the results, as well as saving the predictions and metrics
  63. def get_predictions(config):
  64. models, model_descs = ens.load_models(
  65. f'{ENSEMBLE_PATH}/models/',
  66. config['training']['device'],
  67. )
  68. models = [model.to(config['training']['device']) for model in models]
  69. test_set = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
  70. f'{ENSEMBLE_PATH}/val_dataset.pt'
  71. )
  72. print(f'Loaded {len(test_set)} samples')
  73. # [([model results], labels)]
  74. results = []
  75. # [(class_1, class_2, true_label)]
  76. indv_results = []
  77. for _, (data, target) in tqdm(
  78. enumerate(test_set),
  79. total=len(test_set),
  80. desc='Getting predictions',
  81. unit='sample',
  82. ):
  83. mri, xls = data
  84. mri = mri.unsqueeze(0).to(config['training']['device'])
  85. xls = xls.unsqueeze(0).to(config['training']['device'])
  86. data = (mri, xls)
  87. res = []
  88. for j, model in enumerate(models):
  89. model.eval()
  90. with torch.no_grad():
  91. output = model(data)
  92. output = output.tolist()
  93. if j == 0:
  94. indv_results.append((output[0][0], output[0][1], target[1].item()))
  95. res.append(output)
  96. results.append((res, target.tolist()))
  97. # The results are a list of tuples, where each tuple contains a list of model outputs and the true label
  98. # We want to convert this to 2 list of tuples, one with the ensemble predicted class, ensemble confidence and true label
  99. # And one with the ensemble predicted class, ensemble standard deviation and true label
  100. # [(ensemble predicted class, ensemble confidence, true label)]
  101. confidences = []
  102. # [(ensemble predicted class, ensemble standard deviation, true label)]
  103. stdevs = []
  104. # [(ensemble predicted class, ensemble entropy, true label)]
  105. entropies = []
  106. for result in results:
  107. model_results, true_label = result
  108. # Get the ensemble mean and variance with numpy, as these are lists
  109. mean = np.mean(model_results, axis=0)
  110. variance = np.var(model_results, axis=0)
  111. # Calculate the entropy
  112. entropy = -1 * np.sum(mean * np.log(mean))
  113. # Calculate confidence and standard deviation
  114. confidence = (np.max(mean) - 0.5) * 2
  115. stdev = np.sqrt(variance)
  116. # Get the predicted class
  117. predicted_class = np.argmax(mean)
  118. # Get the confidence and standard deviation of the predicted class
  119. pc_stdev = np.squeeze(stdev)[predicted_class]
  120. # Get the individual classes
  121. class_1 = mean[0][0]
  122. class_2 = mean[0][1]
  123. # Get the true label
  124. true_label = true_label[1]
  125. confidences.append((predicted_class, confidence, true_label, class_1, class_2))
  126. stdevs.append((predicted_class, pc_stdev, true_label, class_1, class_2))
  127. entropies.append((predicted_class, entropy, true_label, class_1, class_2))
  128. return results, confidences, stdevs, entropies, indv_results
  129. if RUN:
  130. results, confs, stdevs, entropies, indv_results = get_predictions(config)
  131. # Convert to pandas dataframes
  132. confs_df = pd.DataFrame(
  133. confs,
  134. columns=['predicted_class', 'confidence', 'true_label', 'class_1', 'class_2'],
  135. )
  136. stdevs_df = pd.DataFrame(
  137. stdevs, columns=['predicted_class', 'stdev', 'true_label', 'class_1', 'class_2']
  138. )
  139. entropies_df = pd.DataFrame(
  140. entropies,
  141. columns=['predicted_class', 'entropy', 'true_label', 'class_1', 'class_2'],
  142. )
  143. indv_df = pd.DataFrame(indv_results, columns=['class_1', 'class_2', 'true_label'])
  144. if not os.path.exists(V2_PATH):
  145. os.makedirs(V2_PATH)
  146. confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
  147. stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
  148. entropies_df.to_csv(f'{V2_PATH}/ensemble_entropies.csv')
  149. indv_df.to_csv(f'{V2_PATH}/individual_results.csv')
  150. else:
  151. confs_df = pd.read_csv(f'{V2_PATH}/ensemble_confidences.csv')
  152. stdevs_df = pd.read_csv(f'{V2_PATH}/ensemble_stdevs.csv')
  153. entropies_df = pd.read_csv(f'{V2_PATH}/ensemble_entropies.csv')
  154. indv_df = pd.read_csv(f'{V2_PATH}/individual_results.csv')
  155. # Plot confidence vs standard deviation, and change color of dots based on if they are correct
  156. correct_conf = confs_df[confs_df['predicted_class'] == confs_df['true_label']]
  157. incorrect_conf = confs_df[confs_df['predicted_class'] != confs_df['true_label']]
  158. correct_stdev = stdevs_df[stdevs_df['predicted_class'] == stdevs_df['true_label']]
  159. incorrect_stdev = stdevs_df[stdevs_df['predicted_class'] != stdevs_df['true_label']]
  160. correct_ent = entropies_df[
  161. entropies_df['predicted_class'] == entropies_df['true_label']
  162. ]
  163. incorrect_ent = entropies_df[
  164. entropies_df['predicted_class'] != entropies_df['true_label']
  165. ]
  166. plot, ax = plt.subplots()
  167. plt.scatter(
  168. correct_conf['confidence'],
  169. correct_stdev['stdev'],
  170. color='green',
  171. label='Correct Prediction',
  172. )
  173. plt.scatter(
  174. incorrect_conf['confidence'],
  175. incorrect_stdev['stdev'],
  176. color='red',
  177. label='Incorrect Prediction',
  178. )
  179. plt.xlabel('Confidence (Raw Value)')
  180. plt.ylabel('Standard Deviation (Raw Value)')
  181. plt.title('Confidence vs Standard Deviation')
  182. plt.legend()
  183. plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
  184. plt.close()
  185. # Do the same for confidence vs entropy
  186. plot, ax = plt.subplots()
  187. plt.scatter(
  188. correct_conf['confidence'],
  189. correct_ent['entropy'],
  190. color='green',
  191. label='Correct Prediction',
  192. )
  193. plt.scatter(
  194. incorrect_conf['confidence'],
  195. incorrect_ent['entropy'],
  196. color='red',
  197. label='Incorrect Prediction',
  198. )
  199. plt.xlabel('Confidence (Raw Value)')
  200. plt.ylabel('Entropy (Raw Value)')
  201. plt.title('Confidence vs Entropy')
  202. plt.legend()
  203. plt.savefig(f'{V2_PATH}/confidence_vs_entropy.png')
  204. plt.close()
  205. # Calculate individual model accuracy and entropy
  206. # Determine predicted class
  207. indv_df['predicted_class'] = indv_df[['class_1', 'class_2']].idxmax(axis=1)
  208. indv_df['predicted_class'] = indv_df['predicted_class'].apply(
  209. lambda x: 0 if x == 'class_1' else 1
  210. )
  211. indv_df['correct'] = indv_df['predicted_class'] == indv_df['true_label']
  212. accuracy_indv = indv_df['correct'].mean()
  213. f1_indv = met.F1(
  214. indv_df['predicted_class'].to_numpy(), indv_df['true_label'].to_numpy()
  215. )
  216. auc_indv = metrics.roc_auc_score(
  217. indv_df['true_label'].to_numpy(), indv_df['class_2'].to_numpy()
  218. )
  219. indv_df['entropy'] = -1 * indv_df[['class_1', 'class_2']].apply(
  220. lambda x: x * np.log(x), axis=0
  221. ).sum(axis=1)
  222. # Calculate percentiles for confidence and standard deviation
  223. quantiles_conf = confs_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  224. 'confidence'
  225. ]
  226. quantiles_stdev = stdevs_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  227. 'stdev'
  228. ]
  229. # Additionally for individual confidence
  230. quantiles_indv_conf = indv_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  231. 'class_2'
  232. ]
  233. # For indivual entropy
  234. quantiles_indv_entropy = indv_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  235. 'entropy'
  236. ]
  237. #
  238. accuracies_conf = []
  239. # Use the quantiles to calculate the coverage
  240. iter_conf = it.islice(quantiles_conf.items(), 0, None)
  241. for quantile in iter_conf:
  242. percentile = quantile[0]
  243. filt = confs_df[confs_df['confidence'] >= quantile[1]]
  244. accuracy = (
  245. filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
  246. )
  247. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  248. accuracies_conf.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
  249. accuracies_df = pd.DataFrame(accuracies_conf)
  250. indv_conf = []
  251. # Use the quantiles to calculate the coverage
  252. iter_conf = it.islice(quantiles_indv_conf.items(), 0, None)
  253. for quantile in iter_conf:
  254. percentile = quantile[0]
  255. filt = indv_df[indv_df['class_2'] >= quantile[1]]
  256. accuracy = filt['correct'].mean()
  257. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  258. indv_conf.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
  259. indv_conf_df = pd.DataFrame(indv_conf)
  260. # Do the same for entropy
  261. indv_entropy = []
  262. iter_entropy = it.islice(quantiles_indv_entropy.items(), 0, None)
  263. for quantile in iter_entropy:
  264. percentile = quantile[0]
  265. filt = indv_df[indv_df['entropy'] <= quantile[1]]
  266. accuracy = filt['correct'].mean()
  267. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  268. indv_entropy.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
  269. indv_entropy_df = pd.DataFrame(indv_entropy)
  270. # Plot the coverage for confidence and accuracy
  271. plot_coverage(
  272. accuracies_df['percentile'],
  273. accuracies_df['accuracy'],
  274. indv_conf_df['accuracy'],
  275. 'Confidence Accuracy Coverage Plot',
  276. 'Minimum Confidence Percentile (Low to High)',
  277. 'Accuracy',
  278. f'{V2_PATH}/coverage_conf.png',
  279. )
  280. # Plot the coverage for confidence and F1
  281. plot_coverage(
  282. accuracies_df['percentile'],
  283. accuracies_df['f1'],
  284. indv_conf_df['f1'],
  285. 'Confidence F1 Coverage Plot',
  286. 'Minimum Confidence Percentile (Low to High)',
  287. 'F1',
  288. f'{V2_PATH}/f1_coverage_conf.png',
  289. )
  290. # Repeat for standard deviation
  291. accuracies_stdev = []
  292. iter_stdev = it.islice(quantiles_stdev.items(), 0, None)
  293. for quantile in iter_stdev:
  294. percentile = quantile[0]
  295. filt = stdevs_df[stdevs_df['stdev'] <= quantile[1]]
  296. accuracy = (
  297. filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
  298. )
  299. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  300. accuracies_stdev.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
  301. accuracies_stdev_df = pd.DataFrame(accuracies_stdev)
  302. fig, ax = plt.subplots()
  303. plt.plot(
  304. accuracies_stdev_df['percentile'],
  305. accuracies_stdev_df['accuracy'],
  306. 'ob',
  307. label='Ensemble',
  308. )
  309. plt.plot(
  310. accuracies_stdev_df['percentile'],
  311. [accuracy_indv] * len(accuracies_stdev_df['percentile']),
  312. 'xr',
  313. label='Individual (on entire dataset)',
  314. )
  315. plt.xlabel('Maximum Standard Deviation Percentile (High to Low)')
  316. plt.ylabel('Accuracy')
  317. plt.title('Standard Deviation Accuracy Coverage Plot')
  318. plt.legend()
  319. plt.gca().invert_xaxis()
  320. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  321. plt.savefig(f'{V2_PATH}/coverage_stdev.png')
  322. plt.close()
  323. # Plot coverage vs F1 for standard deviation
  324. fig, ax = plt.subplots()
  325. plt.plot(
  326. accuracies_stdev_df['percentile'], accuracies_stdev_df['f1'], 'ob', label='Ensemble'
  327. )
  328. plt.plot(
  329. accuracies_stdev_df['percentile'],
  330. [f1_indv] * len(accuracies_stdev_df['percentile']),
  331. 'xr',
  332. label='Individual (on entire dataset)',
  333. )
  334. plt.xlabel('Maximum Standard Deviation Percentile (High to Low)')
  335. plt.ylabel('F1')
  336. plt.title('Standard Deviation F1 Coverage Plot')
  337. plt.legend()
  338. plt.gca().invert_xaxis()
  339. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  340. plt.savefig(f'{V2_PATH}/coverage_f1_stdev.png')
  341. plt.close()
  342. # Print overall accuracy
  343. overall_accuracy = (
  344. confs_df[confs_df['predicted_class'] == confs_df['true_label']].shape[0]
  345. / confs_df.shape[0]
  346. )
  347. overall_f1 = met.F1(
  348. confs_df['predicted_class'].to_numpy(), confs_df['true_label'].to_numpy()
  349. )
  350. # Calculate ECE and MCE
  351. conf_ece = met.ECE(
  352. confs_df['predicted_class'].to_numpy(),
  353. confs_df['confidence'].to_numpy(),
  354. confs_df['true_label'].to_numpy(),
  355. )
  356. stdev_ece = met.ECE(
  357. stdevs_df['predicted_class'].to_numpy(),
  358. stdevs_df['stdev'].to_numpy(),
  359. stdevs_df['true_label'].to_numpy(),
  360. )
  361. print(f'Overall accuracy: {overall_accuracy}, Overall F1: {overall_f1},')
  362. print(f'Confidence ECE: {conf_ece}')
  363. print(f'Standard Deviation ECE: {stdev_ece}')
  364. # Repeat for entropy
  365. quantiles_entropy = entropies_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  366. 'entropy'
  367. ]
  368. accuracies_entropy = []
  369. iter_entropy = it.islice(quantiles_entropy.items(), 0, None)
  370. for quantile in iter_entropy:
  371. percentile = quantile[0]
  372. filt = entropies_df[entropies_df['entropy'] <= quantile[1]]
  373. accuracy = (
  374. filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
  375. )
  376. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  377. accuracies_entropy.append(
  378. {'percentile': percentile, 'accuracy': accuracy, 'f1': f1}
  379. )
  380. accuracies_entropy_df = pd.DataFrame(accuracies_entropy)
  381. # Plot the coverage for entropy and accuracy
  382. plot_coverage(
  383. accuracies_entropy_df['percentile'],
  384. accuracies_entropy_df['accuracy'],
  385. indv_entropy_df['accuracy'],
  386. 'Entropy Accuracy Coverage Plot',
  387. 'Minimum Entropy Percentile (Low to High)',
  388. 'Accuracy',
  389. f'{V2_PATH}/coverage_entropy.png',
  390. )
  391. # Plot the coverage for entropy and F1
  392. plot_coverage(
  393. accuracies_entropy_df['percentile'],
  394. accuracies_entropy_df['f1'],
  395. indv_entropy_df['f1'],
  396. 'Entropy F1 Coverage Plot',
  397. 'Maximum Entropy Percentile (High to Low)',
  398. 'F1',
  399. f'{V2_PATH}/f1_coverage_entropy.png',
  400. flip=True,
  401. )