123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177 |
- from torch import nn
- from jaxtyping import Float
- import torch
- from typing import Tuple
class SepConv3d(nn.Module):
    """Grouped 3-D convolution using one group per output channel.

    Wraps a single ``nn.Conv3d`` (stored as ``self.depthwise``) built with
    ``groups=out_channels``. When ``in_channels == out_channels`` each output
    channel is computed from exactly one input channel, i.e. a depthwise
    convolution. In general each group maps ``in_channels // out_channels``
    input channels to one output channel. ``nn.Conv3d`` requires both channel
    counts to be divisible by ``groups``, so ``in_channels`` must be a
    multiple of ``out_channels``.

    NOTE(review): despite the name there is no pointwise (1x1x1) stage, so
    this is not a full depthwise-separable convolution. Confirm intended.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of output channels (also the group count).
        kernel_size: convolution kernel size (D, H, W).
        stride: convolution stride (int or per-axis tuple).
        padding: convolution padding (int or a string mode such as "same").
        bias: whether the convolution learns an additive bias.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int, int],
        stride: int | Tuple[int, int, int] = 1,
        padding: int | str = 0,
        bias: bool = False,
    ):
        super().__init__()
        self.depthwise = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=out_channels,
            bias=bias,
        )

    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
        return self.depthwise(x)
class SplitConvBlock(nn.Module):
    """Split the input in two, convolve each half independently, re-join.

    The input is split into two chunks along ``split_dim`` with
    ``torch.tensor_split``. Each chunk passes through its own pair of
    ``SepConvBlock``s (kernel (3, 4, 3)). The two results are concatenated
    back along ``split_dim``.

    Args:
        in_channels: channels of the full input; each branch's first
            SepConvBlock is built for ``in_channels // 2`` input channels.
        mid_channels: each branch's intermediate width is ``mid_channels // 2``.
        out_channels: each branch produces ``out_channels // 2`` channels.
        split_dim: dimension that is split and later concatenated.
            NOTE(review): the branch convolutions are sized for
            ``in_channels // 2`` channels, so the shapes only line up when
            ``split_dim`` is the channel dimension (1 for N C D H W) and
            ``in_channels`` is even. Confirm with callers.
        drop_rate: dropout probability forwarded to every SepConvBlock.
    """

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        split_dim: int,
        drop_rate: float,
    ):
        super(SplitConvBlock, self).__init__()
        self.split_dim = split_dim
        self.leftconv_1 = SepConvBlock(
            in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
        )
        self.rightconv_1 = SepConvBlock(
            in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
        )
        self.leftconv_2 = SepConvBlock(
            mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
        )
        self.rightconv_2 = SepConvBlock(
            mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
        )
        # Build the branch pipelines once, here, instead of inside forward().
        # Creating and assigning modules in forward() re-registers children on
        # every call, so the module tree and state_dict() keys would depend on
        # whether forward had run. These Sequentials wrap the same submodules,
        # so they share parameters with leftconv_*/rightconv_*. state_dict()
        # therefore also contains leftblock.*/rightblock.* alias keys,
        # registered after the four convs.
        self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
        self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)

    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
        left, right = torch.tensor_split(x, 2, dim=self.split_dim)
        left = self.leftblock(left)
        right = self.rightblock(right)
        return torch.cat((left, right), dim=self.split_dim)
class MidFlowBlock(nn.Module):
    """Residual block computing ``ELU(block(x) + x)``.

    ``block`` is a ConvBlock with a 3x3x3 kernel and ``padding="same"``, so
    its output has the same shape as ``x`` and the residual sum is valid.

    Args:
        channels: input and output channel count (identical so the skip
            connection can be added).
        drop_rate: dropout probability forwarded to each ConvBlock.
    """

    def __init__(self, channels: int, drop_rate: float):
        super(MidFlowBlock, self).__init__()
        self.conv1 = ConvBlock(
            channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
        )
        self.conv2 = ConvBlock(
            channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
        )
        self.conv3 = ConvBlock(
            channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
        )
        # NOTE(review): only conv1 is on the forward path. conv2 and conv3 are
        # constructed and registered (they appear in parameters() and
        # state_dict()) but never used. Under DistributedDataParallel, unused
        # parameters require find_unused_parameters=True. Chaining all three
        # would be nn.Sequential(self.conv1, self.conv2, self.conv3). Confirm
        # which is intended: either change alters the model output or the
        # checkpoint layout.
        self.block = self.conv1

    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
        # Functional ELU (alpha=1.0, same as nn.ELU()) avoids constructing a
        # throwaway module on every call.
        return nn.functional.elu(self.block(x) + x)
class ConvBlock(nn.Module):
    """Conv3d -> BatchNorm3d -> ELU -> optional MaxPool3d -> Dropout.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size (D, H, W).
        stride: convolution stride.
        padding: forwarded to ``nn.Conv3d``. Accepts "valid", "same", an int,
            or a per-axis tuple. ``nn.Conv3d`` rejects "same" with a stride
            other than 1.
        droprate: dropout probability, applied as the last step.
        pool: if True, apply ``MaxPool3d(kernel_size=3, stride=2)`` after the
            activation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int, int],
        stride: Tuple[int, int, int] = (1, 1, 1),
        padding: str | int | Tuple[int, int, int] = "valid",
        droprate: float = 0.0,
        pool: bool = False,
    ):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
        self.norm = nn.BatchNorm3d(out_channels)
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(droprate)
        # Plain None (not a registered module) when pooling is disabled.
        self.maxpool = nn.MaxPool3d(3, stride=2) if pool else None

    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
        a = self.conv(x)
        a = self.norm(a)
        a = self.elu(a)
        if self.maxpool is not None:
            a = self.maxpool(a)
        return self.dropout(a)
