interpretability.py

import torch
import numpy as np
from captum.attr import Saliency
from model import ModelCT
import os
import json
from matplotlib import pyplot as plt

# Load test data
main_path_to_data = "/data/PSUF_naloge/5-naloga/processed/"
model_weights = "trained_models/trained_model_weights.pth"
with open(os.path.join(main_path_to_data, "test_info.json")) as fp:
    test_info = json.load(fp)

# Load model and set to eval mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ModelCT()
model.to(device)
model.load_state_dict(torch.load(model_weights, map_location=device))
model.eval()
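# Note: autograd stays enabled here (no torch.no_grad()), because Saliency
# needs gradients of the model output with respect to the input image.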

# Load the image as a NumPy array and convert it to a torch tensor
x_np = np.load(os.path.join(main_path_to_data, "severe/study_0582_slice_4.npy")).astype(np.float32)
# If the image is grayscale (2D array), add a channel dimension
if x_np.ndim == 2:
    x_np = np.expand_dims(x_np, axis=0)  # shape becomes (1, H, W)
# Add a batch dimension (if required by your model)
x_tensor = torch.from_numpy(x_np).unsqueeze(0)  # shape becomes (1, 1, H, W)
x_tensor = x_tensor.to(device)
x_tensor.requires_grad_()  # ensure gradients are tracked for saliency

# Create saliency object from Captum
saliency = Saliency(model)
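# Saliency takes the gradient of the model output with respect to the input;
# by default Captum returns the absolute value of that gradient (abs=True).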
# Compute attribution
attribution = saliency.attribute(x_tensor)
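# If ModelCT returned several class logits instead of a single output, a class
# index would have to be passed, e.g. saliency.attribute(x_tensor, target=1)
# (the index 1 here is purely illustrative).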

# Prepare the outputs for visualization: remove batch/channel dimensions as needed
attribution_np = attribution.detach().cpu().numpy().squeeze()
original_np = x_tensor.detach().cpu().numpy().squeeze()
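# Optionally rescale the attribution to [0, 1] so overlays are easier to
# compare across slices (assumption: simple min-max scaling is adequate here):
# attribution_np = (attribution_np - attribution_np.min()) / (np.ptp(attribution_np) + 1e-8)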
# Optionally rotate the images for better visualization
attribution_np = np.rot90(attribution_np)
original_np = np.rot90(original_np)

# Visualize the original image and its saliency map
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.set_title("Original", fontweight='bold', fontsize=10)
ax1.imshow(original_np, cmap='gray')
ax1.axis('off')
ax2.set_title("Relative Pixel Contribution", fontweight='bold', fontsize=10)
ax2.imshow(original_np, cmap='gray')
ax2.imshow(np.ma.masked_where(attribution_np == 0, attribution_np), alpha=0.8, cmap='RdBu')
ax2.axis('off')
plt.show()
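# Optionally save the figure to disk (call savefig before plt.show();
# the file name is just an assumption):
# plt.savefig("saliency_overlay.png", dpi=300, bbox_inches="tight")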