|
@@ -26,7 +26,7 @@ class SepConv3d(nn.Module):
|
|
|
def forward(self, x):
|
|
|
x = self.depthwise(x)
|
|
|
return x
|
|
|
-
|
|
|
+
|
|
|
|
|
|
class SplitConvBlock(nn.Module):
|
|
|
def __init__(self, in_channels, mid_channels, out_channels, split_dim, drop_rate):
|
|
@@ -35,25 +35,22 @@ class SplitConvBlock(nn.Module):
|
|
|
self.split_dim = split_dim
|
|
|
|
|
|
self.leftconv_1 = SepConvBlock(
|
|
|
- in_channels //2, mid_channels //2, (3, 4, 3), droprate=drop_rate
|
|
|
+ in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
|
)
|
|
|
self.rightconv_1 = SepConvBlock(
|
|
|
- in_channels //2, mid_channels //2, (3, 4, 3), droprate=drop_rate
|
|
|
+ in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
|
)
|
|
|
|
|
|
self.leftconv_2 = SepConvBlock(
|
|
|
- mid_channels //2, out_channels //2, (3, 4, 3), droprate=drop_rate
|
|
|
+ mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
|
)
|
|
|
self.rightconv_2 = SepConvBlock(
|
|
|
- mid_channels //2, out_channels //2, (3, 4, 3), droprate=drop_rate
|
|
|
+ mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
|
)
|
|
|
|
|
|
-
|
|
|
-
|
|
|
def forward(self, x):
|
|
|
(left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
|
|
|
|
|
|
-
|
|
|
self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
|
|
|
self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)
|
|
|
|
|
@@ -61,47 +58,47 @@ class SplitConvBlock(nn.Module):
|
|
|
right = self.rightblock(right)
|
|
|
x = torch.cat((left, right), dim=self.split_dim)
|
|
|
return x
|
|
|
-
|
|
|
+
|
|
|
|
|
|
class MidFlowBlock(nn.Module):
|
|
|
def __init__(self, channels, drop_rate):
|
|
|
super(MidFlowBlock, self).__init__()
|
|
|
|
|
|
self.conv1 = ConvBlock(
|
|
|
- channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
|
|
|
+ channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
|
|
|
)
|
|
|
self.conv2 = ConvBlock(
|
|
|
channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
|
|
|
)
|
|
|
self.conv3 = ConvBlock(
|
|
|
- channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
|
|
|
+ channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
|
|
|
)
|
|
|
|
|
|
- #self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
|
|
|
+ # self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
|
|
|
self.block = self.conv1
|
|
|
|
|
|
def forward(self, x):
|
|
|
x = nn.ELU()(self.block(x) + x)
|
|
|
return x
|
|
|
|
|
|
-
|
|
|
+
|
|
|
class ConvBlock(nn.Module):
|
|
|
def __init__(
|
|
|
- self,
|
|
|
- in_channels,
|
|
|
- out_channels,
|
|
|
- kernel_size,
|
|
|
- stride=(1, 1, 1),
|
|
|
- padding="valid",
|
|
|
- droprate=None,
|
|
|
- pool=False,
|
|
|
+ self,
|
|
|
+ in_channels,
|
|
|
+ out_channels,
|
|
|
+ kernel_size,
|
|
|
+ stride=(1, 1, 1),
|
|
|
+ padding="valid",
|
|
|
+ droprate=None,
|
|
|
+ pool=False,
|
|
|
):
|
|
|
super(ConvBlock, self).__init__()
|
|
|
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
|
|
|
self.norm = nn.BatchNorm3d(out_channels)
|
|
|
self.elu = nn.ELU()
|
|
|
self.dropout = nn.Dropout(droprate)
|
|
|
-
|
|
|
+
|
|
|
if pool:
|
|
|
self.maxpool = nn.MaxPool3d(3, stride=2)
|
|
|
else:
|
|
@@ -111,13 +108,12 @@ class ConvBlock(nn.Module):
|
|
|
x = self.conv(x)
|
|
|
x = self.norm(x)
|
|
|
x = self.elu(x)
|
|
|
-
|
|
|
-
|
|
|
+
|
|
|
if self.maxpool:
|
|
|
x = self.maxpool(x)
|
|
|
|
|
|
x = self.dropout(x)
|
|
|
-
|
|
|
+
|
|
|
return x
|
|
|
|
|
|
|
|
@@ -135,25 +131,25 @@ class FullConnBlock(nn.Module):
|
|
|
x = self.elu(x)
|
|
|
x = self.dropout(x)
|
|
|
return x
|
|
|
-
|
|
|
+
|
|
|
|
|
|
class SepConvBlock(nn.Module):
|
|
|
def __init__(
|
|
|
- self,
|
|
|
- in_channels,
|
|
|
- out_channels,
|
|
|
- kernel_size,
|
|
|
- stride = (1, 1, 1),
|
|
|
- padding = "valid",
|
|
|
- droprate = None,
|
|
|
- pool = False,
|
|
|
+ self,
|
|
|
+ in_channels,
|
|
|
+ out_channels,
|
|
|
+ kernel_size,
|
|
|
+ stride=(1, 1, 1),
|
|
|
+ padding="valid",
|
|
|
+ droprate=None,
|
|
|
+ pool=False,
|
|
|
):
|
|
|
super(SepConvBlock, self).__init__()
|
|
|
self.conv = SepConv3d(in_channels, out_channels, kernel_size, stride, padding)
|
|
|
self.norm = nn.BatchNorm3d(out_channels)
|
|
|
self.elu = nn.ELU()
|
|
|
self.dropout = nn.Dropout(droprate)
|
|
|
-
|
|
|
+
|
|
|
if pool:
|
|
|
self.maxpool = nn.MaxPool3d(3, stride=2)
|
|
|
else:
|
|
@@ -163,11 +159,10 @@ class SepConvBlock(nn.Module):
|
|
|
x = self.conv(x)
|
|
|
x = self.norm(x)
|
|
|
x = self.elu(x)
|
|
|
-
|
|
|
-
|
|
|
+
|
|
|
if self.maxpool:
|
|
|
x = self.maxpool(x)
|
|
|
|
|
|
x = self.dropout(x)
|
|
|
|
|
|
- return x
|
|
|
+ return x
|