# threshold_xarray.py (22 KB)
  1. # Rewritten Program to use xarray instead of pandas for thresholding
  2. import xarray as xr
  3. import torch
  4. import numpy as np
  5. import os
  6. import glob
  7. import tomli as toml
  8. from tqdm import tqdm
  9. import utils.metrics as met
  10. import matplotlib.pyplot as plt
  11. import matplotlib.ticker as mtick
  12. # The datastructures for this file are as follows
  13. # models_dict: Dictionary - {model_id: model}
  14. # predictions: DataArray - (data_id, model_id, prediction_value) - Prediction value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
  15. # ensemble_statistics: DataArray - (data_id, statistic) - Statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
  16. # thresholded_predictions: DataArray - (quantile, statistic, metric) - Metric has coords ['accuracy', 'f1'] - only use 'stdev', 'entropy', 'confidence' for statistic
  17. # Additionally, we also have the thresholds and statistics for the individual models
  18. # indv_statistics: DataArray - (data_id, model_id, statistic) - Statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - No stdev as it cannot be calculated for a single model
  19. # indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - Metric has coords ['accuracy', 'f1'] - only use 'entropy', 'confidence' for statistic
  20. # Additionally, we have some for the sensitivity analysis for number of models
  21. # sensitivity_statistics: DataArray - (model_count, statistic) - Statistic has coords ['accuracy', 'f1', 'ECE', 'MCE']
  22. # Loads configuration dictionary
  23. def load_config():
  24. if os.getenv('ADL_CONFIG_PATH') is None:
  25. with open('config.toml', 'rb') as f:
  26. config = toml.load(f)
  27. else:
  28. with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
  29. config = toml.load(f)
  30. return config
  31. # Loads models into a dictionary
  32. def load_models_v2(folder, device):
  33. glob_path = os.path.join(folder, '*.pt')
  34. model_files = glob.glob(glob_path)
  35. model_dict = {}
  36. for model_file in model_files:
  37. model = torch.load(model_file, map_location=device)
  38. model_id = os.path.basename(model_file).split('_')[0]
  39. model_dict[model_id] = model
  40. if len(model_dict) == 0:
  41. raise FileNotFoundError('No models found in the specified directory: ' + folder)
  42. return model_dict
  43. # Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
  44. def preprocess_data(data, device):
  45. mri, xls = data
  46. mri = mri.unsqueeze(0).to(device)
  47. xls = xls.unsqueeze(0).to(device)
  48. return (mri, xls)
  49. # Loads datasets and returns concatenated test and validation datasets
  50. def load_datasets(ensemble_path):
  51. return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
  52. f'{ensemble_path}/val_dataset.pt'
  53. )
  54. # Gets the predictions for a set of models on a dataset
  55. def get_ensemble_predictions(models, dataset, device):
  56. zeros = np.zeros((len(dataset), len(models), 4))
  57. predictions = xr.DataArray(
  58. zeros,
  59. dims=('data_id', 'model_id', 'prediction_value'),
  60. coords={
  61. 'data_id': range(len(dataset)),
  62. 'model_id': list(models.keys()),
  63. 'prediction_value': [
  64. 'negative_prediction',
  65. 'positive_prediction',
  66. 'negative_actual',
  67. 'positive_actual',
  68. ],
  69. },
  70. )
  71. for data_id, (data, target) in tqdm(
  72. enumerate(dataset), total=len(dataset), unit='images'
  73. ):
  74. dat = preprocess_data(data, device)
  75. actual = list(target.cpu().numpy())
  76. for model_id, model in models.items():
  77. with torch.no_grad():
  78. output = model(dat)
  79. prediction = output.cpu().numpy().tolist()[0]
  80. predictions.loc[{'data_id': data_id, 'model_id': model_id}] = (
  81. prediction + actual
  82. )
  83. return predictions
  84. # Compute the ensemble statistics given an array of predictions
  85. def compute_ensemble_statistics(predictions: xr.DataArray):
  86. zeros = np.zeros((len(predictions.data_id), 7))
  87. ensemble_statistics = xr.DataArray(
  88. zeros,
  89. dims=('data_id', 'statistic'),
  90. coords={
  91. 'data_id': predictions.data_id,
  92. 'statistic': [
  93. 'mean',
  94. 'stdev',
  95. 'entropy',
  96. 'confidence',
  97. 'correct',
  98. 'predicted',
  99. 'actual',
  100. ],
  101. },
  102. )
  103. for data_id in predictions.data_id:
  104. data = predictions.loc[{'data_id': data_id}]
  105. mean = data.mean(dim='model_id')[
  106. 0:2
  107. ] # Only take the predictions, not the actual
  108. stdev = data.std(dim='model_id')[
  109. 1
  110. ] # Only need the standard deviation of the postive prediction
  111. entropy = (-mean * np.log(mean)).sum()
  112. # Compute confidence
  113. confidence = mean.max()
  114. # only need one of the actual values, since they are all the same, just get the first actual_positive
  115. actual = data.loc[{'prediction_value': 'positive_actual'}][0]
  116. predicted = mean.argmax()
  117. correct = actual == predicted
  118. ensemble_statistics.loc[{'data_id': data_id}] = [
  119. mean[1],
  120. stdev,
  121. entropy,
  122. confidence,
  123. correct,
  124. predicted,
  125. actual,
  126. ]
  127. return ensemble_statistics
  128. # Compute the thresholded predictions given an array of predictions
  129. def compute_thresholded_predictions(input_stats: xr.DataArray):
  130. quantiles = np.linspace(0.05, 0.95, 19) * 100
  131. metrics = ['accuracy', 'f1']
  132. statistics = ['stdev', 'entropy', 'confidence']
  133. zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))
  134. thresholded_predictions = xr.DataArray(
  135. zeros,
  136. dims=('quantile', 'statistic', 'metric'),
  137. coords={'quantile': quantiles, 'statistic': statistics, 'metric': metrics},
  138. )
  139. for statistic in statistics:
  140. # First, we must compute the quantiles for the statistic
  141. quantile_values = np.percentile(
  142. input_stats.sel(statistic=statistic).values, quantiles, axis=0
  143. )
  144. # Then, we must compute the metrics for each quantile
  145. for i, quantile in enumerate(quantiles):
  146. if low_to_high(statistic):
  147. mask = (
  148. input_stats.sel(statistic=statistic) >= quantile_values[i]
  149. ).values
  150. else:
  151. mask = (
  152. input_stats.sel(statistic=statistic) <= quantile_values[i]
  153. ).values
  154. # Filter the data based on the mask
  155. filtered_data = input_stats.where(
  156. input_stats.data_id.isin(np.where(mask)), drop=True
  157. )
  158. for metric in metrics:
  159. thresholded_predictions.loc[
  160. {'quantile': quantile, 'statistic': statistic, 'metric': metric}
  161. ] = compute_metric(filtered_data, metric)
  162. return thresholded_predictions
  163. # Truth function to determine if metric should be thresholded low to high or high to low
  164. # Low confidence is bad, high entropy is bad, high stdev is bad
  165. # So we threshold confidence low to high, entropy and stdev high to low
  166. # So any values BELOW the cutoff are removed for confidence, and any values ABOVE the cutoff are removed for entropy and stdev
  167. def low_to_high(stat):
  168. return stat in ['confidence']
  169. # Compute a given metric on a DataArray of statistics
  170. def compute_metric(arr, metric):
  171. if metric == 'accuracy':
  172. return np.mean(arr.loc[{'statistic': 'correct'}])
  173. elif metric == 'f1':
  174. return met.F1(
  175. arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
  176. )
  177. else:
  178. raise ValueError('Invalid metric: ' + metric)
  179. # Graph a thresholded prediction for a given statistic and metric
  180. def graph_thresholded_prediction(
  181. thresholded_predictions, statistic, metric, save_path, title, xlabel, ylabel
  182. ):
  183. data = thresholded_predictions.sel(statistic=statistic, metric=metric)
  184. x_data = data.coords['quantile'].values
  185. y_data = data.values
  186. fig, ax = plt.subplots()
  187. ax.plot(x_data, y_data, 'bx-', label='Ensemble')
  188. ax.set_title(title)
  189. ax.set_xlabel(xlabel)
  190. ax.set_ylabel(ylabel)
  191. ax.xaxis.set_major_formatter(mtick.PercentFormatter())
  192. if not low_to_high(statistic):
  193. ax.invert_xaxis()
  194. plt.savefig(save_path)
  195. # Graph all thresholded predictions
  196. def graph_all_thresholded_predictions(thresholded_predictions, save_path):
  197. # Confidence Accuracy
  198. graph_thresholded_prediction(
  199. thresholded_predictions,
  200. 'confidence',
  201. 'accuracy',
  202. f'{save_path}/confidence_accuracy.png',
  203. 'Confidence vs. Accuracy',
  204. 'Confidence',
  205. 'Accuracy',
  206. )
  207. # Confidence F1
  208. graph_thresholded_prediction(
  209. thresholded_predictions,
  210. 'confidence',
  211. 'f1',
  212. f'{save_path}/confidence_f1.png',
  213. 'Confidence vs. F1',
  214. 'Confidence',
  215. 'F1',
  216. )
  217. # Entropy Accuracy
  218. graph_thresholded_prediction(
  219. thresholded_predictions,
  220. 'entropy',
  221. 'accuracy',
  222. f'{save_path}/entropy_accuracy.png',
  223. 'Entropy vs. Accuracy',
  224. 'Entropy',
  225. 'Accuracy',
  226. )
  227. # Entropy F1
  228. graph_thresholded_prediction(
  229. thresholded_predictions,
  230. 'entropy',
  231. 'f1',
  232. f'{save_path}/entropy_f1.png',
  233. 'Entropy vs. F1',
  234. 'Entropy',
  235. 'F1',
  236. )
  237. # Stdev Accuracy
  238. graph_thresholded_prediction(
  239. thresholded_predictions,
  240. 'stdev',
  241. 'accuracy',
  242. f'{save_path}/stdev_accuracy.png',
  243. 'Standard Deviation vs. Accuracy',
  244. 'Standard Deviation',
  245. 'Accuracy',
  246. )
  247. # Stdev F1
  248. graph_thresholded_prediction(
  249. thresholded_predictions,
  250. 'stdev',
  251. 'f1',
  252. f'{save_path}/stdev_f1.png',
  253. 'Standard Deviation vs. F1',
  254. 'Standard Deviation',
  255. 'F1',
  256. )
  257. # Graph two statistics against each other
  258. def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
  259. # Filter for correct predictions
  260. c_stats = stats.where(
  261. stats.data_id.isin(np.where((stats.sel(statistic='correct') == 1).values)),
  262. drop=True,
  263. )
  264. # Filter for incorrect predictions
  265. i_stats = stats.where(
  266. stats.data_id.isin(np.where((stats.sel(statistic='correct') == 0).values)),
  267. drop=True,
  268. )
  269. # x and y data for correct and incorrect predictions
  270. x_data_c = c_stats.sel(statistic=x_stat).values
  271. y_data_c = c_stats.sel(statistic=y_stat).values
  272. x_data_i = i_stats.sel(statistic=x_stat).values
  273. y_data_i = i_stats.sel(statistic=y_stat).values
  274. fig, ax = plt.subplots()
  275. ax.plot(x_data_c, y_data_c, 'go', label='Correct')
  276. ax.plot(x_data_i, y_data_i, 'ro', label='Incorrect')
  277. ax.set_title(title)
  278. ax.set_xlabel(xlabel)
  279. ax.set_ylabel(ylabel)
  280. ax.legend()
  281. plt.savefig(save_path)
  282. # Prune the data based on excluded data_ids
  283. def prune_data(data, excluded_data_ids):
  284. return data.where(~data.data_id.isin(excluded_data_ids), drop=True)
  285. # Calculate individual model statistics
  286. def compute_individual_statistics(predictions: xr.DataArray):
  287. zeros = np.zeros((len(predictions.data_id), len(predictions.model_id), 6))
  288. indv_statistics = xr.DataArray(
  289. zeros,
  290. dims=('data_id', 'model_id', 'statistic'),
  291. coords={
  292. 'data_id': predictions.data_id,
  293. 'model_id': predictions.model_id,
  294. 'statistic': [
  295. 'mean',
  296. 'entropy',
  297. 'confidence',
  298. 'correct',
  299. 'predicted',
  300. 'actual',
  301. ],
  302. },
  303. )
  304. for data_id in predictions.data_id:
  305. for model_id in predictions.model_id:
  306. data = predictions.loc[{'data_id': data_id, 'model_id': model_id}]
  307. mean = data[0:2]
  308. entropy = (-mean * np.log(mean)).sum()
  309. confidence = mean.max()
  310. actual = data[3]
  311. predicted = mean.argmax()
  312. correct = actual == predicted
  313. indv_statistics.loc[{'data_id': data_id, 'model_id': model_id}] = [
  314. mean[1],
  315. entropy,
  316. confidence,
  317. correct,
  318. predicted,
  319. actual,
  320. ]
  321. return indv_statistics
  322. # Compute individual model thresholds
  323. def compute_individual_thresholds(input_stats: xr.DataArray):
  324. quantiles = np.linspace(0.05, 0.95, 19) * 100
  325. metrics = ['accuracy', 'f1']
  326. statistics = ['entropy', 'confidence']
  327. zeros = np.zeros(
  328. (len(input_stats.model_id), len(quantiles), len(statistics), len(metrics))
  329. )
  330. indv_thresholds = xr.DataArray(
  331. zeros,
  332. dims=('model_id', 'quantile', 'statistic', 'metric'),
  333. coords={
  334. 'model_id': input_stats.model_id,
  335. 'quantile': quantiles,
  336. 'statistic': statistics,
  337. 'metric': metrics,
  338. },
  339. )
  340. for model_id in input_stats.model_id:
  341. for statistic in statistics:
  342. # First, we must compute the quantiles for the statistic
  343. quantile_values = np.percentile(
  344. input_stats.sel(model_id=model_id, statistic=statistic).values,
  345. quantiles,
  346. axis=0,
  347. )
  348. # Then, we must compute the metrics for each quantile
  349. for i, quantile in enumerate(quantiles):
  350. if low_to_high(statistic):
  351. mask = (
  352. input_stats.sel(model_id=model_id, statistic=statistic)
  353. >= quantile_values[i]
  354. ).values
  355. else:
  356. mask = (
  357. input_stats.sel(model_id=model_id, statistic=statistic)
  358. <= quantile_values[i]
  359. ).values
  360. # Filter the data based on the mask
  361. filtered_data = input_stats.where(
  362. input_stats.data_id.isin(np.where(mask)), drop=True
  363. )
  364. for metric in metrics:
  365. indv_thresholds.loc[
  366. {
  367. 'model_id': model_id,
  368. 'quantile': quantile,
  369. 'statistic': statistic,
  370. 'metric': metric,
  371. }
  372. ] = compute_metric(filtered_data, metric)
  373. return indv_thresholds
  374. # Graph individual model thresholded predictions
  375. def graph_individual_thresholded_predictions(
  376. indv_thresholds,
  377. ensemble_thresholds,
  378. statistic,
  379. metric,
  380. save_path,
  381. title,
  382. xlabel,
  383. ylabel,
  384. ):
  385. data = indv_thresholds.sel(statistic=statistic, metric=metric)
  386. e_data = ensemble_thresholds.sel(statistic=statistic, metric=metric)
  387. x_data = data.coords['quantile'].values
  388. y_data = data.values
  389. e_x_data = e_data.coords['quantile'].values
  390. e_y_data = e_data.values
  391. fig, ax = plt.subplots()
  392. for model_id in data.coords['model_id'].values:
  393. model_data = data.sel(model_id=model_id)
  394. ax.plot(x_data, model_data)
  395. ax.plot(e_x_data, e_y_data, 'kx-', label='Ensemble')
  396. ax.set_title(title)
  397. ax.set_xlabel(xlabel)
  398. ax.set_ylabel(ylabel)
  399. ax.xaxis.set_major_formatter(mtick.PercentFormatter())
  400. if not low_to_high(statistic):
  401. ax.invert_xaxis()
  402. ax.legend()
  403. plt.savefig(save_path)
  404. # Graph all individual thresholded predictions
  405. def graph_all_individual_thresholded_predictions(
  406. indv_thresholds, ensemble_thresholds, save_path
  407. ):
  408. # Confidence Accuracy
  409. graph_individual_thresholded_predictions(
  410. indv_thresholds,
  411. ensemble_thresholds,
  412. 'confidence',
  413. 'accuracy',
  414. f'{save_path}/indv/confidence_accuracy.png',
  415. 'Confidence vs. Accuracy',
  416. 'Confidence Percentile Threshold',
  417. 'Accuracy',
  418. )
  419. # Confidence F1
  420. graph_individual_thresholded_predictions(
  421. indv_thresholds,
  422. ensemble_thresholds,
  423. 'confidence',
  424. 'f1',
  425. f'{save_path}/indv/confidence_f1.png',
  426. 'Confidence vs. F1',
  427. 'Confidence Percentile Threshold',
  428. 'F1',
  429. )
  430. # Entropy Accuracy
  431. graph_individual_thresholded_predictions(
  432. indv_thresholds,
  433. ensemble_thresholds,
  434. 'entropy',
  435. 'accuracy',
  436. f'{save_path}/indv/entropy_accuracy.png',
  437. 'Entropy vs. Accuracy',
  438. 'Entropy Percentile Threshold',
  439. 'Accuracy',
  440. )
  441. # Entropy F1
  442. graph_individual_thresholded_predictions(
  443. indv_thresholds,
  444. ensemble_thresholds,
  445. 'entropy',
  446. 'f1',
  447. f'{save_path}/indv/entropy_f1.png',
  448. 'Entropy vs. F1',
  449. 'Entropy Percentile Threshold',
  450. 'F1',
  451. )
  452. # Calculate statistics of subsets of models for sensitivity analysis
  453. def calculate_subset_statistics(predictions: xr.DataArray):
  454. # Calculate subsets for 1-50 models
  455. subsets = range(1, len(predictions.model_id) + 1)
  456. zeros = np.zeros(
  457. (len(predictions.data_id), len(subsets), 7)
  458. ) # Include stdev, but for 1 models set to NaN
  459. subset_stats = xr.DataArray(
  460. zeros,
  461. dims=('data_id', 'model_count', 'statistic'),
  462. coords={
  463. 'data_id': predictions.data_id,
  464. 'model_count': subsets,
  465. 'statistic': [
  466. 'mean',
  467. 'stdev',
  468. 'entropy',
  469. 'confidence',
  470. 'correct',
  471. 'predicted',
  472. 'actual',
  473. ],
  474. },
  475. )
  476. for data_id in predictions.data_id:
  477. for subset in subsets:
  478. data = predictions.sel(
  479. data_id=data_id, model_id=predictions.model_id[:subset]
  480. )
  481. mean = data.mean(dim='model_id')[0:2]
  482. stdev = data.std(dim='model_id')[1]
  483. entropy = (-mean * np.log(mean)).sum()
  484. confidence = mean.max()
  485. actual = data[3]
  486. predicted = mean.argmax()
  487. correct = actual == predicted
  488. subset_stats.loc[{'data_id': data_id, 'model_count': subset}] = [
  489. mean[1],
  490. stdev,
  491. entropy,
  492. confidence,
  493. correct,
  494. predicted,
  495. actual,
  496. ]
  497. return subset_stats
  498. # Calculate Accuracy, F1 and ECE for subset stats - sensitivity analysis
  499. def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
  500. subsets = subset_stats.subsets
  501. stats = ['accuracy', 'f1', 'ECE', 'MCE']
  502. zeros = np.zeros((len(subsets), len(stats)))
  503. sens_analysis = xr.DataArray(
  504. zeros,
  505. dims=('model_count', 'statistic'),
  506. coords={'model_count': subsets, 'statistic': ['accuracy', 'f1', 'ECE', 'MCE']},
  507. )
  508. # Main Function
  509. def main():
  510. print('Loading Config...')
  511. config = load_config()
  512. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  513. V4_PATH = ENSEMBLE_PATH + '/v4'
  514. if not os.path.exists(V4_PATH):
  515. os.makedirs(V4_PATH)
  516. print('Config Loaded')
  517. # Load Datasets
  518. print('Loading Datasets...')
  519. dataset = load_datasets(ENSEMBLE_PATH)
  520. print('Datasets Loaded')
  521. # Get Predictions, either by running the models or loading them from a file
  522. if config['ensemble']['run_models']:
  523. # Load Models
  524. print('Loading Models...')
  525. device = torch.device(config['training']['device'])
  526. models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
  527. print('Models Loaded')
  528. # Get Predictions
  529. print('Getting Predictions...')
  530. predictions = get_ensemble_predictions(models, dataset, device)
  531. print('Predictions Loaded')
  532. # Save Prediction
  533. predictions.to_netcdf(f'{V4_PATH}/predictions.nc')
  534. else:
  535. predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')
  536. # Prune Data
  537. print('Pruning Data...')
  538. if config['operation']['exclude_blank_ids']:
  539. excluded_data_ids = config['ensemble']['excluded_ids']
  540. predictions = prune_data(predictions, excluded_data_ids)
  541. # Compute Ensemble Statistics
  542. print('Computing Ensemble Statistics...')
  543. ensemble_statistics = compute_ensemble_statistics(predictions)
  544. ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')
  545. print('Ensemble Statistics Computed')
  546. # Compute Thresholded Predictions
  547. print('Computing Thresholded Predictions...')
  548. thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
  549. thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')
  550. print('Thresholded Predictions Computed')
  551. # Graph Thresholded Predictions
  552. print('Graphing Thresholded Predictions...')
  553. graph_all_thresholded_predictions(thresholded_predictions, V4_PATH)
  554. print('Thresholded Predictions Graphed')
  555. # Additional Graphs
  556. print('Graphing Additional Graphs...')
  557. # Confidence vs stdev
  558. graph_statistics(
  559. ensemble_statistics,
  560. 'confidence',
  561. 'stdev',
  562. f'{V4_PATH}/confidence_stdev.png',
  563. 'Confidence vs. Standard Deviation',
  564. 'Confidence',
  565. 'Standard Deviation',
  566. )
  567. print('Additional Graphs Graphed')
  568. # Compute Individual Statistics
  569. print('Computing Individual Statistics...')
  570. indv_statistics = compute_individual_statistics(predictions)
  571. indv_statistics.to_netcdf(f'{V4_PATH}/indv_statistics.nc')
  572. print('Individual Statistics Computed')
  573. # Compute Individual Thresholds
  574. print('Computing Individual Thresholds...')
  575. indv_thresholds = compute_individual_thresholds(indv_statistics)
  576. indv_thresholds.to_netcdf(f'{V4_PATH}/indv_thresholds.nc')
  577. print('Individual Thresholds Computed')
  578. # Graph Individual Thresholded Predictions
  579. print('Graphing Individual Thresholded Predictions...')
  580. if not os.path.exists(f'{V4_PATH}/indv'):
  581. os.makedirs(f'{V4_PATH}/indv')
  582. graph_all_individual_thresholded_predictions(
  583. indv_thresholds, thresholded_predictions, V4_PATH
  584. )
  585. print('Individual Thresholded Predictions Graphed')
  586. if __name__ == '__main__':
  587. main()