"""Plot a Captum saliency map for a confidently classified positive test sample.

The script loads the trained ``ModelCT`` weights, walks the test set in order
and stops at the first sample whose label is 1 and whose predicted
probability (sigmoid of the model output) exceeds ``CONFIDENCE_THRESHOLD``.
For that sample it shows the input image next to the same image overlaid
with its (optionally Gaussian-smoothed) saliency map.
"""
import json
import os

import numpy as np
import torch
from captum.attr import Saliency
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter
from torch.utils.data import DataLoader

from datareader import DataReader
from model import ModelCT

# Folder with the preprocessed data; must contain ``test_info.json``.
MAIN_PATH_TO_DATA = "C:/Users/Klanecek/Desktop/processed"
# Folder of the trained model; must contain ``trained_model_weights.pth``.
MODEL_FOLDER = "trained_models/testrun123_v2"
# Minimum predicted probability for a positive sample to count as "confident".
CONFIDENCE_THRESHOLD = 0.95
# Standard deviation (in pixels) of the Gaussian used to smooth the saliency
# map; set to 0 to plot the raw map.
SMOOTHING_SIGMA = 2


if __name__ == "__main__":
    # Load the description of the test set.
    with open(os.path.join(MAIN_PATH_TO_DATA, "test_info.json"), encoding="utf-8") as fp:
        test_info = json.load(fp)

    # Load the trained model and switch it to evaluation mode. map_location lets
    # weights that were saved on a GPU be loaded on a CPU-only machine.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ModelCT()
    model.to(device)
    weights_path = os.path.join(MODEL_FOLDER, "trained_model_weights.pth")
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    # Test loader: one sample per batch, in a fixed (unshuffled) order.
    test_datareader = DataReader(MAIN_PATH_TO_DATA, test_info)
    test_generator = DataLoader(
        test_datareader, batch_size=1, shuffle=False, pin_memory=True, num_workers=2
    )

    # Captum's gradient-based Saliency attribution for this model.
    saliency = Saliency(model)

    # Find the first positive sample (y == 1) that the model is confident about.
    for x, y in test_generator:
        if int(y) != 1:
            continue  # negatives can never qualify; skip the forward pass
        x = x.to(device)
        with torch.no_grad():  # prediction only; no autograd graph needed here
            y_hat = torch.sigmoid(model(x))
        if float(y_hat) > CONFIDENCE_THRESHOLD:
            break
    else:
        raise SystemExit(
            f"No test sample with y == 1 and y_hat > {CONFIDENCE_THRESHOLD} was found."
        )

    # Saliency needs gradients, so it must run outside torch.no_grad().
    attribution = saliency.attribute(x)
    # squeeze() drops all size-1 dims; this presumably leaves a single 2-D image
    # (TODO confirm against DataReader). rot90 only reorients it for display.
    attribution = np.rot90(attribution.detach().cpu().numpy().squeeze())
    original = np.rot90(x.detach().cpu().numpy().squeeze())
    if SMOOTHING_SIGMA:
        # Optional smoothing makes the salient regions easier to see.
        attribution = gaussian_filter(attribution, sigma=SMOOTHING_SIGMA)

    # Left: original image. Right: the same image with the saliency map overlaid;
    # pixels whose attribution is exactly 0 are masked so the image shows through.
    _, (ax1, ax2) = plt.subplots(1, 2)
    ax1.set_title("Original", fontweight="bold", fontsize=10)
    ax1.imshow(original, cmap="Greys_r")
    ax1.axis("off")
    ax2.set_title("Relativni doprinos pikslov", fontweight="bold", fontsize=10)
    ax2.imshow(original, cmap="gray")
    ax2.imshow(np.ma.masked_where(attribution == 0, attribution), alpha=0.8, cmap="RdBu")
    ax2.axis("off")
    plt.show()