|
@@ -10,7 +10,7 @@ import sklearn.metrics as metrics
|
|
|
from tqdm import tqdm
|
|
|
import utils.metrics as met
|
|
|
|
|
|
-RUN = True
|
|
|
+RUN = False
|
|
|
|
|
|
# CONFIGURATION
|
|
|
if os.getenv('ADL_CONFIG_PATH') is None:
|
|
@@ -21,239 +21,205 @@ else:
|
|
|
config = toml.load(f)
|
|
|
|
|
|
|
|
|
-# This function returns a list of the accuracies given a threshold
|
|
|
-def threshold(config):
|
|
|
- # First, get the model data
|
|
|
- test_set = torch.load(
|
|
|
- config['paths']['model_output']
|
|
|
- + config['ensemble']['name']
|
|
|
- + '/test_dataset.pt'
|
|
|
- )
|
|
|
+ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
|
|
|
|
|
|
- vs = torch.load(
|
|
|
- config['paths']['model_output'] + config['ensemble']['name'] + '/val_dataset.pt'
|
|
|
- )
|
|
|
+V2_PATH = ENSEMBLE_PATH + '/v2'
|
|
|
|
|
|
- test_set = test_set + vs
|
|
|
|
|
|
- models, _ = ens.load_models(
|
|
|
- config['paths']['model_output'] + config['ensemble']['name'] + '/models/',
|
|
|
- config['training']['device'],
|
|
|
- )
|
|
|
+# Result is a 1x2 tensor, with the softmax of the 2 predicted classes
|
|
|
+# Want to convert to a predicted class and a confidence
|
|
|
+def output_to_confidence(result):
|
|
|
+ predicted_class = torch.argmax(result).item()
|
|
|
+ confidence = (torch.max(result).item() - 0.5) * 2
|
|
|
|
|
|
- indv_model = models[0]
|
|
|
-
|
|
|
- predictions = []
|
|
|
- indv_predictions = []
|
|
|
-
|
|
|
- # Evaluate ensemble and uncertainty test set
|
|
|
- for mdata, target in tqdm(test_set, total=len(test_set)):
|
|
|
- mri, xls = mdata
|
|
|
- mri = mri.unsqueeze(0)
|
|
|
- xls = xls.unsqueeze(0)
|
|
|
- mdata = (mri, xls)
|
|
|
- mean, variance = ens.ensemble_predict(models, mdata)
|
|
|
- stdev = torch.sqrt(variance)
|
|
|
- prediction = mean.item()
|
|
|
-
|
|
|
- target = target[1]
|
|
|
-
|
|
|
- # Check if the prediction is correct
|
|
|
- correct = (prediction < 0.5 and int(target.item()) == 0) or (
|
|
|
- prediction >= 0.5 and int(target.item()) == 1
|
|
|
- )
|
|
|
-
|
|
|
- predictions.append(
|
|
|
- {
|
|
|
- 'Prediction': prediction,
|
|
|
- 'Actual': target.item(),
|
|
|
- 'Stdev': stdev.item(),
|
|
|
- 'Correct': correct,
|
|
|
- }
|
|
|
- )
|
|
|
-
|
|
|
- i_mean = indv_model(mdata)[:, 1].item()
|
|
|
- i_correct = (i_mean < 0.5 and int(target.item()) == 0) or (
|
|
|
- i_mean >= 0.5 and int(target.item()) == 1
|
|
|
- )
|
|
|
-
|
|
|
- indv_predictions.append(
|
|
|
- {
|
|
|
- 'Prediction': i_mean,
|
|
|
- 'Actual': target.item(),
|
|
|
- 'Stdev': 0,
|
|
|
- 'Correct': i_correct,
|
|
|
- }
|
|
|
- )
|
|
|
-
|
|
|
- # Sort the predictions by the uncertainty
|
|
|
- predictions = pd.DataFrame(predictions).sort_values(by='Stdev')
|
|
|
-
|
|
|
- # Calculate the metrics for the individual model
|
|
|
- indv_predictions = pd.DataFrame(indv_predictions)
|
|
|
- indv_correct = indv_predictions['Correct'].sum()
|
|
|
- indv_accuracy = indv_correct / len(indv_predictions)
|
|
|
- indv_false_pos = len(
|
|
|
- indv_predictions[
|
|
|
- (indv_predictions['Prediction'] >= 0.5) & (indv_predictions['Actual'] == 0)
|
|
|
- ]
|
|
|
- )
|
|
|
- indv_false_neg = len(
|
|
|
- indv_predictions[
|
|
|
- (indv_predictions['Prediction'] < 0.5) & (indv_predictions['Actual'] == 1)
|
|
|
- ]
|
|
|
+ return torch.Tensor([predicted_class, confidence])
|
|
|
+
|
|
|
+
|
|
|
+# This function conducts tests on the models and returns the results, as well as saving the predictions and metrics
|
|
|
+def get_predictions(config):
|
|
|
+ models, model_descs = ens.load_models(
|
|
|
+ f'{ENSEMBLE_PATH}/models/',
|
|
|
+ config['training']['device'],
|
|
|
)
|
|
|
- indv_f1 = 2 * indv_correct / (2 * indv_correct + indv_false_pos + indv_false_neg)
|
|
|
- indv_auc = metrics.roc_auc_score(
|
|
|
- indv_predictions['Actual'], indv_predictions['Prediction']
|
|
|
+ models = [model.to(config['training']['device']) for model in models]
|
|
|
+ test_set = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
|
|
|
+ f'{ENSEMBLE_PATH}/val_dataset.pt'
|
|
|
)
|
|
|
|
|
|
- indv_metrics = {'Accuracy': indv_accuracy, 'F1': indv_f1, 'AUC': indv_auc}
|
|
|
-
|
|
|
- thresholds = []
|
|
|
- quantiles = np.arange(0.1, 1, 0.1)
|
|
|
- # get uncertainty quantiles
|
|
|
- for quantile in quantiles:
|
|
|
- thresholds.append(predictions['Stdev'].quantile(quantile))
|
|
|
-
|
|
|
- # Calculate the accuracy of the model for each threshold
|
|
|
- accuracies = []
|
|
|
- # Calculate the accuracy of the model for each threshold
|
|
|
- for threshold, quantile in zip(thresholds, quantiles):
|
|
|
- filtered = predictions[predictions['Stdev'] <= threshold]
|
|
|
- correct = filtered['Correct'].sum()
|
|
|
- total = len(filtered)
|
|
|
- accuracy = correct / total
|
|
|
-
|
|
|
- false_positives = len(
|
|
|
- filtered[(filtered['Prediction'] >= 0.5) & (filtered['Actual'] == 0)]
|
|
|
- )
|
|
|
-
|
|
|
- false_negatives = len(
|
|
|
- filtered[(filtered['Prediction'] < 0.5) & (filtered['Actual'] == 1)]
|
|
|
- )
|
|
|
-
|
|
|
- f1 = 2 * correct / (2 * correct + false_positives + false_negatives)
|
|
|
-
|
|
|
- auc = metrics.roc_auc_score(filtered['Actual'], filtered['Prediction'])
|
|
|
-
|
|
|
- accuracies.append(
|
|
|
- {
|
|
|
- 'Threshold': threshold,
|
|
|
- 'Accuracy': accuracy,
|
|
|
- 'Quantile': quantile,
|
|
|
- 'F1': f1,
|
|
|
- 'AUC': auc,
|
|
|
- }
|
|
|
- )
|
|
|
-
|
|
|
- predictions.to_csv(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
|
|
|
- )
|
|
|
+ # [([model results], labels)]
|
|
|
+ results = []
|
|
|
|
|
|
- indv_predictions.to_csv(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_predictions.csv"
|
|
|
- )
|
|
|
+ # [(class_1, class_2, true_label)]
|
|
|
+ indv_results = []
|
|
|
+
|
|
|
+ for i, (data, target) in tqdm(
|
|
|
+ enumerate(test_set),
|
|
|
+ total=len(test_set),
|
|
|
+ desc='Getting predictions',
|
|
|
+ unit='sample',
|
|
|
+ ):
|
|
|
+ mri, xls = data
|
|
|
+ mri = mri.unsqueeze(0).to(config['training']['device'])
|
|
|
+ xls = xls.unsqueeze(0).to(config['training']['device'])
|
|
|
+ data = (mri, xls)
|
|
|
+ res = []
|
|
|
+ for j, model in enumerate(models):
|
|
|
+ model.eval()
|
|
|
+ with torch.no_grad():
|
|
|
+ output = model(data)
|
|
|
+
|
|
|
+ output = output.tolist()
|
|
|
+
|
|
|
+ if j == 0:
|
|
|
+ indv_results.append((output[0][0], output[0][1], target[1].item()))
|
|
|
|
|
|
- return pd.DataFrame(accuracies), indv_metrics
|
|
|
+ res.append(output)
|
|
|
+ results.append((res, target.tolist()))
|
|
|
+
|
|
|
+ # The results are a list of tuples, where each tuple contains a list of model outputs and the true label
|
|
|
+ # We want to convert this to 2 list of tuples, one with the ensemble predicted class, ensemble confidence and true label
|
|
|
+ # And one with the ensemble predicted class, ensemble standard deviation and true label
|
|
|
+
|
|
|
+ # [(ensemble predicted class, ensemble confidence, true label)]
|
|
|
+ confidences = []
|
|
|
+
|
|
|
+ # [(ensemble predicted class, ensemble standard deviation, true label)]
|
|
|
+ stdevs = []
|
|
|
+
|
|
|
+ for result in results:
|
|
|
+ model_results, true_label = result
|
|
|
+ # Get the ensemble mean and variance with numpy, as these are lists
|
|
|
+ mean = np.mean(model_results, axis=0)
|
|
|
+ variance = np.var(model_results, axis=0)
|
|
|
+
|
|
|
+ # Calculate confidence and standard deviation
|
|
|
+ confidence = (np.max(mean) - 0.5) * 2
|
|
|
+ stdev = np.sqrt(variance)
|
|
|
+
|
|
|
+ # Get the predicted class
|
|
|
+ predicted_class = np.argmax(mean)
|
|
|
+
|
|
|
+ # Get the confidence and standard deviation of the predicted class
|
|
|
+ print(stdev)
|
|
|
+ pc_stdev = np.squeeze(stdev)[predicted_class]
|
|
|
+
|
|
|
+ # Get the true label
|
|
|
+ true_label = true_label[1]
|
|
|
+
|
|
|
+ confidences.append((predicted_class, confidence, true_label))
|
|
|
+ stdevs.append((predicted_class, pc_stdev, true_label))
|
|
|
+
|
|
|
+ return results, confidences, stdevs, indv_results
|
|
|
|
|
|
|
|
|
if RUN:
|
|
|
- result, indv = threshold(config)
|
|
|
- result.to_csv(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
|
|
|
- )
|
|
|
- indv = pd.DataFrame([indv])
|
|
|
- indv.to_csv(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
|
|
|
+ results, confs, stdevs, indv_results = get_predictions(config)
|
|
|
+ # Convert to pandas dataframes
|
|
|
+ confs_df = pd.DataFrame(
|
|
|
+ confs, columns=['predicted_class', 'confidence', 'true_label']
|
|
|
)
|
|
|
+ stdevs_df = pd.DataFrame(stdevs, columns=['predicted_class', 'stdev', 'true_label'])
|
|
|
|
|
|
-result = pd.read_csv(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
|
|
|
-)
|
|
|
-predictions = pd.read_csv(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
|
|
|
-)
|
|
|
-indv = pd.read_csv(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
|
|
|
+ indv_df = pd.DataFrame(indv_results, columns=['class_1', 'class_2', 'true_label'])
|
|
|
+
|
|
|
+ if not os.path.exists(V2_PATH):
|
|
|
+ os.makedirs(V2_PATH)
|
|
|
+
|
|
|
+ confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
|
|
|
+ stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
|
|
|
+ indv_df.to_csv(f'{V2_PATH}/individual_results.csv')
|
|
|
+else:
|
|
|
+ confs_df = pd.read_csv(f'{V2_PATH}/ensemble_confidences.csv')
|
|
|
+ stdevs_df = pd.read_csv(f'{V2_PATH}/ensemble_stdevs.csv')
|
|
|
+ indv_df = pd.read_csv(f'{V2_PATH}/individual_results.csv')
|
|
|
+
|
|
|
+# Plot confidence vs standard deviation
|
|
|
+plt.scatter(confs_df['confidence'], stdevs_df['stdev'])
|
|
|
+plt.xlabel('Confidence')
|
|
|
+plt.ylabel('Standard Deviation')
|
|
|
+plt.title('Confidence vs Standard Deviation')
|
|
|
+plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
|
|
|
+plt.close()
|
|
|
+
|
|
|
+# Calculate Binning for Coverage
|
|
|
+# Sort Dataframes
|
|
|
+confs_df = confs_df.sort_values(by='confidence')
|
|
|
+stdevs_df = stdevs_df.sort_values(by='stdev')
|
|
|
+
|
|
|
+confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
|
|
|
+stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
|
|
|
+
|
|
|
+# Calculate individual model accuracy
|
|
|
+# Determine predicted class
|
|
|
+indv_df['predicted_class'] = indv_df[['class_1', 'class_2']].idxmax(axis=1)
|
|
|
+indv_df['predicted_class'] = indv_df['predicted_class'].apply(
|
|
|
+ lambda x: 0 if x == 'class_1' else 1
|
|
|
)
|
|
|
+indv_df['correct'] = indv_df['predicted_class'] == indv_df['true_label']
|
|
|
+accuracy_indv = indv_df['correct'].mean()
|
|
|
|
|
|
-print(indv)
|
|
|
+# Calculate percentiles for confidence and standard deviation
|
|
|
+quantiles_conf = confs_df.quantile(np.linspace(0, 1, 11))['confidence']
|
|
|
+quantiles_stdev = stdevs_df.quantile(np.linspace(0, 1, 11))['stdev']
|
|
|
|
|
|
+accuracies_conf = []
|
|
|
+# Use the quantiles to calculate the coverage
|
|
|
+for quantile in quantiles_conf.items():
|
|
|
+ percentile = quantile[0]
|
|
|
|
|
|
-plt.figure()
|
|
|
+ filt = confs_df[confs_df['confidence'] >= quantile[1]]
|
|
|
+ accuracy = (
|
|
|
+ filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
|
|
|
+ )
|
|
|
|
|
|
-plt.plot(result['Quantile'], result['Accuracy'], label='Ensemble Accuracy')
|
|
|
+ accuracies_conf.append({'percentile': percentile, 'accuracy': accuracy})
|
|
|
|
|
|
+accuracies_df = pd.DataFrame(accuracies_conf)
|
|
|
+
|
|
|
+# Plot the coverage
|
|
|
+plt.plot(accuracies_df['percentile'], accuracies_df['accuracy'], label='Ensemble')
|
|
|
plt.plot(
|
|
|
- result['Quantile'],
|
|
|
- [indv['Accuracy']] * len(result['Quantile']),
|
|
|
- label='Individual Accuracy',
|
|
|
+ accuracies_df['percentile'],
|
|
|
+ [accuracy_indv] * len(accuracies_df['percentile']),
|
|
|
+ label='Individual',
|
|
|
linestyle='--',
|
|
|
)
|
|
|
+plt.xlabel('Percentile')
|
|
|
+plt.ylabel('Accuracy')
|
|
|
+plt.title('Coverage conf')
|
|
|
plt.legend()
|
|
|
+plt.savefig(f'{V2_PATH}/coverage.png')
|
|
|
+plt.close()
|
|
|
|
|
|
-plt.title('Accuracy vs Coverage')
|
|
|
+# Repeat for standard deviation
|
|
|
+accuracies_stdev = []
|
|
|
|
|
|
-plt.xlabel('Coverage')
|
|
|
-plt.ylabel('Accuracy')
|
|
|
-plt.gca().invert_xaxis()
|
|
|
+for quantile in quantiles_stdev.items():
|
|
|
+ percentile = quantile[0]
|
|
|
|
|
|
-plt.savefig(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.png"
|
|
|
-)
|
|
|
+ filt = stdevs_df[stdevs_df['stdev'] <= quantile[1]]
|
|
|
+ accuracy = (
|
|
|
+ filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
|
|
|
+ )
|
|
|
|
|
|
-plt.figure()
|
|
|
-plt.plot(result['Quantile'], result['F1'], label='Ensemble F1')
|
|
|
-plt.plot(
|
|
|
- result['Quantile'],
|
|
|
- [indv['F1']] * len(result['Quantile']),
|
|
|
- label='Individual F1',
|
|
|
- linestyle='--',
|
|
|
-)
|
|
|
-plt.legend()
|
|
|
-plt.title('F1 vs Coverage')
|
|
|
+ accuracies_stdev.append({'percentile': percentile, 'accuracy': accuracy})
|
|
|
|
|
|
-plt.xlabel('Coverage')
|
|
|
-plt.ylabel('F1')
|
|
|
-plt.gca().invert_xaxis()
|
|
|
+accuracies_stdev_df = pd.DataFrame(accuracies_stdev)
|
|
|
|
|
|
-plt.savefig(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_f1.png"
|
|
|
+# Plot the coverage
|
|
|
+plt.plot(
|
|
|
+ accuracies_stdev_df['percentile'], accuracies_stdev_df['accuracy'], label='Ensemble'
|
|
|
)
|
|
|
-
|
|
|
-plt.figure()
|
|
|
-plt.plot(result['Quantile'], result['AUC'], label='Ensemble AUC')
|
|
|
plt.plot(
|
|
|
- result['Quantile'],
|
|
|
- [indv['AUC']] * len(result['Quantile']),
|
|
|
- label='Individual AUC',
|
|
|
+ accuracies_stdev_df['percentile'],
|
|
|
+ [accuracy_indv] * len(accuracies_stdev_df['percentile']),
|
|
|
+ label='Individual',
|
|
|
linestyle='--',
|
|
|
)
|
|
|
+plt.xlabel('Percentile')
|
|
|
+plt.ylabel('Accuracy')
|
|
|
+plt.title('Coverage Stdev')
|
|
|
plt.legend()
|
|
|
-plt.title('AUC vs Coverage')
|
|
|
-plt.xlabel('Coverage')
|
|
|
-plt.ylabel('AUC')
|
|
|
plt.gca().invert_xaxis()
|
|
|
-
|
|
|
-plt.savefig(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_auc.png"
|
|
|
-)
|
|
|
-
|
|
|
-# create histogram of the incorrect predictions vs the uncertainty
|
|
|
-plt.figure()
|
|
|
-plt.hist(predictions[~predictions['Correct']]['Stdev'], bins=10)
|
|
|
-plt.xlabel('Uncertainty')
|
|
|
-plt.ylabel('Number of incorrect predictions')
|
|
|
-plt.savefig(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/incorrect_predictions.png"
|
|
|
-)
|
|
|
-
|
|
|
-ece = met.ECE(predictions['Prediction'], predictions['Actual'])
|
|
|
-
|
|
|
-print(f'ECE: {ece}')
|
|
|
-
|
|
|
-with open(
|
|
|
- f"{config['paths']['model_output']}{config['ensemble']['name']}/summary.txt", 'a'
|
|
|
-) as f:
|
|
|
- f.write(f'ECE: {ece}\n')
|
|
|
+plt.savefig(f'{V2_PATH}/coverage_stdev.png')
|
|
|
+plt.close()
|