2 İşlemeler 06a736e4f7 ... a02334abbf

Yazar SHA1 Mesaj Tarih
  Nicholas Schense a02334abbf More work done on overall stats - not sure why so many files changed 4 ay önce
  Nicholas Schense f2e7f78a40 More work done on overall stats 4 ay önce

+ 0 - 0
.gitignore


+ 0 - 0
.vscode/launch.json


+ 0 - 0
.vscode/settings.json


+ 0 - 0
LP_ADNIMERGE.csv


+ 0 - 0
README.md


+ 0 - 0
bayesian.py


+ 1 - 0
config.toml

@@ -11,6 +11,7 @@ runs = 50
 max_epochs = 30
 
 [dataset]
+excluded_ids = [91, 108, 268, 269, 272, 279, 293, 296, 307]
 validation_split = 0.4 #Splits the dataset into the train and validation/test set, 50% each for validation and test
 #|---TEST---|---VALIDATION---|---TRAIN---|
 #|splt*0.5  | split*0.5      | 1-split   |

+ 0 - 0
daily_log.md


+ 0 - 0
ensemble_predict.py


+ 0 - 0
planning.md


+ 0 - 0
ruff.toml


+ 85 - 0
sensitivity_analysis.py

@@ -0,0 +1,85 @@
+### This file is a program to run a sentivity analysis to determine what the best number of models to use in the ensemble is.
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import torch 
+import os
+
+import threshold_refac as th
+import pickle as pk
+
+def main():
+    config = th.load_config()
+
+    
+    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
+
+    V3_PATH = ENSEMBLE_PATH + '/v3'
+
+    # Create the directory if it does not exist
+    if not os.path.exists(V3_PATH):
+        os.makedirs(V3_PATH)
+
+    # Load the models
+    device = torch.device(config['training']['device'])
+    models = th.load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
+
+    # Load Dataset
+    dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
+        f'{ENSEMBLE_PATH}/val_dataset.pt'
+    )
+
+    if config['ensemble']['run_models']:
+        # Get thre predicitons of the ensemble
+        ensemble_predictions = th.ensemble_dataset_predictions(models, dataset, device)
+
+        # Save to file using pickle
+        with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
+            pk.dump(ensemble_predictions, f)
+    else:
+        # Load the predictions from file
+        with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
+            ensemble_predictions = pk.load(f)
+
+    # Now that we have the predictions, we can run the sensitivity analysis
+    #We do this by getting the stats for each possible number of models in the ensemble
+    # We will store the results in a dataframe with number of models and the stats
+    results = pd.DataFrame(columns=['num_models', 'ECE', 'accuracy']).set_index('num_models')
+    for i in range(2, len(models) + 1):
+        sel_preds = th.select_subset_models(ensemble_predictions, range(i))
+
+
+        sel_stats = th.calculate_statistics(sel_preds)
+
+        raw_confidence = sel_stats['confidence'].apply(lambda x: (x / 2) + 0.5)
+        sel_stats.insert(4, 'raw_confidence', raw_confidence)
+
+        stats = th.calculate_overall_statistics(sel_stats)
+        ece = stats.at['raw_confidence', 'ECE']
+        accuracy = sel_stats['correct'].mean()
+        results.loc[i] = (ece, accuracy)
+
+    # Save the results to a file
+    results.to_csv(f'{V3_PATH}/sensitivity_analysis.csv')
+
+    # Plot the results
+    plt.plot(results.index, results['ECE'])
+    plt.xlabel('Number of Models')
+    plt.ylabel('ECE')
+    plt.title('Sensitivity Analysis')
+    plt.savefig(f'{V3_PATH}/sensitivity_analysis.png')
+    plt.close()
+
+    plt.plot(results.index, results['accuracy'])
+    plt.xlabel('Number of Models')
+    plt.ylabel('Accuracy')
+    plt.title('Sensitivity Analysis')
+    plt.savefig(f'{V3_PATH}/sensitivity_analysis_accuracy.png')
+    plt.close()
+
+
+
+
+if __name__ == "__main__":
+    main()

+ 0 - 0
threshold.py


+ 193 - 46
threshold_refac.py

@@ -14,24 +14,67 @@ import matplotlib.ticker as ticker
 import glob
 import pickle as pk
 import warnings
+import random as rand
 
 warnings.filterwarnings('error')
 
-# CONFIGURATION
-if os.getenv('ADL_CONFIG_PATH') is None:
-    with open('config.toml', 'rb') as f:
-        config = toml.load(f)
-else:
-    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
-        config = toml.load(f)
 
-ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
+def plot_image_grid(image_ids, dataset, rows, path, titles=None):
+    fig, axs = plt.subplots(rows, len(image_ids) // rows)
+    for i, ax in enumerate(axs.flat):
+        image_id = image_ids[i]
+        image = dataset[image_id][0][0].squeeze().cpu().numpy()
+        # We now have a 3d image of size (91, 109, 91), and we want to take a slice from the middle of the image
+        image = image[:, :, 45]
+
+        ax.imshow(image, cmap='gray')
+        ax.axis('off')
+        if titles is not None:
+            ax.set_title(titles[i])
+
+    plt.savefig(path)
+    plt.close()
+
+
+def plot_single_image(image_id, dataset, path, title=None):
+    fig, ax = plt.subplots()
+    image = dataset[image_id][0][0].squeeze().cpu().numpy()
+    # We now have a 3d image of size (91, 109, 91), and we want to take a slice from the middle of the image
+    image = image[:, :, 45]
+
+    ax.imshow(image, cmap='gray')
+    ax.axis('off')
+    if title is not None:
+        ax.set_title(title)
 
-V3_PATH = ENSEMBLE_PATH + '/v3'
+    plt.savefig(path)
+    plt.close()
 
-# Create the directory if it does not exist
-if not os.path.exists(V3_PATH):
-    os.makedirs(V3_PATH)
+
+# Given a dataframe of the form {data_id: (stat_1, stat_2, ..., correct)}, plot the two statistics against each other and color by correctness
+def plot_statistics_versus(
+    stat_1, stat_2, xaxis_name, yaxis_name, title, dataframe, path, annotate=False
+):
+    # Get correct predictions and incorrect predictions dataframes
+    corr_df = dataframe[dataframe['correct']]
+    incorr_df = dataframe[~dataframe['correct']]
+
+    # Plot the correct and incorrect predictions
+    fig, ax = plt.subplots()
+    ax.scatter(corr_df[stat_1], corr_df[stat_2], c='green', label='Correct')
+    ax.scatter(incorr_df[stat_1], incorr_df[stat_2], c='red', label='Incorrect')
+    ax.legend()
+    ax.set_xlabel(xaxis_name)
+    ax.set_ylabel(yaxis_name)
+    ax.set_title(title)
+
+    if annotate:
+        print('DEBUG -- REMOVE: Annotating')
+        # label correct points green
+        for row in dataframe[[stat_1, stat_2]].itertuples():
+            plt.text(row[1], row[2], row[0], fontsize=6, color='black')
+
+    plt.savefig(path)
 
 
 # Models is a dictionary with the model ids as keys and the model data as values
@@ -99,13 +142,19 @@ def select_individual_model(predictions, model_id):
 
 
 # Given a dictionary of predictions, select a subset of models and eliminate the rest
+# predictions dictory of the form {data_id: (target, {model_id: prediction})}
 def select_subset_models(predictions, model_ids):
     selected_model_predictions = {}
     for key, value in predictions.items():
+        target = value[0]
+        model_predictions = value[1]
+
+        # Filter the model predictions, only keeping selected models
         selected_model_predictions[key] = (
-            value[0],
-            {model_id: value[1][model_id] for model_id in model_ids},
+            target,
+            {model_id: model_predictions[str(model_id + 1)] for model_id in model_ids},
         )
+
     return selected_model_predictions
 
 
@@ -239,7 +288,59 @@ def common_entries(*dcts):
         yield (i,) + tuple(d[i] for d in dcts)
 
 
+# Given ensemble statistics, calculate overall stats (ECE, MCE, Brier Score, NLL)
+def calculate_overall_statistics(ensemble_statistics):
+    predicted = ensemble_statistics['predicted']
+    actual = ensemble_statistics['actual']
+
+    # New dataframe to store the statistics
+    stats_df = pd.DataFrame(
+        columns=['stat', 'ECE', 'MCE', 'Brier Score', 'NLL']
+    ).set_index('stat')
+
+    # Loop through and calculate the ECE, MCE, Brier Score, and NLL
+    for stat in ['confidence', 'entropy', 'stdev', 'raw_confidence']:
+        ece = met.ECE(predicted, ensemble_statistics[stat], actual)
+        mce = met.MCE(predicted, ensemble_statistics[stat], actual)
+        brier = met.brier_binary(ensemble_statistics[stat], actual)
+        nll = met.nll_binary(ensemble_statistics[stat], actual)
+
+        stats_df.loc[stat] = [ece, mce, brier, nll]
+
+    return stats_df
+
+
+# CONFIGURATION
+def load_config():
+    if os.getenv('ADL_CONFIG_PATH') is None:
+        with open('config.toml', 'rb') as f:
+            config = toml.load(f)
+    else:
+        with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
+            config = toml.load(f)
+
+    return config
+
+def prune_dataset(dataset, pruned_ids):
+    pruned_dataset = []
+    for i, (data, target) in enumerate(dataset):
+        if i not in pruned_ids:
+            pruned_dataset.append((data, target))
+
+    return pruned_dataset
+
+
 def main():
+    config = load_config()
+
+    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
+
+    V3_PATH = ENSEMBLE_PATH + '/v3'
+
+    # Create the directory if it does not exist
+    if not os.path.exists(V3_PATH):
+        os.makedirs(V3_PATH)
+
     # Load the models
     device = torch.device(config['training']['device'])
     models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
@@ -249,6 +350,8 @@ def main():
         f'{ENSEMBLE_PATH}/val_dataset.pt'
     )
 
+    dataset = 
+
     if config['ensemble']['run_models']:
         # Get thre predicitons of the ensemble
         ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)
