|
@@ -3,7 +3,7 @@ from torchvision.transforms import ToTensor
|
|
import os
|
|
import os
|
|
import pandas as pd
|
|
import pandas as pd
|
|
import numpy as np
|
|
import numpy as np
|
|
-import layers as ly
|
|
|
|
|
|
+import utils.layers as ly
|
|
|
|
|
|
import torch
|
|
import torch
|
|
import torchvision
|
|
import torchvision
|
|
@@ -32,26 +32,8 @@ class CNN_Net(nn.Module):
|
|
def __init__(self, image_channels, clin_data_channels, droprate):
|
|
def __init__(self, image_channels, clin_data_channels, droprate):
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
- # Initial Convolutional Blocks
|
|
|
|
- self.conv1 = ly.ConvBlock(
|
|
|
|
- image_channels, 192, (11, 13, 11), stride=(4, 4, 4), droprate=droprate, pool=False
|
|
|
|
- )
|
|
|
|
- self.conv2 = ly.ConvBlock(
|
|
|
|
- 192, 384, (5, 6, 5), droprate=droprate, pool=False
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- # Midflow Block
|
|
|
|
- self.midflow = ly.MidFlowBlock(384, droprate)
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- # Split Convolutional Block
|
|
|
|
- self.splitconv = ly.SplitConvBlock(384, 192, 96, 4, droprate)
|
|
|
|
-
|
|
|
|
- #Fully Connected Block
|
|
|
|
- self.fc_image = ly.FullConnBlock(96, 20, droprate=droprate)
|
|
|
|
-
|
|
|
|
-
|
|
|
|
|
|
+ #Image Section
|
|
|
|
+ self.image_section = CNN_Image_Section(image_channels, droprate)
|
|
|
|
|
|
#Data Layers, fully connected
|
|
#Data Layers, fully connected
|
|
self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
|
|
self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
|
|
@@ -61,7 +43,7 @@ class CNN_Net(nn.Module):
|
|
#Final Dense Layer
|
|
#Final Dense Layer
|
|
self.dense1 = nn.Linear(40, 5)
|
|
self.dense1 = nn.Linear(40, 5)
|
|
self.dense2 = nn.Linear(5, 2)
|
|
self.dense2 = nn.Linear(5, 2)
|
|
- self.softmax = nn.Softmax()
|
|
|
|
|
|
+ self.softmax = nn.Softmax(dim = 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -69,19 +51,14 @@ class CNN_Net(nn.Module):
|
|
|
|
|
|
image, clin_data = x
|
|
image, clin_data = x
|
|
|
|
|
|
-
|
|
|
|
- image = self.conv1(image)
|
|
|
|
- image = self.conv2(image)
|
|
|
|
- image = self.midflow(image)
|
|
|
|
- image = self.splitconv(image)
|
|
|
|
- image = torch.flatten(image, 1)
|
|
|
|
- image = self.fc_image(image)
|
|
|
|
|
|
+ image = self.image_section(image)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
|
|
clin_data = self.fc_clin1(clin_data)
|
|
clin_data = self.fc_clin1(clin_data)
|
|
clin_data = self.fc_clin2(clin_data)
|
|
clin_data = self.fc_clin2(clin_data)
|
|
|
|
|
|
|
|
|
|
-
|
|
|
|
x = torch.cat((image, clin_data), dim=1)
|
|
x = torch.cat((image, clin_data), dim=1)
|
|
x = self.dense1(x)
|
|
x = self.dense1(x)
|
|
x = self.dense2(x)
|
|
x = self.dense2(x)
|
|
@@ -119,12 +96,9 @@ class CNN_Image_Section(nn.Module):
|
|
x = self.conv2(x)
|
|
x = self.conv2(x)
|
|
x = self.midflow(x)
|
|
x = self.midflow(x)
|
|
x = self.splitconv(x)
|
|
x = self.splitconv(x)
|
|
- print(x.shape)
|
|
|
|
x = torch.flatten(x, 1)
|
|
x = torch.flatten(x, 1)
|
|
x = self.fc_image(x)
|
|
x = self.fc_image(x)
|
|
|
|
|
|
|
|
+ return x
|
|
|
|
|
|
|
|
|
|
-print(ts.summary(CNN_Image_Section(1, 0.5).cuda(), (1, 91, 109, 91)))
|
|
|
|
-
|
|
|
|
-
|
|
|