"""Visualize one learned filter from the first convolution of a trained ModelCT.

Loads ``trained_model_weights.pth`` into a ``ModelCT`` instance, extracts the
weights of ``model.backbone.conv1`` and displays a single filter as a
grayscale image. Adapt the checkpoint path, layer and filter index as needed.
"""
from matplotlib import pyplot as plt
import torch

from model import ModelCT

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model.
model = ModelCT()
model.to(device)
# map_location remaps the stored tensors onto `device`; without it, a
# checkpoint saved on a GPU cannot be loaded on a CPU-only machine.
model.load_state_dict(torch.load("trained_model_weights.pth", map_location=device))
model.eval()

# Filters of the first convolutional layer (backbone.conv1).
# detach() is the supported replacement for `.data` to drop autograd tracking.
weights = model.backbone.conv1.weight.detach().cpu().numpy()
# NOTE(review): expected weights.shape == (64, 1, 7, 7), i.e. 64 filters,
# 1 input channel, 7x7 kernels. This depends on ModelCT's definition; confirm.

# Visualize the 21st filter (index 20), first input channel.
plt.imshow(weights[20, 0, :, :], cmap="gray")
# Without show(), a plain script run exits without displaying the figure.
plt.show()