4 commits 72c64b23d9 ... 11ea1d0a02

Author SHA1 Message Date
  Nicholas Schense 11ea1d0a02 Merge branch 'system-rewrite' of https://git0.fmf.uni-lj.si/nschense/alzheimers_nn into system-rewrite 1 week ago
  Nicholas Schense 26e4e9c3f3 Commit of work from summer 1 week ago
  Nicholas Schense 4549b2b349 Worked on calibration (little success) and began tweaking graphs for presentation and poster 8 months ago
  Nicholas Schense e9b380d66a Still working on sensitivity analysis work - interesting results for accuracy 8 months ago
5 changed files with 354 additions and 47 deletions
  1. 40 0
      calibration_xarray.py
  2. 37 0
      dataset_size.py
  3. 158 47
      threshold_xarray.py
  4. 21 0
      xarray_images.py
  5. 98 0
      xarray_sensitivity.py

+ 40 - 0
calibration_xarray.py

@@ -0,0 +1,40 @@
+import threshold_xarray as th
+import xarray as xr
+import numpy as np
+import torch
+import os
+
+import sklearn.calibration as cal
+
+
+# The purpose of this file is to fit a calibration on the test set and evaluate the calibration on the validation set.
+# We are using scikit-learn's calibration library to do this.
+
+if __name__ == '__main__':
+    print('Loading Config...')
+    config = th.load_config()
+    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
+    V4_PATH = ENSEMBLE_PATH + '/v4'
+
+    if not os.path.exists(V4_PATH):
+        os.makedirs(V4_PATH)
+    print('Config Loaded')
+
+    # Load the predictions
+    print('Loading Predictions...')
+    val_preds = xr.open_dataset(f'{ENSEMBLE_PATH}/val_predictions.nc')
+    test_preds = xr.open_dataset(f'{ENSEMBLE_PATH}/test_predictions.nc')
+    print('Predictions Loaded')
+
+    # Now the goal is to calibrate the test set, and evaluate the calibration on the validation set.
+    # We do this by binning the data into 15 bins, and then calculating the mean of the predictions in each bin.
+    # We then use this to calibrate the data.
+
+    # First, get the statistics of both sets
+    print('Calculating Statistics...')
+    val_stats = th.compute_ensemble_statistics(val_preds)
+    test_stats = th.compute_ensemble_statistics(test_preds)
+
+    # Calibrate the test set
+    print('Calibrating Test Set...')
+ 
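Note: calibration_xarray.py ends right after printing 'Calibrating Test Set...'. Below is a minimal sketch (not part of the commit) of how that step might continue, assuming the ensemble-mean positive-class probabilities and true labels have been pulled out of test_stats/val_stats into flat numpy arrays (test_probs, test_labels, val_probs, val_labels are hypothetical names), and using isotonic regression as one possible calibrator in place of the imported sklearn.calibration helpers.

# Sketch only, not part of the commit. Fit a calibration mapping on the test
# set and evaluate its effect on the validation set with the ECE helper from
# threshold_xarray.py.
import numpy as np
import threshold_xarray as th
from sklearn.isotonic import IsotonicRegression

calibrator = IsotonicRegression(out_of_bounds='clip')
calibrator.fit(test_probs, test_labels)          # learn the mapping on the test set
val_probs_cal = calibrator.predict(val_probs)    # apply it to the validation set

def ece_of(probs, labels):
    # Confidence of the predicted class, as expected by calculate_ece_stats
    predicted = (probs > 0.5).astype(int)
    confidence = np.maximum(probs, 1 - probs)
    return th.calculate_ece_stats(confidence, predicted, labels).item()

print(f'Validation ECE before calibration: {ece_of(val_probs, val_labels):.4f}')
print(f'Validation ECE after calibration:  {ece_of(val_probs_cal, val_labels):.4f}')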

+ 37 - 0
dataset_size.py

@@ -0,0 +1,37 @@
+# This file just counts how many total images there are in the dataset
+
+import threshold_refac as tr
+import torch
+
+config = tr.load_config()
+ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
+
+test_dset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt')
+val_dset = torch.load(f'{ENSEMBLE_PATH}/val_dataset.pt')
+train_dset = torch.load(f'{ENSEMBLE_PATH}/train_dataset.pt')
+
+
+print(
+    f'Total number of images in dataset: {len(test_dset) + len(val_dset) + len(train_dset)}'
+)
+print(f'Test: {len(test_dset)}, Val: {len(val_dset)}, Train: {len(train_dset)}')
+
+
+def preprocess_data(data, device):
+    mri, xls = data
+    mri = mri.unsqueeze(0).to(device)
+    xls = xls.unsqueeze(0).to(device)
+    return (mri, xls)
+
+
+# Loop through images and determine how many are positive and negative
+positive = 0
+negative = 0
+for _, (_, target) in enumerate(test_dset + train_dset + val_dset):
+    actual = list(target.cpu().numpy())[1].item()
+    if actual == 1:
+        positive += 1
+    else:
+        negative += 1
+
+print(f'Positive: {positive}, Negative: {negative}')

+ 158 - 47
threshold_xarray.py

@@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
 import matplotlib.ticker as mtick
 
 
+
 # The datastructures for this file are as follows
 # models_dict: Dictionary - {model_id: model}
 # predictions: DataArray - (data_id, model_id, prediction_value) - Prediction value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
@@ -65,19 +66,20 @@ def preprocess_data(data, device):
 
 # Loads datasets and returns concatenated test and validation datasets
 def load_datasets(ensemble_path):
-    return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
-        f'{ensemble_path}/val_dataset.pt'
+    return (
+        torch.load(f'{ensemble_path}/test_dataset.pt'),
+        torch.load(f'{ensemble_path}/val_dataset.pt'),
     )
 
 
 # Gets the predictions for a set of models on a dataset
-def get_ensemble_predictions(models, dataset, device):
+def get_ensemble_predictions(models, dataset, device, id_offset=0):
     zeros = np.zeros((len(dataset), len(models), 4))
     predictions = xr.DataArray(
         zeros,
         dims=('data_id', 'model_id', 'prediction_value'),
         coords={
-            'data_id': range(len(dataset)),
+            'data_id': range(id_offset, len(dataset) + id_offset),
             'model_id': list(models.keys()),
             'prediction_value': [
                 'negative_prediction',
@@ -98,9 +100,9 @@ def get_ensemble_predictions(models, dataset, device):
                 output = model(dat)
                 prediction = output.cpu().numpy().tolist()[0]
 
-                predictions.loc[{'data_id': data_id, 'model_id': model_id}] = (
-                    prediction + actual
-                )
+                predictions.loc[
+                    {'data_id': data_id + id_offset, 'model_id': model_id}
+                ] = prediction + actual
 
     return predictions
 
@@ -159,7 +161,7 @@ def compute_ensemble_statistics(predictions: xr.DataArray):
 
 # Compute the thresholded predictions given an array of predictions
 def compute_thresholded_predictions(input_stats: xr.DataArray):
-    quantiles = np.linspace(0.05, 0.95, 19) * 100
+    quantiles = np.linspace(0.00, 1.00, 21) * 100
     metrics = ['accuracy', 'f1']
     statistics = ['stdev', 'entropy', 'confidence']
 
@@ -217,6 +219,13 @@ def compute_metric(arr, metric):
         return met.F1(
             arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
         )
+    elif metric == 'ece':
+        true_labels = arr.loc[{'statistic': 'actual'}].values
+        predicted_labels = arr.loc[{'statistic': 'predicted'}].values
+        confidences = arr.loc[{'statistic': 'confidence'}].values
+
+        return calculate_ece_stats(confidences, predicted_labels, true_labels)
+
     else:
         raise ValueError('Invalid metric: ' + metric)
 
@@ -251,8 +260,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'confidence',
         'accuracy',
         f'{save_path}/confidence_accuracy.png',
-        'Confidence vs. Accuracy',
-        'Confidence',
+        'Coverage Analysis of Confidence vs. Accuracy',
+        'Minimum Confidence Percentile Threshold',
         'Accuracy',
     )
 
@@ -262,9 +271,9 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'confidence',
         'f1',
         f'{save_path}/confidence_f1.png',
-        'Confidence vs. F1',
-        'Confidence',
-        'F1',
+        'Coverage Analysis of Confidence vs. F1 Score',
+        'Minimum Confidence Percentile Threshold',
+        'F1 Score',
     )
 
     # Entropy Accuracy
@@ -273,8 +282,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'entropy',
         'accuracy',
         f'{save_path}/entropy_accuracy.png',
-        'Entropy vs. Accuracy',
-        'Entropy',
+        'Coverage Analysis of Entropy vs. Accuracy',
+        'Maximum Entropy Percentile Threshold',
         'Accuracy',
     )
 
@@ -285,9 +294,9 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'entropy',
         'f1',
         f'{save_path}/entropy_f1.png',
-        'Entropy vs. F1',
-        'Entropy',
-        'F1',
+        'Coverage Analysis of Entropy vs. F1 Score',
+        'Maximum Entropy Percentile Threshold',
+        'F1 Score',
     )
 
     # Stdev Accuracy
@@ -296,8 +305,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'stdev',
         'accuracy',
         f'{save_path}/stdev_accuracy.png',
-        'Standard Deviation vs. Accuracy',
-        'Standard Deviation',
+        'Coverage Analysis of Standard Deviation vs. Accuracy',
+        'Maximum Standard Deviation Percentile Threshold',
         'Accuracy',
     )
 
@@ -307,8 +316,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'stdev',
         'f1',
         f'{save_path}/stdev_f1.png',
-        'Standard Deviation vs. F1',
-        'Standard Deviation',
+        'Coverage Analysis of Standard Deviation vs. F1 Score',
+        'Maximum Standard Deviation Percentile Threshold',
         'F1',
     )
 
@@ -371,7 +380,9 @@ def compute_individual_statistics(predictions: xr.DataArray):
         },
     )
 
-    for data_id in predictions.data_id:
+    for data_id in tqdm(
+        predictions.data_id, total=len(predictions.data_id), unit='images'
+    ):
         for model_id in predictions.model_id:
             data = predictions.loc[{'data_id': data_id, 'model_id': model_id}]
             mean = data[0:2]
@@ -414,7 +425,9 @@ def compute_individual_thresholds(input_stats: xr.DataArray):
         },
     )
 
-    for model_id in input_stats.model_id:
+    for model_id in tqdm(
+        input_stats.model_id, total=len(input_stats.model_id), unit='models'
+    ):
         for statistic in statistics:
             # First, we must compute the quantiles for the statistic
             quantile_values = np.percentile(
@@ -504,8 +517,8 @@ def graph_all_individual_thresholded_predictions(
         'confidence',
         'accuracy',
         f'{save_path}/indv/confidence_accuracy.png',
-        'Confidence vs. Accuracy',
-        'Confidence Percentile Threshold',
+        'Coverage Analysis of Confidence vs. Accuracy for All Models',
+        'Minimum Confidence Percentile Threshold',
         'Accuracy',
     )
 
@@ -516,9 +529,9 @@ def graph_all_individual_thresholded_predictions(
         'confidence',
         'f1',
         f'{save_path}/indv/confidence_f1.png',
-        'Confidence vs. F1',
-        'Confidence Percentile Threshold',
-        'F1',
+        'Coverage Analysis of Confidence vs. F1 Score for All Models',
+        'Minimum Confidence Percentile Threshold',
+        'F1 Score',
     )
 
     # Entropy Accuracy
@@ -528,8 +541,8 @@ def graph_all_individual_thresholded_predictions(
         'entropy',
         'accuracy',
         f'{save_path}/indv/entropy_accuracy.png',
-        'Entropy vs. Accuracy',
-        'Entropy Percentile Threshold',
+        'Coverage Analysis of Entropy vs. Accuracy for All Models',
+        'Maximum Entropy Percentile Threshold',
         'Accuracy',
     )
 
@@ -540,16 +553,17 @@ def graph_all_individual_thresholded_predictions(
         'entropy',
         'f1',
         f'{save_path}/indv/entropy_f1.png',
-        'Entropy vs. F1',
-        'Entropy Percentile Threshold',
-        'F1',
+        'Coverage Analysis of Entropy vs. F1 Score for All Models',
+        'Maximum Entropy Percentile Threshold',
+        'F1 Score',
     )
 
 
 # Calculate statistics of subsets of models for sensitivity analysis
 def calculate_subset_statistics(predictions: xr.DataArray):
-    # Calculate subsets for 1-50 models
-    subsets = range(1, len(predictions.model_id) + 1)
+    # Calculate subsets for 1-49 models
+    subsets = range(1, len(predictions.model_id))
+
     zeros = np.zeros(
         (len(predictions.data_id), len(subsets), 7)
     )  # Include stdev, but for 1 models set to NaN
@@ -572,7 +586,9 @@ def calculate_subset_statistics(predictions: xr.DataArray):
         },
     )
 
-    for data_id in predictions.data_id:
+    for data_id in tqdm(
+        predictions.data_id, total=len(predictions.data_id), unit='images'
+    ):
         for subset in subsets:
             data = predictions.sel(
                 data_id=data_id, model_id=predictions.model_id[:subset]
@@ -581,7 +597,7 @@ def calculate_subset_statistics(predictions: xr.DataArray):
             stdev = data.std(dim='model_id')[1]
             entropy = (-mean * np.log(mean)).sum()
             confidence = mean.max()
-            actual = data[3]
+            actual = data[0][3]
             predicted = mean.argmax()
             correct = actual == predicted
 
@@ -600,17 +616,80 @@ def calculate_subset_statistics(predictions: xr.DataArray):
 
 # Calculate Accuracy, F1 and ECE for subset stats - sensitivity analysis
 def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
-    subsets = subset_stats.subsets
-    stats = ['accuracy', 'f1', 'ECE', 'MCE']
+    subsets = subset_stats.model_count
+    stats = ['accuracy', 'f1', 'ece']
 
     zeros = np.zeros((len(subsets), len(stats)))
 
     sens_analysis = xr.DataArray(
         zeros,
         dims=('model_count', 'statistic'),
-        coords={'model_count': subsets, 'statistic': ['accuracy', 'f1', 'ECE', 'MCE']},
+        coords={'model_count': subsets, 'statistic': stats},
     )
 
+    for subset in tqdm(subsets, total=len(subsets), unit='model subsets'):
+
+        data = subset_stats.sel(model_count=subset)
+        acc = compute_metric(data, 'accuracy').item()
+        f1 = compute_metric(data, 'f1').item()
+        ece = compute_metric(data, 'ece').item()
+
+        sens_analysis.loc[{'model_count': subset.item()}] = [acc, f1, ece]
+
+    return sens_analysis
+
+
+def graph_sensitivity_analysis(
+    sens_analysis: xr.DataArray, statistic, save_path, title, xlabel, ylabel
+):
+    data = sens_analysis.sel(statistic=statistic)
+
+    xdata = data.coords['model_count'].values
+    ydata = data.values
+
+    fig, ax = plt.subplots()
+    ax.plot(xdata, ydata)
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+
+    plt.savefig(save_path)
+
+
+def calculate_overall_stats(ensemble_statistics: xr.DataArray):
+    accuracy = compute_metric(ensemble_statistics, 'accuracy')
+    f1 = compute_metric(ensemble_statistics, 'f1')
+
+    return {'accuracy': accuracy.item(), 'f1': f1.item()}
+
+
+# https://towardsdatascience.com/expected-calibration-error-ece-a-step-by-step-visual-explanation-with-python-code-c3e9aa12937d
+def calculate_ece_stats(confidences, predicted_labels, true_labels, bins=10):
+    bin_boundaries = np.linspace(0, 1, bins + 1)
+    bin_lowers = bin_boundaries[:-1]
+    bin_uppers = bin_boundaries[1:]
+
+    ece = np.zeros(1)
+
+    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
+        in_bin = np.logical_and(
+            confidences > bin_lower.item(), confidences <= bin_upper.item()
+        )
+        prob_in_bin = in_bin.mean()
+
+        if prob_in_bin.item() > 0:
+            accuracy_in_bin = (predicted_labels[in_bin] == true_labels[in_bin]).mean()
+
+            avg_confidence_in_bin = confidences[in_bin].mean()
+
+            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prob_in_bin
+
+    return ece
+
+
+def plot_ece_graph(ece_stats, title, xlabel, ylabel, save_path):
+    fig, ax = plt.subplots()
+
 
 # Main Function
 def main():
@@ -625,7 +704,7 @@ def main():
 
     # Load Datasets
     print('Loading Datasets...')
-    dataset = load_datasets(ENSEMBLE_PATH)
+    (test_dataset, val_dataset) = load_datasets(ENSEMBLE_PATH)
     print('Datasets Loaded')
 
     # Get Predictions, either by running the models or loading them from a file
@@ -638,20 +717,29 @@ def main():
 
         # Get Predictions
         print('Getting Predictions...')
-        predictions = get_ensemble_predictions(models, dataset, device)
+        test_predictions = get_ensemble_predictions(models, test_dataset, device)
+        val_predictions = get_ensemble_predictions(
+            models, val_dataset, device, len(test_dataset)
+        )
         print('Predictions Loaded')
 
         # Save Prediction
-        predictions.to_netcdf(f'{V4_PATH}/predictions.nc')
+        test_predictions.to_netcdf(f'{V4_PATH}/test_predictions.nc')
+        val_predictions.to_netcdf(f'{V4_PATH}/val_predictions.nc')
 
     else:
-        predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')
+        test_predictions = xr.open_dataarray(f'{V4_PATH}/test_predictions.nc')
+        val_predictions = xr.open_dataarray(f'{V4_PATH}/val_predictions.nc')
 
     # Prune Data
     print('Pruning Data...')
     if config['operation']['exclude_blank_ids']:
         excluded_data_ids = config['ensemble']['excluded_ids']
-        predictions = prune_data(predictions, excluded_data_ids)
+        test_predictions = prune_data(test_predictions, excluded_data_ids)
+        val_predictions = prune_data(val_predictions, excluded_data_ids)
+
+    # Concatenate Predictions
+    predictions = xr.concat([test_predictions, val_predictions], dim='data_id')
 
     # Compute Ensemble Statistics
     print('Computing Ensemble Statistics...')
@@ -678,7 +766,7 @@ def main():
         'confidence',
         'stdev',
         f'{V4_PATH}/confidence_stdev.png',
-        'Confidence vs. Standard Deviation',
+        'Confidence and Standard Deviation for Predictions',
         'Confidence',
         'Standard Deviation',
     )
@@ -706,6 +794,29 @@ def main():
     )
     print('Individual Thresholded Predictions Graphed')
 
+    # Compute subset statistics and graph
+    print('Computing Sensitivity Analysis...')
+    subset_stats = calculate_subset_statistics(predictions)
+    sens_analysis = calculate_sensitivity_analysis(subset_stats)
+    graph_sensitivity_analysis(
+        sens_analysis,
+        'accuracy',
+        f'{V4_PATH}/sens_analysis.png',
+        'Sensitivity Analysis of Accuracy vs. # of Models',
+        '# of Models',
+        'Accuracy',
+    )
+    graph_sensitivity_analysis(
+        sens_analysis,
+        'ece',
+        f'{V4_PATH}/sens_analysis_ece.png',
+        'Sensitivity Analysis of ECE vs. # of Models',
+        '# of Models',
+        'ECE',
+    )
+    print(sens_analysis.sel(statistic='accuracy'))
+    print(calculate_overall_stats(ensemble_statistics))
+
 
 if __name__ == '__main__':
     main()
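Note: plot_ece_graph above is left as a stub (only the figure is created). Below is a minimal sketch (not part of the commit) of a reliability diagram that could accompany calculate_ece_stats; the function name and signature are illustrative, and it assumes the same per-image confidences, predicted labels, and true labels that the ECE helper receives.

# Sketch only, not part of the commit. Reliability diagram using the same
# binning scheme as calculate_ece_stats.
import numpy as np
import matplotlib.pyplot as plt

def plot_reliability_diagram(confidences, predicted_labels, true_labels, save_path, bins=10):
    bin_boundaries = np.linspace(0, 1, bins + 1)
    bin_centers = (bin_boundaries[:-1] + bin_boundaries[1:]) / 2
    accuracies = np.full(bins, np.nan)

    for b, (lower, upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
        in_bin = (confidences > lower) & (confidences <= upper)
        if in_bin.any():
            accuracies[b] = (predicted_labels[in_bin] == true_labels[in_bin]).mean()

    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1], linestyle='--', label='Perfect calibration')
    ax.bar(bin_centers, accuracies, width=1 / bins, edgecolor='black', alpha=0.7, label='Observed accuracy')
    ax.set_xlabel('Confidence')
    ax.set_ylabel('Accuracy')
    ax.set_title('Reliability Diagram')
    ax.legend()
    fig.savefig(save_path)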

+ 21 - 0
xarray_images.py

@@ -0,0 +1,21 @@
+# I need a couple of images of positive and negative class members for my poster
+
+import xarray as xr
+import numpy as np
+import matplotlib.pyplot as plt
+import torch as th
+
+
+# Dataset is a torch dataset
+def get_image(dataset, idx):
+    img = dataset[idx][0].numpy()
+    return img
+
+
+def plot_image(img, path):
+    plt.imshow(img)
+    plt.savefig(path)
+
+
+def get_random_positive_image(dataset):
+    idx =
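Note: xarray_images.py breaks off mid-assignment in get_random_positive_image. Below is a minimal sketch (not part of the commit) of one way it might be finished, assuming each dataset item is ((mri, xls), target) with a one-hot target whose index 1 marks the positive class (the layout dataset_size.py relies on), and reusing the get_image helper defined above.

# Sketch only, not part of the commit. Pick a random positive-class image by
# scanning the one-hot targets (index 1 = positive).
import random as rand

def get_random_positive_image(dataset):
    positive_indices = [
        i for i, (_, target) in enumerate(dataset) if target[1].item() == 1
    ]
    idx = rand.choice(positive_indices)
    return get_image(dataset, idx)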

+ 98 - 0
xarray_sensitivity.py

@@ -0,0 +1,98 @@
+# For this project, we want to calculate the accuracy for many different numbers of models and model selections
+# We cannot calculate every possible permutation and combination of models, so we will use a random selection
+
+
+import math
+import itertools as it
+
+import xarray as xr
+import numpy as np
+import matplotlib.pyplot as plt
+
+import torch as th
+import random as rand
+
+import threshold_xarray as txr
+import os
+
+
+# Generate a selection of combinations given an iterable, combination length, and number of combinations
+# This checks the number of possible combinations and compares it to the requested number of combinations
+# If the number of requested combinations is more than the possible number of combinations, it will error
+# If it is equal, it will generate and return all possible combinations
+# If it is somewhat less, it will generate all possible combinations and randomly select the requested number
+# If it is much less, it will randomly build the requested number of combinations and check that they are unique
+def get_combinations(iterable, r, n_combinations):
+    possible_combinations = math.comb(len(iterable), r)
+
+    if n_combinations > possible_combinations:
+        raise ValueError(
+            f'Number of requested combinations {n_combinations} of length {r} on set of length {len(iterable)} is more than the possible number of combinations {possible_combinations}'
+        )
+    elif n_combinations == possible_combinations:
+        return list(it.combinations(iterable, r))
+    else:
+        if n_combinations < possible_combinations / 5:
+            combinations = []
+            while len(combinations) < n_combinations:
+                combination = tuple(sorted(rand.sample(iterable, r)))  # sort so duplicates are caught regardless of draw order
+
+                if combination not in combinations:
+                    combinations.append(combination)
+            return combinations
+        else:
+            combinations = list(it.combinations(iterable, r))  # All possible combinations
+            return rand.sample(
+                combinations, n_combinations
+            )  # Randomly select n_combinations
+        
+
+# Now that we have a function to generate combinations, we can generate a list of 49 * 50 + 1 = 2451 combinations:
+# 50 random combinations for each subset size from 1 to 49 models, plus the single full set of all 50 models
+
+models = list(range(50))
+combos = {}
+for i in range(49):
+    combos[i + 1] = get_combinations(models, i + 1, 50)
+
+combos[50] = [models]
+
+
+# Now that we have the list of combinations, we need the predictions
+print('Loading Config...')
+config = txr.load_config()
+ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
+V4_PATH = ENSEMBLE_PATH + '/v4'
+
+if not os.path.exists(V4_PATH):
+    os.makedirs(V4_PATH)
+print('Config Loaded')
+
+test_predictions = xr.open_dataarray(f'{V4_PATH}/test_predictions.nc')
+val_predictions = xr.open_dataarray(f'{V4_PATH}/val_predictions.nc')
+
+# Prune Data
+print('Pruning Data...')
+if config['operation']['exclude_blank_ids']:
+    excluded_data_ids = config['ensemble']['excluded_ids']
+    test_predictions = txr.prune_data(test_predictions, excluded_data_ids)
+    val_predictions = txr.prune_data(val_predictions, excluded_data_ids)
+
+# Concatenate Predictions
+predictions = xr.concat([test_predictions, val_predictions], dim='data_id')
+
+# Now that we have the list of predictions, we can calculate the accuracy and other stats for each combination
+# We will calculate the accuracy for each combination of models and save the results
+
+# Calculate the accuracy for each combination of models
+for num_models, model_combinations in combos.items():
+    print(f'Calculating Accuracy for {num_models} Models')
+    for i, model_combination in enumerate(model_combinations):
+        print(f'Calculating Accuracy for Combination {i} of {len(model_combinations)}')
+        model_predictions = predictions.sel(model_id=list(model_combination))
+        
+        # Calculate the accuracy
+        num_correct = 
+
+
+
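Note: the accuracy loop in xarray_sensitivity.py stops at the truncated num_correct assignment. Below is a minimal sketch (not part of the commit) of how the inner computation might be completed, using the prediction_value layout from threshold_xarray.py (columns 0-1 are the class probabilities, column 3 is positive_actual); if the model_id coordinates do not match the values stored in the combinations, selecting by position with .isel would be needed instead of .sel.

# Sketch only, not part of the commit. Ensemble accuracy for one combination,
# computed from the models selected into model_predictions above.
mean = model_predictions.mean(dim='model_id')            # (data_id, prediction_value)
predicted = mean[:, 0:2].argmax(dim='prediction_value')  # 0 = negative, 1 = positive
actual = mean[:, 3]                                      # positive_actual is 0 or 1
num_correct = int((predicted == actual).sum())
accuracy = num_correct / len(model_predictions.data_id)

From there, the per-combination accuracies could be collected per model count and saved, mirroring the sensitivity analysis in threshold_xarray.py.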