12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
"""Plot a Captum saliency map for a single preprocessed CT slice.

Loads the trained ``ModelCT`` weights, computes input-gradient saliency for
one ``.npy`` slice, and shows the slice next to the saliency overlay.
"""
import json
import os

import numpy as np
import torch
from captum.attr import Saliency
from matplotlib import pyplot as plt

from model import ModelCT

# Locations of the preprocessed dataset and of the trained weights.
main_path_to_data = "/data/PSUF_naloge/5-naloga/processed/"
model_weights = "trained_models/trained_model_weights.pth"

# Load test-set metadata. JSON is UTF-8, so don't depend on the locale encoding.
# NOTE(review): `test_info` is not used further in this script.
with open(
    os.path.join(main_path_to_data, "test_info.json"), encoding="utf-8"
) as fp:
    test_info = json.load(fp)

# Build the model on the available device, load the weights, switch to eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ModelCT()
model.to(device)
# map_location remaps the stored tensors onto `device`. Without it, a
# checkpoint saved from GPU tensors cannot be loaded on a CPU-only machine,
# which would defeat the CPU fallback chosen above.
model.load_state_dict(torch.load(model_weights, map_location=device))
model.eval()

# Load one slice as float32 and turn it into a batched tensor.
x_np = np.load(
    os.path.join(main_path_to_data, "severe/study_0582_slice_4.npy")
).astype(np.float32)
if x_np.ndim == 2:
    # Grayscale (H, W) slice: add a channel axis -> (1, H, W).
    x_np = np.expand_dims(x_np, axis=0)
# Add a batch axis; a 2-D slice ends up as (1, 1, H, W).
x_tensor = torch.from_numpy(x_np).unsqueeze(0).to(device)
x_tensor.requires_grad_()  # saliency differentiates w.r.t. the input

# Captum can only take the gradient of one scalar per example. If the model
# emits several outputs per example (e.g. class logits), explain the
# predicted one; with a single output keep target=None as before.
# NOTE(review): assumes a multi-output model returns (batch, num_outputs) -- confirm.
with torch.no_grad():
    output = model(x_tensor)
target = None if output[0].numel() == 1 else int(output.argmax(dim=1)[0].item())

# Compute the saliency attribution for the chosen output.
saliency = Saliency(model)
attribution = saliency.attribute(x_tensor, target=target)

# Drop batch/channel axes of size 1 so both arrays are plain 2-D images.
attribution_np = attribution.detach().cpu().numpy().squeeze()
original_np = x_tensor.detach().cpu().numpy().squeeze()

# Rotate both images by 90 degrees for display.
attribution_np = np.rot90(attribution_np)
original_np = np.rot90(original_np)

# Show the original slice and, next to it, the saliency overlaid on the slice.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.set_title("Original", fontweight='bold', fontsize=10)
ax1.imshow(original_np, cmap='gray')
ax1.axis('off')
ax2.set_title("Relative Pixel Contribution", fontweight='bold', fontsize=10)
ax2.imshow(original_np, cmap='gray')
# Pixels with exactly zero attribution are masked so the slice shows through.
# NOTE(review): Saliency returns absolute gradients by default (abs=True), so
# values are non-negative; a diverging colormap like 'RdBu' may suggest signed
# contributions -- consider a sequential colormap.
ax2.imshow(np.ma.masked_where(attribution_np == 0, attribution_np), alpha=0.8, cmap='RdBu')
ax2.axis('off')
plt.show()
|