# layers.py - 3-D convolutional and fully-connected building blocks.
from typing import Tuple

import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import nn
  5. class SepConv3d(nn.Module):
  6. def __init__(
  7. self,
  8. in_channels: int,
  9. out_channels: int,
  10. kernel_size: Tuple[int, int, int],
  11. stride: int | Tuple[int, int, int] = 1,
  12. padding: int | str = 0,
  13. bias: bool = False,
  14. ):
  15. super(SepConv3d, self).__init__()
  16. self.depthwise = nn.Conv3d(
  17. in_channels,
  18. out_channels,
  19. kernel_size,
  20. groups=out_channels,
  21. padding=padding,
  22. bias=bias,
  23. stride=stride,
  24. )
  25. def forward(self, x: Float[torch.Tensor, "N C D H W"]):
  26. x = self.depthwise(x)
  27. return x
  28. class SplitConvBlock(nn.Module):
  29. def __init__(
  30. self,
  31. in_channels: int,
  32. mid_channels: int,
  33. out_channels: int,
  34. split_dim: int,
  35. drop_rate: float,
  36. ):
  37. super(SplitConvBlock, self).__init__()
  38. self.split_dim = split_dim
  39. self.leftconv_1 = SepConvBlock(
  40. in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
  41. )
  42. self.rightconv_1 = SepConvBlock(
  43. in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
  44. )
  45. self.leftconv_2 = SepConvBlock(
  46. mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
  47. )
  48. self.rightconv_2 = SepConvBlock(
  49. mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
  50. )
  51. def forward(self, x: Float[torch.Tensor, "N C D H W"]):
  52. (left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
  53. self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
  54. self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)
  55. left = self.leftblock(left)
  56. right = self.rightblock(right)
  57. a = torch.cat((left, right), dim=self.split_dim)
  58. return a
  59. class MidFlowBlock(nn.Module):
  60. def __init__(self, channels: int, drop_rate: float):
  61. super(MidFlowBlock, self).__init__()
  62. self.conv1 = ConvBlock(
  63. channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
  64. )
  65. self.conv2 = ConvBlock(
  66. channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
  67. )
  68. self.conv3 = ConvBlock(
  69. channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
  70. )
  71. # self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
  72. self.block = self.conv1
  73. def forward(self, x: Float[torch.Tensor, "N C D H W"]):
  74. a = nn.ELU()(self.block(x) + x)
  75. return a
  76. class ConvBlock(nn.Module):
  77. def __init__(
  78. self,
  79. in_channels: int,
  80. out_channels: int,
  81. kernel_size: Tuple[int, int, int],
  82. stride: Tuple[int, int, int] = (1, 1, 1),
  83. padding: str = "valid",
  84. droprate: float = 0.0,
  85. pool: bool = False,
  86. ):
  87. super(ConvBlock, self).__init__()
  88. self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
  89. self.norm = nn.BatchNorm3d(out_channels)
  90. self.elu = nn.ELU()
  91. self.dropout = nn.Dropout(droprate)
  92. if pool:
  93. self.maxpool = nn.MaxPool3d(3, stride=2)
  94. else:
  95. self.maxpool = None
  96. def forward(self, x: Float[torch.Tensor, "N C D H W"]):
  97. a = self.conv(x)
  98. a = self.norm(a)
  99. a = self.elu(a)
  100. if self.maxpool:
  101. a = self.maxpool(a)
  102. a = self.dropout(a)
  103. return a
  104. class FullConnBlock(nn.Module):
  105. def __init__(self, in_channels: int, out_channels: int, droprate: float = 0.0):
  106. super(FullConnBlock, self).__init__()
  107. self.dense = nn.Linear(in_channels, out_channels)
  108. self.norm = nn.BatchNorm1d(out_channels)
  109. self.elu = nn.ELU()
  110. self.dropout = nn.Dropout(droprate)
  111. def forward(self, x: Float[torch.Tensor, "N C"]):
  112. x = self.dense(x)
  113. x = self.norm(x)
  114. x = self.elu(x)
  115. x = self.dropout(x)
  116. return x
  117. class SepConvBlock(nn.Module):
  118. def __init__(
  119. self,
  120. in_channels: int,
  121. out_channels: int,
  122. kernel_size: Tuple[int, int, int],
  123. stride: Tuple[int, int, int] = (1, 1, 1),,
  124. padding: str | int="valid",
  125. droprate: float = 0.0,
  126. pool: bool =False,
  127. ):
  128. super(SepConvBlock, self).__init__()
  129. self.conv = SepConv3d(in_channels, out_channels, kernel_size, stride, padding)
  130. self.norm = nn.BatchNorm3d(out_channels)
  131. self.elu = nn.ELU()
  132. self.dropout = nn.Dropout(droprate)
  133. if pool:
  134. self.maxpool = nn.MaxPool3d(3, stride=2)
  135. else:
  136. self.maxpool = None
  137. def forward(self, x: Float[torch.Tensor, "N C D H W"]):
  138. x = self.conv(x)
  139. x = self.norm(x)
  140. x = self.elu(x)
  141. if self.maxpool:
  142. x = self.maxpool(x)
  143. x = self.dropout(x)
  144. return x