class FullConnBlock(nn.Module):
    """Dense layer followed by BatchNorm1d, ELU and Dropout, in that order.

    Args:
        in_channels: size of each input feature vector.
        out_channels: size of each output feature vector.
        droprate: dropout probability, applied as the last step.
    """

    def __init__(self, in_channels: int, out_channels: int, droprate: float = 0.0):
        super().__init__()
        # Registration order (dense, norm, elu, dropout) fixes state_dict order.
        self.dense = nn.Linear(in_channels, out_channels)
        self.norm = nn.BatchNorm1d(out_channels)
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(droprate)

    def forward(self, x: Float[torch.Tensor, "N C"]):
        out = x
        for stage in (self.dense, self.norm, self.elu, self.dropout):
            out = stage(out)
        return out
class SepConvBlock(nn.Module):
    """SepConv3d -> BatchNorm3d -> ELU -> optional MaxPool3d -> Dropout.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the grouped SepConv3d convolution.
        kernel_size: convolution kernel size (D, H, W).
        stride: convolution stride, forwarded to SepConv3d.
        padding: forwarded to SepConv3d ("valid", "same", or an int).
        droprate: dropout probability, applied as the last step.
        pool: if True, apply ``MaxPool3d(kernel_size=3, stride=2)`` after the
            activation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int, int],
        stride: int | Tuple[int, int, int] = (1, 1, 1),
        padding: str | int = "valid",
        droprate: float = 0.0,
        pool: bool = False,
    ):
        super(SepConvBlock, self).__init__()
        # SepConv3d's bias is left at its default, as in the positional call.
        self.conv = SepConv3d(in_channels, out_channels, kernel_size, stride, padding)
        self.norm = nn.BatchNorm3d(out_channels)
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(droprate)
        # Plain None (not a registered module) when pooling is disabled.
        self.maxpool = nn.MaxPool3d(3, stride=2) if pool else None

    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
        x = self.conv(x)
        x = self.norm(x)
        x = self.elu(x)
        if self.maxpool is not None:
            x = self.maxpool(x)
        return self.dropout(x)
|