# threshold.py — ensemble uncertainty thresholding / coverage analysis
  1. import pandas as pd
  2. import numpy as np
  3. import os
  4. import tomli as toml
  5. from utils.data.datasets import prepare_datasets
  6. import utils.ensemble as ens
  7. import torch
  8. import matplotlib.pyplot as plt
  9. import sklearn.metrics as metrics
  10. from tqdm import tqdm
  11. import utils.metrics as met
  12. import itertools as it
  13. import matplotlib.ticker as ticker
  14. # Define plotting helper function
  15. def plot_coverage(
  16. percentiles,
  17. ensemble_results,
  18. individual_results,
  19. title,
  20. x_lablel,
  21. y_label,
  22. save_path,
  23. flip=False,
  24. ):
  25. fig, ax = plt.subplots()
  26. plt.plot(
  27. percentiles,
  28. ensemble_results,
  29. 'ob',
  30. label='Ensemble',
  31. )
  32. plt.plot(
  33. percentiles,
  34. individual_results,
  35. 'xr',
  36. label='Individual (on entire dataset)',
  37. )
  38. plt.xlabel(x_lablel)
  39. plt.ylabel(y_label)
  40. plt.title(title)
  41. plt.legend()
  42. if flip:
  43. plt.gca().invert_xaxis()
  44. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  45. plt.savefig(save_path)
  46. plt.close()
  47. RUN = False
  48. # CONFIGURATION
  49. if os.getenv('ADL_CONFIG_PATH') is None:
  50. with open('config.toml', 'rb') as f:
  51. config = toml.load(f)
  52. else:
  53. with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
  54. config = toml.load(f)
  55. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  56. V2_PATH = ENSEMBLE_PATH + '/v2'
  57. # Result is a 1x2 tensor, with the softmax of the 2 predicted classes
  58. # Want to convert to a predicted class and a confidence
  59. def output_to_confidence(result):
  60. predicted_class = torch.argmax(result).item()
  61. confidence = (torch.max(result).item() - 0.5) * 2
  62. return torch.Tensor([predicted_class, confidence])
