# visualize_filters.py

  1. from model import ModelCT
  2. from matplotlib import pyplot as plt
  3. import numpy as np
  4. import torch
  5. import os
  6. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  7. model_folder = "trained_models/testrun123_v2/" # mapa naucenega modela
  8. model = ModelCT()
  9. model.to(device)
  10. model.load_state_dict(torch.load(os.path.join(model_folder, "trained_model_weights.pth")))
  11. model.eval()
  12. # Nalozimo npr. filtre iz prve konvolucijske plasti backbone.conv1
  13. weights = model.backbone.conv1.weight.data.cpu().numpy()
  14. # weights.shape = (64,1,7,7) -> 64 filtrov, z 1 kanalom, velikosti 7x7
  15. # Vizualizacija 21. filtra iz prve plasti (backbone.conv1)
  16. plt.imshow(weights[20,0,:,:], cmap='gray')