# threshold_xarray.py

# Rewritten Program to use xarray instead of pandas for thresholding
import xarray as xr
import torch
import numpy as np
import os
import glob
import tomli as toml
from tqdm import tqdm
import utils.metrics as met
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

# The data structures for this file are as follows:
# models_dict: Dictionary - {model_id: model}
# predictions: DataArray - (data_id, model_id, prediction_value) - prediction_value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
# ensemble_statistics: DataArray - (data_id, statistic) - statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
# thresholded_predictions: DataArray - (quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only use 'stdev', 'entropy', 'confidence' for statistic
# Additionally, we also have the thresholds and statistics for the individual models:
# indv_statistics: DataArray - (data_id, model_id, statistic) - statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - no stdev, as it cannot be calculated for a single model
# indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only use 'entropy', 'confidence' for statistic
# Additionally, we have some for the sensitivity analysis of the number of models:
# sensitivity_statistics: DataArray - (data_id, model_count, statistic) - statistic has coords ['accuracy', 'f1', 'ECE', 'MCE']
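# For example (hypothetical values, for two models 'm1' and 'm2'), one row of
# the predictions DataArray would look like:
#   predictions.loc[{'data_id': 0, 'model_id': 'm1'}]
#   -> [0.32, 0.68, 0.0, 1.0]
#   i.e. [negative_prediction, positive_prediction, negative_actual, positive_actual]
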
# Loads configuration dictionary
def load_config():
    if os.getenv('ADL_CONFIG_PATH') is None:
        with open('config.toml', 'rb') as f:
            config = toml.load(f)
    else:
        with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
            config = toml.load(f)
    return config

# Loads models into a dictionary
def load_models_v2(folder, device):
    glob_path = os.path.join(folder, '*.pt')
    model_files = glob.glob(glob_path)
    model_dict = {}
    for model_file in model_files:
        model = torch.load(model_file, map_location=device)
        # The model id is the filename segment before the first underscore
        model_id = os.path.basename(model_file).split('_')[0]
        model_dict[model_id] = model
    if len(model_dict) == 0:
        raise FileNotFoundError('No models found in the specified directory: ' + folder)
    return model_dict

# Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
def preprocess_data(data, device):
    mri, xls = data
    mri = mri.unsqueeze(0).to(device)
    xls = xls.unsqueeze(0).to(device)
    return (mri, xls)

# Loads the test and validation datasets and returns them as a tuple
def load_datasets(ensemble_path):
    return (
        torch.load(f'{ensemble_path}/test_dataset.pt'),
        torch.load(f'{ensemble_path}/val_dataset.pt'),
    )

# Gets the predictions for a set of models on a dataset
def get_ensemble_predictions(models, dataset, device, id_offset=0):
    zeros = np.zeros((len(dataset), len(models), 4))
    predictions = xr.DataArray(
        zeros,
        dims=('data_id', 'model_id', 'prediction_value'),
        coords={
            # Offset the data ids so that test and validation predictions can
            # later be concatenated without colliding
            'data_id': range(id_offset, id_offset + len(dataset)),
            'model_id': list(models.keys()),
            'prediction_value': [
                'negative_prediction',
                'positive_prediction',
                'negative_actual',
                'positive_actual',
            ],
        },
    )
    for data_id, (data, target) in tqdm(
        enumerate(dataset), total=len(dataset), unit='images'
    ):
        dat = preprocess_data(data, device)
        actual = list(target.cpu().numpy())
        for model_id, model in models.items():
            with torch.no_grad():
                output = model(dat)
            prediction = output.cpu().numpy().tolist()[0]
            predictions.loc[
                {'data_id': data_id + id_offset, 'model_id': model_id}
            ] = prediction + actual
    return predictions

# Compute the ensemble statistics given an array of predictions
def compute_ensemble_statistics(predictions: xr.DataArray):
    zeros = np.zeros((len(predictions.data_id), 7))
    ensemble_statistics = xr.DataArray(
        zeros,
        dims=('data_id', 'statistic'),
        coords={
            'data_id': predictions.data_id,
            'statistic': [
                'mean',
                'stdev',
                'entropy',
                'confidence',
                'correct',
                'predicted',
                'actual',
            ],
        },
    )
    for data_id in predictions.data_id:
        data = predictions.loc[{'data_id': data_id}]
        # Only take the predictions, not the actual values
        mean = data.mean(dim='model_id')[0:2]
        # Only need the standard deviation of the positive prediction
        stdev = data.std(dim='model_id')[1]
        # Binary entropy of the mean prediction: -sum(p * log(p))
        entropy = (-mean * np.log(mean)).sum()
        # Confidence is the largest mean class probability
        confidence = mean.max()
        # Only need one of the actual values, since they are all the same; just take the first positive_actual
        actual = data.loc[{'prediction_value': 'positive_actual'}][0]
        predicted = mean.argmax()
        correct = actual == predicted
        ensemble_statistics.loc[{'data_id': data_id}] = [
            mean[1],
            stdev,
            entropy,
            confidence,
            correct,
            predicted,
            actual,
        ]
    return ensemble_statistics

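# For example, if the mean prediction over the models is [0.32, 0.68]:
#   entropy = -(0.32 * ln(0.32) + 0.68 * ln(0.68)) ~= 0.63
#   confidence = 0.68, predicted = 1
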
# Compute the thresholded predictions given an array of ensemble statistics
def compute_thresholded_predictions(input_stats: xr.DataArray):
    quantiles = np.linspace(0.05, 0.95, 19) * 100
    metrics = ['accuracy', 'f1']
    statistics = ['stdev', 'entropy', 'confidence']
    zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))
    thresholded_predictions = xr.DataArray(
        zeros,
        dims=('quantile', 'statistic', 'metric'),
        coords={'quantile': quantiles, 'statistic': statistics, 'metric': metrics},
    )
    for statistic in statistics:
        # First, we must compute the quantiles for the statistic
        quantile_values = np.percentile(
            input_stats.sel(statistic=statistic).values, quantiles, axis=0
        )
        # Then, we must compute the metrics for each quantile
        for i, quantile in enumerate(quantiles):
            if low_to_high(statistic):
                mask = (
                    input_stats.sel(statistic=statistic) >= quantile_values[i]
                ).values
            else:
                mask = (
                    input_stats.sel(statistic=statistic) <= quantile_values[i]
                ).values
            # Filter the data based on the mask, selecting by position so that
            # this also works for non-contiguous data_ids after pruning
            filtered_data = input_stats.isel(data_id=np.where(mask)[0])
            for metric in metrics:
                thresholded_predictions.loc[
                    {'quantile': quantile, 'statistic': statistic, 'metric': metric}
                ] = compute_metric(filtered_data, metric)
    return thresholded_predictions

# Truth function to determine if a statistic should be thresholded low-to-high or high-to-low
# Low confidence is bad, high entropy is bad, high stdev is bad
# So we threshold confidence low-to-high, and entropy and stdev high-to-low
# That is, any values BELOW the cutoff are removed for confidence, and any values ABOVE the cutoff are removed for entropy and stdev
def low_to_high(stat):
    return stat in ['confidence']

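# For example, at the 20th-percentile cutoff of each statistic:
#   'confidence' (low-to-high): keep samples with confidence >= the cutoff
#   'entropy' / 'stdev':        keep samples with the statistic <= the cutoff
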
# Compute a given metric on a DataArray of statistics
def compute_metric(arr, metric):
    if metric == 'accuracy':
        return np.mean(arr.loc[{'statistic': 'correct'}])
    elif metric == 'f1':
        return met.F1(
            arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
        )
    else:
        raise ValueError('Invalid metric: ' + metric)

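# For example, on an ensemble_statistics array:
#   compute_metric(stats, 'accuracy') -> mean of the 'correct' flags
#   compute_metric(stats, 'f1')       -> F1 score of 'predicted' against 'actual'
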
# Graph a thresholded prediction for a given statistic and metric
def graph_thresholded_prediction(
    thresholded_predictions, statistic, metric, save_path, title, xlabel, ylabel
):
    data = thresholded_predictions.sel(statistic=statistic, metric=metric)
    x_data = data.coords['quantile'].values
    y_data = data.values
    fig, ax = plt.subplots()
    ax.plot(x_data, y_data, 'bx-', label='Ensemble')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.xaxis.set_major_formatter(mtick.PercentFormatter())
    if not low_to_high(statistic):
        ax.invert_xaxis()
    plt.savefig(save_path)

# Graph all thresholded predictions
def graph_all_thresholded_predictions(thresholded_predictions, save_path):
    # Confidence Accuracy
    graph_thresholded_prediction(
        thresholded_predictions,
        'confidence',
        'accuracy',
        f'{save_path}/confidence_accuracy.png',
        'Coverage Analysis of Confidence vs. Accuracy',
        'Minimum Confidence Percentile Threshold',
        'Accuracy',
    )
    # Confidence F1
    graph_thresholded_prediction(
        thresholded_predictions,
        'confidence',
        'f1',
        f'{save_path}/confidence_f1.png',
        'Coverage Analysis of Confidence vs. F1 Score',
        'Minimum Confidence Percentile Threshold',
        'F1 Score',
    )
    # Entropy Accuracy
    graph_thresholded_prediction(
        thresholded_predictions,
        'entropy',
        'accuracy',
        f'{save_path}/entropy_accuracy.png',
        'Coverage Analysis of Entropy vs. Accuracy',
        'Maximum Entropy Percentile Threshold',
        'Accuracy',
    )
    # Entropy F1
    graph_thresholded_prediction(
        thresholded_predictions,
        'entropy',
        'f1',
        f'{save_path}/entropy_f1.png',
        'Coverage Analysis of Entropy vs. F1 Score',
        'Maximum Entropy Percentile Threshold',
        'F1 Score',
    )
    # Stdev Accuracy
    graph_thresholded_prediction(
        thresholded_predictions,
        'stdev',
        'accuracy',
        f'{save_path}/stdev_accuracy.png',
        'Coverage Analysis of Standard Deviation vs. Accuracy',
        'Maximum Standard Deviation Percentile Threshold',
        'Accuracy',
    )
    # Stdev F1
    graph_thresholded_prediction(
        thresholded_predictions,
        'stdev',
        'f1',
        f'{save_path}/stdev_f1.png',
        'Coverage Analysis of Standard Deviation vs. F1 Score',
        'Maximum Standard Deviation Percentile Threshold',
        'F1 Score',
    )

# Graph two statistics against each other
def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
    # Filter for correct predictions, selecting by position so that this also
    # works for non-contiguous data_ids after pruning
    c_stats = stats.isel(
        data_id=np.where((stats.sel(statistic='correct') == 1).values)[0]
    )
    # Filter for incorrect predictions
    i_stats = stats.isel(
        data_id=np.where((stats.sel(statistic='correct') == 0).values)[0]
    )
    # x and y data for correct and incorrect predictions
    x_data_c = c_stats.sel(statistic=x_stat).values
    y_data_c = c_stats.sel(statistic=y_stat).values
    x_data_i = i_stats.sel(statistic=x_stat).values
    y_data_i = i_stats.sel(statistic=y_stat).values
    fig, ax = plt.subplots()
    ax.plot(x_data_c, y_data_c, 'go', label='Correct')
    ax.plot(x_data_i, y_data_i, 'ro', label='Incorrect')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()
    plt.savefig(save_path)

# Prune the data based on excluded data_ids
def prune_data(data, excluded_data_ids):
    return data.where(~data.data_id.isin(excluded_data_ids), drop=True)

# Calculate individual model statistics
def compute_individual_statistics(predictions: xr.DataArray):
    zeros = np.zeros((len(predictions.data_id), len(predictions.model_id), 6))
    indv_statistics = xr.DataArray(
        zeros,
        dims=('data_id', 'model_id', 'statistic'),
        coords={
            'data_id': predictions.data_id,
            'model_id': predictions.model_id,
            'statistic': [
                'mean',
                'entropy',
                'confidence',
                'correct',
                'predicted',
                'actual',
            ],
        },
    )
    for data_id in predictions.data_id:
        for model_id in predictions.model_id:
            data = predictions.loc[{'data_id': data_id, 'model_id': model_id}]
            # The first two entries are the class predictions
            mean = data[0:2]
            entropy = (-mean * np.log(mean)).sum()
            confidence = mean.max()
            # The fourth entry is positive_actual
            actual = data[3]
            predicted = mean.argmax()
            correct = actual == predicted
            indv_statistics.loc[{'data_id': data_id, 'model_id': model_id}] = [
                mean[1],
                entropy,
                confidence,
                correct,
                predicted,
                actual,
            ]
    return indv_statistics

# Compute individual model thresholds
def compute_individual_thresholds(input_stats: xr.DataArray):
    quantiles = np.linspace(0.05, 0.95, 19) * 100
    metrics = ['accuracy', 'f1']
    statistics = ['entropy', 'confidence']
    zeros = np.zeros(
        (len(input_stats.model_id), len(quantiles), len(statistics), len(metrics))
    )
    indv_thresholds = xr.DataArray(
        zeros,
        dims=('model_id', 'quantile', 'statistic', 'metric'),
        coords={
            'model_id': input_stats.model_id,
            'quantile': quantiles,
            'statistic': statistics,
            'metric': metrics,
        },
    )
    for model_id in input_stats.model_id:
        for statistic in statistics:
            # First, we must compute the quantiles for the statistic
            quantile_values = np.percentile(
                input_stats.sel(model_id=model_id, statistic=statistic).values,
                quantiles,
                axis=0,
            )
            # Then, we must compute the metrics for each quantile
            for i, quantile in enumerate(quantiles):
                if low_to_high(statistic):
                    mask = (
                        input_stats.sel(model_id=model_id, statistic=statistic)
                        >= quantile_values[i]
                    ).values
                else:
                    mask = (
                        input_stats.sel(model_id=model_id, statistic=statistic)
                        <= quantile_values[i]
                    ).values
                # Filter the data based on the mask, selecting by position so
                # that this also works for non-contiguous data_ids after pruning
                filtered_data = input_stats.isel(data_id=np.where(mask)[0])
                for metric in metrics:
                    indv_thresholds.loc[
                        {
                            'model_id': model_id,
                            'quantile': quantile,
                            'statistic': statistic,
                            'metric': metric,
                        }
                    ] = compute_metric(filtered_data, metric)
    return indv_thresholds

# Graph individual model thresholded predictions
def graph_individual_thresholded_predictions(
    indv_thresholds,
    ensemble_thresholds,
    statistic,
    metric,
    save_path,
    title,
    xlabel,
    ylabel,
):
    data = indv_thresholds.sel(statistic=statistic, metric=metric)
    e_data = ensemble_thresholds.sel(statistic=statistic, metric=metric)
    x_data = data.coords['quantile'].values
    y_data = data.values
    e_x_data = e_data.coords['quantile'].values
    e_y_data = e_data.values
    fig, ax = plt.subplots()
    for model_id in data.coords['model_id'].values:
        model_data = data.sel(model_id=model_id)
        ax.plot(x_data, model_data)
    ax.plot(e_x_data, e_y_data, 'kx-', label='Ensemble')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.xaxis.set_major_formatter(mtick.PercentFormatter())
    if not low_to_high(statistic):
        ax.invert_xaxis()
    ax.legend()
    plt.savefig(save_path)

# Graph all individual thresholded predictions
def graph_all_individual_thresholded_predictions(
    indv_thresholds, ensemble_thresholds, save_path
):
    # Confidence Accuracy
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'confidence',
        'accuracy',
        f'{save_path}/indv/confidence_accuracy.png',
        'Coverage Analysis of Confidence vs. Accuracy for All Models',
        'Minimum Confidence Percentile Threshold',
        'Accuracy',
    )
    # Confidence F1
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'confidence',
        'f1',
        f'{save_path}/indv/confidence_f1.png',
        'Coverage Analysis of Confidence vs. F1 Score for All Models',
        'Minimum Confidence Percentile Threshold',
        'F1 Score',
    )
    # Entropy Accuracy
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'entropy',
        'accuracy',
        f'{save_path}/indv/entropy_accuracy.png',
        'Coverage Analysis of Entropy vs. Accuracy for All Models',
        'Maximum Entropy Percentile Threshold',
        'Accuracy',
    )
    # Entropy F1
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'entropy',
        'f1',
        f'{save_path}/indv/entropy_f1.png',
        'Coverage Analysis of Entropy vs. F1 Score for All Models',
        'Maximum Entropy Percentile Threshold',
        'F1 Score',
    )

# Calculate statistics of subsets of models for sensitivity analysis
def calculate_subset_statistics(predictions: xr.DataArray):
    # Calculate statistics for subsets of 1..N models
    subsets = range(1, len(predictions.model_id) + 1)
    # Include stdev, though it is not meaningful for a subset of one model
    zeros = np.zeros((len(predictions.data_id), len(subsets), 7))
    subset_stats = xr.DataArray(
        zeros,
        dims=('data_id', 'model_count', 'statistic'),
        coords={
            'data_id': predictions.data_id,
            'model_count': subsets,
            'statistic': [
                'mean',
                'stdev',
                'entropy',
                'confidence',
                'correct',
                'predicted',
                'actual',
            ],
        },
    )
    for data_id in predictions.data_id:
        for subset in subsets:
            data = predictions.sel(
                data_id=data_id, model_id=predictions.model_id[:subset]
            )
            mean = data.mean(dim='model_id')[0:2]
            stdev = data.std(dim='model_id')[1]
            entropy = (-mean * np.log(mean)).sum()
            confidence = mean.max()
            # All models share the same actual; take positive_actual from the first
            actual = data[0][3]
            predicted = mean.argmax()
            correct = actual == predicted
            subset_stats.loc[{'data_id': data_id, 'model_count': subset}] = [
                mean[1],
                stdev,
                entropy,
                confidence,
                correct,
                predicted,
                actual,
            ]
    return subset_stats

# Calculate accuracy and F1 for each subset size - sensitivity analysis
def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
    subsets = subset_stats.model_count
    stats = ['accuracy', 'f1']
    zeros = np.zeros((len(subsets), len(stats)))
    sens_analysis = xr.DataArray(
        zeros,
        dims=('model_count', 'statistic'),
        coords={'model_count': subsets, 'statistic': ['accuracy', 'f1']},
    )
    for subset in subsets:
        data = subset_stats.sel(model_count=subset)
        acc = compute_metric(data, 'accuracy')
        f1 = compute_metric(data, 'f1')
        sens_analysis.loc[{'model_count': subset}] = [acc, f1]
    return sens_analysis

# Graph the sensitivity analysis for a given statistic
def graph_sensitivity_analysis(
    sens_analysis: xr.DataArray, statistic, save_path, title, xlabel, ylabel
):
    data = sens_analysis.sel(statistic=statistic)
    xdata = data.coords['model_count'].values
    ydata = data.values
    fig, ax = plt.subplots()
    ax.plot(xdata, ydata)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    plt.savefig(save_path)

# Calculate overall accuracy and F1 for the ensemble
def calculate_overall_stats(ensemble_statistics: xr.DataArray):
    accuracy = compute_metric(ensemble_statistics, 'accuracy')
    f1 = compute_metric(ensemble_statistics, 'f1')
    return {'accuracy': accuracy.item(), 'f1': f1.item()}

# https://towardsdatascience.com/expected-calibration-error-ece-a-step-by-step-visual-explanation-with-python-code-c3e9aa12937d
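# ECE = sum_b (n_b / N) * |acc(b) - conf(b)|, where the bins b partition the
# confidence range [0, 1], n_b / N is the fraction of samples falling in bin b,
# acc(b) is the mean accuracy within the bin, and conf(b) is its mean confidence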
def calculate_ece_stats(statistics, bins=10):
    bin_boundaries = np.linspace(0, 1, bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    # Confidence lies in [0.5, 1]; rescale it to [0, 1] so it spans the bins
    confidences = (statistics.sel(statistic='confidence').values - 0.5) * 2
    accuracies = statistics.sel(statistic='correct').values
    ece = np.zeros(1)
    bin_accuracies = xr.DataArray(
        np.zeros(bins), dims='lower_bound', coords={'lower_bound': bin_lowers}
    )
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = np.logical_and(
            confidences > bin_lower.item(), confidences <= bin_upper.item()
        )
        prob_in_bin = in_bin.mean()
        if prob_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].mean()
            bin_accuracies.loc[{'lower_bound': bin_lower}] = accuracy_in_bin
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prob_in_bin
    bin_accuracies.attrs['ece'] = ece
    bin_accuracies.attrs['bin_number'] = bins
    return bin_accuracies

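# For example (hypothetical usage, not part of the pipeline in main):
#   ece_stats = calculate_ece_stats(ensemble_statistics)
#   ece_stats.attrs['ece']  # accumulated expected calibration error
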
# Plot an ECE reliability diagram (bin accuracy vs. confidence)
def plot_ece_graph(ece_stats, title, xlabel, ylabel, save_path):
    fig, ax = plt.subplots()
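    # Minimal sketch of a reliability diagram, assuming ece_stats is the
    # DataArray returned by calculate_ece_stats: per-bin accuracy as values,
    # bin lower bounds as the 'lower_bound' coordinate
    bin_width = 1 / ece_stats.attrs['bin_number']
    ax.bar(
        ece_stats.coords['lower_bound'].values,
        ece_stats.values,
        width=bin_width,
        align='edge',
        edgecolor='black',
        label='Bin Accuracy',
    )
    # A perfectly calibrated model would lie on the diagonal
    ax.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()
    plt.savefig(save_path)
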
# Main Function
def main():
    print('Loading Config...')
    config = load_config()
    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V4_PATH = ENSEMBLE_PATH + '/v4'
    if not os.path.exists(V4_PATH):
        os.makedirs(V4_PATH)
    print('Config Loaded')
    # Load Datasets
    print('Loading Datasets...')
    (test_dataset, val_dataset) = load_datasets(ENSEMBLE_PATH)
    print('Datasets Loaded')
    # Get predictions, either by running the models or by loading them from a file
    if config['ensemble']['run_models']:
        # Load Models
        print('Loading Models...')
        device = torch.device(config['training']['device'])
        models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
        print('Models Loaded')
        # Get Predictions
        print('Getting Predictions...')
        test_predictions = get_ensemble_predictions(models, test_dataset, device)
        val_predictions = get_ensemble_predictions(
            models, val_dataset, device, len(test_dataset)
        )
        print('Predictions Loaded')
        # Save Predictions
        test_predictions.to_netcdf(f'{V4_PATH}/test_predictions.nc')
        val_predictions.to_netcdf(f'{V4_PATH}/val_predictions.nc')
    else:
        test_predictions = xr.open_dataarray(f'{V4_PATH}/test_predictions.nc')
        val_predictions = xr.open_dataarray(f'{V4_PATH}/val_predictions.nc')
    # Prune Data
    print('Pruning Data...')
    if config['operation']['exclude_blank_ids']:
        excluded_data_ids = config['ensemble']['excluded_ids']
        test_predictions = prune_data(test_predictions, excluded_data_ids)
        val_predictions = prune_data(val_predictions, excluded_data_ids)
    # Concatenate Predictions
    predictions = xr.concat([test_predictions, val_predictions], dim='data_id')
    # Compute Ensemble Statistics
    print('Computing Ensemble Statistics...')
    ensemble_statistics = compute_ensemble_statistics(predictions)
    ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')
    print('Ensemble Statistics Computed')
    # Compute Thresholded Predictions
    print('Computing Thresholded Predictions...')
    thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
    thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')
    print('Thresholded Predictions Computed')
    # Graph Thresholded Predictions
    print('Graphing Thresholded Predictions...')
    graph_all_thresholded_predictions(thresholded_predictions, V4_PATH)
    print('Thresholded Predictions Graphed')
    # Additional Graphs
    print('Graphing Additional Graphs...')
    # Confidence vs. stdev
    graph_statistics(
        ensemble_statistics,
        'confidence',
        'stdev',
        f'{V4_PATH}/confidence_stdev.png',
        'Confidence and Standard Deviation for Predictions',
        'Confidence',
        'Standard Deviation',
    )
    print('Additional Graphs Graphed')
    # Compute Individual Statistics
    print('Computing Individual Statistics...')
    indv_statistics = compute_individual_statistics(predictions)
    indv_statistics.to_netcdf(f'{V4_PATH}/indv_statistics.nc')
    print('Individual Statistics Computed')
    # Compute Individual Thresholds
    print('Computing Individual Thresholds...')
    indv_thresholds = compute_individual_thresholds(indv_statistics)
    indv_thresholds.to_netcdf(f'{V4_PATH}/indv_thresholds.nc')
    print('Individual Thresholds Computed')
    # Graph Individual Thresholded Predictions
    print('Graphing Individual Thresholded Predictions...')
    if not os.path.exists(f'{V4_PATH}/indv'):
        os.makedirs(f'{V4_PATH}/indv')
    graph_all_individual_thresholded_predictions(
        indv_thresholds, thresholded_predictions, V4_PATH
    )
    print('Individual Thresholded Predictions Graphed')
    # Compute subset statistics and graph the sensitivity analysis
    subset_stats = calculate_subset_statistics(predictions)
    sens_analysis = calculate_sensitivity_analysis(subset_stats)
    graph_sensitivity_analysis(
        sens_analysis,
        'accuracy',
        f'{V4_PATH}/sens_analysis.png',
        'Sensitivity Analysis of Accuracy vs. # of Models',
        '# of Models',
        'Accuracy',
    )
    print(sens_analysis.sel(statistic='accuracy'))
    print(calculate_overall_stats(ensemble_statistics))


if __name__ == '__main__':
    main()