# This function conducts tests on the models and returns the results, as well as saving the predictions and metrics
def get_predictions(config):
    """Evaluate every ensemble member on the combined test+validation set.

    Returns a 5-tuple:
        results      -- [([per-model softmax outputs], label list)] per sample
        confidences  -- [(pred class, confidence, true label, class_1, class_2)]
        stdevs       -- [(pred class, stdev of pred class, true label, class_1, class_2)]
        entropies    -- [(pred class, entropy, true label, class_1, class_2)]
        indv_results -- [(class_1, class_2, true label)] from the FIRST model only
    """
    models, model_descs = ens.load_models(
        f'{ENSEMBLE_PATH}/models/',
        config['training']['device'],
    )
    models = [model.to(config['training']['device']) for model in models]
    # `+` concatenation assumes the saved datasets are list-like — TODO confirm
    test_set = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
        f'{ENSEMBLE_PATH}/val_dataset.pt'
    )
    print(f'Loaded {len(test_set)} samples')

    # [([model results], labels)]
    results = []
    # [(class_1, class_2, true_label)]
    indv_results = []

    for i, (data, target) in tqdm(
        enumerate(test_set),
        total=len(test_set),
        desc='Getting predictions',
        unit='sample',
    ):
        mri, xls = data
        # Add a batch dimension of 1 and move both inputs to the device.
        mri = mri.unsqueeze(0).to(config['training']['device'])
        xls = xls.unsqueeze(0).to(config['training']['device'])
        data = (mri, xls)
        res = []
        for j, model in enumerate(models):
            model.eval()
            with torch.no_grad():
                output = model(data)
                output = output.tolist()
                # The first model doubles as the "individual" baseline.
                # target[1] presumably holds the positive-class label — verify against caller.
                if j == 0:
                    indv_results.append((output[0][0], output[0][1], target[1].item()))
                res.append(output)
        results.append((res, target.tolist()))
    # The results are a list of tuples, where each tuple contains a list of model outputs and the true label
    # We want to convert this to 2 list of tuples, one with the ensemble predicted class, ensemble confidence and true label
    # And one with the ensemble predicted class, ensemble standard deviation and true label
    # [(ensemble predicted class, ensemble confidence, true label)]
    confidences = []
    # [(ensemble predicted class, ensemble standard deviation, true label)]
    stdevs = []
    # [(ensemble predicted class, ensemble entropy, true label)]
    entropies = []

    for result in results:
        model_results, true_label = result
        # Get the ensemble mean and variance with numpy, as these are lists
        mean = np.mean(model_results, axis=0)
        variance = np.var(model_results, axis=0)
        # Calculate the entropy
        # NOTE(review): np.log(mean) yields -inf/NaN if any mean probability is 0.
        entropy = -1 * np.sum(mean * np.log(mean))
        # Calculate confidence and standard deviation
        # Confidence rescales the top probability from [0.5, 1] to [0, 1].
        confidence = (np.max(mean) - 0.5) * 2
        stdev = np.sqrt(variance)
        # Get the predicted class
        predicted_class = np.argmax(mean)
        # Get the confidence and standard deviation of the predicted class
        pc_stdev = np.squeeze(stdev)[predicted_class]
        # Get the individual classes
        # mean is indexed [0][k]: outputs appear to keep a leading batch axis of 1 — TODO confirm
        class_1 = mean[0][0]
        class_2 = mean[0][1]
        # Get the true label
        true_label = true_label[1]
        confidences.append((predicted_class, confidence, true_label, class_1, class_2))
        stdevs.append((predicted_class, pc_stdev, true_label, class_1, class_2))
        entropies.append((predicted_class, entropy, true_label, class_1, class_2))

    return results, confidences, stdevs, entropies, indv_results
  130. if RUN:
  131. results, confs, stdevs, entropies, indv_results = get_predictions(config)
  132. # Convert to pandas dataframes
  133. confs_df = pd.DataFrame(
  134. confs,
  135. columns=['predicted_class', 'confidence', 'true_label', 'class_1', 'class_2'],
  136. )
  137. stdevs_df = pd.DataFrame(
  138. stdevs, columns=['predicted_class', 'stdev', 'true_label', 'class_1', 'class_2']
  139. )
  140. entropies_df = pd.DataFrame(
  141. entropies,
  142. columns=['predicted_class', 'entropy', 'true_label', 'class_1', 'class_2'],
  143. )
  144. indv_df = pd.DataFrame(indv_results, columns=['class_1', 'class_2', 'true_label'])
  145. if not os.path.exists(V2_PATH):
  146. os.makedirs(V2_PATH)
  147. confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
  148. stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
  149. entropies_df.to_csv(f'{V2_PATH}/ensemble_entropies.csv')
  150. indv_df.to_csv(f'{V2_PATH}/individual_results.csv')
  151. else:
  152. confs_df = pd.read_csv(f'{V2_PATH}/ensemble_confidences.csv')
  153. stdevs_df = pd.read_csv(f'{V2_PATH}/ensemble_stdevs.csv')
  154. entropies_df = pd.read_csv(f'{V2_PATH}/ensemble_entropies.csv')
  155. indv_df = pd.read_csv(f'{V2_PATH}/individual_results.csv')
  156. # Plot confidence vs standard deviation, and change color of dots based on if they are correct
  157. correct_conf = confs_df[confs_df['predicted_class'] == confs_df['true_label']]
  158. incorrect_conf = confs_df[confs_df['predicted_class'] != confs_df['true_label']]
  159. correct_stdev = stdevs_df[stdevs_df['predicted_class'] == stdevs_df['true_label']]
  160. incorrect_stdev = stdevs_df[stdevs_df['predicted_class'] != stdevs_df['true_label']]
  161. correct_ent = entropies_df[
  162. entropies_df['predicted_class'] == entropies_df['true_label']
  163. ]
  164. incorrect_ent = entropies_df[
  165. entropies_df['predicted_class'] != entropies_df['true_label']
  166. ]
  167. plot, ax = plt.subplots()
  168. plt.scatter(
  169. correct_conf['confidence'],
  170. correct_stdev['stdev'],
  171. color='green',
  172. label='Correct Prediction',
  173. )
  174. plt.scatter(
  175. incorrect_conf['confidence'],
  176. incorrect_stdev['stdev'],
  177. color='red',
  178. label='Incorrect Prediction',
  179. )
  180. plt.xlabel('Confidence (Raw Value)')
  181. plt.ylabel('Standard Deviation (Raw Value)')
  182. plt.title('Confidence vs Standard Deviation')
  183. plt.legend()
  184. plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
  185. plt.close()
  186. # Do the same for confidence vs entropy
  187. plot, ax = plt.subplots()
  188. plt.scatter(
  189. correct_conf['confidence'],
  190. correct_ent['entropy'],
  191. color='green',
  192. label='Correct Prediction',
  193. )
  194. plt.scatter(
  195. incorrect_conf['confidence'],
  196. incorrect_ent['entropy'],
  197. color='red',
  198. label='Incorrect Prediction',
  199. )
  200. plt.xlabel('Confidence (Raw Value)')
  201. plt.ylabel('Entropy (Raw Value)')
  202. plt.title('Confidence vs Entropy')
  203. plt.legend()
  204. plt.savefig(f'{V2_PATH}/confidence_vs_entropy.png')
  205. plt.close()
  206. # Calculate individual model accuracy and entropy
  207. # Determine predicted class
  208. indv_df['predicted_class'] = indv_df[['class_1', 'class_2']].idxmax(axis=1)
  209. indv_df['predicted_class'] = indv_df['predicted_class'].apply(
  210. lambda x: 0 if x == 'class_1' else 1
  211. )
  212. indv_df['correct'] = indv_df['predicted_class'] == indv_df['true_label']
  213. accuracy_indv = indv_df['correct'].mean()
  214. f1_indv = met.F1(
  215. indv_df['predicted_class'].to_numpy(), indv_df['true_label'].to_numpy()
  216. )
  217. auc_indv = metrics.roc_auc_score(
  218. indv_df['true_label'].to_numpy(), indv_df['class_2'].to_numpy()
  219. )
  220. indv_df['entropy'] = -1 * indv_df[['class_1', 'class_2']].apply(
  221. lambda x: x * np.log(x), axis=0
  222. ).sum(axis=1)
  223. # Calculate percentiles for confidence and standard deviation
  224. quantiles_conf = confs_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  225. 'confidence'
  226. ]
  227. quantiles_stdev = stdevs_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  228. 'stdev'
  229. ]
  230. # Additionally for individual confidence
  231. quantiles_indv_conf = indv_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  232. 'class_2'
  233. ]
  234. # For indivual entropy
  235. quantiles_indv_entropy = indv_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  236. 'entropy'
  237. ]
  238. #
  239. accuracies_conf = []
  240. # Use the quantiles to calculate the coverage
  241. iter_conf = it.islice(quantiles_conf.items(), 0, None)
  242. for quantile in iter_conf:
  243. percentile = quantile[0]
  244. filt = confs_df[confs_df['confidence'] >= quantile[1]]
  245. accuracy = (
  246. filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
  247. )
  248. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  249. accuracies_conf.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
  250. accuracies_df = pd.DataFrame(accuracies_conf)
  251. indv_conf = []
  252. # Use the quantiles to calculate the coverage
  253. iter_conf = it.islice(quantiles_indv_conf.items(), 0, None)
  254. for quantile in iter_conf:
  255. percentile = quantile[0]
  256. filt = indv_df[indv_df['class_2'] >= quantile[1]]
  257. accuracy = filt['correct'].mean()
  258. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  259. indv_conf.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
  260. indv_conf_df = pd.DataFrame(indv_conf)
  261. # Do the same for entropy
  262. indv_entropy = []
  263. iter_entropy = it.islice(quantiles_indv_entropy.items(), 0, None)
  264. for quantile in iter_entropy:
  265. percentile = quantile[0]
  266. filt = indv_df[indv_df['entropy'] <= quantile[1]]
  267. accuracy = filt['correct'].mean()
  268. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  269. indv_entropy.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
  270. indv_entropy_df = pd.DataFrame(indv_entropy)
  271. # Plot the coverage for confidence and accuracy
  272. plot_coverage(
  273. accuracies_df['percentile'],
  274. accuracies_df['accuracy'],
  275. indv_conf_df['accuracy'],
  276. 'Confidence Accuracy Coverage Plot',
  277. 'Minimum Confidence Percentile (Low to High)',
  278. 'Accuracy',
  279. f'{V2_PATH}/coverage_conf.png',
  280. )
  281. # Plot the coverage for confidence and F1
  282. plot_coverage(
  283. accuracies_df['percentile'],
  284. accuracies_df['f1'],
  285. indv_conf_df['f1'],
  286. 'Confidence F1 Coverage Plot',
  287. 'Minimum Confidence Percentile (Low to High)',
  288. 'F1',
  289. f'{V2_PATH}/f1_coverage_conf.png',
  290. )
  291. # Repeat for standard deviation
  292. accuracies_stdev = []
  293. iter_stdev = it.islice(quantiles_stdev.items(), 0, None)
  294. for quantile in iter_stdev:
  295. percentile = quantile[0]
  296. filt = stdevs_df[stdevs_df['stdev'] <= quantile[1]]
  297. accuracy = (
  298. filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
  299. )
  300. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  301. accuracies_stdev.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
  302. accuracies_stdev_df = pd.DataFrame(accuracies_stdev)
  303. fig, ax = plt.subplots()
  304. plt.plot(
  305. accuracies_stdev_df['percentile'],
  306. accuracies_stdev_df['accuracy'],
  307. 'ob',
  308. label='Ensemble',
  309. )
  310. plt.plot(
  311. accuracies_stdev_df['percentile'],
  312. [accuracy_indv] * len(accuracies_stdev_df['percentile']),
  313. 'xr',
  314. label='Individual (on entire dataset)',
  315. )
  316. plt.xlabel('Maximum Standard Deviation Percentile (High to Low)')
  317. plt.ylabel('Accuracy')
  318. plt.title('Standard Deviation Accuracy Coverage Plot')
  319. plt.legend()
  320. plt.gca().invert_xaxis()
  321. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  322. plt.savefig(f'{V2_PATH}/coverage_stdev.png')
  323. plt.close()
  324. # Plot coverage vs F1 for standard deviation
  325. fig, ax = plt.subplots()
  326. plt.plot(
  327. accuracies_stdev_df['percentile'], accuracies_stdev_df['f1'], 'ob', label='Ensemble'
  328. )
  329. plt.plot(
  330. accuracies_stdev_df['percentile'],
  331. [f1_indv] * len(accuracies_stdev_df['percentile']),
  332. 'xr',
  333. label='Individual (on entire dataset)',
  334. )
  335. plt.xlabel('Maximum Standard Deviation Percentile (High to Low)')
  336. plt.ylabel('F1')
  337. plt.title('Standard Deviation F1 Coverage Plot')
  338. plt.legend()
  339. plt.gca().invert_xaxis()
  340. ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
  341. plt.savefig(f'{V2_PATH}/coverage_f1_stdev.png')
  342. plt.close()
  343. # Print overall accuracy
  344. overall_accuracy = (
  345. confs_df[confs_df['predicted_class'] == confs_df['true_label']].shape[0]
  346. / confs_df.shape[0]
  347. )
  348. overall_f1 = met.F1(
  349. confs_df['predicted_class'].to_numpy(), confs_df['true_label'].to_numpy()
  350. )
  351. # Calculate ECE and MCE
  352. conf_ece = met.ECE(
  353. confs_df['predicted_class'].to_numpy(),
  354. confs_df['confidence'].to_numpy(),
  355. confs_df['true_label'].to_numpy(),
  356. )
  357. stdev_ece = met.ECE(
  358. stdevs_df['predicted_class'].to_numpy(),
  359. stdevs_df['stdev'].to_numpy(),
  360. stdevs_df['true_label'].to_numpy(),
  361. )
  362. print(f'Overall accuracy: {overall_accuracy}, Overall F1: {overall_f1},')
  363. print(f'Confidence ECE: {conf_ece}')
  364. print(f'Standard Deviation ECE: {stdev_ece}')
  365. # Repeat for entropy
  366. quantiles_entropy = entropies_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
  367. 'entropy'
  368. ]
  369. accuracies_entropy = []
  370. iter_entropy = it.islice(quantiles_entropy.items(), 0, None)
  371. for quantile in iter_entropy:
  372. percentile = quantile[0]
  373. filt = entropies_df[entropies_df['entropy'] <= quantile[1]]
  374. accuracy = (
  375. filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
  376. )
  377. f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
  378. accuracies_entropy.append(
  379. {'percentile': percentile, 'accuracy': accuracy, 'f1': f1}
  380. )
  381. accuracies_entropy_df = pd.DataFrame(accuracies_entropy)
  382. # Plot the coverage for entropy and accuracy
  383. plot_coverage(
  384. accuracies_entropy_df['percentile'],
  385. accuracies_entropy_df['accuracy'],
  386. indv_entropy_df['accuracy'],
  387. 'Entropy Accuracy Coverage Plot',
  388. 'Minimum Entropy Percentile (Low to High)',
  389. 'Accuracy',
  390. f'{V2_PATH}/coverage_entropy.png',
  391. )
  392. # Plot the coverage for entropy and F1
  393. plot_coverage(
  394. accuracies_entropy_df['percentile'],
  395. accuracies_entropy_df['f1'],
  396. indv_entropy_df['f1'],
  397. 'Entropy F1 Coverage Plot',
  398. 'Maximum Entropy Percentile (High to Low)',
  399. 'F1',
  400. f'{V2_PATH}/f1_coverage_entropy.png',
  401. flip=True,
  402. )