|
@@ -0,0 +1,18 @@
|
|
|
+from model import ModelCT
|
|
|
+from matplotlib import pyplot as plt
|
|
|
+import torch
|
|
|
+
|
|
|
+# Pseudo code for visualizing filters, adapt as needed
|
|
|
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
+
|
|
|
+# Load model
|
|
|
+model = ModelCT()
|
|
|
+model.to(device)
|
|
|
+model.load_state_dict(torch.load("trained_model_weights.pth"))
|
|
|
+model.eval()
|
|
|
+
|
|
|
+# Example of loading filters from the first convolutional layer (backbone.conv1)
|
|
|
+weights = model.backbone.conv1.weight.data.cpu().numpy() # weights.shape = (64,1,7,7) -> 64 filters, 1 channel, size 7x7
|
|
|
+
|
|
|
+# Visualization of 21st filter from the first convolutional layer
|
|
|
+plt.imshow(weights[20,0,:,:], cmap='gray')
|