
continuing work on rewrite

Nicholas Schense, 3 days ago
parent
commit
fb995072b3
8 files changed with 567 additions and 411 deletions
  1. ensemble_predict.py (+18 -17)
  2. model_evaluation.py (+81 -0)
  3. sensitivity_analysis.py (+34 -32)
  4. threshold_refac.py (+147 -113)
  5. threshold_xarray.py (+235 -230)
  6. utils/data/datasets.py (+15 -4)
  7. utils/ensemble.py (+34 -13)
  8. utils/models/cnn.py (+3 -2)

+ 18 - 17
ensemble_predict.py

@@ -7,28 +7,29 @@ import math
 import torch
 
 # CONFIGURATION
-if os.getenv('ADL_CONFIG_PATH') is None:
-    with open('config.toml', 'rb') as f:
+if os.getenv("ADL_CONFIG_PATH") is None:
+    with open("config.toml", "rb") as f:
         config = toml.load(f)
 else:
-    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
+    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
         config = toml.load(f)
 
 # Force cuDNN initialization
-force_init_cudnn(config['training']['device'])
+force_init_cudnn(config["training"]["device"])
 
 
 ensemble_folder = (
-    config['paths']['model_output'] + config['ensemble']['name'] + '/models/'
+    config["paths"]["model_output"] + config["ensemble"]["name"] + "/models/"
 )
-models, model_descs = ens.load_models(ensemble_folder, config['training']['device'])
+models, model_descs = ens.load_models(ensemble_folder, config["training"]["device"])
 models, model_descs = ens.prune_models(
-    models, model_descs, ensemble_folder, config['ensemble']['prune_threshold']
+    models, model_descs, ensemble_folder, config["ensemble"]["prune_threshold"]
 )
 
 # Load test data
 test_dataset = torch.load(
-    config['paths']['model_output'] + config['ensemble']['name'] + '/test_dataset.pt'
+    config["paths"]["model_output"] + config["ensemble"]["name"] + "/test_dataset.pt",
+    weights_only=False,
 )
 
 # Evaluate ensemble and uncertainty test set
@@ -67,22 +68,22 @@ accuracy = correct / total
 with open(
     ensemble_folder
     + f"ensemble_test_results_{config['ensemble']['prune_threshold']}.txt",
-    'w',
+    "w",
 ) as f:
-    f.write('Accuracy: ' + str(accuracy) + '\n')
-    f.write('Correct: ' + str(correct) + '\n')
-    f.write('Total: ' + str(total) + '\n')
+    f.write("Accuracy: " + str(accuracy) + "\n")
+    f.write("Correct: " + str(correct) + "\n")
+    f.write("Total: " + str(total) + "\n")
 
     for exp, pred, stdev in zip(actual, predictions, stdevs):
         f.write(
             str(exp)
-            + ', '
+            + ", "
             + str(pred)
-            + ', '
+            + ", "
             + str(stdev)
-            + ', '
+            + ", "
             + str(yes_votes)
-            + ', '
+            + ", "
             + str(no_votes)
-            + '\n'
+            + "\n"
         )
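
The new weights_only=False arguments reflect the stricter default in recent PyTorch releases, where torch.load refuses to unpickle arbitrary Python classes unless they are allow-listed. A minimal sketch of the two options, assuming a hypothetical checkpoint path model.pt; note that torch.serialization.safe_globals is a context manager, while add_safe_globals registers classes globally (the form threshold_refac.py uses below):

import torch
import torch.serialization
from utils.models.cnn import CNN  # one of the classes stored inside the pickled checkpoint

# Option 1: trust the file and opt out of the weights-only restriction.
model = torch.load("model.pt", map_location="cpu", weights_only=False)

# Option 2: keep weights_only=True but allow-list the classes the pickle references,
# either registered globally or scoped to a single load.
torch.serialization.add_safe_globals([CNN])
with torch.serialization.safe_globals([CNN]):
    model = torch.load("model.pt", map_location="cpu", weights_only=True)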

+ 81 - 0
model_evaluation.py

@@ -0,0 +1,81 @@
+import utils.ensemble as ens
+import os
+import tomli as tml
+from utils.system import force_init_cudnn
+import torch
+import pathlib as pl
+from utils.data.datasets import ADNIDataset
+import xarray as xr
+
+
+# CONFIGURATION
+with open(os.getenv("ADL_CONFIG_PATH", "config.toml"), "rb") as f:
+    config = tml.load(f)
+
+force_init_cudnn(config["training"]["device"])
+
+# INIT DATA AND MODELS
+ensemble_folder: pl.Path = (
+    config["paths"]["model_output"] + config["ensemble"]["name"] + "/models/"
+)
+
+# Load test data
+test_dataset: ADNIDataset = torch.load(
+    config["paths"]["model_output"] + config["ensemble"]["name"] + "/test_dataset.pt",
+    weights_only=False,
+)
+
+
+models = ens.load_models(pl.Path(ensemble_folder), config["training"]["device"])
+
+# We are generating a large matrix, with the dimensions of the models, the test set, and the number of classes
+# Therefore we are capturing the output of every model for every item in the test set and storing it in a matrix
+
+type ResultsMatrix = xr.DataArray
+type ActualMatrix = xr.DataArray
+
+results: ResultsMatrix = xr.DataArray(
+    data=0,
+    dims=["model", "test_item", "class"],
+    coords={
+        "model": ens.get_model_names(models),
+        "test_item": range(len(test_dataset)),
+        "class": [0, 1],
+    },
+)
+
+actual: ActualMatrix = xr.DataArray(
+    data=0,
+    dims=["test_item", "class"],
+    coords={
+        "test_item": range(len(test_dataset)),
+        "class": [0, 1],
+    },
+)
+
+final: xr.Dataset = xr.Dataset(
+    data_vars={
+        "evaluated": results,
+        "actual": actual,
+    },
+)
+
+
+# Iterate over the test set and get the predictions for each model
+for i, (unp_data, target) in enumerate(test_dataset):
+    data = ens.prepare_datasets(unp_data)
+
+    for j, (model_obj, model_name) in enumerate(models):
+        model_obj.eval()
+        with torch.no_grad():
+            output: torch.Tensor = model_obj(data)
+            final.results.loc[dict(model=model_name, test_item=i)] = output.numpy()  # type: ignore
+            final.actual.loc[dict(test_item=i)] = target.numpy()  # type: ignore
+
+
+# Save the results to a file
+final.to_netcdf(  # type: ignore
+    config["paths"]["model_output"] + config["ensemble"]["name"] + "/test_results.nc",
+    mode="w",
+    format="NETCDF4",
+)
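
The evaluation loop above fills one (model, test_item) slice of the results array per prediction and bundles everything into an xr.Dataset. A minimal sketch of that label-based assignment pattern, with hypothetical model names and a two-class output; the data variable is addressed by its name ("evaluated") on the Dataset:

import numpy as np
import xarray as xr

results = xr.DataArray(
    data=np.zeros((2, 3, 2)),
    dims=["model", "test_item", "class"],
    coords={"model": ["cnn_a", "cnn_b"], "test_item": range(3), "class": [0, 1]},
)
final = xr.Dataset({"evaluated": results})

# Write one model's two-class output for one test item into the named variable.
output = np.array([0.2, 0.8])
final["evaluated"].loc[dict(model="cnn_a", test_item=0)] = output

final.to_netcdf("test_results.nc", mode="w", format="NETCDF4")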

+ 34 - 32
sensitivity_analysis.py

@@ -3,83 +3,85 @@
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
-import torch 
+import torch
 import os
 
 import threshold_refac as th
 import pickle as pk
+import utils.models.cnn
+
+torch.serialization.safe_globals([utils.models.cnn.CNN])
+
 
 def main():
     config = th.load_config()
 
-    
     ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
 
-    V3_PATH = ENSEMBLE_PATH + '/v3'
+    V3_PATH = ENSEMBLE_PATH + "/v3"
 
     # Create the directory if it does not exist
     if not os.path.exists(V3_PATH):
         os.makedirs(V3_PATH)
 
     # Load the models
-    device = torch.device(config['training']['device'])
-    models = th.load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
+    device = torch.device(config["training"]["device"])
+    models = th.load_models_v2(f"{ENSEMBLE_PATH}/models/", device)
 
     # Load Dataset
-    dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
-        f'{ENSEMBLE_PATH}/val_dataset.pt'
-    )
+    dataset = torch.load(
+        f"{ENSEMBLE_PATH}/test_dataset.pt", weights_only=False
+    ) + torch.load(f"{ENSEMBLE_PATH}/val_dataset.pt", weights_only=False)
 
-    if config['ensemble']['run_models']:
+    if config["ensemble"]["run_models"]:
         # Get the predictions of the ensemble
         ensemble_predictions = th.ensemble_dataset_predictions(models, dataset, device)
 
         # Save to file using pickle
-        with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
+        with open(f"{V3_PATH}/ensemble_predictions.pk", "wb") as f:
             pk.dump(ensemble_predictions, f)
     else:
         # Load the predictions from file
-        with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
+        with open(f"{V3_PATH}/ensemble_predictions.pk", "rb") as f:
             ensemble_predictions = pk.load(f)
 
     # Now that we have the predictions, we can run the sensitivity analysis
-    #We do this by getting the stats for each possible number of models in the ensemble
+    # We do this by getting the stats for each possible number of models in the ensemble
     # We will store the results in a dataframe with number of models and the stats
-    results = pd.DataFrame(columns=['num_models', 'ECE', 'accuracy']).set_index('num_models')
+    results = pd.DataFrame(columns=["num_models", "ECE", "accuracy"]).set_index(
+        "num_models"
+    )
     for i in range(2, len(models) + 1):
         sel_preds = th.select_subset_models(ensemble_predictions, range(i))
 
-
         sel_stats = th.calculate_statistics(sel_preds)
 
-        raw_confidence = sel_stats['confidence'].apply(lambda x: (x / 2) + 0.5)
-        sel_stats.insert(4, 'raw_confidence', raw_confidence)
+        raw_confidence = sel_stats["confidence"].apply(lambda x: (x / 2) + 0.5)
+        sel_stats.insert(4, "raw_confidence", raw_confidence)
 
         stats = th.calculate_overall_statistics(sel_stats)
-        ece = stats.at['raw_confidence', 'ECE']
-        accuracy = sel_stats['correct'].mean()
+        ece = stats.at["raw_confidence", "ECE"]
+        accuracy = sel_stats["correct"].mean()
         results.loc[i] = (ece, accuracy)
 
     # Save the results to a file
-    results.to_csv(f'{V3_PATH}/sensitivity_analysis.csv')
+    results.to_csv(f"{V3_PATH}/sensitivity_analysis.csv")
 
     # Plot the results
-    plt.plot(results.index, results['ECE'])
-    plt.xlabel('Number of Models')
-    plt.ylabel('ECE')
-    plt.title('Sensitivity Analysis')
-    plt.savefig(f'{V3_PATH}/sensitivity_analysis.png')
+    plt.plot(results.index, results["ECE"])
+    plt.xlabel("Number of Models")
+    plt.ylabel("ECE")
+    plt.title("Sensitivity Analysis")
+    plt.savefig(f"{V3_PATH}/sensitivity_analysis.png")
     plt.close()
 
-    plt.plot(results.index, results['accuracy'])
-    plt.xlabel('Number of Models')
-    plt.ylabel('Accuracy')
-    plt.title('Sensitivity Analysis')
-    plt.savefig(f'{V3_PATH}/sensitivity_analysis_accuracy.png')
+    plt.plot(results.index, results["accuracy"])
+    plt.xlabel("Number of Models")
+    plt.ylabel("Accuracy")
+    plt.title("Sensitivity Analysis")
+    plt.savefig(f"{V3_PATH}/sensitivity_analysis_accuracy.png")
     plt.close()
 
 
-
-
 if __name__ == "__main__":
-    main()
+    main()
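
The sensitivity loop above grows the ensemble one model at a time and records a metric per ensemble size. A self-contained sketch of that bookkeeping with stand-in random predictions (the real script gets its statistics from th.select_subset_models and th.calculate_statistics); only accuracy is shown:

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Stand-in data: positive-class probability per (model, item), plus true labels.
all_preds = rng.uniform(0.0, 1.0, size=(5, 100))
labels = rng.integers(0, 2, size=100)

results = pd.DataFrame(columns=["num_models", "accuracy"]).set_index("num_models")

for num_models in range(2, all_preds.shape[0] + 1):
    mean_prob = all_preds[:num_models].mean(axis=0)  # average the first n models
    predicted = (mean_prob >= 0.5).astype(int)
    results.loc[num_models] = (predicted == labels).mean()

results.to_csv("sensitivity_analysis.csv")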

+ 147 - 113
threshold_refac.py

@@ -11,8 +11,40 @@ import matplotlib.ticker as ticker
 import glob
 import pickle as pk
 import warnings
+import utils.models.cnn
+import torch.serialization
+import utils.models.layers
+import utils.data.datasets
+import pandas.core.frame
+import pandas.core.internals.managers
+
+torch.serialization.add_safe_globals(
+    [
+        torch.nn.modules.linear.Linear,
+        torch.nn.modules.batchnorm.BatchNorm3d,
+        torch.nn.modules.container.Sequential,
+        torch.nn.modules.activation.ELU,
+        utils.models.layers.ConvBlock,
+        utils.models.layers.FullConnBlock,
+        utils.models.layers.SplitConvBlock,
+        utils.models.layers.SepConv3d,
+        utils.models.cnn.CNN_Image_Section,
+        torch.nn.modules.dropout.Dropout,
+        torch.nn.modules.conv.Conv3d,
+        torch.nn.modules.batchnorm.BatchNorm1d,
+        utils.models.layers.MidFlowBlock,
+        torch.nn.modules.activation.Softmax,
+        utils.models.layers.SepConvBlock,
+        utils.models.cnn.CNN,
+        utils.data.datasets.ADNIDataset,
+        pandas.core.frame.DataFrame,
+        pandas.core.internals.managers.BlockManager,
+        pandas._libs.internals._unpickle_block,
+    ]
+)
+
 
-warnings.filterwarnings('error')
+warnings.filterwarnings("error")
 
 
 def plot_image_grid(image_ids, dataset, rows, path, titles=None):
@@ -23,8 +55,8 @@ def plot_image_grid(image_ids, dataset, rows, path, titles=None):
         # We now have a 3d image of size (91, 109, 91), and we want to take a slice from the middle of the image
         image = image[:, :, 45]
 
-        ax.imshow(image, cmap='gray')
-        ax.axis('off')
+        ax.imshow(image, cmap="gray")
+        ax.axis("off")
         if titles is not None:
             ax.set_title(titles[i])
 
@@ -38,8 +70,8 @@ def plot_single_image(image_id, dataset, path, title=None):
     # We now have a 3d image of size (91, 109, 91), and we want to take a slice from the middle of the image
     image = image[:, :, 45]
 
-    ax.imshow(image, cmap='gray')
-    ax.axis('off')
+    ax.imshow(image, cmap="gray")
+    ax.axis("off")
     if title is not None:
         ax.set_title(title)
 
@@ -52,23 +84,23 @@ def plot_statistics_versus(
     stat_1, stat_2, xaxis_name, yaxis_name, title, dataframe, path, annotate=False
 ):
     # Get correct predictions and incorrect predictions dataframes
-    corr_df = dataframe[dataframe['correct']]
-    incorr_df = dataframe[~dataframe['correct']]
+    corr_df = dataframe[dataframe["correct"]]
+    incorr_df = dataframe[~dataframe["correct"]]
 
     # Plot the correct and incorrect predictions
     fig, ax = plt.subplots()
-    ax.scatter(corr_df[stat_1], corr_df[stat_2], c='green', label='Correct')
-    ax.scatter(incorr_df[stat_1], incorr_df[stat_2], c='red', label='Incorrect')
+    ax.scatter(corr_df[stat_1], corr_df[stat_2], c="green", label="Correct")
+    ax.scatter(incorr_df[stat_1], incorr_df[stat_2], c="red", label="Incorrect")
     ax.legend()
     ax.set_xlabel(xaxis_name)
     ax.set_ylabel(yaxis_name)
     ax.set_title(title)
 
     if annotate:
-        print('DEBUG -- REMOVE: Annotating')
+        print("DEBUG -- REMOVE: Annotating")
         # label correct points green
         for row in dataframe[[stat_1, stat_2]].itertuples():
-            plt.text(row[1], row[2], row[0], fontsize=6, color='black')
+            plt.text(row[1], row[2], row[0], fontsize=6, color="black")
 
     plt.savefig(path)
 
@@ -87,23 +119,25 @@ def get_model_predictions(models, data):
 
 
 def load_models_v2(folder, device):
-    glob_path = os.path.join(folder, '*.pt')
+    glob_path = os.path.join(folder, "*.pt")
     model_files = glob.glob(glob_path)
     model_dict = {}
 
     for model_file in model_files:
+        print(model_file)
         model = torch.load(model_file, map_location=device)
-        model_id = os.path.basename(model_file).split('_')[0]
+        model_id = os.path.basename(model_file).split("_")[0]
         model_dict[model_id] = model
 
     if len(model_dict) == 0:
-        raise FileNotFoundError('No models found in the specified directory: ' + folder)
+        raise FileNotFoundError("No models found in the specified directory: " + folder)
 
     return model_dict
 
 
 # Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
 def preprocess_data(data, device):
+
     mri, xls = data
     mri = mri.unsqueeze(0).to(device)
     xls = xls.unsqueeze(0).to(device)
@@ -160,13 +194,13 @@ def calculate_statistics(predictions):
     # Create DataFrame with columns for each statistic
     stats_df = pd.DataFrame(
         columns=[
-            'mean',
-            'stdev',
-            'entropy',
-            'confidence',
-            'correct',
-            'predicted',
-            'actual',
+            "mean",
+            "stdev",
+            "entropy",
+            "confidence",
+            "correct",
+            "predicted",
+            "actual",
         ]
     )
 
@@ -215,7 +249,7 @@ def conduct_threshold_analysis(statistics, statistic_name, low_to_high=True):
     )
 
     # Dictionary of form {threshold: {metric: value}}
-    thresholds_pd = pd.DataFrame(index=percentile_df.index, columns=['accuracy', 'f1'])
+    thresholds_pd = pd.DataFrame(index=percentile_df.index, columns=["accuracy", "f1"])
     for percentile, value in percentile_df.items():
         # Filter the statistics
         if low_to_high:
@@ -224,11 +258,11 @@ def conduct_threshold_analysis(statistics, statistic_name, low_to_high=True):
             filtered_statistics = statistics[statistics[statistic_name] >= value]
 
         # Calculate accuracy and f1 score
-        accuracy = filtered_statistics['correct'].mean()
+        accuracy = filtered_statistics["correct"].mean()
 
         # Calculate F1 score
-        predicted = filtered_statistics['predicted'].values
-        actual = filtered_statistics['actual'].values
+        predicted = filtered_statistics["predicted"].values
+        actual = filtered_statistics["actual"].values
 
         f1 = metrics.f1_score(actual, predicted)
 
@@ -251,7 +285,7 @@ def plot_threshold_analysis(
     metric = list(thresholds_metric.values)
 
     # Plot the metric against the threshold
-    plt.plot(thresholds, metric, 'bo-', label='Ensemble')
+    plt.plot(thresholds, metric, "bo-", label="Ensemble")
 
     if additional_set is not None:
         # Get the thresholds and metrics
@@ -259,7 +293,7 @@ def plot_threshold_analysis(
         metric = list(additional_set.values)
 
         # Plot the metric against the threshold
-        plt.plot(thresholds, metric, 'rx-', label='Individual')
+        plt.plot(thresholds, metric, "rx-", label="Individual")
 
     if flip:
         ax.invert_xaxis()
@@ -286,16 +320,16 @@ def common_entries(*dcts):
 
 # Given ensemble statistics, calculate overall stats (ECE, MCE, Brier Score, NLL)
 def calculate_overall_statistics(ensemble_statistics):
-    predicted = ensemble_statistics['predicted']
-    actual = ensemble_statistics['actual']
+    predicted = ensemble_statistics["predicted"]
+    actual = ensemble_statistics["actual"]
 
     # New dataframe to store the statistics
     stats_df = pd.DataFrame(
-        columns=['stat', 'ECE', 'MCE', 'Brier Score', 'NLL']
-    ).set_index('stat')
+        columns=["stat", "ECE", "MCE", "Brier Score", "NLL"]
+    ).set_index("stat")
 
     # Loop through and calculate the ECE, MCE, Brier Score, and NLL
-    for stat in ['confidence', 'entropy', 'stdev', 'raw_confidence']:
+    for stat in ["confidence", "entropy", "stdev", "raw_confidence"]:
         ece = met.ECE(predicted, ensemble_statistics[stat], actual)
         mce = met.MCE(predicted, ensemble_statistics[stat], actual)
         brier = met.brier_binary(ensemble_statistics[stat], actual)
@@ -308,15 +342,16 @@ def calculate_overall_statistics(ensemble_statistics):
 
 # CONFIGURATION
 def load_config():
-    if os.getenv('ADL_CONFIG_PATH') is None:
-        with open('config.toml', 'rb') as f:
+    if os.getenv("ADL_CONFIG_PATH") is None:
+        with open("config.toml", "rb") as f:
             config = toml.load(f)
     else:
-        with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
+        with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
             config = toml.load(f)
 
     return config
 
+
 def prune_dataset(dataset, pruned_ids):
     pruned_dataset = []
     for i, (data, target) in enumerate(dataset):
@@ -331,59 +366,58 @@ def main():
 
     ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
 
-    V3_PATH = ENSEMBLE_PATH + '/v3'
+    V3_PATH = ENSEMBLE_PATH + "/v3"
 
     # Create the directory if it does not exist
     if not os.path.exists(V3_PATH):
         os.makedirs(V3_PATH)
 
     # Load the models
-    device = torch.device(config['training']['device'])
-    models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
+    device = torch.device(config["training"]["device"])
+    models = load_models_v2(f"{ENSEMBLE_PATH}/models/", device)
 
     # Load Dataset
-    dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
-        f'{ENSEMBLE_PATH}/val_dataset.pt'
-    )
-
+    dataset = torch.load(
+        f"{ENSEMBLE_PATH}/test_dataset.pt", weights_only=False
+    ) + torch.load(f"{ENSEMBLE_PATH}/val_dataset.pt", weights_only=False)
 
-    if config['ensemble']['run_models']:
+    if config["ensemble"]["run_models"]:
         # Get the predictions of the ensemble
         ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)
 
         # Save to file using pickle
-        with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
+        with open(f"{V3_PATH}/ensemble_predictions.pk", "wb") as f:
             pk.dump(ensemble_predictions, f)
     else:
         # Load the predictions from file
-        with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
+        with open(f"{V3_PATH}/ensemble_predictions.pk", "rb") as f:
             ensemble_predictions = pk.load(f)
 
     # Get the statistics and thresholds of the ensemble
     ensemble_statistics = calculate_statistics(ensemble_predictions)
 
     stdev_thresholds = conduct_threshold_analysis(
-        ensemble_statistics, 'stdev', low_to_high=True
+        ensemble_statistics, "stdev", low_to_high=True
     )
     entropy_thresholds = conduct_threshold_analysis(
-        ensemble_statistics, 'entropy', low_to_high=True
+        ensemble_statistics, "entropy", low_to_high=True
     )
     confidence_thresholds = conduct_threshold_analysis(
-        ensemble_statistics, 'confidence', low_to_high=False
+        ensemble_statistics, "confidence", low_to_high=False
     )
 
-    raw_confidence = ensemble_statistics['confidence'].apply(lambda x: (x / 2) + 0.5)
-    ensemble_statistics.insert(4, 'raw_confidence', raw_confidence)
+    raw_confidence = ensemble_statistics["confidence"].apply(lambda x: (x / 2) + 0.5)
+    ensemble_statistics.insert(4, "raw_confidence", raw_confidence)
 
     # Plot confidence vs standard deviation
     plot_statistics_versus(
-        'raw_confidence',
-        'stdev',
-        'Confidence',
-        'Standard Deviation',
-        'Confidence vs Standard Deviation',
+        "raw_confidence",
+        "stdev",
+        "Confidence",
+        "Standard Deviation",
+        "Confidence vs Standard Deviation",
         ensemble_statistics,
-        f'{V3_PATH}/confidence_vs_stdev.png',
+        f"{V3_PATH}/confidence_vs_stdev.png",
         annotate=True,
     )
 
@@ -393,35 +427,35 @@ def main():
         [279, 202, 28, 107, 27, 121],
         dataset,
         2,
-        f'{V3_PATH}/image_grid.png',
+        f"{V3_PATH}/image_grid.png",
         titles=[
-            'Weird: 279',
-            'Weird: 202',
-            'Weird: 28',
-            'Normal: 107',
-            'Normal: 27',
-            'Normal: 121',
+            "Weird: 279",
+            "Weird: 202",
+            "Weird: 28",
+            "Normal: 107",
+            "Normal: 27",
+            "Normal: 121",
         ],
     )
 
     # Filter dataset for where confidence < .7 and stdev < .1
     weird_results = ensemble_statistics.loc[
         (
-            (ensemble_statistics['raw_confidence'] < 0.7)
-            & (ensemble_statistics['stdev'] < 0.1)
+            (ensemble_statistics["raw_confidence"] < 0.7)
+            & (ensemble_statistics["stdev"] < 0.1)
         )
     ]
     normal_results = ensemble_statistics.loc[
         ~(
-            (ensemble_statistics['raw_confidence'] < 0.7)
-            & (ensemble_statistics['stdev'] < 0.1)
+            (ensemble_statistics["raw_confidence"] < 0.7)
+            & (ensemble_statistics["stdev"] < 0.1)
         )
     ]
     # Get the data ids in a list
     # Plot the images
-    if not os.path.exists(f'{V3_PATH}/images'):
-        os.makedirs(f'{V3_PATH}/images/weird')
-        os.makedirs(f'{V3_PATH}/images/normal')
+    if not os.path.exists(f"{V3_PATH}/images"):
+        os.makedirs(f"{V3_PATH}/images/weird")
+        os.makedirs(f"{V3_PATH}/images/normal")
 
     for i in weird_results.itertuples():
         id = i.Index
@@ -431,8 +465,8 @@ def main():
         plot_single_image(
             id,
             dataset,
-            f'{V3_PATH}/images/weird/{id}.png',
-            title=f'ID: {id}, Confidence: {conf}, Stdev: {stdev}',
+            f"{V3_PATH}/images/weird/{id}.png",
+            title=f"ID: {id}, Confidence: {conf}, Stdev: {stdev}",
         )
 
     for i in normal_results.itertuples():
@@ -443,8 +477,8 @@ def main():
         plot_single_image(
             id,
             dataset,
-            f'{V3_PATH}/images/normal/{id}.png',
-            title=f'ID: {id}, Confidence: {conf}, Stdev: {stdev}',
+            f"{V3_PATH}/images/normal/{id}.png",
+            title=f"ID: {id}, Confidence: {conf}, Stdev: {stdev}",
         )
 
     # Calculate overall statistics
@@ -454,81 +488,81 @@ def main():
     print(overall_statistics)
 
     # Print overall ensemble statistics
-    print('Ensemble Statistics')
+    print("Ensemble Statistics")
     print(f"Accuracy: {ensemble_statistics['correct'].mean()}")
     print(
         f"F1 Score: {metrics.f1_score(ensemble_statistics['actual'], ensemble_statistics['predicted'])}"
     )
 
     # Get the predictions, statistics and thresholds an individual model
-    indv_id = config['ensemble']['individual_id']
+    indv_id = config["ensemble"]["individual_id"]
     indv_predictions = select_individual_model(ensemble_predictions, indv_id)
     indv_statistics = calculate_statistics(indv_predictions)
 
     # Calculate entropy and confidence thresholds for individual model
     indv_entropy_thresholds = conduct_threshold_analysis(
-        indv_statistics, 'entropy', low_to_high=True
+        indv_statistics, "entropy", low_to_high=True
     )
     indv_confidence_thresholds = conduct_threshold_analysis(
-        indv_statistics, 'confidence', low_to_high=False
+        indv_statistics, "confidence", low_to_high=False
     )
 
     # Plot the threshold analysis for standard deviation
     plot_threshold_analysis(
-        stdev_thresholds['accuracy'],
-        'Stdev Threshold Analysis for Accuracy',
-        'Stdev Threshold',
-        'Accuracy',
-        f'{V3_PATH}/stdev_threshold_analysis.png',
+        stdev_thresholds["accuracy"],
+        "Stdev Threshold Analysis for Accuracy",
+        "Stdev Threshold",
+        "Accuracy",
+        f"{V3_PATH}/stdev_threshold_analysis.png",
         flip=True,
     )
     plot_threshold_analysis(
-        stdev_thresholds['f1'],
-        'Stdev Threshold Analysis for F1 Score',
-        'Stdev Threshold',
-        'F1 Score',
-        f'{V3_PATH}/stdev_threshold_analysis_f1.png',
+        stdev_thresholds["f1"],
+        "Stdev Threshold Analysis for F1 Score",
+        "Stdev Threshold",
+        "F1 Score",
+        f"{V3_PATH}/stdev_threshold_analysis_f1.png",
         flip=True,
     )
 
     # Plot the threshold analysis for entropy
     plot_threshold_analysis(
-        entropy_thresholds['accuracy'],
-        'Entropy Threshold Analysis for Accuracy',
-        'Entropy Threshold',
-        'Accuracy',
-        f'{V3_PATH}/entropy_threshold_analysis.png',
-        indv_entropy_thresholds['accuracy'],
+        entropy_thresholds["accuracy"],
+        "Entropy Threshold Analysis for Accuracy",
+        "Entropy Threshold",
+        "Accuracy",
+        f"{V3_PATH}/entropy_threshold_analysis.png",
+        indv_entropy_thresholds["accuracy"],
         flip=True,
     )
     plot_threshold_analysis(
-        entropy_thresholds['f1'],
-        'Entropy Threshold Analysis for F1 Score',
-        'Entropy Threshold',
-        'F1 Score',
-        f'{V3_PATH}/entropy_threshold_analysis_f1.png',
-        indv_entropy_thresholds['f1'],
+        entropy_thresholds["f1"],
+        "Entropy Threshold Analysis for F1 Score",
+        "Entropy Threshold",
+        "F1 Score",
+        f"{V3_PATH}/entropy_threshold_analysis_f1.png",
+        indv_entropy_thresholds["f1"],
         flip=True,
     )
 
     # Plot the threshold analysis for confidence
     plot_threshold_analysis(
-        confidence_thresholds['accuracy'],
-        'Confidence Threshold Analysis for Accuracy',
-        'Confidence Threshold',
-        'Accuracy',
-        f'{V3_PATH}/confidence_threshold_analysis.png',
-        indv_confidence_thresholds['accuracy'],
+        confidence_thresholds["accuracy"],
+        "Confidence Threshold Analysis for Accuracy",
+        "Confidence Threshold",
+        "Accuracy",
+        f"{V3_PATH}/confidence_threshold_analysis.png",
+        indv_confidence_thresholds["accuracy"],
     )
     plot_threshold_analysis(
-        confidence_thresholds['f1'],
-        'Confidence Threshold Analysis for F1 Score',
-        'Confidence Threshold',
-        'F1 Score',
-        f'{V3_PATH}/confidence_threshold_analysis_f1.png',
-        indv_confidence_thresholds['f1'],
+        confidence_thresholds["f1"],
+        "Confidence Threshold Analysis for F1 Score",
+        "Confidence Threshold",
+        "F1 Score",
+        f"{V3_PATH}/confidence_threshold_analysis_f1.png",
+        indv_confidence_thresholds["f1"],
     )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
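
conduct_threshold_analysis above sweeps percentile cutoffs of an uncertainty statistic, keeps only the items below (or above) each cutoff, and scores accuracy and F1 on what remains. A minimal sketch of that coverage filter on stand-in data, using stdev as the statistic in the low_to_high direction:

import numpy as np
import pandas as pd
from sklearn import metrics

rng = np.random.default_rng(0)

# Stand-in per-item ensemble statistics.
stats = pd.DataFrame(
    {
        "stdev": rng.uniform(0.0, 0.3, size=200),
        "predicted": rng.integers(0, 2, size=200),
        "actual": rng.integers(0, 2, size=200),
    }
)
stats["correct"] = stats["predicted"] == stats["actual"]

# Percentile cutoffs of the statistic; keep items below each cutoff and score the subset.
percentiles = stats["stdev"].quantile(np.linspace(0.05, 0.95, 19))
thresholds = pd.DataFrame(index=percentiles.index, columns=["accuracy", "f1"])
for quantile, cutoff in percentiles.items():
    kept = stats[stats["stdev"] < cutoff]
    thresholds.loc[quantile] = (
        kept["correct"].mean(),
        metrics.f1_score(kept["actual"], kept["predicted"]),
    )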

+ 235 - 230
threshold_xarray.py

@@ -10,7 +10,9 @@ from tqdm import tqdm
 import utils.metrics as met
 import matplotlib.pyplot as plt
 import matplotlib.ticker as mtick
+import utils.models.cnn
 
+torch.serialization.safe_globals([utils.models.cnn.CNN])
 
 
 # The datastructures for this file are as follows
@@ -29,11 +31,11 @@ import matplotlib.ticker as mtick
 
 # Loads configuration dictionary
 def load_config():
-    if os.getenv('ADL_CONFIG_PATH') is None:
-        with open('config.toml', 'rb') as f:
+    if os.getenv("ADL_CONFIG_PATH") is None:
+        with open("config.toml", "rb") as f:
             config = toml.load(f)
     else:
-        with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
+        with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
             config = toml.load(f)
 
     return config
@@ -41,17 +43,20 @@ def load_config():
 
 # Loads models into a dictionary
 def load_models_v2(folder, device):
-    glob_path = os.path.join(folder, '*.pt')
+    glob_path = os.path.join(folder, "*.pt")
     model_files = glob.glob(glob_path)
     model_dict = {}
 
     for model_file in model_files:
+        with open(model_file, "r") as f:
+            print(torch.serialization.get_unsafe_globals_in_checkpoint(f))
+
         model = torch.load(model_file, map_location=device)
-        model_id = os.path.basename(model_file).split('_')[0]
+        model_id = os.path.basename(model_file).split("_")[0]
         model_dict[model_id] = model
 
     if len(model_dict) == 0:
-        raise FileNotFoundError('No models found in the specified directory: ' + folder)
+        raise FileNotFoundError("No models found in the specified directory: " + folder)
 
     return model_dict
 
@@ -67,8 +72,8 @@ def preprocess_data(data, device):
 # Loads datasets and returns concatenated test and validation datasets
 def load_datasets(ensemble_path):
     return (
-        torch.load(f'{ensemble_path}/test_dataset.pt'),
-        torch.load(f'{ensemble_path}/val_dataset.pt'),
+        torch.load(f"{ensemble_path}/test_dataset.pt"),
+        torch.load(f"{ensemble_path}/val_dataset.pt"),
     )
 
 
@@ -77,21 +82,21 @@ def get_ensemble_predictions(models, dataset, device, id_offset=0):
     zeros = np.zeros((len(dataset), len(models), 4))
     predictions = xr.DataArray(
         zeros,
-        dims=('data_id', 'model_id', 'prediction_value'),
+        dims=("data_id", "model_id", "prediction_value"),
         coords={
-            'data_id': range(id_offset, len(dataset) + id_offset),
-            'model_id': list(models.keys()),
-            'prediction_value': [
-                'negative_prediction',
-                'positive_prediction',
-                'negative_actual',
-                'positive_actual',
+            "data_id": range(id_offset, len(dataset) + id_offset),
+            "model_id": list(models.keys()),
+            "prediction_value": [
+                "negative_prediction",
+                "positive_prediction",
+                "negative_actual",
+                "positive_actual",
             ],
         },
     )
 
     for data_id, (data, target) in tqdm(
-        enumerate(dataset), total=len(dataset), unit='images'
+        enumerate(dataset), total=len(dataset), unit="images"
     ):
         dat = preprocess_data(data, device)
         actual = list(target.cpu().numpy())
@@ -101,8 +106,8 @@ def get_ensemble_predictions(models, dataset, device, id_offset=0):
                 prediction = output.cpu().numpy().tolist()[0]
 
                 predictions.loc[
-                    {'data_id': data_id + id_offset, 'model_id': model_id}
-                ] = prediction + actual
+                    {"data_id": data_id + id_offset, "model_id": model_id}
+                ] = (prediction + actual)
 
     return predictions
 
@@ -113,27 +118,27 @@ def compute_ensemble_statistics(predictions: xr.DataArray):
 
     ensemble_statistics = xr.DataArray(
         zeros,
-        dims=('data_id', 'statistic'),
+        dims=("data_id", "statistic"),
         coords={
-            'data_id': predictions.data_id,
-            'statistic': [
-                'mean',
-                'stdev',
-                'entropy',
-                'confidence',
-                'correct',
-                'predicted',
-                'actual',
+            "data_id": predictions.data_id,
+            "statistic": [
+                "mean",
+                "stdev",
+                "entropy",
+                "confidence",
+                "correct",
+                "predicted",
+                "actual",
             ],
         },
     )
 
     for data_id in predictions.data_id:
-        data = predictions.loc[{'data_id': data_id}]
-        mean = data.mean(dim='model_id')[
+        data = predictions.loc[{"data_id": data_id}]
+        mean = data.mean(dim="model_id")[
             0:2
         ]  # Only take the predictions, not the actual
-        stdev = data.std(dim='model_id')[
+        stdev = data.std(dim="model_id")[
             1
         ]  # Only need the standard deviation of the positive prediction
         entropy = (-mean * np.log(mean)).sum()
@@ -142,11 +147,11 @@ def compute_ensemble_statistics(predictions: xr.DataArray):
         confidence = mean.max()
 
         # only need one of the actual values, since they are all the same, just get the first actual_positive
-        actual = data.loc[{'prediction_value': 'positive_actual'}][0]
+        actual = data.loc[{"prediction_value": "positive_actual"}][0]
         predicted = mean.argmax()
         correct = actual == predicted
 
-        ensemble_statistics.loc[{'data_id': data_id}] = [
+        ensemble_statistics.loc[{"data_id": data_id}] = [
             mean[1],
             stdev,
             entropy,
@@ -162,15 +167,15 @@ def compute_ensemble_statistics(predictions: xr.DataArray):
 # Compute the thresholded predictions given an array of predictions
 def compute_thresholded_predictions(input_stats: xr.DataArray):
     quantiles = np.linspace(0.00, 1.00, 21) * 100
-    metrics = ['accuracy', 'f1']
-    statistics = ['stdev', 'entropy', 'confidence']
+    metrics = ["accuracy", "f1"]
+    statistics = ["stdev", "entropy", "confidence"]
 
     zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))
 
     thresholded_predictions = xr.DataArray(
         zeros,
-        dims=('quantile', 'statistic', 'metric'),
-        coords={'quantile': quantiles, 'statistic': statistics, 'metric': metrics},
+        dims=("quantile", "statistic", "metric"),
+        coords={"quantile": quantiles, "statistic": statistics, "metric": metrics},
     )
 
     for statistic in statistics:
@@ -197,7 +202,7 @@ def compute_thresholded_predictions(input_stats: xr.DataArray):
 
             for metric in metrics:
                 thresholded_predictions.loc[
-                    {'quantile': quantile, 'statistic': statistic, 'metric': metric}
+                    {"quantile": quantile, "statistic": statistic, "metric": metric}
                 ] = compute_metric(filtered_data, metric)
 
     return thresholded_predictions
@@ -208,26 +213,26 @@ def compute_thresholded_predictions(input_stats: xr.DataArray):
 # So we threshold confidence low to high, entropy and stdev high to low
 # So any values BELOW the cutoff are removed for confidence, and any values ABOVE the cutoff are removed for entropy and stdev
 def low_to_high(stat):
-    return stat in ['confidence']
+    return stat in ["confidence"]
 
 
 # Compute a given metric on a DataArray of statistics
 def compute_metric(arr, metric):
-    if metric == 'accuracy':
-        return np.mean(arr.loc[{'statistic': 'correct'}])
-    elif metric == 'f1':
+    if metric == "accuracy":
+        return np.mean(arr.loc[{"statistic": "correct"}])
+    elif metric == "f1":
         return met.F1(
-            arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
+            arr.loc[{"statistic": "predicted"}], arr.loc[{"statistic": "actual"}]
         )
-    elif metric == 'ece':
-        true_labels = arr.loc[{'statistic': 'actual'}].values
-        predicted_labels = arr.loc[{'statistic': 'predicted'}].values
-        confidences = arr.loc[{'statistic': 'confidence'}].values
+    elif metric == "ece":
+        true_labels = arr.loc[{"statistic": "actual"}].values
+        predicted_labels = arr.loc[{"statistic": "predicted"}].values
+        confidences = arr.loc[{"statistic": "confidence"}].values
 
         return calculate_ece_stats(confidences, predicted_labels, true_labels)
 
     else:
-        raise ValueError('Invalid metric: ' + metric)
+        raise ValueError("Invalid metric: " + metric)
 
 
 # Graph a thresholded prediction for a given statistic and metric
@@ -236,11 +241,11 @@ def graph_thresholded_prediction(
 ):
     data = thresholded_predictions.sel(statistic=statistic, metric=metric)
 
-    x_data = data.coords['quantile'].values
+    x_data = data.coords["quantile"].values
     y_data = data.values
 
     fig, ax = plt.subplots()
-    ax.plot(x_data, y_data, 'bx-', label='Ensemble')
+    ax.plot(x_data, y_data, "bx-", label="Ensemble")
     ax.set_title(title)
     ax.set_xlabel(xlabel)
     ax.set_ylabel(ylabel)
@@ -257,68 +262,68 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
     # Confidence Accuracy
     graph_thresholded_prediction(
         thresholded_predictions,
-        'confidence',
-        'accuracy',
-        f'{save_path}/confidence_accuracy.png',
-        'Coverage Analysis of Confidence vs. Accuracy',
-        'Minimum Confidence Percentile Threshold',
-        'Accuracy',
+        "confidence",
+        "accuracy",
+        f"{save_path}/confidence_accuracy.png",
+        "Coverage Analysis of Confidence vs. Accuracy",
+        "Minimum Confidence Percentile Threshold",
+        "Accuracy",
     )
 
     # Confidence F1
     graph_thresholded_prediction(
         thresholded_predictions,
-        'confidence',
-        'f1',
-        f'{save_path}/confidence_f1.png',
-        'Coverage Analysis of Confidence vs. F1 Score',
-        'Minimum Confidence Percentile Threshold',
-        'F1 Score',
+        "confidence",
+        "f1",
+        f"{save_path}/confidence_f1.png",
+        "Coverage Analysis of Confidence vs. F1 Score",
+        "Minimum Confidence Percentile Threshold",
+        "F1 Score",
     )
 
     # Entropy Accuracy
     graph_thresholded_prediction(
         thresholded_predictions,
-        'entropy',
-        'accuracy',
-        f'{save_path}/entropy_accuracy.png',
-        'Coverage Analysis of Entropy vs. Accuracy',
-        'Maximum Entropy Percentile Threshold',
-        'Accuracy',
+        "entropy",
+        "accuracy",
+        f"{save_path}/entropy_accuracy.png",
+        "Coverage Analysis of Entropy vs. Accuracy",
+        "Maximum Entropy Percentile Threshold",
+        "Accuracy",
     )
 
     # Entropy F1
 
     graph_thresholded_prediction(
         thresholded_predictions,
-        'entropy',
-        'f1',
-        f'{save_path}/entropy_f1.png',
-        'Coverage Analysis of Entropy vs. F1 Score',
-        'Maximum Entropy Percentile Threshold',
-        'F1 Score',
+        "entropy",
+        "f1",
+        f"{save_path}/entropy_f1.png",
+        "Coverage Analysis of Entropy vs. F1 Score",
+        "Maximum Entropy Percentile Threshold",
+        "F1 Score",
     )
 
     # Stdev Accuracy
     graph_thresholded_prediction(
         thresholded_predictions,
-        'stdev',
-        'accuracy',
-        f'{save_path}/stdev_accuracy.png',
-        'Coverage Analysis of Standard Deviation vs. Accuracy',
-        'Maximum Standard Deviation Percentile Threshold',
-        'Accuracy',
+        "stdev",
+        "accuracy",
+        f"{save_path}/stdev_accuracy.png",
+        "Coverage Analysis of Standard Deviation vs. Accuracy",
+        "Maximum Standard Deviation Percentile Threshold",
+        "Accuracy",
     )
 
     # Stdev F1
     graph_thresholded_prediction(
         thresholded_predictions,
-        'stdev',
-        'f1',
-        f'{save_path}/stdev_f1.png',
-        'Coverage Analysis of Standard Deviation vs. F1 Score',
-        'Maximum Standard Deviation Percentile Threshold',
-        'F1',
+        "stdev",
+        "f1",
+        f"{save_path}/stdev_f1.png",
+        "Coverage Analysis of Standard Deviation vs. F1 Score",
+        "Maximum Standard Deviation Percentile Threshold",
+        "F1",
     )
 
 
@@ -326,13 +331,13 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
 def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
     # Filter for correct predictions
     c_stats = stats.where(
-        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 1).values)),
+        stats.data_id.isin(np.where((stats.sel(statistic="correct") == 1).values)),
         drop=True,
     )
 
     # Filter for incorrect predictions
     i_stats = stats.where(
-        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 0).values)),
+        stats.data_id.isin(np.where((stats.sel(statistic="correct") == 0).values)),
         drop=True,
     )
 
@@ -344,8 +349,8 @@ def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
     y_data_i = i_stats.sel(statistic=y_stat).values
 
     fig, ax = plt.subplots()
-    ax.plot(x_data_c, y_data_c, 'go', label='Correct')
-    ax.plot(x_data_i, y_data_i, 'ro', label='Incorrect')
+    ax.plot(x_data_c, y_data_c, "go", label="Correct")
+    ax.plot(x_data_i, y_data_i, "ro", label="Incorrect")
     ax.set_title(title)
     ax.set_xlabel(xlabel)
     ax.set_ylabel(ylabel)
@@ -365,26 +370,26 @@ def compute_individual_statistics(predictions: xr.DataArray):
 
     indv_statistics = xr.DataArray(
         zeros,
-        dims=('data_id', 'model_id', 'statistic'),
+        dims=("data_id", "model_id", "statistic"),
         coords={
-            'data_id': predictions.data_id,
-            'model_id': predictions.model_id,
-            'statistic': [
-                'mean',
-                'entropy',
-                'confidence',
-                'correct',
-                'predicted',
-                'actual',
+            "data_id": predictions.data_id,
+            "model_id": predictions.model_id,
+            "statistic": [
+                "mean",
+                "entropy",
+                "confidence",
+                "correct",
+                "predicted",
+                "actual",
             ],
         },
     )
 
     for data_id in tqdm(
-        predictions.data_id, total=len(predictions.data_id), unit='images'
+        predictions.data_id, total=len(predictions.data_id), unit="images"
     ):
         for model_id in predictions.model_id:
-            data = predictions.loc[{'data_id': data_id, 'model_id': model_id}]
+            data = predictions.loc[{"data_id": data_id, "model_id": model_id}]
             mean = data[0:2]
             entropy = (-mean * np.log(mean)).sum()
             confidence = mean.max()
@@ -392,7 +397,7 @@ def compute_individual_statistics(predictions: xr.DataArray):
             predicted = mean.argmax()
             correct = actual == predicted
 
-            indv_statistics.loc[{'data_id': data_id, 'model_id': model_id}] = [
+            indv_statistics.loc[{"data_id": data_id, "model_id": model_id}] = [
                 mean[1],
                 entropy,
                 confidence,
@@ -407,8 +412,8 @@ def compute_individual_statistics(predictions: xr.DataArray):
 # Compute individual model thresholds
 def compute_individual_thresholds(input_stats: xr.DataArray):
     quantiles = np.linspace(0.05, 0.95, 19) * 100
-    metrics = ['accuracy', 'f1']
-    statistics = ['entropy', 'confidence']
+    metrics = ["accuracy", "f1"]
+    statistics = ["entropy", "confidence"]
 
     zeros = np.zeros(
         (len(input_stats.model_id), len(quantiles), len(statistics), len(metrics))
@@ -416,17 +421,17 @@ def compute_individual_thresholds(input_stats: xr.DataArray):
 
     indv_thresholds = xr.DataArray(
         zeros,
-        dims=('model_id', 'quantile', 'statistic', 'metric'),
+        dims=("model_id", "quantile", "statistic", "metric"),
         coords={
-            'model_id': input_stats.model_id,
-            'quantile': quantiles,
-            'statistic': statistics,
-            'metric': metrics,
+            "model_id": input_stats.model_id,
+            "quantile": quantiles,
+            "statistic": statistics,
+            "metric": metrics,
         },
     )
 
     for model_id in tqdm(
-        input_stats.model_id, total=len(input_stats.model_id), unit='models'
+        input_stats.model_id, total=len(input_stats.model_id), unit="models"
     ):
         for statistic in statistics:
             # First, we must compute the quantiles for the statistic
@@ -457,10 +462,10 @@ def compute_individual_thresholds(input_stats: xr.DataArray):
                 for metric in metrics:
                     indv_thresholds.loc[
                         {
-                            'model_id': model_id,
-                            'quantile': quantile,
-                            'statistic': statistic,
-                            'metric': metric,
+                            "model_id": model_id,
+                            "quantile": quantile,
+                            "statistic": statistic,
+                            "metric": metric,
                         }
                     ] = compute_metric(filtered_data, metric)
 
@@ -481,18 +486,18 @@ def graph_individual_thresholded_predictions(
     data = indv_thresholds.sel(statistic=statistic, metric=metric)
     e_data = ensemble_thresholds.sel(statistic=statistic, metric=metric)
 
-    x_data = data.coords['quantile'].values
+    x_data = data.coords["quantile"].values
     y_data = data.values
 
-    e_x_data = e_data.coords['quantile'].values
+    e_x_data = e_data.coords["quantile"].values
     e_y_data = e_data.values
 
     fig, ax = plt.subplots()
-    for model_id in data.coords['model_id'].values:
+    for model_id in data.coords["model_id"].values:
         model_data = data.sel(model_id=model_id)
         ax.plot(x_data, model_data)
 
-    ax.plot(e_x_data, e_y_data, 'kx-', label='Ensemble')
+    ax.plot(e_x_data, e_y_data, "kx-", label="Ensemble")
 
     ax.set_title(title)
     ax.set_xlabel(xlabel)
@@ -514,48 +519,48 @@ def graph_all_individual_thresholded_predictions(
     graph_individual_thresholded_predictions(
         indv_thresholds,
         ensemble_thresholds,
-        'confidence',
-        'accuracy',
-        f'{save_path}/indv/confidence_accuracy.png',
-        'Coverage Analysis of Confidence vs. Accuracy for All Models',
-        'Minumum Confidence Percentile Threshold',
-        'Accuracy',
+        "confidence",
+        "accuracy",
+        f"{save_path}/indv/confidence_accuracy.png",
+        "Coverage Analysis of Confidence vs. Accuracy for All Models",
+        "Minumum Confidence Percentile Threshold",
+        "Accuracy",
     )
 
     # Confidence F1
     graph_individual_thresholded_predictions(
         indv_thresholds,
         ensemble_thresholds,
-        'confidence',
-        'f1',
-        f'{save_path}/indv/confidence_f1.png',
-        'Coverage Analysis of Confidence vs. F1 Score for All Models',
-        'Minimum Confidence Percentile Threshold',
-        'F1 Score',
+        "confidence",
+        "f1",
+        f"{save_path}/indv/confidence_f1.png",
+        "Coverage Analysis of Confidence vs. F1 Score for All Models",
+        "Minimum Confidence Percentile Threshold",
+        "F1 Score",
     )
 
     # Entropy Accuracy
     graph_individual_thresholded_predictions(
         indv_thresholds,
         ensemble_thresholds,
-        'entropy',
-        'accuracy',
-        f'{save_path}/indv/entropy_accuracy.png',
-        'Coverage Analysis of Entropy vs. Accuracy for All Models',
-        'Maximum Entropy Percentile Threshold',
-        'Accuracy',
+        "entropy",
+        "accuracy",
+        f"{save_path}/indv/entropy_accuracy.png",
+        "Coverage Analysis of Entropy vs. Accuracy for All Models",
+        "Maximum Entropy Percentile Threshold",
+        "Accuracy",
     )
 
     # Entropy F1
     graph_individual_thresholded_predictions(
         indv_thresholds,
         ensemble_thresholds,
-        'entropy',
-        'f1',
-        f'{save_path}/indv/entropy_f1.png',
-        'Coverage Analysis of Entropy vs. F1 Score for All Models',
-        'Maximum Entropy Percentile Threshold',
-        'F1 Score',
+        "entropy",
+        "f1",
+        f"{save_path}/indv/entropy_f1.png",
+        "Coverage Analysis of Entropy vs. F1 Score for All Models",
+        "Maximum Entropy Percentile Threshold",
+        "F1 Score",
     )
 
 
@@ -570,38 +575,38 @@ def calculate_subset_statistics(predictions: xr.DataArray):
 
     subset_stats = xr.DataArray(
         zeros,
-        dims=('data_id', 'model_count', 'statistic'),
+        dims=("data_id", "model_count", "statistic"),
         coords={
-            'data_id': predictions.data_id,
-            'model_count': subsets,
-            'statistic': [
-                'mean',
-                'stdev',
-                'entropy',
-                'confidence',
-                'correct',
-                'predicted',
-                'actual',
+            "data_id": predictions.data_id,
+            "model_count": subsets,
+            "statistic": [
+                "mean",
+                "stdev",
+                "entropy",
+                "confidence",
+                "correct",
+                "predicted",
+                "actual",
             ],
         },
     )
 
     for data_id in tqdm(
-        predictions.data_id, total=len(predictions.data_id), unit='images'
+        predictions.data_id, total=len(predictions.data_id), unit="images"
     ):
         for subset in subsets:
             data = predictions.sel(
                 data_id=data_id, model_id=predictions.model_id[:subset]
             )
-            mean = data.mean(dim='model_id')[0:2]
-            stdev = data.std(dim='model_id')[1]
+            mean = data.mean(dim="model_id")[0:2]
+            stdev = data.std(dim="model_id")[1]
             entropy = (-mean * np.log(mean)).sum()
             confidence = mean.max()
             actual = data[0][3]
             predicted = mean.argmax()
             correct = actual == predicted
 
-            subset_stats.loc[{'data_id': data_id, 'model_count': subset}] = [
+            subset_stats.loc[{"data_id": data_id, "model_count": subset}] = [
                 mean[1],
                 stdev,
                 entropy,
@@ -617,24 +622,24 @@ def calculate_subset_statistics(predictions: xr.DataArray):
 # Calculate Accuracy, F1 and ECE for subset stats - sensitivity analysis
 def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
     subsets = subset_stats.model_count
-    stats = ['accuracy', 'f1', 'ece']
+    stats = ["accuracy", "f1", "ece"]
 
     zeros = np.zeros((len(subsets), len(stats)))
 
     sens_analysis = xr.DataArray(
         zeros,
-        dims=('model_count', 'statistic'),
-        coords={'model_count': subsets, 'statistic': stats},
+        dims=("model_count", "statistic"),
+        coords={"model_count": subsets, "statistic": stats},
     )
 
-    for subset in tqdm(subsets, total=len(subsets), unit='model subsets'):
+    for subset in tqdm(subsets, total=len(subsets), unit="model subsets"):
 
         data = subset_stats.sel(model_count=subset)
-        acc = compute_metric(data, 'accuracy').item()
-        f1 = compute_metric(data, 'f1').item()
-        ece = compute_metric(data, 'ece').item()
+        acc = compute_metric(data, "accuracy").item()
+        f1 = compute_metric(data, "f1").item()
+        ece = compute_metric(data, "ece").item()
 
-        sens_analysis.loc[{'model_count': subset.item()}] = [acc, f1, ece]
+        sens_analysis.loc[{"model_count": subset.item()}] = [acc, f1, ece]
 
     return sens_analysis
 
@@ -644,7 +649,7 @@ def graph_sensitivity_analysis(
 ):
     data = sens_analysis.sel(statistic=statistic)
 
-    xdata = data.coords['model_count'].values
+    xdata = data.coords["model_count"].values
     ydata = data.values
 
     fig, ax = plt.subplots()
@@ -657,10 +662,10 @@ def graph_sensitivity_analysis(
 
 
 def calculate_overall_stats(ensemble_statistics: xr.DataArray):
-    accuracy = compute_metric(ensemble_statistics, 'accuracy')
-    f1 = compute_metric(ensemble_statistics, 'f1')
+    accuracy = compute_metric(ensemble_statistics, "accuracy")
+    f1 = compute_metric(ensemble_statistics, "f1")
 
-    return {'accuracy': accuracy.item(), 'f1': f1.item()}
+    return {"accuracy": accuracy.item(), "f1": f1.item()}
 
 
 # https://towardsdatascience.com/expected-calibration-error-ece-a-step-by-step-visual-explanation-with-python-code-c3e9aa12937d
@@ -693,130 +698,130 @@ def plot_ece_graph(ece_stats, title, xlabel, ylabel, save_path):
 
 # Main Function
 def main():
-    print('Loading Config...')
+    print("Loading Config...")
     config = load_config()
     ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
-    V4_PATH = ENSEMBLE_PATH + '/v4'
+    V4_PATH = ENSEMBLE_PATH + "/v4"
 
     if not os.path.exists(V4_PATH):
         os.makedirs(V4_PATH)
-    print('Config Loaded')
+    print("Config Loaded")
 
     # Load Datasets
-    print('Loading Datasets...')
+    print("Loading Datasets...")
     (test_dataset, val_dataset) = load_datasets(ENSEMBLE_PATH)
-    print('Datasets Loaded')
+    print("Datasets Loaded")
 
     # Get Predictions, either by running the models or loading them from a file
-    if config['ensemble']['run_models']:
+    if config["ensemble"]["run_models"]:
         # Load Models
-        print('Loading Models...')
-        device = torch.device(config['training']['device'])
-        models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
-        print('Models Loaded')
+        print("Loading Models...")
+        device = torch.device(config["training"]["device"])
+        models = load_models_v2(f"{ENSEMBLE_PATH}/models/", device)
+        print("Models Loaded")
 
         # Get Predictions
-        print('Getting Predictions...')
+        print("Getting Predictions...")
         test_predictions = get_ensemble_predictions(models, test_dataset, device)
         val_predictions = get_ensemble_predictions(
             models, val_dataset, device, len(test_dataset)
         )
-        print('Predictions Loaded')
+        print("Predictions Loaded")
 
         # Save Prediction
-        test_predictions.to_netcdf(f'{V4_PATH}/test_predictions.nc')
-        val_predictions.to_netcdf(f'{V4_PATH}/val_predictions.nc')
+        test_predictions.to_netcdf(f"{V4_PATH}/test_predictions.nc")
+        val_predictions.to_netcdf(f"{V4_PATH}/val_predictions.nc")
 
     else:
-        test_predictions = xr.open_dataarray(f'{V4_PATH}/test_predictions.nc')
-        val_predictions = xr.open_dataarray(f'{V4_PATH}/val_predictions.nc')
+        test_predictions = xr.open_dataarray(f"{V4_PATH}/test_predictions.nc")
+        val_predictions = xr.open_dataarray(f"{V4_PATH}/val_predictions.nc")
 
     # Prune Data
-    print('Pruning Data...')
-    if config['operation']['exclude_blank_ids']:
-        excluded_data_ids = config['ensemble']['excluded_ids']
+    print("Pruning Data...")
+    if config["operation"]["exclude_blank_ids"]:
+        excluded_data_ids = config["ensemble"]["excluded_ids"]
         test_predictions = prune_data(test_predictions, excluded_data_ids)
         val_predictions = prune_data(val_predictions, excluded_data_ids)
 
     # Concatenate Predictions
-    predictions = xr.concat([test_predictions, val_predictions], dim='data_id')
+    predictions = xr.concat([test_predictions, val_predictions], dim="data_id")
 
     # Compute Ensemble Statistics
-    print('Computing Ensemble Statistics...')
+    print("Computing Ensemble Statistics...")
     ensemble_statistics = compute_ensemble_statistics(predictions)
-    ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')
-    print('Ensemble Statistics Computed')
+    ensemble_statistics.to_netcdf(f"{V4_PATH}/ensemble_statistics.nc")
+    print("Ensemble Statistics Computed")
 
     # Compute Thresholded Predictions
-    print('Computing Thresholded Predictions...')
+    print("Computing Thresholded Predictions...")
     thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
-    thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')
-    print('Thresholded Predictions Computed')
+    thresholded_predictions.to_netcdf(f"{V4_PATH}/thresholded_predictions.nc")
+    print("Thresholded Predictions Computed")
 
     # Graph Thresholded Predictions
-    print('Graphing Thresholded Predictions...')
+    print("Graphing Thresholded Predictions...")
     graph_all_thresholded_predictions(thresholded_predictions, V4_PATH)
-    print('Thresholded Predictions Graphed')
+    print("Thresholded Predictions Graphed")
 
     # Additional Graphs
-    print('Graphing Additional Graphs...')
+    print("Graphing Additional Graphs...")
     # Confidence vs stdev
     graph_statistics(
         ensemble_statistics,
-        'confidence',
-        'stdev',
-        f'{V4_PATH}/confidence_stdev.png',
-        'Confidence and Standard Deviation for Predictions',
-        'Confidence',
-        'Standard Deviation',
+        "confidence",
+        "stdev",
+        f"{V4_PATH}/confidence_stdev.png",
+        "Confidence and Standard Deviation for Predictions",
+        "Confidence",
+        "Standard Deviation",
     )
-    print('Additional Graphs Graphed')
+    print("Additional Graphs Graphed")
 
     # Compute Individual Statistics
-    print('Computing Individual Statistics...')
+    print("Computing Individual Statistics...")
     indv_statistics = compute_individual_statistics(predictions)
-    indv_statistics.to_netcdf(f'{V4_PATH}/indv_statistics.nc')
-    print('Individual Statistics Computed')
+    indv_statistics.to_netcdf(f"{V4_PATH}/indv_statistics.nc")
+    print("Individual Statistics Computed")
 
     # Compute Individual Thresholds
-    print('Computing Individual Thresholds...')
+    print("Computing Individual Thresholds...")
     indv_thresholds = compute_individual_thresholds(indv_statistics)
-    indv_thresholds.to_netcdf(f'{V4_PATH}/indv_thresholds.nc')
-    print('Individual Thresholds Computed')
+    indv_thresholds.to_netcdf(f"{V4_PATH}/indv_thresholds.nc")
+    print("Individual Thresholds Computed")
 
     # Graph Individual Thresholded Predictions
-    print('Graphing Individual Thresholded Predictions...')
-    if not os.path.exists(f'{V4_PATH}/indv'):
-        os.makedirs(f'{V4_PATH}/indv')
+    print("Graphing Individual Thresholded Predictions...")
+    if not os.path.exists(f"{V4_PATH}/indv"):
+        os.makedirs(f"{V4_PATH}/indv")
 
     graph_all_individual_thresholded_predictions(
         indv_thresholds, thresholded_predictions, V4_PATH
     )
-    print('Individual Thresholded Predictions Graphed')
+    print("Individual Thresholded Predictions Graphed")
 
     # Compute subset statistics and graph
-    print('Computing Sensitivity Analysis...')
+    print("Computing Sensitivity Analysis...")
     subset_stats = calculate_subset_statistics(predictions)
     sens_analysis = calculate_sensitivity_analysis(subset_stats)
     graph_sensitivity_analysis(
         sens_analysis,
-        'accuracy',
-        f'{V4_PATH}/sens_analysis.png',
-        'Sensitivity Analsis of Accuracy vs. # of Models',
-        '# of Models',
-        'Accuracy',
+        "accuracy",
+        f"{V4_PATH}/sens_analysis.png",
+        "Sensitivity Analsis of Accuracy vs. # of Models",
+        "# of Models",
+        "Accuracy",
     )
     graph_sensitivity_analysis(
         sens_analysis,
-        'ece',
-        f'{V4_PATH}/sens_analysis_ece.png',
-        'Sensitivity Analysis of ECE vs. # of Models',
-        '# of Models',
-        'ECE',
+        "ece",
+        f"{V4_PATH}/sens_analysis_ece.png",
+        "Sensitivity Analysis of ECE vs. # of Models",
+        "# of Models",
+        "ECE",
     )
-    print(sens_analysis.sel(statistic='accuracy'))
+    print(sens_analysis.sel(statistic="accuracy"))
     print(calculate_overall_stats(ensemble_statistics))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

+ 15 - 4
utils/data/datasets.py

@@ -9,7 +9,8 @@ from torch.utils.data import Dataset
 import pandas as pd
 from torch.utils.data import DataLoader
 import math
-
+from typing import Tuple
+import pathlib as pl
 
 """
 Prepares CustomDatasets for training, validating, and testing CNN
@@ -102,7 +103,7 @@ def get_train_val_test(AD_list, NL_list, val_split):
 
 
 class ADNIDataset(Dataset):
-    def __init__(self, mri, xls: pd.DataFrame, device=torch.device("cpu")):
+    def __init__(self, mri, xls: pd.DataFrame, data_dir: pl.Path, device: torch.device = torch.device("cpu")):
         self.mri_data = mri  # DATA IS A LIST WITH TUPLES (image_dir, class_id)
         self.xls_data = xls
         self.device = device
@@ -124,9 +125,15 @@ class ADNIDataset(Dataset):
         return xls_tensor
 
     def __getitem__(
-        self, idx
-    ):  # RETURNS TUPLE WITH IMAGE AND CLASS_ID, BASED ON INDEX IDX
+        self, idx: int
+    ) -> Tuple[
+        Tuple[torch.Tensor, torch.Tensor], torch.Tensor
+    ]:  # RETURNS TUPLE WITH IMAGE AND CLASS_ID, BASED ON INDEX IDX
         mri_path, class_id = self.mri_data[idx]
+
+        # Resolve the stored filename against the dataset's data directory
+        # (assumes data_dir from __init__ is kept as self.data_dir; completes the unfinished line)
+        mri_path = pl.Path(mri_path).name
+        adj_path = self.data_dir / mri_path
+
         mri = nib.load(mri_path)
         mri_data = mri.get_fdata()
 
@@ -147,6 +154,10 @@ class ADNIDataset(Dataset):
 
         return (mri_tensor, xls_tensor), class_id
 
+    def __iter__(self):
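+        # Yield samples in index order so the dataset can be iterated directly (e.g. in a plain for-loop)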
+        for i in range(len(self)):
+            yield self.__getitem__(i)
+
 
 def initalize_dataloaders(
     training_data,

+ 34 - 13
utils/ensemble.py

@@ -1,30 +1,51 @@
 import torch
-import os
-from glob import glob
+import pathlib
+import utils.models.cnn as c
+from typing import Tuple, List
+import xarray as xr
 
 
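+# Type aliases: a loaded model paired with its description, an xarray of model predictions,
+# and the (MRI, clinical) input pair expected by the CNN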
+type ModelPair = Tuple[c.CNN, str]
+type ModelPredictionData = xr.DataArray
+type InputData = Tuple[torch.Tensor, torch.Tensor]
+
 # This file contains functions to ensemble a folder of models and evaluate them on a test set, with included uncertainty estimation.
 
 
-def load_models(folder, device):
-    glob_path = os.path.join(folder, "*.pt")
-    model_files = glob(glob_path)
+def load_models(folder: pathlib.Path, device: str) -> List[ModelPair]:
+    model_files = folder.glob("*.pt")
 
-    models = []
-    model_descs = []
+    model_pairs: List[ModelPair] = []
 
     for model_file in model_files:
-        model = torch.load(model_file, map_location=device)
-        models.append(model)
+        model: c.CNN = torch.load(model_file, map_location=device, weights_only=False)
 
         # Extract model description from filename
-        desc = os.path.basename(model_file)
-        model_descs.append(os.path.splitext(desc)[0])
+        model_pairs.append((model, model_file.stem))
+
+    return model_pairs
+
+
+def prepare_datasets(data: Tuple[torch.Tensor, torch.Tensor]) -> InputData:
+    # Unpack the sample and ensure each tensor has a leading batch dimension
+    # (Tensor.unsqueeze is not in-place, so the results must be reassigned)
+    mri_data, xls_data = data
+    mri_data = mri_data.unsqueeze(0)
+    xls_data = xls_data.unsqueeze(0)
+
+    # Combine MRI and XLS data into a tuple
+    return (mri_data, xls_data)
+
+
+def get_model_names(models: List[ModelPair]) -> List[str]:
+    # Extract model names from the model pairs
+    return [model_pair[1] for model_pair in models]
+
 
-    return models, model_descs
+def get_model_objects(models: List[ModelPair]) -> List[c.CNN]:
+    # Extract model objects from the model pairs
+    return [model_pair[0] for model_pair in models]
 
 
-def ensemble_predict(models, input):
+def ensemble_predict(models: List[c.CNN], input: InputData):
     predictions = []
     for model in models:
         model.eval()

+ 3 - 2
utils/models/cnn.py

@@ -1,3 +1,4 @@
+from typing import Tuple
 from torch import nn
 import utils.models.layers as ly
 import torch
@@ -36,8 +37,8 @@ class CNN(nn.Module):
         self.dense2 = nn.Linear(5, 2)
         self.softmax = nn.Softmax(dim=1)
 
-    def forward(self, x):
-        image, clin_data = x
+    def forward(self, x_in: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        image, clin_data = x_in
 
         image = self.image_section(image)