layers.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. from torch import nn
  2. import torch
  3. class SepConv3d(nn.Module):
  4. def __init__(
  5. self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False
  6. ):
  7. super(SepConv3d, self).__init__()
  8. self.depthwise = nn.Conv3d(
  9. in_channels,
  10. out_channels,
  11. kernel_size,
  12. groups=out_channels,
  13. padding=padding,
  14. bias=bias,
  15. stride=stride,
  16. )
  17. def forward(self, x):
  18. x = self.depthwise(x)
  19. return x
  20. class SplitConvBlock(nn.Module):
  21. def __init__(self, in_channels, mid_channels, out_channels, split_dim, drop_rate):
  22. super(SplitConvBlock, self).__init__()
  23. self.split_dim = split_dim
  24. self.leftconv_1 = SepConvBlock(
  25. in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
  26. )
  27. self.rightconv_1 = SepConvBlock(
  28. in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
  29. )
  30. self.leftconv_2 = SepConvBlock(
  31. mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
  32. )
  33. self.rightconv_2 = SepConvBlock(
  34. mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
  35. )
  36. def forward(self, x):
  37. (left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
  38. self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
  39. self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)
  40. left = self.leftblock(left)
  41. right = self.rightblock(right)
  42. x = torch.cat((left, right), dim=self.split_dim)
  43. return x
  44. class MidFlowBlock(nn.Module):
  45. def __init__(self, channels, drop_rate):
  46. super(MidFlowBlock, self).__init__()
  47. self.conv1 = ConvBlock(
  48. channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
  49. )
  50. self.conv2 = ConvBlock(
  51. channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
  52. )
  53. self.conv3 = ConvBlock(
  54. channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
  55. )
  56. # self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
  57. self.block = self.conv1
  58. def forward(self, x):
  59. x = nn.ELU()(self.block(x) + x)
  60. return x
  61. class ConvBlock(nn.Module):
  62. def __init__(
  63. self,
  64. in_channels,
  65. out_channels,
  66. kernel_size,
  67. stride=(1, 1, 1),
  68. padding="valid",
  69. droprate=None,
  70. pool=False,
  71. ):
  72. super(ConvBlock, self).__init__()
  73. self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
  74. self.norm = nn.BatchNorm3d(out_channels)
  75. self.elu = nn.ELU()
  76. self.dropout = nn.Dropout(droprate)
  77. if pool:
  78. self.maxpool = nn.MaxPool3d(3, stride=2)
  79. else:
  80. self.maxpool = None
  81. def forward(self, x):
  82. x = self.conv(x)
  83. x = self.norm(x)
  84. x = self.elu(x)
  85. if self.maxpool:
  86. x = self.maxpool(x)
  87. x = self.dropout(x)
  88. return x
  89. class FullConnBlock(nn.Module):
  90. def __init__(self, in_channels, out_channels, droprate=0.0):
  91. super(FullConnBlock, self).__init__()
  92. self.dense = nn.Linear(in_channels, out_channels)
  93. self.norm = nn.BatchNorm1d(out_channels)
  94. self.elu = nn.ELU()
  95. self.dropout = nn.Dropout(droprate)
  96. def forward(self, x):
  97. x = self.dense(x)
  98. x = self.norm(x)
  99. x = self.elu(x)
  100. x = self.dropout(x)
  101. return x
  102. class SepConvBlock(nn.Module):
  103. def __init__(
  104. self,
  105. in_channels,
  106. out_channels,
  107. kernel_size,
  108. stride=(1, 1, 1),
  109. padding="valid",
  110. droprate=None,
  111. pool=False,
  112. ):
  113. super(SepConvBlock, self).__init__()
  114. self.conv = SepConv3d(in_channels, out_channels, kernel_size, stride, padding)
  115. self.norm = nn.BatchNorm3d(out_channels)
  116. self.elu = nn.ELU()
  117. self.dropout = nn.Dropout(droprate)
  118. if pool:
  119. self.maxpool = nn.MaxPool3d(3, stride=2)
  120. else:
  121. self.maxpool = None
  122. def forward(self, x):
  123. x = self.conv(x)
  124. x = self.norm(x)
  125. x = self.elu(x)
  126. if self.maxpool:
  127. x = self.maxpool(x)
  128. x = self.dropout(x)
  129. return x