# threshold.py — ensemble confidence / standard-deviation threshold analysis
  1. import pandas as pd
  2. import numpy as np
  3. import os
  4. import tomli as toml
  5. from utils.data.datasets import prepare_datasets
  6. import utils.ensemble as ens
  7. import torch
  8. import matplotlib.pyplot as plt
  9. import sklearn.metrics as metrics
  10. from tqdm import tqdm
  11. import utils.metrics as met
  12. RUN = False
  13. # CONFIGURATION
  14. if os.getenv('ADL_CONFIG_PATH') is None:
  15. with open('config.toml', 'rb') as f:
  16. config = toml.load(f)
  17. else:
  18. with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
  19. config = toml.load(f)
  20. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  21. V2_PATH = ENSEMBLE_PATH + '/v2'
  22. # Result is a 1x2 tensor, with the softmax of the 2 predicted classes
  23. # Want to convert to a predicted class and a confidence
  24. def output_to_confidence(result):
  25. predicted_class = torch.argmax(result).item()
  26. confidence = (torch.max(result).item() - 0.5) * 2
  27. return torch.Tensor([predicted_class, confidence])
  28. # This function conducts tests on the models and returns the results, as well as saving the predictions and metrics
  29. def get_predictions(config):
  30. models, model_descs = ens.load_models(
  31. f'{ENSEMBLE_PATH}/models/',
  32. config['training']['device'],
  33. )
  34. models = [model.to(config['training']['device']) for model in models]
  35. test_set = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
  36. f'{ENSEMBLE_PATH}/val_dataset.pt'
  37. )
  38. # [([model results], labels)]
  39. results = []
  40. # [(class_1, class_2, true_label)]
  41. indv_results = []
  42. for i, (data, target) in tqdm(
  43. enumerate(test_set),
  44. total=len(test_set),
  45. desc='Getting predictions',
  46. unit='sample',
  47. ):
  48. mri, xls = data
  49. mri = mri.unsqueeze(0).to(config['training']['device'])
  50. xls = xls.unsqueeze(0).to(config['training']['device'])
  51. data = (mri, xls)
  52. res = []
  53. for j, model in enumerate(models):
  54. model.eval()
  55. with torch.no_grad():
  56. output = model(data)
  57. output = output.tolist()
  58. if j == 0:
  59. indv_results.append((output[0][0], output[0][1], target[1].item()))
  60. res.append(output)
  61. results.append((res, target.tolist()))
  62. # The results are a list of tuples, where each tuple contains a list of model outputs and the true label
  63. # We want to convert this to 2 list of tuples, one with the ensemble predicted class, ensemble confidence and true label
  64. # And one with the ensemble predicted class, ensemble standard deviation and true label
  65. # [(ensemble predicted class, ensemble confidence, true label)]
  66. confidences = []
  67. # [(ensemble predicted class, ensemble standard deviation, true label)]
  68. stdevs = []
  69. for result in results:
  70. model_results, true_label = result
  71. # Get the ensemble mean and variance with numpy, as these are lists
  72. mean = np.mean(model_results, axis=0)
  73. variance = np.var(model_results, axis=0)
  74. # Calculate confidence and standard deviation
  75. confidence = (np.max(mean) - 0.5) * 2
  76. stdev = np.sqrt(variance)
  77. # Get the predicted class
  78. predicted_class = np.argmax(mean)
  79. # Get the confidence and standard deviation of the predicted class
  80. print(stdev)
  81. pc_stdev = np.squeeze(stdev)[predicted_class]
  82. # Get the true label
  83. true_label = true_label[1]
  84. confidences.append((predicted_class, confidence, true_label))
  85. stdevs.append((predicted_class, pc_stdev, true_label))
  86. return results, confidences, stdevs, indv_results
# Generate fresh predictions (expensive) or reuse cached CSVs from a
# previous RUN = True pass.
if RUN:
    results, confs, stdevs, indv_results = get_predictions(config)
    # Convert to pandas dataframes
    confs_df = pd.DataFrame(
        confs, columns=['predicted_class', 'confidence', 'true_label']
    )
    stdevs_df = pd.DataFrame(stdevs, columns=['predicted_class', 'stdev', 'true_label'])
    indv_df = pd.DataFrame(indv_results, columns=['class_1', 'class_2', 'true_label'])
    if not os.path.exists(V2_PATH):
        os.makedirs(V2_PATH)
    # NOTE(review): written without index=False, so the pandas index is saved
    # and comes back as an extra 'Unnamed: 0' column on the read path below —
    # harmless to current consumers, but confirm before relying on columns.
    confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
    stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
    indv_df.to_csv(f'{V2_PATH}/individual_results.csv')
