12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- import json
- import torch
- from model import ModelCT
- from datareader import DataReader
- from torch.utils.data import DataLoader
- from matplotlib import pyplot as plt
- import torch.nn.functional as F
- import numpy as np
- import os
- # Razred za shranjevanje konvolucij
class SaveOutput:
    """Callable used as a forward hook that records every output it receives.

    Register an instance with ``module.register_forward_hook(instance)``.
    Each time a hooked module runs forward, its output is appended to
    ``self.outputs``, so the list reflects call order.
    """

    def __init__(self):
        # Recorded module outputs, oldest first.
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        """Hook entry point: store ``module_out``; ``module`` and ``module_in`` are unused.

        Returns None on purpose. A forward hook that returns a value would
        replace the module's output, and this hook must only observe.
        """
        self.outputs += [module_out]

    def clear(self):
        """Drop all recorded outputs.

        The attribute is rebound to a fresh list, so a list obtained earlier
        from ``self.outputs`` keeps its contents.
        """
        self.outputs = list()
if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the test-set metadata.
    main_path_to_data = "C:/Users/Klanecek/Desktop/processed"
    model_folder = "trained_models/testrun123_v2"
    with open(os.path.join(main_path_to_data, "test_info.json"), encoding="utf-8") as fp:
        test_info = json.load(fp)

    # Load the trained weights and put the model in eval mode.
    # map_location=device lets weights saved during a GPU run load on a
    # CPU-only machine; without it torch.load tries to restore them onto CUDA.
    model = ModelCT()
    model.to(device)
    weights_path = os.path.join(model_folder, "trained_model_weights.pth")
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    # Attach SaveOutput as a forward hook on every Conv2d layer, so a forward
    # pass appends each layer's output to SO.outputs in execution order.
    # Keep the handles so the hooks can be removed once we are done.
    SO = SaveOutput()
    hook_handles = [
        layer.register_forward_hook(SO)
        for layer in model.modules()
        if isinstance(layer, torch.nn.Conv2d)
    ]

    # DataLoader with batch_size=1, so each item is a single test example.
    # Pinned memory only speeds up host->GPU copies, so enable it only with CUDA.
    test_datareader = DataReader(main_path_to_data, test_info)
    test_generator = DataLoader(
        test_datareader,
        batch_size=1,
        shuffle=False,
        pin_memory=device.type == "cuda",
        num_workers=2,
    )

    # Take the first test example; item_test[0] is the model input.
    item_test = next(iter(test_generator))

    # Propagate it through the network; the hooks record every Conv2d output.
    # We only inspect activations, so skip building an autograd graph.
    # Otherwise the stored outputs would keep that graph alive.
    input_image = item_test[0].to(device)
    with torch.no_grad():
        model(input_image)

    # Remove the hooks so later forward passes don't keep appending to SO.outputs.
    for handle in hook_handles:
        handle.remove()

    # Choose which feature map to view. A Conv2d output is indexed as
    # [batch, output_channel, H, W]; with batch_size=1 the batch index is 0.
    batch_index = 0
    idx_layer = 5  # Conv2d layer, in execution order (original note: 21 in our model)
    idx_convolution = 17  # output channel (filter) of that layer; max depends on the layer

    num_recorded = len(SO.outputs)
    if not -num_recorded <= idx_layer < num_recorded:
        raise IndexError(
            f"idx_layer={idx_layer} is out of range: the forward pass recorded "
            f"{num_recorded} Conv2d outputs"
        )

    # Visualize the chosen feature map and, for reference, the input slice.
    feature_map = SO.outputs[idx_layer][batch_index][idx_convolution].detach().cpu().numpy()
    plt.figure()
    plt.imshow(np.rot90(feature_map), cmap="gray")
    plt.title(f"Conv2d layer {idx_layer}, feature map {idx_convolution}")

    # Assumes the input batch is (batch, channel, H, W), as the original
    # [0,0,:,:] indexing did; TODO confirm against DataReader.
    plt.figure()
    plt.imshow(np.rot90(input_image[0, 0].cpu().numpy()), cmap="gray")
    plt.title("Input image")

    # Without show() nothing is displayed when this runs as a plain script.
    plt.show()
|