from torch import nn from jaxtyping import Float import torch from typing import Tuple class SepCNV3d(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size: Tuple[int, int, int], stride: int | Tuple[int, int, int] = 1, padding: int | str = 0, bias: bool = False, ): super(SepCNV3d, self).__init__() self.depthwise = nn.Conv3d( in_channels, out_channels, kernel_size, groups=out_channels, padding=padding, bias=bias, stride=stride, ) def forward(self, x: Float[torch.Tensor, "N C D H W"]): x = self.depthwise(x) return x class SplitCNVBlock(nn.Module): def __init__( self, in_channels: int, mid_channels: int, out_channels: int, split_dim: int, drop_rate: float, ): super(SplitCNVBlock, self).__init__() self.split_dim = split_dim self.leftcnv_1 = SepCNVBlock( in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate ) self.rightcnv_1 = SepCNVBlock( in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate ) self.leftcnv_2 = SepCNVBlock( mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate ) self.rightcnv_2 = SepCNVBlock( mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate ) def forward(self, x: Float[torch.Tensor, "N C D H W"]): (left, right) = torch.tensor_split(x, 2, dim=self.split_dim) self.leftblock = nn.Sequential(self.leftcnv_1, self.leftcnv_2) self.rightblock = nn.Sequential(self.rightcnv_1, self.rightcnv_2) left = self.leftblock(left) right = self.rightblock(right) a = torch.cat((left, right), dim=self.split_dim) return a class MidFlowBlock(nn.Module): def __init__(self, channels: int, drop_rate: float): super(MidFlowBlock, self).__init__() self.cnv1 = CNVBlock( channels, channels, (3, 3, 3), droprate=drop_rate, padding="same" ) self.cnv2 = CNVBlock( channels, channels, (3, 3, 3), droprate=drop_rate, padding="same" ) self.cnv3 = CNVBlock( channels, channels, (3, 3, 3), droprate=drop_rate, padding="same" ) # self.block = nn.Sequential(self.cnv1, self.cnv2, self.cnv3) self.block = self.cnv1 def forward(self, x: Float[torch.Tensor, "N C D H W"]): a = nn.ELU()(self.block(x) + x) return a class CNVBlock(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size: Tuple[int, int, int], stride: Tuple[int, int, int] = (1, 1, 1), padding: str = "valid", droprate: float = 0.0, pool: bool = False, ): super(CNVBlock, self).__init__() self.cnv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding) self.norm = nn.BatchNorm3d(out_channels) self.elu = nn.ELU() self.dropout = nn.Dropout(droprate) if pool: self.maxpool = nn.MaxPool3d(3, stride=2) else: self.maxpool = None def forward(self, x: Float[torch.Tensor, "N C D H W"]): a = self.cnv(x) a = self.norm(a) a = self.elu(a) if self.maxpool: a = self.maxpool(a) a = self.dropout(a) return a class FullConnBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int, droprate: float = 0.0): super(FullConnBlock, self).__init__() self.dense = nn.Linear(in_channels, out_channels) self.norm = nn.BatchNorm1d(out_channels) self.elu = nn.ELU() self.dropout = nn.Dropout(droprate) def forward(self, x: Float[torch.Tensor, "N C"]): x = self.dense(x) x = self.norm(x) x = self.elu(x) x = self.dropout(x) return x class SepCNVBlock(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size: Tuple[int, int, int], stride: Tuple[int, int, int] = (1, 1, 1), padding: str | int = "valid", droprate: float = 0.0, pool: bool = False, ): super(SepCNVBlock, self).__init__() self.cnv = SepCNV3d(in_channels, out_channels, kernel_size, stride, padding) self.norm = nn.BatchNorm3d(out_channels) self.elu = nn.ELU() self.dropout = nn.Dropout(droprate) if pool: self.maxpool = nn.MaxPool3d(3, stride=2) else: self.maxpool = None def forward(self, x: Float[torch.Tensor, "N C D H W"]): x = self.cnv(x) x = self.norm(x) x = self.elu(x) if self.maxpool: x = self.maxpool(x) x = self.dropout(x) return x