visualize_saliency.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import torch
  2. import numpy as np
  3. from captum.attr import Saliency
  4. from datareader import DataReader
  5. from model import ModelCT
  6. from torch.utils.data import DataLoader
  7. import os
  8. import json
  9. from matplotlib import pyplot as plt
  10. from scipy.ndimage import gaussian_filter
  11. if __name__ == "__main__":
  12. # Nalozimo testne podatke
  13. main_path_to_data = "/data/PSUF_naloge/5-naloga/processed"
  14. model_folder = "trained_models/testrun123"
  15. with open (os.path.join(main_path_to_data, "test_info.json")) as fp:
  16. test_info = json.load(fp)
  17. # Nalozimo model, ga damo v eval mode
  18. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  19. model = ModelCT()
  20. model.to(device)
  21. model.load_state_dict(torch.load(os.path.join(model_folder, "trained_model_weights.pth")))
  22. model.eval()
  23. # Naredimo testni generator
  24. test_datareader = DataReader(main_path_to_data, test_info)
  25. test_generator = DataLoader(test_datareader, batch_size=1, shuffle=False, pin_memory = True, num_workers=2)
  26. # Iz knjiznice captum nalozimo Saliency
  27. saliency = Saliency(model)
  28. # V testnih podatkih poiscemo primer z dobro klasifikacijo hude okuzbe (y==1, y_hat > 0.95)
  29. for item_test in test_generator:
  30. x, y = item_test
  31. x = x.to(device)
  32. y = y.to(device)
  33. # Forward pass
  34. y_hat = model.forward(x)
  35. y_hat = torch.sigmoid(y_hat)
  36. if int(y) == 1 and float(y_hat) > 0.95:
  37. attribution = saliency.attribute(x)
  38. attribution = np.rot90(attribution.detach().cpu().numpy().squeeze())
  39. original = np.rot90(x.cpu().numpy().squeeze())
  40. # Po zelji zgladimo saliency z gaussom
  41. attribution = gaussian_filter(attribution, sigma=2)
  42. # Vizualiziramo saliency map in originalno sliko
  43. fig, (ax1, ax2) = plt.subplots(1, 2)
  44. ax1.set_title("Orignal", fontweight='bold', fontsize=10)
  45. ax1.imshow(original, cmap='Greys_r')
  46. ax1.axis('off')
  47. ax2.set_title("Relativni doprinos pikslov", fontweight='bold', fontsize=10)
  48. ax2.imshow(original, cmap='gray')
  49. ax2.imshow(np.ma.masked_where(attribution==0, attribution), alpha=0.8, cmap='RdBu')
  50. ax2.axis('off')
  51. break