Przeglądaj źródła

added code for interpretability

Zan Klanecek 1 tydzień temu
rodzic
commit
89320e5ad9
1 zmienionych plików z 61 dodań i 0 usunięć
  1. 61 0
      Problem 5/interpretabillity.py

+ 61 - 0
Problem 5/interpretabillity.py

@@ -0,0 +1,61 @@
+import torch
+import numpy as np
+from captum.attr import Saliency
+from model import ModelCT
+import os
+import json
+from matplotlib import pyplot as plt
+
+# Load test data
+main_path_to_data = "/data/PSUF_naloge/5-naloga/processed/"
+model_weights = "trained_models/trained_model_weights.pth"
+
+with open(os.path.join(main_path_to_data, "test_info.json")) as fp:
+    test_info = json.load(fp)
+
+# Load model and set to eval mode
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = ModelCT()
+model.to(device)
+model.load_state_dict(torch.load(model_weights))
+model.eval()
+
+# Load the image as a NumPy array and convert it to a torch tensor
+x_np = np.load(os.path.join(main_path_to_data, "severe/study_0582_slice_4.npy")).astype(np.float32)
+
+# If the image is grayscale (2D array), add a channel dimension:
+if x_np.ndim == 2:
+    x_np = np.expand_dims(x_np, axis=0)  # shape becomes (1, H, W)
+
+# Add a batch dimension (if required by your model)
+x_tensor = torch.from_numpy(x_np).unsqueeze(0)  # shape becomes (1, 1, H, W)
+x_tensor = x_tensor.to(device)
+x_tensor.requires_grad_()  # Ensure gradients are tracked for saliency
+
+# Create saliency object from Captum
+saliency = Saliency(model)
+
+# Compute attribution
+attribution = saliency.attribute(x_tensor)
+
+# Prepare the outputs for visualization: remove batch/channel dimensions as needed
+attribution_np = attribution.detach().cpu().numpy().squeeze()
+original_np = x_tensor.detach().cpu().numpy().squeeze()
+
+# Optionally rotate the images for better visualization
+attribution_np = np.rot90(attribution_np)
+original_np = np.rot90(original_np)
+
+# Visualize the original image and its saliency map
+fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
+
+ax1.set_title("Original", fontweight='bold', fontsize=10)
+ax1.imshow(original_np, cmap='gray')
+ax1.axis('off')
+
+ax2.set_title("Relative Pixel Contribution", fontweight='bold', fontsize=10)
+ax2.imshow(original_np, cmap='gray')
+ax2.imshow(np.ma.masked_where(attribution_np == 0, attribution_np), alpha=0.8, cmap='RdBu')
+ax2.axis('off')
+
+plt.show()