@@ -10,7 +10,9 @@ from tqdm import tqdm

 import utils.metrics as met
 import matplotlib.pyplot as plt
 import matplotlib.ticker as mtick

+import utils.models.cnn
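+
+# PyTorch 2.6 made weights_only=True the default for torch.load, so the custom
+# CNN class must be registered as a safe global before pickled model objects
+# can be loaded.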
+torch.serialization.add_safe_globals([utils.models.cnn.CNN])


 # The data structures for this file are as follows

@@ -29,11 +31,11 @@ import matplotlib.ticker as mtick

 # Loads configuration dictionary
 def load_config():
-    if os.getenv('ADL_CONFIG_PATH') is None:
-        with open('config.toml', 'rb') as f:
+    if os.getenv("ADL_CONFIG_PATH") is None:
+        with open("config.toml", "rb") as f:
             config = toml.load(f)
     else:
-        with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
+        with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
             config = toml.load(f)

     return config
@@ -41,17 +43,20 @@ def load_config():


 # Loads models into a dictionary
 def load_models_v2(folder, device):
-    glob_path = os.path.join(folder, '*.pt')
+    glob_path = os.path.join(folder, "*.pt")
     model_files = glob.glob(glob_path)
     model_dict = {}

     for model_file in model_files:
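+        # Debug aid: print any globals in this checkpoint that are not on the
+        # weights_only allowlist (torch.serialization, PyTorch >= 2.6).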
+        with open(model_file, "rb") as f:
+            print(torch.serialization.get_unsafe_globals_in_checkpoint(f))
+
         model = torch.load(model_file, map_location=device)
-        model_id = os.path.basename(model_file).split('_')[0]
+        model_id = os.path.basename(model_file).split("_")[0]
         model_dict[model_id] = model

     if len(model_dict) == 0:
-        raise FileNotFoundError('No models found in the specified directory: ' + folder)
+        raise FileNotFoundError("No models found in the specified directory: " + folder)

     return model_dict
@@ -67,8 +72,8 @@ def preprocess_data(data, device):
 # Loads datasets and returns the test and validation datasets
 def load_datasets(ensemble_path):
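+    # NOTE: these .pt files hold pickled dataset objects; under the
+    # weights_only=True default their classes may also need allowlisting, or an
+    # explicit weights_only=False here (assumption, not verified).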
     return (
-        torch.load(f'{ensemble_path}/test_dataset.pt'),
-        torch.load(f'{ensemble_path}/val_dataset.pt'),
+        torch.load(f"{ensemble_path}/test_dataset.pt"),
+        torch.load(f"{ensemble_path}/val_dataset.pt"),
     )

@@ -77,21 +82,21 @@ def get_ensemble_predictions(models, dataset, device, id_offset=0):
     zeros = np.zeros((len(dataset), len(models), 4))
     predictions = xr.DataArray(
         zeros,
-        dims=('data_id', 'model_id', 'prediction_value'),
+        dims=("data_id", "model_id", "prediction_value"),
         coords={
-            'data_id': range(id_offset, len(dataset) + id_offset),
-            'model_id': list(models.keys()),
-            'prediction_value': [
-                'negative_prediction',
-                'positive_prediction',
-                'negative_actual',
-                'positive_actual',
+            "data_id": range(id_offset, len(dataset) + id_offset),
+            "model_id": list(models.keys()),
+            "prediction_value": [
+                "negative_prediction",
+                "positive_prediction",
+                "negative_actual",
+                "positive_actual",
             ],
         },
     )

     for data_id, (data, target) in tqdm(
-        enumerate(dataset), total=len(dataset), unit='images'
+        enumerate(dataset), total=len(dataset), unit="images"
     ):
         dat = preprocess_data(data, device)
         actual = list(target.cpu().numpy())
@@ -101,8 +106,8 @@ def get_ensemble_predictions(models, dataset, device, id_offset=0):
         prediction = output.cpu().numpy().tolist()[0]

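+        # Note: list concatenation, not numeric addition; the stored row is
+        # [negative_prediction, positive_prediction, negative_actual, positive_actual]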
         predictions.loc[
-            {'data_id': data_id + id_offset, 'model_id': model_id}
-        ] = prediction + actual
+            {"data_id": data_id + id_offset, "model_id": model_id}
+        ] = (prediction + actual)

     return predictions
@@ -113,27 +118,27 @@ def compute_ensemble_statistics(predictions: xr.DataArray):

     ensemble_statistics = xr.DataArray(
         zeros,
-        dims=('data_id', 'statistic'),
+        dims=("data_id", "statistic"),
         coords={
-            'data_id': predictions.data_id,
-            'statistic': [
-                'mean',
-                'stdev',
-                'entropy',
-                'confidence',
-                'correct',
-                'predicted',
-                'actual',
+            "data_id": predictions.data_id,
+            "statistic": [
+                "mean",
+                "stdev",
+                "entropy",
+                "confidence",
+                "correct",
+                "predicted",
+                "actual",
             ],
         },
     )

     for data_id in predictions.data_id:
-        data = predictions.loc[{'data_id': data_id}]
-        mean = data.mean(dim='model_id')[
+        data = predictions.loc[{"data_id": data_id}]
+        mean = data.mean(dim="model_id")[
             0:2
         ]  # Only take the predictions, not the actual
-        stdev = data.std(dim='model_id')[
+        stdev = data.std(dim="model_id")[
             1
         ]  # Only need the standard deviation of the positive prediction
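+        # Shannon entropy of the mean prediction: H = -sum(p * log(p))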
         entropy = (-mean * np.log(mean)).sum()
@@ -142,11 +147,11 @@ def compute_ensemble_statistics(predictions: xr.DataArray):
         confidence = mean.max()

         # only need one of the actual values, since they are all the same, just get the first actual_positive
-        actual = data.loc[{'prediction_value': 'positive_actual'}][0]
+        actual = data.loc[{"prediction_value": "positive_actual"}][0]
         predicted = mean.argmax()
         correct = actual == predicted

-        ensemble_statistics.loc[{'data_id': data_id}] = [
+        ensemble_statistics.loc[{"data_id": data_id}] = [
             mean[1],
             stdev,
             entropy,
@@ -162,15 +167,15 @@ def compute_ensemble_statistics(predictions: xr.DataArray):
 # Compute the thresholded predictions given an array of predictions
 def compute_thresholded_predictions(input_stats: xr.DataArray):
     quantiles = np.linspace(0.00, 1.00, 21) * 100
-    metrics = ['accuracy', 'f1']
-    statistics = ['stdev', 'entropy', 'confidence']
+    metrics = ["accuracy", "f1"]
+    statistics = ["stdev", "entropy", "confidence"]

     zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))

     thresholded_predictions = xr.DataArray(
         zeros,
-        dims=('quantile', 'statistic', 'metric'),
-        coords={'quantile': quantiles, 'statistic': statistics, 'metric': metrics},
+        dims=("quantile", "statistic", "metric"),
+        coords={"quantile": quantiles, "statistic": statistics, "metric": metrics},
     )

     for statistic in statistics:
@@ -197,7 +202,7 @@ def compute_thresholded_predictions(input_stats: xr.DataArray):

         for metric in metrics:
             thresholded_predictions.loc[
-                {'quantile': quantile, 'statistic': statistic, 'metric': metric}
+                {"quantile": quantile, "statistic": statistic, "metric": metric}
             ] = compute_metric(filtered_data, metric)

     return thresholded_predictions
@@ -208,26 +213,26 @@ def compute_thresholded_predictions(input_stats: xr.DataArray):
 # So we threshold confidence low to high, entropy and stdev high to low
 # So any values BELOW the cutoff are removed for confidence, and any values ABOVE the cutoff are removed for entropy and stdev
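+# e.g. at a 20th-percentile threshold, confidence keeps the most-confident 80%
+# of predictions, while entropy and stdev keep only the 20% least-uncertain.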
 def low_to_high(stat):
-    return stat in ['confidence']
+    return stat in ["confidence"]


 # Compute a given metric on a DataArray of statistics
 def compute_metric(arr, metric):
-    if metric == 'accuracy':
-        return np.mean(arr.loc[{'statistic': 'correct'}])
-    elif metric == 'f1':
+    if metric == "accuracy":
+        return np.mean(arr.loc[{"statistic": "correct"}])
+    elif metric == "f1":
         return met.F1(
-            arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
+            arr.loc[{"statistic": "predicted"}], arr.loc[{"statistic": "actual"}]
         )
-    elif metric == 'ece':
-        true_labels = arr.loc[{'statistic': 'actual'}].values
-        predicted_labels = arr.loc[{'statistic': 'predicted'}].values
-        confidences = arr.loc[{'statistic': 'confidence'}].values
+    elif metric == "ece":
+        true_labels = arr.loc[{"statistic": "actual"}].values
+        predicted_labels = arr.loc[{"statistic": "predicted"}].values
+        confidences = arr.loc[{"statistic": "confidence"}].values

         return calculate_ece_stats(confidences, predicted_labels, true_labels)

     else:
-        raise ValueError('Invalid metric: ' + metric)
+        raise ValueError("Invalid metric: " + metric)


 # Graph a thresholded prediction for a given statistic and metric
@@ -236,11 +241,11 @@ def graph_thresholded_prediction(
 ):
     data = thresholded_predictions.sel(statistic=statistic, metric=metric)

-    x_data = data.coords['quantile'].values
+    x_data = data.coords["quantile"].values
     y_data = data.values

     fig, ax = plt.subplots()
-    ax.plot(x_data, y_data, 'bx-', label='Ensemble')
+    ax.plot(x_data, y_data, "bx-", label="Ensemble")
     ax.set_title(title)
     ax.set_xlabel(xlabel)
     ax.set_ylabel(ylabel)
@@ -257,68 +262,68 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
     # Confidence Accuracy
     graph_thresholded_prediction(
         thresholded_predictions,
-        'confidence',
-        'accuracy',
-        f'{save_path}/confidence_accuracy.png',
-        'Coverage Analysis of Confidence vs. Accuracy',
-        'Minimum Confidence Percentile Threshold',
-        'Accuracy',
+        "confidence",
+        "accuracy",
+        f"{save_path}/confidence_accuracy.png",
+        "Coverage Analysis of Confidence vs. Accuracy",
+        "Minimum Confidence Percentile Threshold",
+        "Accuracy",
     )

     # Confidence F1
     graph_thresholded_prediction(
         thresholded_predictions,
-        'confidence',
-        'f1',
-        f'{save_path}/confidence_f1.png',
-        'Coverage Analysis of Confidence vs. F1 Score',
-        'Minimum Confidence Percentile Threshold',
-        'F1 Score',
+        "confidence",
+        "f1",
+        f"{save_path}/confidence_f1.png",
+        "Coverage Analysis of Confidence vs. F1 Score",
+        "Minimum Confidence Percentile Threshold",
+        "F1 Score",
     )

     # Entropy Accuracy
     graph_thresholded_prediction(
         thresholded_predictions,
-        'entropy',
-        'accuracy',
-        f'{save_path}/entropy_accuracy.png',
-        'Coverage Analysis of Entropy vs. Accuracy',
-        'Maximum Entropy Percentile Threshold',
-        'Accuracy',
+        "entropy",
+        "accuracy",
+        f"{save_path}/entropy_accuracy.png",
+        "Coverage Analysis of Entropy vs. Accuracy",
+        "Maximum Entropy Percentile Threshold",
+        "Accuracy",
     )

     # Entropy F1

     graph_thresholded_prediction(
         thresholded_predictions,
-        'entropy',
-        'f1',
-        f'{save_path}/entropy_f1.png',
-        'Coverage Analysis of Entropy vs. F1 Score',
-        'Maximum Entropy Percentile Threshold',
-        'F1 Score',
+        "entropy",
+        "f1",
+        f"{save_path}/entropy_f1.png",
+        "Coverage Analysis of Entropy vs. F1 Score",
+        "Maximum Entropy Percentile Threshold",
+        "F1 Score",
     )

     # Stdev Accuracy
     graph_thresholded_prediction(
         thresholded_predictions,
-        'stdev',
-        'accuracy',
-        f'{save_path}/stdev_accuracy.png',
-        'Coverage Analysis of Standard Deviation vs. Accuracy',
-        'Maximum Standard Deviation Percentile Threshold',
-        'Accuracy',
+        "stdev",
+        "accuracy",
+        f"{save_path}/stdev_accuracy.png",
+        "Coverage Analysis of Standard Deviation vs. Accuracy",
+        "Maximum Standard Deviation Percentile Threshold",
+        "Accuracy",
     )

     # Stdev F1
     graph_thresholded_prediction(
         thresholded_predictions,
-        'stdev',
-        'f1',
-        f'{save_path}/stdev_f1.png',
-        'Coverage Analysis of Standard Deviation vs. F1 Score',
-        'Maximum Standard Deviation Percentile Threshold',
-        'F1',
+        "stdev",
+        "f1",
+        f"{save_path}/stdev_f1.png",
+        "Coverage Analysis of Standard Deviation vs. F1 Score",
+        "Maximum Standard Deviation Percentile Threshold",
+        "F1",
     )

@@ -326,13 +331,13 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
 def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
     # Filter for correct predictions
     c_stats = stats.where(
-        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 1).values)),
+        stats.data_id.isin(np.where((stats.sel(statistic="correct") == 1).values)),
         drop=True,
     )

     # Filter for incorrect predictions
     i_stats = stats.where(
-        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 0).values)),
+        stats.data_id.isin(np.where((stats.sel(statistic="correct") == 0).values)),
        drop=True,
     )

@@ -344,8 +349,8 @@ def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
     y_data_i = i_stats.sel(statistic=y_stat).values

     fig, ax = plt.subplots()
-    ax.plot(x_data_c, y_data_c, 'go', label='Correct')
-    ax.plot(x_data_i, y_data_i, 'ro', label='Incorrect')
+    ax.plot(x_data_c, y_data_c, "go", label="Correct")
+    ax.plot(x_data_i, y_data_i, "ro", label="Incorrect")
     ax.set_title(title)
     ax.set_xlabel(xlabel)
     ax.set_ylabel(ylabel)
@@ -365,26 +370,26 @@ def compute_individual_statistics(predictions: xr.DataArray):

     indv_statistics = xr.DataArray(
         zeros,
-        dims=('data_id', 'model_id', 'statistic'),
+        dims=("data_id", "model_id", "statistic"),
         coords={
-            'data_id': predictions.data_id,
-            'model_id': predictions.model_id,
-            'statistic': [
-                'mean',
-                'entropy',
-                'confidence',
-                'correct',
-                'predicted',
-                'actual',
+            "data_id": predictions.data_id,
+            "model_id": predictions.model_id,
+            "statistic": [
+                "mean",
+                "entropy",
+                "confidence",
+                "correct",
+                "predicted",
+                "actual",
             ],
         },
     )

     for data_id in tqdm(
-        predictions.data_id, total=len(predictions.data_id), unit='images'
+        predictions.data_id, total=len(predictions.data_id), unit="images"
     ):
         for model_id in predictions.model_id:
-            data = predictions.loc[{'data_id': data_id, 'model_id': model_id}]
+            data = predictions.loc[{"data_id": data_id, "model_id": model_id}]
             mean = data[0:2]
             entropy = (-mean * np.log(mean)).sum()
             confidence = mean.max()
@@ -392,7 +397,7 @@ def compute_individual_statistics(predictions: xr.DataArray):
             predicted = mean.argmax()
             correct = actual == predicted

-            indv_statistics.loc[{'data_id': data_id, 'model_id': model_id}] = [
+            indv_statistics.loc[{"data_id": data_id, "model_id": model_id}] = [
                 mean[1],
                 entropy,
                 confidence,
@@ -407,8 +412,8 @@ def compute_individual_statistics(predictions: xr.DataArray):
 # Compute individual model thresholds
 def compute_individual_thresholds(input_stats: xr.DataArray):
     quantiles = np.linspace(0.05, 0.95, 19) * 100
-    metrics = ['accuracy', 'f1']
-    statistics = ['entropy', 'confidence']
+    metrics = ["accuracy", "f1"]
+    statistics = ["entropy", "confidence"]

     zeros = np.zeros(
         (len(input_stats.model_id), len(quantiles), len(statistics), len(metrics))
@@ -416,17 +421,17 @@ def compute_individual_thresholds(input_stats: xr.DataArray):

     indv_thresholds = xr.DataArray(
         zeros,
-        dims=('model_id', 'quantile', 'statistic', 'metric'),
+        dims=("model_id", "quantile", "statistic", "metric"),
         coords={
-            'model_id': input_stats.model_id,
-            'quantile': quantiles,
-            'statistic': statistics,
-            'metric': metrics,
+            "model_id": input_stats.model_id,
+            "quantile": quantiles,
+            "statistic": statistics,
+            "metric": metrics,
         },
     )

     for model_id in tqdm(
-        input_stats.model_id, total=len(input_stats.model_id), unit='models'
+        input_stats.model_id, total=len(input_stats.model_id), unit="models"
     ):
         for statistic in statistics:
             # First, we must compute the quantiles for the statistic
@@ -457,10 +462,10 @@ def compute_individual_thresholds(input_stats: xr.DataArray):
             for metric in metrics:
                 indv_thresholds.loc[
                     {
-                        'model_id': model_id,
-                        'quantile': quantile,
-                        'statistic': statistic,
-                        'metric': metric,
+                        "model_id": model_id,
+                        "quantile": quantile,
+                        "statistic": statistic,
+                        "metric": metric,
                     }
                 ] = compute_metric(filtered_data, metric)

@@ -481,18 +486,18 @@ def graph_individual_thresholded_predictions(
     data = indv_thresholds.sel(statistic=statistic, metric=metric)
     e_data = ensemble_thresholds.sel(statistic=statistic, metric=metric)

-    x_data = data.coords['quantile'].values
+    x_data = data.coords["quantile"].values
     y_data = data.values

-    e_x_data = e_data.coords['quantile'].values
+    e_x_data = e_data.coords["quantile"].values
     e_y_data = e_data.values

     fig, ax = plt.subplots()
-    for model_id in data.coords['model_id'].values:
+    for model_id in data.coords["model_id"].values:
         model_data = data.sel(model_id=model_id)
         ax.plot(x_data, model_data)

-    ax.plot(e_x_data, e_y_data, 'kx-', label='Ensemble')
+    ax.plot(e_x_data, e_y_data, "kx-", label="Ensemble")

     ax.set_title(title)
     ax.set_xlabel(xlabel)
@@ -514,48 +519,48 @@ def graph_all_individual_thresholded_predictions(
     graph_individual_thresholded_predictions(
         indv_thresholds,
         ensemble_thresholds,
-        'confidence',
-        'accuracy',
-        f'{save_path}/indv/confidence_accuracy.png',
-        'Coverage Analysis of Confidence vs. Accuracy for All Models',
-        'Minumum Confidence Percentile Threshold',
-        'Accuracy',
+        "confidence",
+        "accuracy",
+        f"{save_path}/indv/confidence_accuracy.png",
+        "Coverage Analysis of Confidence vs. Accuracy for All Models",
+        "Minimum Confidence Percentile Threshold",
+        "Accuracy",
     )

     # Confidence F1
     graph_individual_thresholded_predictions(
         indv_thresholds,
         ensemble_thresholds,
-        'confidence',
-        'f1',
-        f'{save_path}/indv/confidence_f1.png',
-        'Coverage Analysis of Confidence vs. F1 Score for All Models',
-        'Minimum Confidence Percentile Threshold',
-        'F1 Score',
+        "confidence",
+        "f1",
+        f"{save_path}/indv/confidence_f1.png",
+        "Coverage Analysis of Confidence vs. F1 Score for All Models",
+        "Minimum Confidence Percentile Threshold",
+        "F1 Score",
     )

     # Entropy Accuracy
     graph_individual_thresholded_predictions(
         indv_thresholds,
         ensemble_thresholds,
-        'entropy',
-        'accuracy',
-        f'{save_path}/indv/entropy_accuracy.png',
-        'Coverage Analysis of Entropy vs. Accuracy for All Models',
-        'Maximum Entropy Percentile Threshold',
-        'Accuracy',
+        "entropy",
+        "accuracy",
+        f"{save_path}/indv/entropy_accuracy.png",
+        "Coverage Analysis of Entropy vs. Accuracy for All Models",
+        "Maximum Entropy Percentile Threshold",
+        "Accuracy",
     )

     # Entropy F1
     graph_individual_thresholded_predictions(
         indv_thresholds,
         ensemble_thresholds,
-        'entropy',
-        'f1',
-        f'{save_path}/indv/entropy_f1.png',
-        'Coverage Analysis of Entropy vs. F1 Score for All Models',
-        'Maximum Entropy Percentile Threshold',
-        'F1 Score',
+        "entropy",
+        "f1",
+        f"{save_path}/indv/entropy_f1.png",
+        "Coverage Analysis of Entropy vs. F1 Score for All Models",
+        "Maximum Entropy Percentile Threshold",
+        "F1 Score",
     )

@@ -570,38 +575,38 @@ def calculate_subset_statistics(predictions: xr.DataArray):

     subset_stats = xr.DataArray(
         zeros,
-        dims=('data_id', 'model_count', 'statistic'),
+        dims=("data_id", "model_count", "statistic"),
         coords={
-            'data_id': predictions.data_id,
-            'model_count': subsets,
-            'statistic': [
-                'mean',
-                'stdev',
-                'entropy',
-                'confidence',
-                'correct',
-                'predicted',
-                'actual',
+            "data_id": predictions.data_id,
+            "model_count": subsets,
+            "statistic": [
+                "mean",
+                "stdev",
+                "entropy",
+                "confidence",
+                "correct",
+                "predicted",
+                "actual",
             ],
         },
     )

     for data_id in tqdm(
-        predictions.data_id, total=len(predictions.data_id), unit='images'
+        predictions.data_id, total=len(predictions.data_id), unit="images"
     ):
         for subset in subsets:
             data = predictions.sel(
                 data_id=data_id, model_id=predictions.model_id[:subset]
             )
-            mean = data.mean(dim='model_id')[0:2]
-            stdev = data.std(dim='model_id')[1]
+            mean = data.mean(dim="model_id")[0:2]
+            stdev = data.std(dim="model_id")[1]
             entropy = (-mean * np.log(mean)).sum()
             confidence = mean.max()
             actual = data[0][3]
             predicted = mean.argmax()
             correct = actual == predicted

-            subset_stats.loc[{'data_id': data_id, 'model_count': subset}] = [
+            subset_stats.loc[{"data_id": data_id, "model_count": subset}] = [
                 mean[1],
                 stdev,
                 entropy,
@@ -617,24 +622,24 @@ def calculate_subset_statistics(predictions: xr.DataArray):
 # Calculate Accuracy, F1 and ECE for subset stats - sensitivity analysis
 def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
     subsets = subset_stats.model_count
-    stats = ['accuracy', 'f1', 'ece']
+    stats = ["accuracy", "f1", "ece"]

     zeros = np.zeros((len(subsets), len(stats)))

     sens_analysis = xr.DataArray(
         zeros,
-        dims=('model_count', 'statistic'),
-        coords={'model_count': subsets, 'statistic': stats},
+        dims=("model_count", "statistic"),
+        coords={"model_count": subsets, "statistic": stats},
     )

-    for subset in tqdm(subsets, total=len(subsets), unit='model subsets'):
+    for subset in tqdm(subsets, total=len(subsets), unit="model subsets"):
         data = subset_stats.sel(model_count=subset)
-        acc = compute_metric(data, 'accuracy').item()
-        f1 = compute_metric(data, 'f1').item()
-        ece = compute_metric(data, 'ece').item()
+        acc = compute_metric(data, "accuracy").item()
+        f1 = compute_metric(data, "f1").item()
+        ece = compute_metric(data, "ece").item()

-        sens_analysis.loc[{'model_count': subset.item()}] = [acc, f1, ece]
+        sens_analysis.loc[{"model_count": subset.item()}] = [acc, f1, ece]

     return sens_analysis
@@ -644,7 +649,7 @@ def graph_sensitivity_analysis(
 ):
     data = sens_analysis.sel(statistic=statistic)

-    xdata = data.coords['model_count'].values
+    xdata = data.coords["model_count"].values
     ydata = data.values

     fig, ax = plt.subplots()
@@ -657,10 +662,10 @@


 def calculate_overall_stats(ensemble_statistics: xr.DataArray):
-    accuracy = compute_metric(ensemble_statistics, 'accuracy')
-    f1 = compute_metric(ensemble_statistics, 'f1')
+    accuracy = compute_metric(ensemble_statistics, "accuracy")
+    f1 = compute_metric(ensemble_statistics, "f1")

-    return {'accuracy': accuracy.item(), 'f1': f1.item()}
+    return {"accuracy": accuracy.item(), "f1": f1.item()}


 # https://towardsdatascience.com/expected-calibration-error-ece-a-step-by-step-visual-explanation-with-python-code-c3e9aa12937d
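+# ECE = sum over confidence bins b of (n_b / N) * |accuracy(b) - avg confidence(b)|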
@@ -693,130 +698,130 @@ def plot_ece_graph(ece_stats, title, xlabel, ylabel, save_path):

 # Main Function
 def main():
-    print('Loading Config...')
+    print("Loading Config...")
     config = load_config()
     ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
-    V4_PATH = ENSEMBLE_PATH + '/v4'
+    V4_PATH = ENSEMBLE_PATH + "/v4"

     if not os.path.exists(V4_PATH):
         os.makedirs(V4_PATH)
-    print('Config Loaded')
+    print("Config Loaded")

     # Load Datasets
-    print('Loading Datasets...')
+    print("Loading Datasets...")
     (test_dataset, val_dataset) = load_datasets(ENSEMBLE_PATH)
-    print('Datasets Loaded')
+    print("Datasets Loaded")

     # Get Predictions, either by running the models or loading them from a file
-    if config['ensemble']['run_models']:
+    if config["ensemble"]["run_models"]:
         # Load Models
-        print('Loading Models...')
-        device = torch.device(config['training']['device'])
-        models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
-        print('Models Loaded')
+        print("Loading Models...")
+        device = torch.device(config["training"]["device"])
+        models = load_models_v2(f"{ENSEMBLE_PATH}/models/", device)
+        print("Models Loaded")

         # Get Predictions
-        print('Getting Predictions...')
+        print("Getting Predictions...")
         test_predictions = get_ensemble_predictions(models, test_dataset, device)
         val_predictions = get_ensemble_predictions(
             models, val_dataset, device, len(test_dataset)
         )
-        print('Predictions Loaded')
+        print("Predictions Loaded")

         # Save Predictions
-        test_predictions.to_netcdf(f'{V4_PATH}/test_predictions.nc')
-        val_predictions.to_netcdf(f'{V4_PATH}/val_predictions.nc')
+        test_predictions.to_netcdf(f"{V4_PATH}/test_predictions.nc")
+        val_predictions.to_netcdf(f"{V4_PATH}/val_predictions.nc")
     else:
-        test_predictions = xr.open_dataarray(f'{V4_PATH}/test_predictions.nc')
-        val_predictions = xr.open_dataarray(f'{V4_PATH}/val_predictions.nc')
+        test_predictions = xr.open_dataarray(f"{V4_PATH}/test_predictions.nc")
+        val_predictions = xr.open_dataarray(f"{V4_PATH}/val_predictions.nc")

     # Prune Data
-    print('Pruning Data...')
-    if config['operation']['exclude_blank_ids']:
-        excluded_data_ids = config['ensemble']['excluded_ids']
+    print("Pruning Data...")
+    if config["operation"]["exclude_blank_ids"]:
+        excluded_data_ids = config["ensemble"]["excluded_ids"]
         test_predictions = prune_data(test_predictions, excluded_data_ids)
         val_predictions = prune_data(val_predictions, excluded_data_ids)

     # Concatenate Predictions
-    predictions = xr.concat([test_predictions, val_predictions], dim='data_id')
+    predictions = xr.concat([test_predictions, val_predictions], dim="data_id")

     # Compute Ensemble Statistics
-    print('Computing Ensemble Statistics...')
+    print("Computing Ensemble Statistics...")
     ensemble_statistics = compute_ensemble_statistics(predictions)
-    ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')
-    print('Ensemble Statistics Computed')
+    ensemble_statistics.to_netcdf(f"{V4_PATH}/ensemble_statistics.nc")
+    print("Ensemble Statistics Computed")

     # Compute Thresholded Predictions
-    print('Computing Thresholded Predictions...')
+    print("Computing Thresholded Predictions...")
     thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
-    thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')
-    print('Thresholded Predictions Computed')
+    thresholded_predictions.to_netcdf(f"{V4_PATH}/thresholded_predictions.nc")
+    print("Thresholded Predictions Computed")

     # Graph Thresholded Predictions
-    print('Graphing Thresholded Predictions...')
+    print("Graphing Thresholded Predictions...")
     graph_all_thresholded_predictions(thresholded_predictions, V4_PATH)
-    print('Thresholded Predictions Graphed')
+    print("Thresholded Predictions Graphed")

     # Additional Graphs
-    print('Graphing Additional Graphs...')
+    print("Graphing Additional Graphs...")
     # Confidence vs stdev
     graph_statistics(
         ensemble_statistics,
-        'confidence',
-        'stdev',
-        f'{V4_PATH}/confidence_stdev.png',
-        'Confidence and Standard Deviation for Predictions',
-        'Confidence',
-        'Standard Deviation',
+        "confidence",
+        "stdev",
+        f"{V4_PATH}/confidence_stdev.png",
+        "Confidence and Standard Deviation for Predictions",
+        "Confidence",
+        "Standard Deviation",
     )
-    print('Additional Graphs Graphed')
+    print("Additional Graphs Graphed")

     # Compute Individual Statistics
-    print('Computing Individual Statistics...')
+    print("Computing Individual Statistics...")
     indv_statistics = compute_individual_statistics(predictions)
-    indv_statistics.to_netcdf(f'{V4_PATH}/indv_statistics.nc')
-    print('Individual Statistics Computed')
+    indv_statistics.to_netcdf(f"{V4_PATH}/indv_statistics.nc")
+    print("Individual Statistics Computed")

     # Compute Individual Thresholds
-    print('Computing Individual Thresholds...')
+    print("Computing Individual Thresholds...")
     indv_thresholds = compute_individual_thresholds(indv_statistics)
-    indv_thresholds.to_netcdf(f'{V4_PATH}/indv_thresholds.nc')
-    print('Individual Thresholds Computed')
+    indv_thresholds.to_netcdf(f"{V4_PATH}/indv_thresholds.nc")
+    print("Individual Thresholds Computed")

     # Graph Individual Thresholded Predictions
-    print('Graphing Individual Thresholded Predictions...')
-    if not os.path.exists(f'{V4_PATH}/indv'):
-        os.makedirs(f'{V4_PATH}/indv')
+    print("Graphing Individual Thresholded Predictions...")
+    if not os.path.exists(f"{V4_PATH}/indv"):
+        os.makedirs(f"{V4_PATH}/indv")

     graph_all_individual_thresholded_predictions(
         indv_thresholds, thresholded_predictions, V4_PATH
     )
-    print('Individual Thresholded Predictions Graphed')
+    print("Individual Thresholded Predictions Graphed")

     # Compute subset statistics and graph
-    print('Computing Sensitivity Analysis...')
+    print("Computing Sensitivity Analysis...")
     subset_stats = calculate_subset_statistics(predictions)
     sens_analysis = calculate_sensitivity_analysis(subset_stats)
     graph_sensitivity_analysis(
         sens_analysis,
-        'accuracy',
-        f'{V4_PATH}/sens_analysis.png',
-        'Sensitivity Analsis of Accuracy vs. # of Models',
-        '# of Models',
-        'Accuracy',
+        "accuracy",
+        f"{V4_PATH}/sens_analysis.png",
+        "Sensitivity Analysis of Accuracy vs. # of Models",
+        "# of Models",
+        "Accuracy",
     )
     graph_sensitivity_analysis(
         sens_analysis,
-        'ece',
-        f'{V4_PATH}/sens_analysis_ece.png',
-        'Sensitivity Analysis of ECE vs. # of Models',
-        '# of Models',
-        'ECE',
+        "ece",
+        f"{V4_PATH}/sens_analysis_ece.png",
+        "Sensitivity Analysis of ECE vs. # of Models",
+        "# of Models",
+        "ECE",
     )
-    print(sens_analysis.sel(statistic='accuracy'))
+    print(sens_analysis.sel(statistic="accuracy"))
     print(calculate_overall_stats(ensemble_statistics))


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()