@@ -263,6 +366,7 @@ def main():
 
     # Get the statistics and thresholds of the ensemble
     ensemble_statistics = calculate_statistics(ensemble_predictions)
+
     stdev_thresholds = conduct_threshold_analysis(
         ensemble_statistics, 'stdev', low_to_high=True
     )
@@ -273,43 +377,86 @@ def main():
         ensemble_statistics, 'confidence', low_to_high=False
     )
 
-    # Print ECE and MCE Values
-    conf_ece = met.ECE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['confidence'],
-        ensemble_statistics['actual'],
-    )
-    conf_mce = met.MCE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['confidence'],
-        ensemble_statistics['actual'],
+    raw_confidence = ensemble_statistics['confidence'].apply(lambda x: (x / 2) + 0.5)
+    ensemble_statistics.insert(4, 'raw_confidence', raw_confidence)
+
+    # Plot confidence vs standard deviation
+    plot_statistics_versus(
+        'raw_confidence',
+        'stdev',
+        'Confidence',
+        'Standard Deviation',
+        'Confidence vs Standard Deviation',
+        ensemble_statistics,
+        f'{V3_PATH}/confidence_vs_stdev.png',
+        annotate=True,
     )
 
-    ent_ece = met.ECE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['entropy'],
-        ensemble_statistics['actual'],
-    )
-    ent_mce = met.MCE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['entropy'],
-        ensemble_statistics['actual'],
+    # Plot images - 3 weird and 3 normal
+    # Selected from confidence vs stdev plot
+    plot_image_grid(
+        [279, 202, 28, 107, 27, 121],
+        dataset,
+        2,
+        f'{V3_PATH}/image_grid.png',
+        titles=[
+            'Weird: 279',
+            'Weird: 202',
+            'Weird: 28',
+            'Normal: 107',
+            'Normal: 27',
+            'Normal: 121',
+        ],
     )
 
-    stdev_ece = met.ECE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['stdev'],
-        ensemble_statistics['actual'],
-    )
-    stdev_mce = met.MCE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['stdev'],
-        ensemble_statistics['actual'],
-    )
+    # Filter dataset for where confidence < .7 and stdev < .1
+    weird_results = ensemble_statistics.loc[
+        (
+            (ensemble_statistics['raw_confidence'] < 0.7)
+            & (ensemble_statistics['stdev'] < 0.1)
+        )
+    ]
+    normal_results = ensemble_statistics.loc[
+        ~(
+            (ensemble_statistics['raw_confidence'] < 0.7)
+            & (ensemble_statistics['stdev'] < 0.1)
+        )
+    ]
+    # Get the data ids in a list
+    # Plot the images
+    if not os.path.exists(f'{V3_PATH}/images'):
+        os.makedirs(f'{V3_PATH}/images/weird')
+        os.makedirs(f'{V3_PATH}/images/normal')
+
+    for i in weird_results.itertuples():
+        id = i.Index
+        conf = i.raw_confidence
+        stdev = i.stdev
+
+        plot_single_image(
+            id,
+            dataset,
+            f'{V3_PATH}/images/weird/{id}.png',
+            title=f'ID: {id}, Confidence: {conf}, Stdev: {stdev}',
+        )
+
+    for i in normal_results.itertuples():
+        id = i.Index
+        conf = i.raw_confidence
+        stdev = i.stdev
+
+        plot_single_image(
+            id,
+            dataset,
+            f'{V3_PATH}/images/normal/{id}.png',
+            title=f'ID: {id}, Confidence: {conf}, Stdev: {stdev}',
+        )
+
+    # Calculate overall statistics
+    overall_statistics = calculate_overall_statistics(ensemble_statistics)
 
-    print(f'Confidence ECE: {conf_ece}, Confidence MCE: {conf_mce}')
-    print(f'Entropy ECE: {ent_ece}, Entropy MCE: {ent_mce}')
-    print(f'Stdev ECE: {stdev_ece}, Stdev MCE: {stdev_mce}')
+    # Print overall statistics
+    print(overall_statistics)
 
     # Print overall ensemble statistics
     print('Ensemble Statistics')

+ 0 - 0
train_cnn.py


+ 0 - 0
utils/data/datasets.py


+ 0 - 0
utils/data/preprocessing.py


+ 0 - 0
utils/ensemble.py


+ 11 - 0
utils/metrics.py

@@ -75,3 +75,14 @@ def AUC(confidences, true_labels):
 
 def entropy(confidences):
     return -np.sum(confidences * np.log(confidences))
+
+### Negative Log Likelyhood for binary classification
+def nll_binary(confidences, true_labels):
+    return -np.sum(np.log(confidences[true_labels == 1])) - np.sum(np.log(1 - confidences[true_labels == 0]))
+
+### Breier score for binary classification
+def brier_binary(confidences, true_labels):
+    return np.mean((confidences - true_labels) ** 2)
+
+
+

+ 0 - 0
utils/models/cnn.py


+ 0 - 0
utils/models/layers.py


+ 0 - 0
utils/system.py


+ 0 - 0
utils/testing.py


+ 0 - 0
utils/training.py