import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import tomli as toml
import torch
from tqdm import tqdm

import utils.ensemble as ens
from utils.data.datasets import prepare_datasets

# CONFIGURATION
# Load the config from ADL_CONFIG_PATH if set, otherwise from the local config.toml
config_path = os.getenv("ADL_CONFIG_PATH", "config.toml")
with open(config_path, "rb") as f:
    config = toml.load(f)


# Returns a DataFrame of accuracy, F1, and AUC for the test set filtered at each
# uncertainty-quantile threshold
def threshold(config):
    # Load the datasets; the validation split is folded into the test set below
    _, vs, test_set = prepare_datasets(
        config["paths"]["mri_data"],
        config["paths"]["xls_data"],
        config["dataset"]["validation_split"],
        944,
        config["training"]["device"],
    )
    test_set = test_set + vs

    models, _ = ens.load_models(
        config["paths"]["model_output"] + config["ensemble"]["name"] + "/",
        config["training"]["device"],
    )

    predictions = []

    # Evaluate the ensemble and its uncertainty on the test set
    for mdata, target in tqdm(test_set, total=len(test_set)):
        mri, xls = mdata
        mri = mri.unsqueeze(0)
        xls = xls.unsqueeze(0)
        mdata = (mri, xls)

        mean, variance = ens.ensemble_predict(models, mdata)
        stdev = torch.sqrt(variance)
        prediction = mean.item()
        target = target[1]

        # Check whether the prediction is correct
        correct = (prediction < 0.5 and int(target.item()) == 0) or (
            prediction >= 0.5 and int(target.item()) == 1
        )

        predictions.append(
            {
                "Prediction": prediction,
                "Actual": target.item(),
                "Stdev": stdev.item(),
                "Correct": correct,
            }
        )

    # Sort the predictions by uncertainty
    predictions = pd.DataFrame(predictions).sort_values(by="Stdev")

    # Uncertainty thresholds at the 10%..90% quantiles
    quantiles = np.arange(0.1, 1, 0.1)
    thresholds = [predictions["Stdev"].quantile(q) for q in quantiles]

    # Calculate accuracy, F1, and AUC over the predictions retained at each threshold
    accuracies = []
    for thresh, quantile in zip(thresholds, quantiles):
        filtered = predictions[predictions["Stdev"] <= thresh]
        correct = filtered["Correct"].sum()
        total = len(filtered)
        accuracy = correct / total

        true_positives = len(
            filtered[(filtered["Prediction"] >= 0.5) & (filtered["Actual"] == 1)]
        )
        false_positives = len(
            filtered[(filtered["Prediction"] >= 0.5) & (filtered["Actual"] == 0)]
        )
        false_negatives = len(
            filtered[(filtered["Prediction"] < 0.5) & (filtered["Actual"] == 1)]
        )

        # F1 is computed from true positives only, not from all correct predictions
        f1 = 2 * true_positives / (
            2 * true_positives + false_positives + false_negatives
        )

        auc = metrics.roc_auc_score(filtered["Actual"], filtered["Prediction"])

        accuracies.append(
            {
                "Threshold": thresh,
                "Accuracy": accuracy,
                "Quantile": quantile,
                "F1": f1,
                "AUC": auc,
            }
        )

    predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
    )

    return pd.DataFrame(accuracies)


result = threshold(config)
result.to_csv("coverage.csv")

# Reload from CSV so the plots are generated from the saved artifacts
result = pd.read_csv("coverage.csv")
predictions = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
)

print(result)

# Accuracy vs. coverage
plt.figure()
plt.plot(result["Quantile"], result["Accuracy"])
plt.xlabel("Coverage")
plt.ylabel("Accuracy")
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.png"
)

# F1 vs. coverage
plt.figure()
plt.plot(result["Quantile"], result["F1"])
plt.xlabel("Coverage")
plt.ylabel("F1")
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_f1.png"
)

# AUC vs. coverage
plt.figure()
plt.plot(result["Quantile"], result["AUC"])
plt.xlabel("Coverage")
plt.ylabel("AUC")
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_auc.png"
)
# Histogram of the incorrect predictions vs. their uncertainty
plt.figure()
plt.hist(predictions[~predictions["Correct"]]["Stdev"], bins=10)
plt.xlabel("Uncertainty")
plt.ylabel("Number of incorrect predictions")
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/incorrect_predictions.png"
)