# threshold_xarray.py
# Rewritten program that uses xarray instead of pandas for thresholding.
import glob
import os

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import tomli as toml
import torch
import xarray as xr
from tqdm import tqdm

import utils.metrics as met
import utils.models.cnn
  13. torch.serialization.safe_globals([utils.models.cnn.CNN])
# The data structures used in this file are as follows:
# models_dict: Dictionary - {model_id: model}
# predictions: DataArray - (data_id, model_id, prediction_value) - prediction_value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
# ensemble_statistics: DataArray - (data_id, statistic) - statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
# thresholded_predictions: DataArray - (quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only uses 'stdev', 'entropy', 'confidence' for statistic
# Additionally, we also have the thresholds and statistics for the individual models:
# indv_statistics: DataArray - (data_id, model_id, statistic) - statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - no stdev, as it cannot be calculated for a single model
# indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only uses 'entropy', 'confidence' for statistic
# Additionally, we have two arrays for the sensitivity analysis over the number of models:
# subset_stats: DataArray - (data_id, model_count, statistic) - statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
# sens_analysis: DataArray - (model_count, statistic) - statistic has coords ['accuracy', 'f1', 'ece']
  24. # Loads configuration dictionary
  25. def load_config():
  26. if os.getenv("ADL_CONFIG_PATH") is None:
  27. with open("config.toml", "rb") as f:
  28. config = toml.load(f)
  29. else:
  30. with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
  31. config = toml.load(f)
  32. return config
  33. # Loads models into a dictionary
  34. def load_models_v2(folder, device):
  35. glob_path = os.path.join(folder, "*.pt")
  36. model_files = glob.glob(glob_path)
  37. model_dict = {}
  38. for model_file in model_files:
  39. with open(model_file, "r") as f:
  40. print(torch.serialization.get_unsafe_globals_in_checkpoint(f))
  41. model = torch.load(model_file, map_location=device)
  42. model_id = os.path.basename(model_file).split("_")[0]
  43. model_dict[model_id] = model
  44. if len(model_dict) == 0:
  45. raise FileNotFoundError("No models found in the specified directory: " + folder)
  46. return model_dict
  47. # Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
  48. def preprocess_data(data, device):
  49. mri, xls = data
  50. mri = mri.unsqueeze(0).to(device)
  51. xls = xls.unsqueeze(0).to(device)
  52. return (mri, xls)
  53. # Loads datasets and returns concatenated test and validation datasets
  54. def load_datasets(ensemble_path):
  55. return (
  56. torch.load(f"{ensemble_path}/test_dataset.pt"),
  57. torch.load(f"{ensemble_path}/val_dataset.pt"),
  58. )
  59. # Gets the predictions for a set of models on a dataset
  60. def get_ensemble_predictions(models, dataset, device, id_offset=0):
  61. zeros = np.zeros((len(dataset), len(models), 4))
  62. predictions = xr.DataArray(
  63. zeros,
  64. dims=("data_id", "model_id", "prediction_value"),
  65. coords={
  66. "data_id": range(id_offset, len(dataset) + id_offset),
  67. "model_id": list(models.keys()),
  68. "prediction_value": [
  69. "negative_prediction",
  70. "positive_prediction",
  71. "negative_actual",
  72. "positive_actual",
  73. ],
  74. },
  75. )
  76. for data_id, (data, target) in tqdm(
  77. enumerate(dataset), total=len(dataset), unit="images"
  78. ):
  79. dat = preprocess_data(data, device)
  80. actual = list(target.cpu().numpy())
  81. for model_id, model in models.items():
  82. with torch.no_grad():
  83. output = model(dat)
  84. prediction = output.cpu().numpy().tolist()[0]
  85. predictions.loc[
  86. {"data_id": data_id + id_offset, "model_id": model_id}
  87. ] = (prediction + actual)
  88. return predictions
  89. # Compute the ensemble statistics given an array of predictions
  90. def compute_ensemble_statistics(predictions: xr.DataArray):
  91. zeros = np.zeros((len(predictions.data_id), 7))
  92. ensemble_statistics = xr.DataArray(
  93. zeros,
  94. dims=("data_id", "statistic"),
  95. coords={
  96. "data_id": predictions.data_id,
  97. "statistic": [
  98. "mean",
  99. "stdev",
  100. "entropy",
  101. "confidence",
  102. "correct",
  103. "predicted",
  104. "actual",
  105. ],
  106. },
  107. )
  108. for data_id in predictions.data_id:
  109. data = predictions.loc[{"data_id": data_id}]
  110. mean = data.mean(dim="model_id")[
  111. 0:2
  112. ] # Only take the predictions, not the actual
  113. stdev = data.std(dim="model_id")[
  114. 1
  115. ] # Only need the standard deviation of the postive prediction
  116. entropy = (-mean * np.log(mean)).sum()
  117. # Compute confidence
  118. confidence = mean.max()
  119. # only need one of the actual values, since they are all the same, just get the first actual_positive
  120. actual = data.loc[{"prediction_value": "positive_actual"}][0]
  121. predicted = mean.argmax()
  122. correct = actual == predicted
  123. ensemble_statistics.loc[{"data_id": data_id}] = [
  124. mean[1],
  125. stdev,
  126. entropy,
  127. confidence,
  128. correct,
  129. predicted,
  130. actual,
  131. ]
  132. return ensemble_statistics
  133. # Compute the thresholded predictions given an array of predictions
  134. def compute_thresholded_predictions(input_stats: xr.DataArray):
  135. quantiles = np.linspace(0.00, 1.00, 21) * 100
  136. metrics = ["accuracy", "f1"]
  137. statistics = ["stdev", "entropy", "confidence"]
  138. zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))
  139. thresholded_predictions = xr.DataArray(
  140. zeros,
  141. dims=("quantile", "statistic", "metric"),
  142. coords={"quantile": quantiles, "statistic": statistics, "metric": metrics},
  143. )
  144. for statistic in statistics:
  145. # First, we must compute the quantiles for the statistic
  146. quantile_values = np.percentile(
  147. input_stats.sel(statistic=statistic).values, quantiles, axis=0
  148. )
  149. # Then, we must compute the metrics for each quantile
  150. for i, quantile in enumerate(quantiles):
  151. if low_to_high(statistic):
  152. mask = (
  153. input_stats.sel(statistic=statistic) >= quantile_values[i]
  154. ).values
  155. else:
  156. mask = (
  157. input_stats.sel(statistic=statistic) <= quantile_values[i]
  158. ).values
  159. # Filter the data based on the mask
  160. filtered_data = input_stats.where(
  161. input_stats.data_id.isin(np.where(mask)), drop=True
  162. )
  163. for metric in metrics:
  164. thresholded_predictions.loc[
  165. {"quantile": quantile, "statistic": statistic, "metric": metric}
  166. ] = compute_metric(filtered_data, metric)
  167. return thresholded_predictions
  168. # Truth function to determine if metric should be thresholded low to high or high to low
  169. # Low confidence is bad, high entropy is bad, high stdev is bad
  170. # So we threshold confidence low to high, entropy and stdev high to low
  171. # So any values BELOW the cutoff are removed for confidence, and any values ABOVE the cutoff are removed for entropy and stdev
  172. def low_to_high(stat):
  173. return stat in ["confidence"]
  174. # Compute a given metric on a DataArray of statstics
  175. def compute_metric(arr, metric):
  176. if metric == "accuracy":
  177. return np.mean(arr.loc[{"statistic": "correct"}])
  178. elif metric == "f1":
  179. return met.F1(
  180. arr.loc[{"statistic": "predicted"}], arr.loc[{"statistic": "actual"}]
  181. )
  182. elif metric == "ece":
  183. true_labels = arr.loc[{"statistic": "actual"}].values
  184. predicted_labels = arr.loc[{"statistic": "predicted"}].values
  185. confidences = arr.loc[{"statistic": "confidence"}].values
  186. return calculate_ece_stats(confidences, predicted_labels, true_labels)
  187. else:
  188. raise ValueError("Invalid metric: " + metric)
  189. # Graph a thresholded prediction for a given statistic and metric
  190. def graph_thresholded_prediction(
  191. thresholded_predictions, statistic, metric, save_path, title, xlabel, ylabel
  192. ):
  193. data = thresholded_predictions.sel(statistic=statistic, metric=metric)
  194. x_data = data.coords["quantile"].values
  195. y_data = data.values
  196. fig, ax = plt.subplots()
  197. ax.plot(x_data, y_data, "bx-", label="Ensemble")
  198. ax.set_title(title)
  199. ax.set_xlabel(xlabel)
  200. ax.set_ylabel(ylabel)
  201. ax.xaxis.set_major_formatter(mtick.PercentFormatter())
  202. if not low_to_high(statistic):
  203. ax.invert_xaxis()
  204. plt.savefig(save_path)
  205. # Graph all thresholded predictions
  206. def graph_all_thresholded_predictions(thresholded_predictions, save_path):
  207. # Confidence Accuracy
  208. graph_thresholded_prediction(
  209. thresholded_predictions,
  210. "confidence",
  211. "accuracy",
  212. f"{save_path}/confidence_accuracy.png",
  213. "Coverage Analysis of Confidence vs. Accuracy",
  214. "Minimum Confidence Percentile Threshold",
  215. "Accuracy",
  216. )
  217. # Confidence F1
  218. graph_thresholded_prediction(
  219. thresholded_predictions,
  220. "confidence",
  221. "f1",
  222. f"{save_path}/confidence_f1.png",
  223. "Coverage Analysis of Confidence vs. F1 Score",
  224. "Minimum Confidence Percentile Threshold",
  225. "F1 Score",
  226. )
  227. # Entropy Accuracy
  228. graph_thresholded_prediction(
  229. thresholded_predictions,
  230. "entropy",
  231. "accuracy",
  232. f"{save_path}/entropy_accuracy.png",
  233. "Coverage Analysis of Entropy vs. Accuracy",
  234. "Maximum Entropy Percentile Threshold",
  235. "Accuracy",
  236. )
  237. # Entropy F1
  238. graph_thresholded_prediction(
  239. thresholded_predictions,
  240. "entropy",
  241. "f1",
  242. f"{save_path}/entropy_f1.png",
  243. "Coverage Analysis of Entropy vs. F1 Score",
  244. "Maximum Entropy Percentile Threshold",
  245. "F1 Score",
  246. )
  247. # Stdev Accuracy
  248. graph_thresholded_prediction(
  249. thresholded_predictions,
  250. "stdev",
  251. "accuracy",
  252. f"{save_path}/stdev_accuracy.png",
  253. "Coverage Analysis of Standard Deviation vs. Accuracy",
  254. "Maximum Standard Deviation Percentile Threshold",
  255. "Accuracy",
  256. )
  257. # Stdev F1
  258. graph_thresholded_prediction(
  259. thresholded_predictions,
  260. "stdev",
  261. "f1",
  262. f"{save_path}/stdev_f1.png",
  263. "Coverage Analysis of Standard Deviation vs. F1 Score",
  264. "Maximum Standard Deviation Percentile Threshold",
  265. "F1",
  266. )
  267. # Graph two statistics against each other
  268. def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
  269. # Filter for correct predictions
  270. c_stats = stats.where(
  271. stats.data_id.isin(np.where((stats.sel(statistic="correct") == 1).values)),
  272. drop=True,
  273. )
  274. # Filter for incorrect predictions
  275. i_stats = stats.where(
  276. stats.data_id.isin(np.where((stats.sel(statistic="correct") == 0).values)),
  277. drop=True,
  278. )
  279. # x and y data for correct and incorrect predictions
  280. x_data_c = c_stats.sel(statistic=x_stat).values
  281. y_data_c = c_stats.sel(statistic=y_stat).values
  282. x_data_i = i_stats.sel(statistic=x_stat).values
  283. y_data_i = i_stats.sel(statistic=y_stat).values
  284. fig, ax = plt.subplots()
  285. ax.plot(x_data_c, y_data_c, "go", label="Correct")
  286. ax.plot(x_data_i, y_data_i, "ro", label="Incorrect")
  287. ax.set_title(title)
  288. ax.set_xlabel(xlabel)
  289. ax.set_ylabel(ylabel)
  290. ax.legend()
  291. plt.savefig(save_path)
  292. # Prune the data based on excluded data_ids
  293. def prune_data(data, excluded_data_ids):
  294. return data.where(~data.data_id.isin(excluded_data_ids), drop=True)
  295. # Calculate individual model statistics
  296. def compute_individual_statistics(predictions: xr.DataArray):
  297. zeros = np.zeros((len(predictions.data_id), len(predictions.model_id), 6))
  298. indv_statistics = xr.DataArray(
  299. zeros,
  300. dims=("data_id", "model_id", "statistic"),
  301. coords={
  302. "data_id": predictions.data_id,
  303. "model_id": predictions.model_id,
  304. "statistic": [
  305. "mean",
  306. "entropy",
  307. "confidence",
  308. "correct",
  309. "predicted",
  310. "actual",
  311. ],
  312. },
  313. )
  314. for data_id in tqdm(
  315. predictions.data_id, total=len(predictions.data_id), unit="images"
  316. ):
  317. for model_id in predictions.model_id:
  318. data = predictions.loc[{"data_id": data_id, "model_id": model_id}]
  319. mean = data[0:2]
  320. entropy = (-mean * np.log(mean)).sum()
  321. confidence = mean.max()
  322. actual = data[3]
  323. predicted = mean.argmax()
  324. correct = actual == predicted
  325. indv_statistics.loc[{"data_id": data_id, "model_id": model_id}] = [
  326. mean[1],
  327. entropy,
  328. confidence,
  329. correct,
  330. predicted,
  331. actual,
  332. ]
  333. return indv_statistics
  334. # Compute individual model thresholds
  335. def compute_individual_thresholds(input_stats: xr.DataArray):
  336. quantiles = np.linspace(0.05, 0.95, 19) * 100
  337. metrics = ["accuracy", "f1"]
  338. statistics = ["entropy", "confidence"]
  339. zeros = np.zeros(
  340. (len(input_stats.model_id), len(quantiles), len(statistics), len(metrics))
  341. )
  342. indv_thresholds = xr.DataArray(
  343. zeros,
  344. dims=("model_id", "quantile", "statistic", "metric"),
  345. coords={
  346. "model_id": input_stats.model_id,
  347. "quantile": quantiles,
  348. "statistic": statistics,
  349. "metric": metrics,
  350. },
  351. )
  352. for model_id in tqdm(
  353. input_stats.model_id, total=len(input_stats.model_id), unit="models"
  354. ):
  355. for statistic in statistics:
  356. # First, we must compute the quantiles for the statistic
  357. quantile_values = np.percentile(
  358. input_stats.sel(model_id=model_id, statistic=statistic).values,
  359. quantiles,
  360. axis=0,
  361. )
  362. # Then, we must compute the metrics for each quantile
  363. for i, quantile in enumerate(quantiles):
  364. if low_to_high(statistic):
  365. mask = (
  366. input_stats.sel(model_id=model_id, statistic=statistic)
  367. >= quantile_values[i]
  368. ).values
  369. else:
  370. mask = (
  371. input_stats.sel(model_id=model_id, statistic=statistic)
  372. <= quantile_values[i]
  373. ).values
  374. # Filter the data based on the mask
  375. filtered_data = input_stats.where(
  376. input_stats.data_id.isin(np.where(mask)), drop=True
  377. )
  378. for metric in metrics:
  379. indv_thresholds.loc[
  380. {
  381. "model_id": model_id,
  382. "quantile": quantile,
  383. "statistic": statistic,
  384. "metric": metric,
  385. }
  386. ] = compute_metric(filtered_data, metric)
  387. return indv_thresholds
  388. # Graph individual model thresholded predictions
  389. def graph_individual_thresholded_predictions(
  390. indv_thresholds,
  391. ensemble_thresholds,
  392. statistic,
  393. metric,
  394. save_path,
  395. title,
  396. xlabel,
  397. ylabel,
  398. ):
  399. data = indv_thresholds.sel(statistic=statistic, metric=metric)
  400. e_data = ensemble_thresholds.sel(statistic=statistic, metric=metric)
  401. x_data = data.coords["quantile"].values
  402. y_data = data.values
  403. e_x_data = e_data.coords["quantile"].values
  404. e_y_data = e_data.values
  405. fig, ax = plt.subplots()
  406. for model_id in data.coords["model_id"].values:
  407. model_data = data.sel(model_id=model_id)
  408. ax.plot(x_data, model_data)
  409. ax.plot(e_x_data, e_y_data, "kx-", label="Ensemble")
  410. ax.set_title(title)
  411. ax.set_xlabel(xlabel)
  412. ax.set_ylabel(ylabel)
  413. ax.xaxis.set_major_formatter(mtick.PercentFormatter())
  414. if not low_to_high(statistic):
  415. ax.invert_xaxis()
  416. ax.legend()
  417. plt.savefig(save_path)
  418. # Graph all individual thresholded predictions
  419. def graph_all_individual_thresholded_predictions(
  420. indv_thresholds, ensemble_thresholds, save_path
  421. ):
  422. # Confidence Accuracy
  423. graph_individual_thresholded_predictions(
  424. indv_thresholds,
  425. ensemble_thresholds,
  426. "confidence",
  427. "accuracy",
  428. f"{save_path}/indv/confidence_accuracy.png",
  429. "Coverage Analysis of Confidence vs. Accuracy for All Models",
  430. "Minumum Confidence Percentile Threshold",
  431. "Accuracy",
  432. )
  433. # Confidence F1
  434. graph_individual_thresholded_predictions(
  435. indv_thresholds,
  436. ensemble_thresholds,
  437. "confidence",
  438. "f1",
  439. f"{save_path}/indv/confidence_f1.png",
  440. "Coverage Analysis of Confidence vs. F1 Score for All Models",
  441. "Minimum Confidence Percentile Threshold",
  442. "F1 Score",
  443. )
  444. # Entropy Accuracy
  445. graph_individual_thresholded_predictions(
  446. indv_thresholds,
  447. ensemble_thresholds,
  448. "entropy",
  449. "accuracy",
  450. f"{save_path}/indv/entropy_accuracy.png",
  451. "Coverage Analysis of Entropy vs. Accuracy for All Models",
  452. "Maximum Entropy Percentile Threshold",
  453. "Accuracy",
  454. )
  455. # Entropy F1
  456. graph_individual_thresholded_predictions(
  457. indv_thresholds,
  458. ensemble_thresholds,
  459. "entropy",
  460. "f1",
  461. f"{save_path}/indv/entropy_f1.png",
  462. "Coverage Analysis of Entropy vs. F1 Score for All Models",
  463. "Maximum Entropy Percentile Threshold",
  464. "F1 Score",
  465. )
  466. # Calculate statistics of subsets of models for sensitivity analysis
  467. def calculate_subset_statistics(predictions: xr.DataArray):
  468. # Calculate subsets for 1-49 models
  469. subsets = range(1, len(predictions.model_id))
  470. zeros = np.zeros(
  471. (len(predictions.data_id), len(subsets), 7)
  472. ) # Include stdev, but for 1 models set to NaN
  473. subset_stats = xr.DataArray(
  474. zeros,
  475. dims=("data_id", "model_count", "statistic"),
  476. coords={
  477. "data_id": predictions.data_id,
  478. "model_count": subsets,
  479. "statistic": [
  480. "mean",
  481. "stdev",
  482. "entropy",
  483. "confidence",
  484. "correct",
  485. "predicted",
  486. "actual",
  487. ],
  488. },
  489. )
  490. for data_id in tqdm(
  491. predictions.data_id, total=len(predictions.data_id), unit="images"
  492. ):
  493. for subset in subsets:
  494. data = predictions.sel(
  495. data_id=data_id, model_id=predictions.model_id[:subset]
  496. )
  497. mean = data.mean(dim="model_id")[0:2]
  498. stdev = data.std(dim="model_id")[1]
  499. entropy = (-mean * np.log(mean)).sum()
  500. confidence = mean.max()
  501. actual = data[0][3]
  502. predicted = mean.argmax()
  503. correct = actual == predicted
  504. subset_stats.loc[{"data_id": data_id, "model_count": subset}] = [
  505. mean[1],
  506. stdev,
  507. entropy,
  508. confidence,
  509. correct,
  510. predicted,
  511. actual,
  512. ]
  513. return subset_stats
  514. # Calculate Accuracy, F1 and ECE for subset stats - sensityvity analysis
  515. def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
  516. subsets = subset_stats.model_count
  517. stats = ["accuracy", "f1", "ece"]
  518. zeros = np.zeros((len(subsets), len(stats)))
  519. sens_analysis = xr.DataArray(
  520. zeros,
  521. dims=("model_count", "statistic"),
  522. coords={"model_count": subsets, "statistic": stats},
  523. )
  524. for subset in tqdm(subsets, total=len(subsets), unit="model subsets"):
  525. data = subset_stats.sel(model_count=subset)
  526. acc = compute_metric(data, "accuracy").item()
  527. f1 = compute_metric(data, "f1").item()
  528. ece = compute_metric(data, "ece").item()
  529. sens_analysis.loc[{"model_count": subset.item()}] = [acc, f1, ece]
  530. return sens_analysis
  531. def graph_sensitivity_analysis(
  532. sens_analysis: xr.DataArray, statistic, save_path, title, xlabel, ylabel
  533. ):
  534. data = sens_analysis.sel(statistic=statistic)
  535. xdata = data.coords["model_count"].values
  536. ydata = data.values
  537. fig, ax = plt.subplots()
  538. ax.plot(xdata, ydata)
  539. ax.set_title(title)
  540. ax.set_xlabel(xlabel)
  541. ax.set_ylabel(ylabel)
  542. plt.savefig(save_path)
  543. def calculate_overall_stats(ensemble_statistics: xr.DataArray):
  544. accuracy = compute_metric(ensemble_statistics, "accuracy")
  545. f1 = compute_metric(ensemble_statistics, "f1")
  546. return {"accuracy": accuracy.item(), "f1": f1.item()}
  547. # https://towardsdatascience.com/expected-calibration-error-ece-a-step-by-step-visual-explanation-with-python-code-c3e9aa12937d
  548. def calculate_ece_stats(confidences, predicted_labels, true_labels, bins=10):
  549. bin_boundaries = np.linspace(0, 1, bins + 1)
  550. bin_lowers = bin_boundaries[:-1]
  551. bin_uppers = bin_boundaries[1:]
  552. ece = np.zeros(1)
  553. for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
  554. in_bin = np.logical_and(
  555. confidences > bin_lower.item(), confidences <= bin_upper.item()
  556. )
  557. prob_in_bin = in_bin.mean()
  558. if prob_in_bin.item() > 0:
  559. accuracy_in_bin = true_labels[in_bin].mean()
  560. avg_confidence_in_bin = confidences[in_bin].mean()
  561. ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prob_in_bin
  562. return ece
  563. def plot_ece_graph(ece_stats, title, xlabel, ylabel, save_path):
  564. fix, ax = plt.subplot()
  565. # Main Function
  566. def main():
  567. print("Loading Config...")
  568. config = load_config()
  569. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  570. V4_PATH = ENSEMBLE_PATH + "/v4"
  571. if not os.path.exists(V4_PATH):
  572. os.makedirs(V4_PATH)
  573. print("Config Loaded")
  574. # Load Datasets
  575. print("Loading Datasets...")
  576. (test_dataset, val_dataset) = load_datasets(ENSEMBLE_PATH)
  577. print("Datasets Loaded")
  578. # Get Predictions, either by running the models or loading them from a file
  579. if config["ensemble"]["run_models"]:
  580. # Load Models
  581. print("Loading Models...")
  582. device = torch.device(config["training"]["device"])
  583. models = load_models_v2(f"{ENSEMBLE_PATH}/models/", device)
  584. print("Models Loaded")
  585. # Get Predictions
  586. print("Getting Predictions...")
  587. test_predictions = get_ensemble_predictions(models, test_dataset, device)
  588. val_predictions = get_ensemble_predictions(
  589. models, val_dataset, device, len(test_dataset)
  590. )
  591. print("Predictions Loaded")
  592. # Save Prediction
  593. test_predictions.to_netcdf(f"{V4_PATH}/test_predictions.nc")
  594. val_predictions.to_netcdf(f"{V4_PATH}/val_predictions.nc")
  595. else:
  596. test_predictions = xr.open_dataarray(f"{V4_PATH}/test_predictions.nc")
  597. val_predictions = xr.open_dataarray(f"{V4_PATH}/val_predictions.nc")
  598. # Prune Data
  599. print("Pruning Data...")
  600. if config["operation"]["exclude_blank_ids"]:
  601. excluded_data_ids = config["ensemble"]["excluded_ids"]
  602. test_predictions = prune_data(test_predictions, excluded_data_ids)
  603. val_predictions = prune_data(val_predictions, excluded_data_ids)
  604. # Concatenate Predictions
  605. predictions = xr.concat([test_predictions, val_predictions], dim="data_id")
  606. # Compute Ensemble Statistics
  607. print("Computing Ensemble Statistics...")
  608. ensemble_statistics = compute_ensemble_statistics(predictions)
  609. ensemble_statistics.to_netcdf(f"{V4_PATH}/ensemble_statistics.nc")
  610. print("Ensemble Statistics Computed")
  611. # Compute Thresholded Predictions
  612. print("Computing Thresholded Predictions...")
  613. thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
  614. thresholded_predictions.to_netcdf(f"{V4_PATH}/thresholded_predictions.nc")
  615. print("Thresholded Predictions Computed")
  616. # Graph Thresholded Predictions
  617. print("Graphing Thresholded Predictions...")
  618. graph_all_thresholded_predictions(thresholded_predictions, V4_PATH)
  619. print("Thresholded Predictions Graphed")
  620. # Additional Graphs
  621. print("Graphing Additional Graphs...")
  622. # Confidence vs stdev
  623. graph_statistics(
  624. ensemble_statistics,
  625. "confidence",
  626. "stdev",
  627. f"{V4_PATH}/confidence_stdev.png",
  628. "Confidence and Standard Deviation for Predictions",
  629. "Confidence",
  630. "Standard Deviation",
  631. )
  632. print("Additional Graphs Graphed")
  633. # Compute Individual Statistics
  634. print("Computing Individual Statistics...")
  635. indv_statistics = compute_individual_statistics(predictions)
  636. indv_statistics.to_netcdf(f"{V4_PATH}/indv_statistics.nc")
  637. print("Individual Statistics Computed")
  638. # Compute Individual Thresholds
  639. print("Computing Individual Thresholds...")
  640. indv_thresholds = compute_individual_thresholds(indv_statistics)
  641. indv_thresholds.to_netcdf(f"{V4_PATH}/indv_thresholds.nc")
  642. print("Individual Thresholds Computed")
  643. # Graph Individual Thresholded Predictions
  644. print("Graphing Individual Thresholded Predictions...")
  645. if not os.path.exists(f"{V4_PATH}/indv"):
  646. os.makedirs(f"{V4_PATH}/indv")
  647. graph_all_individual_thresholded_predictions(
  648. indv_thresholds, thresholded_predictions, V4_PATH
  649. )
  650. print("Individual Thresholded Predictions Graphed")
  651. # Compute subset statistics and graph
  652. print("Computing Sensitivity Analysis...")
  653. subset_stats = calculate_subset_statistics(predictions)
  654. sens_analysis = calculate_sensitivity_analysis(subset_stats)
  655. graph_sensitivity_analysis(
  656. sens_analysis,
  657. "accuracy",
  658. f"{V4_PATH}/sens_analysis.png",
  659. "Sensitivity Analsis of Accuracy vs. # of Models",
  660. "# of Models",
  661. "Accuracy",
  662. )
  663. graph_sensitivity_analysis(
  664. sens_analysis,
  665. "ece",
  666. f"{V4_PATH}/sens_analysis_ece.png",
  667. "Sensitivity Analysis of ECE vs. # of Models",
  668. "# of Models",
  669. "ECE",
  670. )
  671. print(sens_analysis.sel(statistic="accuracy"))
  672. print(calculate_overall_stats(ensemble_statistics))
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()