# visualize_filters.py (639 B)
  1. from model import ModelCT
  2. from matplotlib import pyplot as plt
  3. import torch
  4. # Pseudo code for visualizing filters, adapt as needed
  5. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  6. # Load model
  7. model = ModelCT()
  8. model.to(device)
  9. model.load_state_dict(torch.load("trained_model_weights.pth"))
  10. model.eval()
  11. # Example of loading filters from the first convolutional layer (backbone.conv1)
  12. weights = model.backbone.conv1.weight.data.cpu().numpy() # weights.shape = (64,1,7,7) -> 64 filters, 1 channel, size 7x7
  13. # Visualization of 21st filter from the first convolutional layer
  14. plt.imshow(weights[20,0,:,:], cmap='gray')