# sanity_check.py — sanity check / report over saved model evaluation results.
import xarray as xr
import numpy as np
import sys
import os

# Make the parent directory importable so `utils.config` resolves when this
# script is run directly from its own folder.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
)  # to allow imports from parent directory
from utils.config import config
import pathlib as pl
import colorama as clr

# Load the evaluation results written by the training/evaluation pipeline.
# NOTE(review): assumes the dataset contains "predictions" with dims
# (model, img_id, img_class) and one-hot "labels" with dims (img_id, label)
# — confirm against the writer of model_evaluation_results.nc.
model_dataset_path = pl.Path(config["output"]["path"]) / "model_evaluation_results.nc"
array = xr.open_dataset(model_dataset_path)  # type: ignore
predictions: xr.DataArray = array["predictions"]
labels: xr.DataArray = array["labels"]
  15. # Average predictions across models
  16. avg_predictions = predictions.mean(dim="model")
  17. # Sort from highest to lowest confidence for the positive class (img_class=1)
  18. sorted_indices = np.argsort(-avg_predictions.sel(img_class=1).values)
  19. sorted_avg_predictions = avg_predictions.isel(img_id=sorted_indices)
  20. sorted_labels = labels.isel(img_id=sorted_indices)
  21. # Print out all predictions with their labels
  22. top_n = sorted_avg_predictions.sizes[
  23. "img_id"
  24. ] # Change this value to print more or fewer
  25. print(
  26. clr.Fore.CYAN
  27. + f"Top {top_n} Predictions (Confidence for Positive Class):"
  28. + clr.Style.RESET_ALL
  29. )
  30. for i in range(top_n):
  31. confidence = sorted_avg_predictions.sel(img_class=1).isel(img_id=i).item()
  32. label = sorted_labels.isel(img_id=i, label=1).values
  33. correctness = (
  34. "CORRECT"
  35. if (confidence >= 0.5 and label == 1) or (confidence < 0.5 and label == 0)
  36. else "INCORRECT"
  37. )
  38. color = clr.Fore.GREEN if correctness == "CORRECT" else clr.Fore.RED
  39. print(
  40. f"Image ID: {sorted_avg_predictions.img_id.isel(img_id=i).item():<8}, "
  41. f"Confidence: {confidence:.4f}, "
  42. f"Label: {label:<3}, " + color + f"{correctness:<9}" + clr.Style.RESET_ALL
  43. )
  44. # Calculate overall accuracy
  45. predicted_positive = avg_predictions.sel(img_class=1) >= 0.5
  46. true_positive = labels.sel(label=1) == 1
  47. correct_predictions = (predicted_positive == true_positive).sum().item()
  48. total_predictions = len(avg_predictions.img_id)
  49. overall_accuracy = (
  50. correct_predictions / total_predictions if total_predictions > 0 else 0.0
  51. )
  52. print(
  53. clr.Fore.MAGENTA
  54. + f"\nOverall Accuracy (Threshold 0.5): {overall_accuracy:.4f}"
  55. + clr.Style.RESET_ALL
  56. )
  57. # Then go through all individual models and print out their accuracies for comparison, sorted from highest to lowest
  58. model_accuracies = []
  59. for model_idx in predictions.coords["model"].values:
  60. model_preds = predictions.sel(model=model_idx)
  61. predicted_positive = model_preds.sel(img_class=1) >= 0.5
  62. correct_predictions = (predicted_positive == true_positive).sum().item()
  63. accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
  64. model_accuracies.append((model_idx, accuracy))
  65. # Sort by accuracy
  66. model_accuracies.sort(key=lambda x: x[1], reverse=True)
  67. print(
  68. clr.Fore.CYAN
  69. + f"\nIndividual Model Accuracies (Threshold 0.5):"
  70. + clr.Style.RESET_ALL
  71. )
  72. for model_idx, accuracy in model_accuracies:
  73. print(f"Model {int(model_idx):<3}: Accuracy: {accuracy:.4f}")
  74. # Then calculate the average accuracy if we were to ensemble the top K models, for K=1 to total number of models
  75. total_models = len(predictions.coords["model"].values)
  76. ensemble_accuracies = []
  77. for k in range(1, total_models + 1):
  78. top_k_models = [ma[0] for ma in model_accuracies[:k]]
  79. ensemble_preds = predictions.sel(model=top_k_models).mean(dim="model")
  80. predicted_positive = ensemble_preds.sel(img_class=1) >= 0.5
  81. correct_predictions = (predicted_positive == true_positive).sum().item()
  82. accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
  83. ensemble_accuracies.append((k, accuracy))
  84. print(
  85. clr.Fore.CYAN
  86. + f"\nEnsemble Accuracies for Top K Models (Threshold 0.5):"
  87. + clr.Style.RESET_ALL
  88. )
  89. for k, accuracy in ensemble_accuracies:
  90. print(f"Top {k:<3} Models: Ensemble Accuracy: {accuracy:.4f}")
  91. # Finally, identify the top 5 most confidently incorrect predictions
  92. incorrect_predictions = []
  93. for i in range(len(avg_predictions.img_id)):
  94. confidence = avg_predictions.sel(img_class=1).isel(img_id=i).item()
  95. label = labels.isel(img_id=i, label=1).values
  96. predicted_label = 1 if confidence >= 0.5 else 0
  97. if predicted_label != label:
  98. incorrect_predictions.append((i, confidence, label))
  99. # Sort by confidence
  100. incorrect_predictions.sort(key=lambda x: -abs(x[1] - 0.5))
  101. top_incorrect = incorrect_predictions[:5]
  102. print(
  103. clr.Fore.YELLOW
  104. + f"\nTop 5 Most Confident Incorrect Predictions:"
  105. + clr.Style.RESET_ALL
  106. )
  107. for i, confidence, label in top_incorrect:
  108. predicted_label = 1 if confidence >= 0.5 else 0
  109. print(
  110. f"Image ID: {avg_predictions.img_id.isel(img_id=i).item():<8}, "
  111. f"Confidence: {confidence:.4f}, "
  112. f"Predicted Label: {predicted_label:<3}, "
  113. f"True Label: {label:<3}"
  114. )