import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import tomli as toml
import torch
from tqdm import tqdm

import utils.ensemble as ens
from utils.data.datasets import prepare_datasets
# CONFIGURATION
# Load the TOML config from ADL_CONFIG_PATH if set, otherwise from the local config.toml
if os.getenv("ADL_CONFIG_PATH") is None:
    with open("config.toml", "rb") as f:
        config = toml.load(f)
else:
    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
        config = toml.load(f)


# This function returns a DataFrame of accuracy, F1, and AUC for each uncertainty-quantile threshold
def threshold(config):
    # First, get the model data
    ts, vs, test_set = prepare_datasets(
        config["paths"]["mri_data"],
        config["paths"]["xls_data"],
        config["dataset"]["validation_split"],
        944,
        config["training"]["device"],
    )

    # Evaluate on the validation and test sets combined
    test_set = test_set + vs

    models, _ = ens.load_models(
        config["paths"]["model_output"] + config["ensemble"]["name"] + "/",
        config["training"]["device"],
    )
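
    # ens.ensemble_predict is assumed to return the mean of the member models'
    # outputs and the variance across members; the square root of that variance
    # (the standard deviation) is used below as the per-sample uncertainty.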
    predictions = []

    # Evaluate the ensemble and its uncertainty on the test set
    for mdata, target in tqdm(test_set, total=len(test_set)):
        mri, xls = mdata

        # Add a batch dimension to each modality
        mri = mri.unsqueeze(0)
        xls = xls.unsqueeze(0)
        mdata = (mri, xls)

        mean, variance = ens.ensemble_predict(models, mdata)
        stdev = torch.sqrt(variance)
        prediction = mean.item()
        target = target[1]

        # Check if the prediction is correct
        correct = (prediction < 0.5 and int(target.item()) == 0) or (
            prediction >= 0.5 and int(target.item()) == 1
        )

        predictions.append(
            {
                "Prediction": prediction,
                "Actual": target.item(),
                "Stdev": stdev.item(),
                "Correct": correct,
            }
        )
    # Sort the predictions by uncertainty
    predictions = pd.DataFrame(predictions).sort_values(by="Stdev")

    # Get the uncertainty threshold at each decile
    quantiles = np.arange(0.1, 1, 0.1)
    thresholds = [predictions["Stdev"].quantile(q) for q in quantiles]
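
    # Keeping only predictions with Stdev below the q-th quantile retains roughly
    # a fraction q of the most certain samples, so q is the coverage level used
    # on the x-axis of the plots below.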
    # Calculate the accuracy, F1, and AUC of the model at each threshold
    accuracies = []
    for thresh, quantile in zip(thresholds, quantiles):
        filtered = predictions[predictions["Stdev"] <= thresh]
        correct = filtered["Correct"].sum()
        total = len(filtered)
        accuracy = correct / total

        # F1 is computed from true positives only, not from all correct predictions
        true_positives = len(
            filtered[(filtered["Prediction"] >= 0.5) & (filtered["Actual"] == 1)]
        )
        false_positives = len(
            filtered[(filtered["Prediction"] >= 0.5) & (filtered["Actual"] == 0)]
        )
        false_negatives = len(
            filtered[(filtered["Prediction"] < 0.5) & (filtered["Actual"] == 1)]
        )
        f1 = 2 * true_positives / (
            2 * true_positives + false_positives + false_negatives
        )
        auc = metrics.roc_auc_score(filtered["Actual"], filtered["Prediction"])

        accuracies.append(
            {
                "Threshold": thresh,
                "Accuracy": accuracy,
                "Quantile": quantile,
                "F1": f1,
                "AUC": auc,
            }
        )

    predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
    )

    return pd.DataFrame(accuracies)
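

# Run the uncertainty-threshold sweep, then reload the saved CSVs for plotting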
result = threshold(config)
result.to_csv("coverage.csv")

result = pd.read_csv("coverage.csv")
predictions = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
)

print(result)
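
# Coverage curves: each point shows the metric computed on only the most certain
# fraction (quantile) of predictions; the x-axis is inverted so coverage
# decreases from left to right.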
plt.figure()
plt.plot(result["Quantile"], result["Accuracy"])
plt.xlabel("Coverage")
plt.ylabel("Accuracy")
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.png"
)

plt.figure()
plt.plot(result["Quantile"], result["F1"])
plt.xlabel("Coverage")
plt.ylabel("F1")
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_f1.png"
)

plt.figure()
plt.plot(result["Quantile"], result["AUC"])
plt.xlabel("Coverage")
plt.ylabel("AUC")
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_auc.png"
)
# Create a histogram of the incorrect predictions vs. their uncertainty
plt.figure()
plt.hist(predictions[~predictions["Correct"]]["Stdev"], bins=10)
plt.xlabel("Uncertainty")
plt.ylabel("Number of incorrect predictions")
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/incorrect_predictions.png"
)