|
@@ -42,41 +42,55 @@ class CNN_Net(nn.Module):
|
|
|
# Midflow Block
|
|
|
self.midflow = ly.MidFlowBlock(384, droprate)
|
|
|
|
|
|
- # Combine
|
|
|
- self.combined = nn.Sequential(self.conv1, self.conv2, self.midflow)
|
|
|
+
|
|
|
|
|
|
# Split Convolutional Block
|
|
|
self.splitconv = ly.SplitConvBlock(384, 192, 96, 4, droprate)
|
|
|
|
|
|
#Fully Connected Block
|
|
|
- self.fc1 = ly.FullyConnectedBlock(96, 20, droprate=droprate)
|
|
|
+ self.fc_image = ly.FullyConnectedBlock(96, 20, droprate=droprate)
|
|
|
|
|
|
- self.image_layers = nn.Sequential(self.combined, self.splitconv).double()
|
|
|
|
|
|
|
|
|
#Data Layers, fully connected
|
|
|
- self.fc1 = ly.FullyConnectedBlock(clin_data_channels, 64, droprate=droprate)
|
|
|
- self.fc2 = ly.FullyConnectedBlock(64, 20, droprate=droprate)
|
|
|
+ self.fc_clin1 = ly.FullyConnectedBlock(clin_data_channels, 64, droprate=droprate)
|
|
|
+ self.fc_clin2 = ly.FullyConnectedBlock(64, 20, droprate=droprate)
|
|
|
|
|
|
- #Connect Data
|
|
|
- self.data_layers = nn.Sequential(self.fc1, self.fc2).double()
|
|
|
|
|
|
#Final Dense Layer
|
|
|
self.dense1 = nn.Linear(40, 5)
|
|
|
self.dense2 = nn.Linear(5, 2)
|
|
|
self.softmax = nn.Softmax()
|
|
|
|
|
|
- self.final_layers = nn.Sequential(self.dense1, self.dense2, self.softmax)
|
|
|
+
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
image, clin_data = x
|
|
|
|
|
|
- print(image.shape)
|
|
|
+ print("Input image shape:", image.shape)
|
|
|
|
|
|
- image = self.image_layers(image)
|
|
|
+ image = self.conv1(image)
|
|
|
+ print("Conv1 shape:", image.shape)
|
|
|
+ image = self.conv2(image)
|
|
|
+ print("Conv2 shape:", image.shape)
|
|
|
+ image = self.midflow(image)
|
|
|
+ print("Midflow shape:", image.shape)
|
|
|
+ image = self.splitconv(image)
|
|
|
+ print("Splitconv shape:", image.shape)
|
|
|
+ image = torch.flatten(image, 1)
|
|
|
+ print("Flatten shape:", image.shape)
|
|
|
+ image = self.fc_image(image)
|
|
|
+
|
|
|
+ clin_data = self.fc_clin1(clin_data)
|
|
|
+ clin_data = self.fc_clin2(clin_data)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
x = torch.cat((image, clin_data), dim=1)
|
|
|
- x = self.final_layers(x)
|
|
|
+ x = self.dense1(x)
|
|
|
+ x = self.dense2(x)
|
|
|
+ x = self.softmax(x)
|
|
|
return x
|
|
|
|
|
|
|