visualize_convolutions.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import json
  2. import torch
  3. from model import ModelCT
  4. from datareader import DataReader
  5. from torch.utils.data import DataLoader
  6. from matplotlib import pyplot as plt
  7. import torch.nn.functional as F
  8. import numpy as np
  9. import os
  10. # Razred za shranjevanje konvolucij
  11. class SaveOutput:
  12. def __init__(self):
  13. self.outputs = []
  14. def __call__(self, module, module_in, module_out):
  15. self.outputs.append(module_out)
  16. def clear(self):
  17. self.outputs = []
  18. if __name__ == '__main__':
  19. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  20. # Nalozimo testne podatke
  21. main_path_to_data = "/data/PSUF_naloge/5-naloga/processed"
  22. model_folder = "trained_models/testrun123"
  23. with open (os.path.join(main_path_to_data, "test_info.json")) as fp:
  24. test_info = json.load(fp)
  25. # Nalozimo model, ki ga damo v eval mode
  26. model = ModelCT()
  27. model.to(device)
  28. model.load_state_dict(torch.load(os.path.join(model_folder, "trained_model_weights.pth")))
  29. model.eval()
  30. # Inicializiramo razred v SO in registriramo kavlje v nasem modelu
  31. SO = SaveOutput()
  32. for layer in model.modules():
  33. if isinstance(layer, torch.nn.modules.conv.Conv2d):
  34. handle = layer.register_forward_hook(SO)
  35. # Naredimo test_generator z batch_size=1
  36. test_datareader = DataReader(main_path_to_data, test_info)
  37. test_generator = DataLoader(test_datareader, batch_size=1, shuffle=False, pin_memory=True, num_workers=2)
  38. # Vzamemo prvi testni primer npr.
  39. item_test = next(iter(test_generator))
  40. # Propagiramo prvi testni primer skozi mrezo, razred SaveOutput si shrani vse konvolucije
  41. input_image = item_test[0].to(device)
  42. _ = model(input_image)
  43. # Izberemo katero konvolucijo bi radi pogledali (color_channel bo vedno 0)
  44. color_channel = 0 # indeks barvnega kanala (pri nas le 1 kanal)
  45. idx_layer = 5 # indeks konvolucijske plasti (Conv2d) - (pri nas 21)
  46. idx_convolution = 17 # indeks konvolucije na dani plasti (max odvisen od plasti)
  47. # Vizualiziramo
  48. convolution_0_5_17 = SO.outputs[idx_layer][color_channel][idx_convolution][:,:].cpu().detach().numpy()
  49. plt.figure()
  50. plt.imshow(np.rot90(convolution_0_5_17), cmap="gray")
  51. plt.figure()
  52. plt.imshow(np.rot90(input_image.cpu().numpy()[0,0,:,:]),cmap="gray")