sanity_check.py

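"""Sanity-check a saved ensemble evaluation.

Opens the NetCDF dataset named by config["analysis"]["evaluation_name"] under
../model_evaluations, then prints every image's ensemble-averaged prediction
next to its label, the overall accuracy at a 0.5 threshold, each individual
model's accuracy, the accuracy of ensembling the top-K models for every K, and
the five most confidently wrong predictions.
"""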
import xarray as xr
import numpy as np
from utils.config import config
import pathlib as pl
import colorama as clr
import os

os.chdir(pl.Path(__file__).parent)

model_dataset_path = pl.Path("../model_evaluations") / pl.Path(
    config["analysis"]["evaluation_name"].strip()
).with_suffix(".nc")
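
# Optional guard: fail fast with a clear error if the evaluation file has not
# been written yet, instead of the generic failure from xr.open_dataset.
if not model_dataset_path.exists():
    raise FileNotFoundError(f"Expected evaluation dataset at {model_dataset_path}")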

array = xr.open_dataset(model_dataset_path)  # type: ignore
predictions: xr.DataArray = array["predictions"]
labels: xr.DataArray = array["labels"]
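
# Assumed dataset layout, inferred from the indexing below: "predictions" has
# dims (model, img_id, img_class) and "labels" has dims (img_id, label), with
# index 1 on img_class/label corresponding to the positive class.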

# Average predictions across models
avg_predictions = predictions.mean(dim="model")

# Sort from highest to lowest confidence for the positive class (img_class=1)
sorted_indices = np.argsort(-avg_predictions.sel(img_class=1).values)
sorted_avg_predictions = avg_predictions.isel(img_id=sorted_indices)
sorted_labels = labels.isel(img_id=sorted_indices)

# Print out all predictions with their labels
top_n = sorted_avg_predictions.sizes["img_id"]  # Change this value to print more or fewer
print(
    clr.Fore.CYAN
    + f"Top {top_n} Predictions (Confidence for Positive Class):"
    + clr.Style.RESET_ALL
)
for i in range(top_n):
    confidence = sorted_avg_predictions.sel(img_class=1).isel(img_id=i).item()
    label = sorted_labels.isel(img_id=i, label=1).item()  # plain int, not a 0-d array
    correctness = (
        "CORRECT"
        if (confidence >= 0.5 and label == 1) or (confidence < 0.5 and label == 0)
        else "INCORRECT"
    )
    color = clr.Fore.GREEN if correctness == "CORRECT" else clr.Fore.RED
    print(
        f"Image ID: {sorted_avg_predictions.img_id.isel(img_id=i).item():<8}, "
        f"Confidence: {confidence:.4f}, "
        f"Label: {label:<3}, " + color + f"{correctness:<9}" + clr.Style.RESET_ALL
    )

# Calculate overall accuracy
predicted_positive = avg_predictions.sel(img_class=1) >= 0.5
true_positive = labels.sel(label=1) == 1
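# Note: `true_positive` means "the image's true label is positive"; the
# element-wise equality below therefore counts correct negatives as well as
# correct positives.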
correct_predictions = (predicted_positive == true_positive).sum().item()
total_predictions = len(avg_predictions.img_id)
overall_accuracy = (
    correct_predictions / total_predictions if total_predictions > 0 else 0.0
)
print(
    clr.Fore.MAGENTA
    + f"\nOverall Accuracy (Threshold 0.5): {overall_accuracy:.4f}"
    + clr.Style.RESET_ALL
)

# Then go through all individual models and print out their accuracies for
# comparison, sorted from highest to lowest
model_accuracies = []
for model_idx in predictions.coords["model"].values:
    model_preds = predictions.sel(model=model_idx)
    predicted_positive = model_preds.sel(img_class=1) >= 0.5
    correct_predictions = (predicted_positive == true_positive).sum().item()
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
    model_accuracies.append((model_idx, accuracy))

# Sort by accuracy
model_accuracies.sort(key=lambda x: x[1], reverse=True)
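# This ranking (best individual model first) also determines which models go
# into the top-K ensembles computed below.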

print(
    clr.Fore.CYAN
    + "\nIndividual Model Accuracies (Threshold 0.5):"
    + clr.Style.RESET_ALL
)
for model_idx, accuracy in model_accuracies:
    print(f"Model {int(model_idx):<3}: Accuracy: {accuracy:.4f}")

# Calculate the accuracy we would get by ensembling (averaging) the top K
# models, for K = 1 up to the total number of models
total_models = len(predictions.coords["model"].values)
ensemble_accuracies = []
for k in range(1, total_models + 1):
    top_k_models = [ma[0] for ma in model_accuracies[:k]]
    ensemble_preds = predictions.sel(model=top_k_models).mean(dim="model")
    predicted_positive = ensemble_preds.sel(img_class=1) >= 0.5
    correct_predictions = (predicted_positive == true_positive).sum().item()
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
    ensemble_accuracies.append((k, accuracy))

print(
    clr.Fore.CYAN
    + "\nEnsemble Accuracies for Top K Models (Threshold 0.5):"
    + clr.Style.RESET_ALL
)
for k, accuracy in ensemble_accuracies:
    print(f"Top {k:<3} Models: Ensemble Accuracy: {accuracy:.4f}")

# Finally, identify the top 5 most confidently incorrect predictions
incorrect_predictions = []
for i in range(len(avg_predictions.img_id)):
    confidence = avg_predictions.sel(img_class=1).isel(img_id=i).item()
    label = labels.isel(img_id=i, label=1).item()  # plain int, not a 0-d array
    predicted_label = 1 if confidence >= 0.5 else 0
    if predicted_label != label:
        incorrect_predictions.append((i, confidence, label))

# Sort by distance from the 0.5 threshold, so the most confident mistakes come first
incorrect_predictions.sort(key=lambda x: -abs(x[1] - 0.5))
top_incorrect = incorrect_predictions[:5]
print(
    clr.Fore.YELLOW
    + "\nTop 5 Most Confident Incorrect Predictions:"
    + clr.Style.RESET_ALL
)
for i, confidence, label in top_incorrect:
    predicted_label = 1 if confidence >= 0.5 else 0
    print(
        f"Image ID: {avg_predictions.img_id.isel(img_id=i).item():<8}, "
        f"Confidence: {confidence:.4f}, "
        f"Predicted Label: {predicted_label:<3}, "
        f"True Label: {label:<3}"
    )