@@ -8,11 +8,11 @@ import torch
 import torchvision
 
 
-class SeperableConv3d(nn.Module):
+class SepConv3d(nn.Module):
     def __init__(
         self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False
     ):
-        super(SeperableConv3d, self).__init__()
+        super(SepConv3d, self).__init__()
         self.depthwise = nn.Conv3d(
             in_channels,
             out_channels,
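The hunk above cuts off after the first arguments of the depthwise convolution. For orientation, a minimal sketch of the textbook depthwise-separable 3D convolution this class is named after; the groups argument and the 1x1x1 pointwise convolution are assumptions for illustration and may not match the project's exact implementation:

    import torch
    import torch.nn as nn

    class SepConv3dSketch(nn.Module):
        """Illustrative depthwise-separable 3D conv (not the repo's code)."""
        def __init__(self, in_channels, out_channels, kernel_size,
                     stride=1, padding=0, bias=False):
            super().__init__()
            # Depthwise: one filter per input channel (groups=in_channels).
            self.depthwise = nn.Conv3d(
                in_channels, in_channels, kernel_size,
                stride=stride, padding=padding, groups=in_channels, bias=bias,
            )
            # Pointwise: 1x1x1 conv mixes channels up to out_channels.
            self.pointwise = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=bias)

        def forward(self, x):
            return self.pointwise(self.depthwise(x))

    x = torch.randn(1, 8, 16, 16, 16)
    y = SepConv3dSketch(8, 16, kernel_size=3, padding=1)(x)  # -> (1, 16, 16, 16, 16)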
@@ -34,17 +34,17 @@ class SplitConvBlock(nn.Module):
 
         self.split_dim = split_dim
 
-        self.leftconv_1 = SeperableConvolutionalBlock(
+        self.leftconv_1 = SepConvBlock(
             in_channels //2, mid_channels //2, (3, 4, 3), droprate=drop_rate
         )
-        self.rightconv_1 = SeperableConvolutionalBlock(
+        self.rightconv_1 = SepConvBlock(
             in_channels //2, mid_channels //2, (3, 4, 3), droprate=drop_rate
         )
 
-        self.leftconv_2 = SeperableConvolutionalBlock(
+        self.leftconv_2 = SepConvBlock(
             mid_channels //2, out_channels //2, (3, 4, 3), droprate=drop_rate
         )
-        self.rightconv_2 = SeperableConvolutionalBlock(
+        self.rightconv_2 = SepConvBlock(
             mid_channels //2, out_channels //2, (3, 4, 3), droprate=drop_rate
         )
 
@@ -53,6 +53,8 @@ class SplitConvBlock(nn.Module):
     def forward(self, x):
         (left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
 
+        print(left.shape, right.shape)
+
         self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
         self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)
 
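For reference, torch.tensor_split returns views and divides the given dimension as evenly as possible, which is what the newly added print statement will report. A quick illustration with made-up sizes (the real split_dim and tensor shape depend on the model configuration):

    import torch

    x = torch.randn(2, 8, 16, 20, 16)              # (N, C, D, H, W), sizes for illustration only
    left, right = torch.tensor_split(x, 2, dim=1)  # halves along the channel dimension
    print(left.shape, right.shape)                 # torch.Size([2, 4, 16, 20, 16]) for both halves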
@@ -66,25 +68,25 @@ class MidFlowBlock(nn.Module):
     def __init__(self, channels, drop_rate):
         super(MidFlowBlock, self).__init__()
 
-        self.conv1 = ConvolutionalBlock(
-            channels, channels, (3, 3, 3), droprate=drop_rate
+        self.conv1 = ConvBlock(
+            channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
         )
-        self.conv2 = ConvolutionalBlock(
-            channels, channels, (3, 3, 3), droprate=drop_rate
+        self.conv2 = ConvBlock(
+            channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
         )
-        self.conv3 = ConvolutionalBlock(
-            channels, channels, (3, 3, 3), droprate=drop_rate
+        self.conv3 = ConvBlock(
+            channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
         )
 
         #self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
         self.block = self.conv1
 
     def forward(self, x):
-        x = nn.ELU(self.block(x) + x)
-        return
+        x = nn.ELU()(self.block(x) + x)
+        return x
 
 
-class ConvolutionalBlock(nn.Module):
+class ConvBlock(nn.Module):
     def __init__(
         self,
         in_channels,
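Two things in this hunk work together: padding="same" keeps the spatial dimensions of self.block(x) equal to those of x, so the residual addition is shape-compatible, and the forward fix is needed because nn.ELU is a module class, so nn.ELU(tensor) only constructs a module (misinterpreting the tensor as its alpha argument) without applying anything. A minimal sketch of the corrected residual pattern, with a functional equivalent; the conv here is just a stand-in for self.block:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    block = nn.Conv3d(4, 4, 3, padding="same")   # stand-in for self.block, preserves spatial dims
    x = torch.randn(1, 4, 8, 8, 8)

    out = nn.ELU()(block(x) + x)                 # instantiate the ELU module, then call it
    out_f = F.elu(block(x) + x)                  # functional equivalent
    assert torch.allclose(out, out_f)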
@@ -95,7 +97,7 @@ class ConvolutionalBlock(nn.Module):
         droprate=None,
         pool=False,
     ):
-        super(ConvolutionalBlock, self).__init__()
+        super(ConvBlock, self).__init__()
         self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
         self.norm = nn.BatchNorm3d(out_channels)
         self.elu = nn.ELU()
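Since ConvBlock passes the new padding="same" straight into nn.Conv3d, note that string padding is only accepted for unit stride; a strided "same" convolution is rejected at construction time. A quick check with arbitrary shapes:

    import torch
    import torch.nn as nn

    x = torch.randn(1, 3, 8, 8, 8)
    conv = nn.Conv3d(3, 16, (3, 3, 3), stride=1, padding="same")
    print(conv(x).shape)                             # spatial dims preserved: (1, 16, 8, 8, 8)

    try:
        nn.Conv3d(3, 16, (3, 3, 3), stride=2, padding="same")
    except ValueError as err:
        print(err)                                   # padding='same' requires stride 1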
@@ -120,9 +122,9 @@ class ConvolutionalBlock(nn.Module):
         return x
 
 
-class FullyConnectedBlock(nn.Module):
+class FullConnBlock(nn.Module):
     def __init__(self, in_channels, out_channels, droprate=0.0):
-        super(FullyConnectedBlock, self).__init__()
+        super(FullConnBlock, self).__init__()
         self.dense = nn.Linear(in_channels, out_channels)
         self.norm = nn.BatchNorm1d(out_channels)
         self.elu = nn.ELU()
@@ -136,7 +138,7 @@ class FullyConnectedBlock(nn.Module):
         return x
 
 
-class SeperableConvolutionalBlock(nn.Module):
+class SepConvBlock(nn.Module):
     def __init__(
         self,
         in_channels,
@@ -147,8 +149,8 @@ class SeperableConvolutionalBlock(nn.Module):
         droprate = None,
         pool = False,
     ):
-        super(SeperableConvolutionalBlock, self).__init__()
-        self.conv = SeperableConv3d(in_channels, out_channels, kernel_size, stride, padding)
+        super(SepConvBlock, self).__init__()
+        self.conv = SepConv3d(in_channels, out_channels, kernel_size, stride, padding)
         self.norm = nn.BatchNorm3d(out_channels)
         self.elu = nn.ELU()
         self.dropout = nn.Dropout(droprate)
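One caveat visible in the unchanged context lines: droprate defaults to None, while nn.Dropout validates a numeric probability, so constructing SepConvBlock without an explicit droprate would fail on the nn.Dropout(droprate) line. A minimal illustration of that behaviour (not code from this repo):

    import torch.nn as nn

    nn.Dropout(0.2)       # fine: p is a float in [0, 1]
    try:
        nn.Dropout(None)  # Dropout's range check compares None against an int
    except TypeError:
        print("nn.Dropout needs a numeric p; pass an explicit droprate")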