|
@@ -0,0 +1,61 @@
|
|
|
+import torch
|
|
|
+import numpy as np
|
|
|
+from captum.attr import Saliency
|
|
|
+from model import ModelCT
|
|
|
+import os
|
|
|
+import json
|
|
|
+from matplotlib import pyplot as plt
|
|
|
+
|
|
|
+# Load test data
|
|
|
+main_path_to_data = "/data/PSUF_naloge/5-naloga/processed/"
|
|
|
+model_weights = "trained_models/trained_model_weights.pth"
|
|
|
+
|
|
|
+with open(os.path.join(main_path_to_data, "test_info.json")) as fp:
|
|
|
+ test_info = json.load(fp)
|
|
|
+
|
|
|
+# Load model and set to eval mode
|
|
|
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
+model = ModelCT()
|
|
|
+model.to(device)
|
|
|
+model.load_state_dict(torch.load(model_weights))
|
|
|
+model.eval()
|
|
|
+
|
|
|
+# Load the image as a NumPy array and convert it to a torch tensor
|
|
|
+x_np = np.load(os.path.join(main_path_to_data, "severe/study_0582_slice_4.npy")).astype(np.float32)
|
|
|
+
|
|
|
+# If the image is grayscale (2D array), add a channel dimension:
|
|
|
+if x_np.ndim == 2:
|
|
|
+ x_np = np.expand_dims(x_np, axis=0) # shape becomes (1, H, W)
|
|
|
+
|
|
|
+# Add a batch dimension (if required by your model)
|
|
|
+x_tensor = torch.from_numpy(x_np).unsqueeze(0) # shape becomes (1, 1, H, W)
|
|
|
+x_tensor = x_tensor.to(device)
|
|
|
+x_tensor.requires_grad_() # Ensure gradients are tracked for saliency
|
|
|
+
|
|
|
+# Create saliency object from Captum
|
|
|
+saliency = Saliency(model)
|
|
|
+
|
|
|
+# Compute attribution
|
|
|
+attribution = saliency.attribute(x_tensor)
|
|
|
+
|
|
|
+# Prepare the outputs for visualization: remove batch/channel dimensions as needed
|
|
|
+attribution_np = attribution.detach().cpu().numpy().squeeze()
|
|
|
+original_np = x_tensor.detach().cpu().numpy().squeeze()
|
|
|
+
|
|
|
+# Optionally rotate the images for better visualization
|
|
|
+attribution_np = np.rot90(attribution_np)
|
|
|
+original_np = np.rot90(original_np)
|
|
|
+
|
|
|
+# Visualize the original image and its saliency map
|
|
|
+fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
|
|
|
+
|
|
|
+ax1.set_title("Original", fontweight='bold', fontsize=10)
|
|
|
+ax1.imshow(original_np, cmap='gray')
|
|
|
+ax1.axis('off')
|
|
|
+
|
|
|
+ax2.set_title("Relative Pixel Contribution", fontweight='bold', fontsize=10)
|
|
|
+ax2.imshow(original_np, cmap='gray')
|
|
|
+ax2.imshow(np.ma.masked_where(attribution_np == 0, attribution_np), alpha=0.8, cmap='RdBu')
|
|
|
+ax2.axis('off')
|
|
|
+
|
|
|
+plt.show()
|