# threshold_xarray.py
# Rewritten program to use xarray instead of pandas for thresholding
import xarray as xr
import torch
import numpy as np
import os
import glob
import tomli as toml
from tqdm import tqdm
import utils.metrics as met
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

# The data structures for this file are as follows
# models_dict: Dictionary - {model_id: model}
# predictions: DataArray - (data_id, model_id, prediction_value) - prediction_value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
# ensemble_statistics: DataArray - (data_id, statistic) - statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
# thresholded_predictions: DataArray - (quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only use 'stdev', 'entropy', 'confidence' for statistic
# Additionally, we also have the thresholds and statistics for the individual models
# indv_statistics: DataArray - (data_id, model_id, statistic) - statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - no stdev, as it cannot be calculated for a single model
# indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only use 'entropy', 'confidence' for statistic
# Additionally, we have some for the sensitivity analysis on the number of models
# sensitivity_statistics: DataArray - (data_id, model_count, statistic) - statistic has coords ['accuracy', 'f1', 'ECE', 'MCE']
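# For example, one model's positive prediction for one image can be read with
#   predictions.loc[{'data_id': 0, 'model_id': '1', 'prediction_value': 'positive_prediction'}]
# (illustrative ids; the model_id values come from the model filenames, see load_models_v2 below)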
# Loads configuration dictionary
def load_config():
    if os.getenv('ADL_CONFIG_PATH') is None:
        with open('config.toml', 'rb') as f:
            config = toml.load(f)
    else:
        with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
            config = toml.load(f)

    return config
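# A minimal config.toml sketch; the sections and keys are the ones this script
# reads, while the values shown are only illustrative:
#
#   [paths]
#   model_output = 'out/'
#
#   [ensemble]
#   name = 'ensemble_1'
#   run_models = true
#   excluded_ids = []
#
#   [training]
#   device = 'cuda'
#
#   [operation]
#   exclude_blank_ids = false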
# Loads models into a dictionary
def load_models_v2(folder, device):
    glob_path = os.path.join(folder, '*.pt')
    model_files = glob.glob(glob_path)
    model_dict = {}

    for model_file in model_files:
        model = torch.load(model_file, map_location=device)
        model_id = os.path.basename(model_file).split('_')[0]
        model_dict[model_id] = model

    if len(model_dict) == 0:
        raise FileNotFoundError('No models found in the specified directory: ' + folder)

    return model_dict
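# The model id is taken from the text before the first underscore in the filename,
# so a checkpoint saved as, say, '7_model.pt' (a hypothetical name) is keyed as '7'.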
# Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
def preprocess_data(data, device):
    mri, xls = data
    mri = mri.unsqueeze(0).to(device)
    xls = xls.unsqueeze(0).to(device)
    return (mri, xls)
# Loads the test and validation datasets and returns them as a tuple
def load_datasets(ensemble_path):
    return (
        torch.load(f'{ensemble_path}/test_dataset.pt'),
        torch.load(f'{ensemble_path}/val_dataset.pt'),
    )
# Gets the predictions for a set of models on a dataset
def get_ensemble_predictions(models, dataset, device, id_offset=0):
    zeros = np.zeros((len(dataset), len(models), 4))
    predictions = xr.DataArray(
        zeros,
        dims=('data_id', 'model_id', 'prediction_value'),
        coords={
            'data_id': range(id_offset, len(dataset) + id_offset),
            'model_id': list(models.keys()),
            'prediction_value': [
                'negative_prediction',
                'positive_prediction',
                'negative_actual',
                'positive_actual',
            ],
        },
    )

    for data_id, (data, target) in tqdm(
        enumerate(dataset), total=len(dataset), unit='images'
    ):
        dat = preprocess_data(data, device)
        actual = list(target.cpu().numpy())
        for model_id, model in models.items():
            with torch.no_grad():
                output = model(dat)
                prediction = output.cpu().numpy().tolist()[0]
                # Each row stores [negative_prediction, positive_prediction, negative_actual, positive_actual]
                predictions.loc[
                    {'data_id': data_id + id_offset, 'model_id': model_id}
                ] = prediction + actual

    return predictions
# Compute the ensemble statistics given an array of predictions
def compute_ensemble_statistics(predictions: xr.DataArray):
    zeros = np.zeros((len(predictions.data_id), 7))
    ensemble_statistics = xr.DataArray(
        zeros,
        dims=('data_id', 'statistic'),
        coords={
            'data_id': predictions.data_id,
            'statistic': [
                'mean',
                'stdev',
                'entropy',
                'confidence',
                'correct',
                'predicted',
                'actual',
            ],
        },
    )

    for data_id in predictions.data_id:
        data = predictions.loc[{'data_id': data_id}]
        mean = data.mean(dim='model_id')[
            0:2
        ]  # Only take the predictions, not the actual values
        stdev = data.std(dim='model_id')[
            1
        ]  # Only need the standard deviation of the positive prediction
        entropy = (-mean * np.log(mean)).sum()

        # Compute confidence
        confidence = mean.max()

        # Only need one of the actual values, since they are all the same; just get the first 'positive_actual'
        actual = data.loc[{'prediction_value': 'positive_actual'}][0]
        predicted = mean.argmax()
        correct = actual == predicted

        ensemble_statistics.loc[{'data_id': data_id}] = [
            mean[1],
            stdev,
            entropy,
            confidence,
            correct,
            predicted,
            actual,
        ]

    return ensemble_statistics
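# A worked example of the statistics above (illustrative numbers): if the ensemble
# mean prediction for an image is [0.2, 0.8], then
#   entropy    = -(0.2 * ln(0.2) + 0.8 * ln(0.8)) ≈ 0.50
#   confidence = max(0.2, 0.8) = 0.8
#   predicted  = argmax([0.2, 0.8]) = 1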
# Compute the thresholded predictions given an array of ensemble statistics
def compute_thresholded_predictions(input_stats: xr.DataArray):
    quantiles = np.linspace(0.00, 1.00, 21) * 100
    metrics = ['accuracy', 'f1']
    statistics = ['stdev', 'entropy', 'confidence']
    zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))
    thresholded_predictions = xr.DataArray(
        zeros,
        dims=('quantile', 'statistic', 'metric'),
        coords={'quantile': quantiles, 'statistic': statistics, 'metric': metrics},
    )

    for statistic in statistics:
        # First, we must compute the quantiles for the statistic
        quantile_values = np.percentile(
            input_stats.sel(statistic=statistic).values, quantiles, axis=0
        )

        # Then, we must compute the metrics for each quantile
        for i, quantile in enumerate(quantiles):
            if low_to_high(statistic):
                mask = (
                    input_stats.sel(statistic=statistic) >= quantile_values[i]
                ).values
            else:
                mask = (
                    input_stats.sel(statistic=statistic) <= quantile_values[i]
                ).values

            # Filter the data based on the mask
            filtered_data = input_stats.where(
                input_stats.data_id.isin(np.where(mask)), drop=True
            )

            for metric in metrics:
                thresholded_predictions.loc[
                    {'quantile': quantile, 'statistic': statistic, 'metric': metric}
                ] = compute_metric(filtered_data, metric)

    return thresholded_predictions
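# Illustration: for 'confidence' (thresholded low-to-high), the mask at the 25th
# quantile keeps roughly the 75% most confident samples, so each point on the
# coverage graphs below reports a metric over the subset that survives the cutoff.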
# Truth function to determine if a statistic should be thresholded low-to-high or high-to-low
# Low confidence is bad, high entropy is bad, high stdev is bad
# So we threshold confidence low-to-high, and entropy and stdev high-to-low
# So any values BELOW the cutoff are removed for confidence, and any values ABOVE the cutoff are removed for entropy and stdev
def low_to_high(stat):
    return stat in ['confidence']
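# e.g. low_to_high('confidence') -> True, while low_to_high('entropy') and
# low_to_high('stdev') -> False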
# Compute a given metric on a DataArray of statistics
def compute_metric(arr, metric):
    if metric == 'accuracy':
        return np.mean(arr.loc[{'statistic': 'correct'}])
    elif metric == 'f1':
        return met.F1(
            arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
        )
    elif metric == 'ece':
        true_labels = arr.loc[{'statistic': 'actual'}].values
        predicted_labels = arr.loc[{'statistic': 'predicted'}].values
        confidences = arr.loc[{'statistic': 'confidence'}].values
        return calculate_ece_stats(confidences, predicted_labels, true_labels)
    else:
        raise ValueError('Invalid metric: ' + metric)
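# For instance, compute_metric(filtered_data, 'accuracy') averages the 'correct'
# flags of the retained samples, i.e. the fraction classified correctly; the 'ece'
# branch defers to calculate_ece_stats, defined further below.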
# Graph a thresholded prediction for a given statistic and metric
def graph_thresholded_prediction(
    thresholded_predictions, statistic, metric, save_path, title, xlabel, ylabel
):
    data = thresholded_predictions.sel(statistic=statistic, metric=metric)

    x_data = data.coords['quantile'].values
    y_data = data.values

    fig, ax = plt.subplots()
    ax.plot(x_data, y_data, 'bx-', label='Ensemble')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.xaxis.set_major_formatter(mtick.PercentFormatter())

    if not low_to_high(statistic):
        ax.invert_xaxis()

    plt.savefig(save_path)
# Graph all thresholded predictions
def graph_all_thresholded_predictions(thresholded_predictions, save_path):
    # Confidence Accuracy
    graph_thresholded_prediction(
        thresholded_predictions,
        'confidence',
        'accuracy',
        f'{save_path}/confidence_accuracy.png',
        'Coverage Analysis of Confidence vs. Accuracy',
        'Minimum Confidence Percentile Threshold',
        'Accuracy',
    )

    # Confidence F1
    graph_thresholded_prediction(
        thresholded_predictions,
        'confidence',
        'f1',
        f'{save_path}/confidence_f1.png',
        'Coverage Analysis of Confidence vs. F1 Score',
        'Minimum Confidence Percentile Threshold',
        'F1 Score',
    )

    # Entropy Accuracy
    graph_thresholded_prediction(
        thresholded_predictions,
        'entropy',
        'accuracy',
        f'{save_path}/entropy_accuracy.png',
        'Coverage Analysis of Entropy vs. Accuracy',
        'Maximum Entropy Percentile Threshold',
        'Accuracy',
    )

    # Entropy F1
    graph_thresholded_prediction(
        thresholded_predictions,
        'entropy',
        'f1',
        f'{save_path}/entropy_f1.png',
        'Coverage Analysis of Entropy vs. F1 Score',
        'Maximum Entropy Percentile Threshold',
        'F1 Score',
    )

    # Stdev Accuracy
    graph_thresholded_prediction(
        thresholded_predictions,
        'stdev',
        'accuracy',
        f'{save_path}/stdev_accuracy.png',
        'Coverage Analysis of Standard Deviation vs. Accuracy',
        'Maximum Standard Deviation Percentile Threshold',
        'Accuracy',
    )

    # Stdev F1
    graph_thresholded_prediction(
        thresholded_predictions,
        'stdev',
        'f1',
        f'{save_path}/stdev_f1.png',
        'Coverage Analysis of Standard Deviation vs. F1 Score',
        'Maximum Standard Deviation Percentile Threshold',
        'F1',
    )
# Graph two statistics against each other
def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
    # Filter for correct predictions
    c_stats = stats.where(
        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 1).values)),
        drop=True,
    )

    # Filter for incorrect predictions
    i_stats = stats.where(
        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 0).values)),
        drop=True,
    )

    # x and y data for correct and incorrect predictions
    x_data_c = c_stats.sel(statistic=x_stat).values
    y_data_c = c_stats.sel(statistic=y_stat).values
    x_data_i = i_stats.sel(statistic=x_stat).values
    y_data_i = i_stats.sel(statistic=y_stat).values

    fig, ax = plt.subplots()
    ax.plot(x_data_c, y_data_c, 'go', label='Correct')
    ax.plot(x_data_i, y_data_i, 'ro', label='Incorrect')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()

    plt.savefig(save_path)
# Prune the data based on excluded data_ids
def prune_data(data, excluded_data_ids):
    return data.where(~data.data_id.isin(excluded_data_ids), drop=True)
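# e.g. prune_data(predictions, [10, 11]) drops the samples whose data_id is 10 or
# 11 (illustrative ids); main() uses this with config['ensemble']['excluded_ids'].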
# Calculate individual model statistics
def compute_individual_statistics(predictions: xr.DataArray):
    zeros = np.zeros((len(predictions.data_id), len(predictions.model_id), 6))
    indv_statistics = xr.DataArray(
        zeros,
        dims=('data_id', 'model_id', 'statistic'),
        coords={
            'data_id': predictions.data_id,
            'model_id': predictions.model_id,
            'statistic': [
                'mean',
                'entropy',
                'confidence',
                'correct',
                'predicted',
                'actual',
            ],
        },
    )

    for data_id in tqdm(
        predictions.data_id, total=len(predictions.data_id), unit='images'
    ):
        for model_id in predictions.model_id:
            data = predictions.loc[{'data_id': data_id, 'model_id': model_id}]
            mean = data[0:2]  # The two prediction values
            entropy = (-mean * np.log(mean)).sum()
            confidence = mean.max()
            actual = data[3]  # positive_actual
            predicted = mean.argmax()
            correct = actual == predicted

            indv_statistics.loc[{'data_id': data_id, 'model_id': model_id}] = [
                mean[1],
                entropy,
                confidence,
                correct,
                predicted,
                actual,
            ]

    return indv_statistics
# Compute individual model thresholds
def compute_individual_thresholds(input_stats: xr.DataArray):
    quantiles = np.linspace(0.05, 0.95, 19) * 100
    metrics = ['accuracy', 'f1']
    statistics = ['entropy', 'confidence']
    zeros = np.zeros(
        (len(input_stats.model_id), len(quantiles), len(statistics), len(metrics))
    )
    indv_thresholds = xr.DataArray(
        zeros,
        dims=('model_id', 'quantile', 'statistic', 'metric'),
        coords={
            'model_id': input_stats.model_id,
            'quantile': quantiles,
            'statistic': statistics,
            'metric': metrics,
        },
    )

    for model_id in tqdm(
        input_stats.model_id, total=len(input_stats.model_id), unit='models'
    ):
        for statistic in statistics:
            # First, we must compute the quantiles for the statistic
            quantile_values = np.percentile(
                input_stats.sel(model_id=model_id, statistic=statistic).values,
                quantiles,
                axis=0,
            )

            # Then, we must compute the metrics for each quantile
            for i, quantile in enumerate(quantiles):
                if low_to_high(statistic):
                    mask = (
                        input_stats.sel(model_id=model_id, statistic=statistic)
                        >= quantile_values[i]
                    ).values
                else:
                    mask = (
                        input_stats.sel(model_id=model_id, statistic=statistic)
                        <= quantile_values[i]
                    ).values

                # Filter the data based on the mask
                filtered_data = input_stats.where(
                    input_stats.data_id.isin(np.where(mask)), drop=True
                )

                for metric in metrics:
                    indv_thresholds.loc[
                        {
                            'model_id': model_id,
                            'quantile': quantile,
                            'statistic': statistic,
                            'metric': metric,
                        }
                    ] = compute_metric(filtered_data, metric)

    return indv_thresholds
# Graph individual model thresholded predictions
def graph_individual_thresholded_predictions(
    indv_thresholds,
    ensemble_thresholds,
    statistic,
    metric,
    save_path,
    title,
    xlabel,
    ylabel,
):
    data = indv_thresholds.sel(statistic=statistic, metric=metric)
    e_data = ensemble_thresholds.sel(statistic=statistic, metric=metric)

    x_data = data.coords['quantile'].values
    y_data = data.values

    e_x_data = e_data.coords['quantile'].values
    e_y_data = e_data.values

    fig, ax = plt.subplots()
    for model_id in data.coords['model_id'].values:
        model_data = data.sel(model_id=model_id)
        ax.plot(x_data, model_data)

    ax.plot(e_x_data, e_y_data, 'kx-', label='Ensemble')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.xaxis.set_major_formatter(mtick.PercentFormatter())

    if not low_to_high(statistic):
        ax.invert_xaxis()

    ax.legend()
    plt.savefig(save_path)
# Graph all individual thresholded predictions
def graph_all_individual_thresholded_predictions(
    indv_thresholds, ensemble_thresholds, save_path
):
    # Confidence Accuracy
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'confidence',
        'accuracy',
        f'{save_path}/indv/confidence_accuracy.png',
        'Coverage Analysis of Confidence vs. Accuracy for All Models',
        'Minimum Confidence Percentile Threshold',
        'Accuracy',
    )

    # Confidence F1
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'confidence',
        'f1',
        f'{save_path}/indv/confidence_f1.png',
        'Coverage Analysis of Confidence vs. F1 Score for All Models',
        'Minimum Confidence Percentile Threshold',
        'F1 Score',
    )

    # Entropy Accuracy
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'entropy',
        'accuracy',
        f'{save_path}/indv/entropy_accuracy.png',
        'Coverage Analysis of Entropy vs. Accuracy for All Models',
        'Maximum Entropy Percentile Threshold',
        'Accuracy',
    )

    # Entropy F1
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'entropy',
        'f1',
        f'{save_path}/indv/entropy_f1.png',
        'Coverage Analysis of Entropy vs. F1 Score for All Models',
        'Maximum Entropy Percentile Threshold',
        'F1 Score',
    )
# Calculate statistics of subsets of models for sensitivity analysis
def calculate_subset_statistics(predictions: xr.DataArray):
    # Calculate subsets for 1 to (number of models - 1) models
    subsets = range(1, len(predictions.model_id))
    zeros = np.zeros(
        (len(predictions.data_id), len(subsets), 7)
    )  # Include stdev, though it is degenerate for a subset of a single model

    subset_stats = xr.DataArray(
        zeros,
        dims=('data_id', 'model_count', 'statistic'),
        coords={
            'data_id': predictions.data_id,
            'model_count': subsets,
            'statistic': [
                'mean',
                'stdev',
                'entropy',
                'confidence',
                'correct',
                'predicted',
                'actual',
            ],
        },
    )

    for data_id in tqdm(
        predictions.data_id, total=len(predictions.data_id), unit='images'
    ):
        for subset in subsets:
            data = predictions.sel(
                data_id=data_id, model_id=predictions.model_id[:subset]
            )
            mean = data.mean(dim='model_id')[0:2]
            stdev = data.std(dim='model_id')[1]
            entropy = (-mean * np.log(mean)).sum()
            confidence = mean.max()
            actual = data[0][3]  # positive_actual, taken from the first model
            predicted = mean.argmax()
            correct = actual == predicted

            subset_stats.loc[{'data_id': data_id, 'model_count': subset}] = [
                mean[1],
                stdev,
                entropy,
                confidence,
                correct,
                predicted,
                actual,
            ]

    return subset_stats
# Calculate accuracy, F1, and ECE for the subset stats - sensitivity analysis
def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
    subsets = subset_stats.model_count
    stats = ['accuracy', 'f1', 'ece']
    zeros = np.zeros((len(subsets), len(stats)))
    sens_analysis = xr.DataArray(
        zeros,
        dims=('model_count', 'statistic'),
        coords={'model_count': subsets, 'statistic': stats},
    )

    for subset in tqdm(subsets, total=len(subsets), unit='model subsets'):
        data = subset_stats.sel(model_count=subset)
        acc = compute_metric(data, 'accuracy').item()
        f1 = compute_metric(data, 'f1').item()
        ece = compute_metric(data, 'ece').item()
        sens_analysis.loc[{'model_count': subset.item()}] = [acc, f1, ece]

    return sens_analysis
def graph_sensitivity_analysis(
    sens_analysis: xr.DataArray, statistic, save_path, title, xlabel, ylabel
):
    data = sens_analysis.sel(statistic=statistic)

    xdata = data.coords['model_count'].values
    ydata = data.values

    fig, ax = plt.subplots()
    ax.plot(xdata, ydata)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    plt.savefig(save_path)
def calculate_overall_stats(ensemble_statistics: xr.DataArray):
    accuracy = compute_metric(ensemble_statistics, 'accuracy')
    f1 = compute_metric(ensemble_statistics, 'f1')
    return {'accuracy': accuracy.item(), 'f1': f1.item()}
# https://towardsdatascience.com/expected-calibration-error-ece-a-step-by-step-visual-explanation-with-python-code-c3e9aa12937d
def calculate_ece_stats(confidences, predicted_labels, true_labels, bins=10):
    bin_boundaries = np.linspace(0, 1, bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    ece = np.zeros(1)
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        # Select the samples whose confidence falls into this bin
        in_bin = np.logical_and(
            confidences > bin_lower.item(), confidences <= bin_upper.item()
        )
        prob_in_bin = in_bin.mean()

        if prob_in_bin.item() > 0:
            # Accuracy within the bin, taken here as the mean of the true labels
            # (predicted_labels is accepted for symmetry with compute_metric but is not used)
            accuracy_in_bin = true_labels[in_bin].mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prob_in_bin

    return ece
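# A usage sketch with made-up numbers (signature per calculate_ece_stats above):
#   confidences = np.array([0.9, 0.6, 0.8])
#   predicted   = np.array([1, 1, 1])
#   true_labels = np.array([1, 0, 1])
#   ece = calculate_ece_stats(confidences, predicted, true_labels)
# Each sample lands in one of the 10 equal-width confidence bins; the result
# accumulates |avg_confidence_in_bin - accuracy_in_bin| weighted by each bin's
# fraction of samples.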
# Plot an ECE graph (stub: only sets up the figure so far)
def plot_ece_graph(ece_stats, title, xlabel, ylabel, save_path):
    fig, ax = plt.subplots()
# Main Function
def main():
    print('Loading Config...')
    config = load_config()
    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V4_PATH = ENSEMBLE_PATH + '/v4'
    if not os.path.exists(V4_PATH):
        os.makedirs(V4_PATH)
    print('Config Loaded')

    # Load Datasets
    print('Loading Datasets...')
    (test_dataset, val_dataset) = load_datasets(ENSEMBLE_PATH)
    print('Datasets Loaded')

    # Get Predictions, either by running the models or loading them from a file
    if config['ensemble']['run_models']:
        # Load Models
        print('Loading Models...')
        device = torch.device(config['training']['device'])
        models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
        print('Models Loaded')

        # Get Predictions
        print('Getting Predictions...')
        test_predictions = get_ensemble_predictions(models, test_dataset, device)
        val_predictions = get_ensemble_predictions(
            models, val_dataset, device, len(test_dataset)
        )
        print('Predictions Loaded')

        # Save Predictions
        test_predictions.to_netcdf(f'{V4_PATH}/test_predictions.nc')
        val_predictions.to_netcdf(f'{V4_PATH}/val_predictions.nc')
    else:
        test_predictions = xr.open_dataarray(f'{V4_PATH}/test_predictions.nc')
        val_predictions = xr.open_dataarray(f'{V4_PATH}/val_predictions.nc')

    # Prune Data
    print('Pruning Data...')
    if config['operation']['exclude_blank_ids']:
        excluded_data_ids = config['ensemble']['excluded_ids']
        test_predictions = prune_data(test_predictions, excluded_data_ids)
        val_predictions = prune_data(val_predictions, excluded_data_ids)

    # Concatenate Predictions
    predictions = xr.concat([test_predictions, val_predictions], dim='data_id')

    # Compute Ensemble Statistics
    print('Computing Ensemble Statistics...')
    ensemble_statistics = compute_ensemble_statistics(predictions)
    ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')
    print('Ensemble Statistics Computed')

    # Compute Thresholded Predictions
    print('Computing Thresholded Predictions...')
    thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
    thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')
    print('Thresholded Predictions Computed')

    # Graph Thresholded Predictions
    print('Graphing Thresholded Predictions...')
    graph_all_thresholded_predictions(thresholded_predictions, V4_PATH)
    print('Thresholded Predictions Graphed')

    # Additional Graphs
    print('Graphing Additional Graphs...')

    # Confidence vs stdev
    graph_statistics(
        ensemble_statistics,
        'confidence',
        'stdev',
        f'{V4_PATH}/confidence_stdev.png',
        'Confidence and Standard Deviation for Predictions',
        'Confidence',
        'Standard Deviation',
    )
    print('Additional Graphs Graphed')

    # Compute Individual Statistics
    print('Computing Individual Statistics...')
    indv_statistics = compute_individual_statistics(predictions)
    indv_statistics.to_netcdf(f'{V4_PATH}/indv_statistics.nc')
    print('Individual Statistics Computed')

    # Compute Individual Thresholds
    print('Computing Individual Thresholds...')
    indv_thresholds = compute_individual_thresholds(indv_statistics)
    indv_thresholds.to_netcdf(f'{V4_PATH}/indv_thresholds.nc')
    print('Individual Thresholds Computed')

    # Graph Individual Thresholded Predictions
    print('Graphing Individual Thresholded Predictions...')
    if not os.path.exists(f'{V4_PATH}/indv'):
        os.makedirs(f'{V4_PATH}/indv')

    graph_all_individual_thresholded_predictions(
        indv_thresholds, thresholded_predictions, V4_PATH
    )
    print('Individual Thresholded Predictions Graphed')

    # Compute subset statistics and graph the sensitivity analysis
    print('Computing Sensitivity Analysis...')
    subset_stats = calculate_subset_statistics(predictions)
    sens_analysis = calculate_sensitivity_analysis(subset_stats)
    graph_sensitivity_analysis(
        sens_analysis,
        'accuracy',
        f'{V4_PATH}/sens_analysis.png',
        'Sensitivity Analysis of Accuracy vs. # of Models',
        '# of Models',
        'Accuracy',
    )
    graph_sensitivity_analysis(
        sens_analysis,
        'ece',
        f'{V4_PATH}/sens_analysis_ece.png',
        'Sensitivity Analysis of ECE vs. # of Models',
        '# of Models',
        'ECE',
    )
    print(sens_analysis.sel(statistic='accuracy'))
    print(calculate_overall_stats(ensemble_statistics))


if __name__ == '__main__':
    main()