|
@@ -4,7 +4,7 @@ import torch
|
|
|
from typing import Tuple
|
|
from typing import Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
-class SepConv3d(nn.Module):
|
|
|
|
|
|
|
+class SepCNV3d(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
|
in_channels: int,
|
|
in_channels: int,
|
|
@@ -14,7 +14,7 @@ class SepConv3d(nn.Module):
|
|
|
padding: int | str = 0,
|
|
padding: int | str = 0,
|
|
|
bias: bool = False,
|
|
bias: bool = False,
|
|
|
):
|
|
):
|
|
|
- super(SepConv3d, self).__init__()
|
|
|
|
|
|
|
+ super(SepCNV3d, self).__init__()
|
|
|
self.depthwise = nn.Conv3d(
|
|
self.depthwise = nn.Conv3d(
|
|
|
in_channels,
|
|
in_channels,
|
|
|
out_channels,
|
|
out_channels,
|
|
@@ -30,7 +30,7 @@ class SepConv3d(nn.Module):
|
|
|
return x
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
-class SplitConvBlock(nn.Module):
|
|
|
|
|
|
|
+class SplitCNVBlock(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
|
in_channels: int,
|
|
in_channels: int,
|
|
@@ -39,29 +39,29 @@ class SplitConvBlock(nn.Module):
|
|
|
split_dim: int,
|
|
split_dim: int,
|
|
|
drop_rate: float,
|
|
drop_rate: float,
|
|
|
):
|
|
):
|
|
|
- super(SplitConvBlock, self).__init__()
|
|
|
|
|
|
|
+ super(SplitCNVBlock, self).__init__()
|
|
|
|
|
|
|
|
self.split_dim = split_dim
|
|
self.split_dim = split_dim
|
|
|
|
|
|
|
|
- self.leftconv_1 = SepConvBlock(
|
|
|
|
|
|
|
+ self.leftcnv_1 = SepCNVBlock(
|
|
|
in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
|
)
|
|
)
|
|
|
- self.rightconv_1 = SepConvBlock(
|
|
|
|
|
|
|
+ self.rightcnv_1 = SepCNVBlock(
|
|
|
in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- self.leftconv_2 = SepConvBlock(
|
|
|
|
|
|
|
+ self.leftcnv_2 = SepCNVBlock(
|
|
|
mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
|
)
|
|
)
|
|
|
- self.rightconv_2 = SepConvBlock(
|
|
|
|
|
|
|
+ self.rightcnv_2 = SepCNVBlock(
|
|
|
mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
def forward(self, x: Float[torch.Tensor, "N C D H W"]):
|
|
def forward(self, x: Float[torch.Tensor, "N C D H W"]):
|
|
|
(left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
|
|
(left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
|
|
|
|
|
|
|
|
- self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
|
|
|
|
|
- self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)
|
|
|
|
|
|
|
+ self.leftblock = nn.Sequential(self.leftcnv_1, self.leftcnv_2)
|
|
|
|
|
+ self.rightblock = nn.Sequential(self.rightcnv_1, self.rightcnv_2)
|
|
|
|
|
|
|
|
left = self.leftblock(left)
|
|
left = self.leftblock(left)
|
|
|
right = self.rightblock(right)
|
|
right = self.rightblock(right)
|
|
@@ -73,25 +73,25 @@ class MidFlowBlock(nn.Module):
|
|
|
def __init__(self, channels: int, drop_rate: float):
|
|
def __init__(self, channels: int, drop_rate: float):
|
|
|
super(MidFlowBlock, self).__init__()
|
|
super(MidFlowBlock, self).__init__()
|
|
|
|
|
|
|
|
- self.conv1 = ConvBlock(
|
|
|
|
|
|
|
+ self.cnv1 = CNVBlock(
|
|
|
channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
|
|
channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
|
|
|
)
|
|
)
|
|
|
- self.conv2 = ConvBlock(
|
|
|
|
|
|
|
+ self.cnv2 = CNVBlock(
|
|
|
channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
|
|
channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
|
|
|
)
|
|
)
|
|
|
- self.conv3 = ConvBlock(
|
|
|
|
|
|
|
+ self.cnv3 = CNVBlock(
|
|
|
channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
|
|
channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
|
|
|
|
|
- self.block = self.conv1
|
|
|
|
|
|
|
+ # self.block = nn.Sequential(self.cnv1, self.cnv2, self.cnv3)
|
|
|
|
|
+ self.block = self.cnv1
|
|
|
|
|
|
|
|
def forward(self, x: Float[torch.Tensor, "N C D H W"]):
|
|
def forward(self, x: Float[torch.Tensor, "N C D H W"]):
|
|
|
a = nn.ELU()(self.block(x) + x)
|
|
a = nn.ELU()(self.block(x) + x)
|
|
|
return a
|
|
return a
|
|
|
|
|
|
|
|
|
|
|
|
|
-class ConvBlock(nn.Module):
|
|
|
|
|
|
|
+class CNVBlock(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
|
in_channels: int,
|
|
in_channels: int,
|
|
@@ -102,8 +102,8 @@ class ConvBlock(nn.Module):
|
|
|
droprate: float = 0.0,
|
|
droprate: float = 0.0,
|
|
|
pool: bool = False,
|
|
pool: bool = False,
|
|
|
):
|
|
):
|
|
|
- super(ConvBlock, self).__init__()
|
|
|
|
|
- self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
|
|
|
|
|
|
|
+ super(CNVBlock, self).__init__()
|
|
|
|
|
+ self.cnv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
|
|
|
self.norm = nn.BatchNorm3d(out_channels)
|
|
self.norm = nn.BatchNorm3d(out_channels)
|
|
|
self.elu = nn.ELU()
|
|
self.elu = nn.ELU()
|
|
|
self.dropout = nn.Dropout(droprate)
|
|
self.dropout = nn.Dropout(droprate)
|
|
@@ -114,7 +114,7 @@ class ConvBlock(nn.Module):
|
|
|
self.maxpool = None
|
|
self.maxpool = None
|
|
|
|
|
|
|
|
def forward(self, x: Float[torch.Tensor, "N C D H W"]):
|
|
def forward(self, x: Float[torch.Tensor, "N C D H W"]):
|
|
|
- a = self.conv(x)
|
|
|
|
|
|
|
+ a = self.cnv(x)
|
|
|
a = self.norm(a)
|
|
a = self.norm(a)
|
|
|
a = self.elu(a)
|
|
a = self.elu(a)
|
|
|
|
|
|
|
@@ -142,7 +142,7 @@ class FullConnBlock(nn.Module):
|
|
|
return x
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
-class SepConvBlock(nn.Module):
|
|
|
|
|
|
|
+class SepCNVBlock(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
|
in_channels: int,
|
|
in_channels: int,
|
|
@@ -153,8 +153,8 @@ class SepConvBlock(nn.Module):
|
|
|
droprate: float = 0.0,
|
|
droprate: float = 0.0,
|
|
|
pool: bool = False,
|
|
pool: bool = False,
|
|
|
):
|
|
):
|
|
|
- super(SepConvBlock, self).__init__()
|
|
|
|
|
- self.conv = SepConv3d(in_channels, out_channels, kernel_size, stride, padding)
|
|
|
|
|
|
|
+ super(SepCNVBlock, self).__init__()
|
|
|
|
|
+ self.cnv = SepCNV3d(in_channels, out_channels, kernel_size, stride, padding)
|
|
|
self.norm = nn.BatchNorm3d(out_channels)
|
|
self.norm = nn.BatchNorm3d(out_channels)
|
|
|
self.elu = nn.ELU()
|
|
self.elu = nn.ELU()
|
|
|
self.dropout = nn.Dropout(droprate)
|
|
self.dropout = nn.Dropout(droprate)
|
|
@@ -165,7 +165,7 @@ class SepConvBlock(nn.Module):
|
|
|
self.maxpool = None
|
|
self.maxpool = None
|
|
|
|
|
|
|
|
def forward(self, x: Float[torch.Tensor, "N C D H W"]):
|
|
def forward(self, x: Float[torch.Tensor, "N C D H W"]):
|
|
|
- x = self.conv(x)
|
|
|
|
|
|
|
+ x = self.cnv(x)
|
|
|
x = self.norm(x)
|
|
x = self.norm(x)
|
|
|
x = self.elu(x)
|
|
x = self.elu(x)
|
|
|
|
|
|