else:
    confs_df = pd.read_csv(f'{V2_PATH}/ensemble_confidences.csv')
    stdevs_df = pd.read_csv(f'{V2_PATH}/ensemble_stdevs.csv')
    indv_df = pd.read_csv(f'{V2_PATH}/individual_results.csv')
# Plot confidence vs standard deviation
# NOTE(review): this scatter pairs confs_df and stdevs_df row-by-row, which
# is only valid while both frames share the original sample order.  The
# sorted re-save below overwrites the CSVs, so a later RUN = False pass
# loads frames sorted by *different* keys and this pairing silently
# mismatches — confirm and consider joining on the saved index instead.
plt.scatter(confs_df['confidence'], stdevs_df['stdev'])
plt.xlabel('Confidence')
plt.ylabel('Standard Deviation')
plt.title('Confidence vs Standard Deviation')
plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
plt.close()

# Calculate Binning for Coverage
# Sort Dataframes
confs_df = confs_df.sort_values(by='confidence')
stdevs_df = stdevs_df.sort_values(by='stdev')
# Persist the sorted versions (overwrites the unsorted CSVs written above).
confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
  117. # Calculate individual model accuracy
  118. # Determine predicted class
  119. indv_df['predicted_class'] = indv_df[['class_1', 'class_2']].idxmax(axis=1)
  120. indv_df['predicted_class'] = indv_df['predicted_class'].apply(
  121. lambda x: 0 if x == 'class_1' else 1
  122. )
  123. indv_df['correct'] = indv_df['predicted_class'] == indv_df['true_label']
  124. accuracy_indv = indv_df['correct'].mean()
  125. # Calculate percentiles for confidence and standard deviation
  126. quantiles_conf = confs_df.quantile(np.linspace(0, 1, 11))['confidence']
  127. quantiles_stdev = stdevs_df.quantile(np.linspace(0, 1, 11))['stdev']
  128. accuracies_conf = []
  129. # Use the quantiles to calculate the coverage
  130. for quantile in quantiles_conf.items():
  131. percentile = quantile[0]
  132. filt = confs_df[confs_df['confidence'] >= quantile[1]]
  133. accuracy = (
  134. filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
  135. )
  136. accuracies_conf.append({'percentile': percentile, 'accuracy': accuracy})
  137. accuracies_df = pd.DataFrame(accuracies_conf)
  138. # Plot the coverage
  139. plt.plot(accuracies_df['percentile'], accuracies_df['accuracy'], label='Ensemble')
  140. plt.plot(
  141. accuracies_df['percentile'],
  142. [accuracy_indv] * len(accuracies_df['percentile']),
  143. label='Individual',
  144. linestyle='--',
  145. )
  146. plt.xlabel('Percentile')
  147. plt.ylabel('Accuracy')
  148. plt.title('Coverage conf')
  149. plt.legend()
  150. plt.savefig(f'{V2_PATH}/coverage.png')
  151. plt.close()
  152. # Repeat for standard deviation
  153. accuracies_stdev = []
  154. for quantile in quantiles_stdev.items():
  155. percentile = quantile[0]
  156. filt = stdevs_df[stdevs_df['stdev'] <= quantile[1]]
  157. accuracy = (
  158. filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
  159. )
  160. accuracies_stdev.append({'percentile': percentile, 'accuracy': accuracy})
  161. accuracies_stdev_df = pd.DataFrame(accuracies_stdev)
  162. # Plot the coverage
  163. plt.plot(
  164. accuracies_stdev_df['percentile'], accuracies_stdev_df['accuracy'], label='Ensemble'
  165. )
  166. plt.plot(
  167. accuracies_stdev_df['percentile'],
  168. [accuracy_indv] * len(accuracies_stdev_df['percentile']),
  169. label='Individual',
  170. linestyle='--',
  171. )
  172. plt.xlabel('Percentile')
  173. plt.ylabel('Accuracy')
  174. plt.title('Coverage Stdev')
  175. plt.legend()
  176. plt.gca().invert_xaxis()
  177. plt.savefig(f'{V2_PATH}/coverage_stdev.png')
  178. plt.close()