layers.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. from torch import nn
  2. from torchvision.transforms import ToTensor
  3. import os
  4. import pandas as pd
  5. import numpy as np
  6. import torch
  7. import torchvision
  8. class SepConv3d(nn.Module):
  9. def __init__(
  10. self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False
  11. ):
  12. super(SepConv3d, self).__init__()
  13. self.depthwise = nn.Conv3d(
  14. in_channels,
  15. out_channels,
  16. kernel_size,
  17. groups=out_channels,
  18. padding=padding,
  19. bias=bias,
  20. stride=stride,
  21. )
  22. def forward(self, x):
  23. x = self.depthwise(x)
  24. return x
  25. class SplitConvBlock(nn.Module):
  26. def __init__(self, in_channels, mid_channels, out_channels, split_dim, drop_rate):
  27. super(SplitConvBlock, self).__init__()
  28. self.split_dim = split_dim
  29. self.leftconv_1 = SepConvBlock(
  30. in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
  31. )
  32. self.rightconv_1 = SepConvBlock(
  33. in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
  34. )
  35. self.leftconv_2 = SepConvBlock(
  36. mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
  37. )
  38. self.rightconv_2 = SepConvBlock(
  39. mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
  40. )
  41. def forward(self, x):
  42. (left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
  43. self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
  44. self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)
  45. left = self.leftblock(left)
  46. right = self.rightblock(right)
  47. x = torch.cat((left, right), dim=self.split_dim)
  48. return x
  49. class MidFlowBlock(nn.Module):
  50. def __init__(self, channels, drop_rate):
  51. super(MidFlowBlock, self).__init__()
  52. self.conv1 = ConvBlock(
  53. channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
  54. )
  55. self.conv2 = ConvBlock(
  56. channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
  57. )
  58. self.conv3 = ConvBlock(
  59. channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
  60. )
  61. # self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
  62. self.block = self.conv1
  63. def forward(self, x):
  64. x = nn.ELU()(self.block(x) + x)
  65. return x
  66. class ConvBlock(nn.Module):
  67. def __init__(
  68. self,
  69. in_channels,
  70. out_channels,
  71. kernel_size,
  72. stride=(1, 1, 1),
  73. padding="valid",
  74. droprate=None,
  75. pool=False,
  76. ):
  77. super(ConvBlock, self).__init__()
  78. self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
  79. self.norm = nn.BatchNorm3d(out_channels)
  80. self.elu = nn.ELU()
  81. self.dropout = nn.Dropout(droprate)
  82. if pool:
  83. self.maxpool = nn.MaxPool3d(3, stride=2)
  84. else:
  85. self.maxpool = None
  86. def forward(self, x):
  87. x = self.conv(x)
  88. x = self.norm(x)
  89. x = self.elu(x)
  90. if self.maxpool:
  91. x = self.maxpool(x)
  92. x = self.dropout(x)
  93. return x
  94. class FullConnBlock(nn.Module):
  95. def __init__(self, in_channels, out_channels, droprate=0.0):
  96. super(FullConnBlock, self).__init__()
  97. self.dense = nn.Linear(in_channels, out_channels)
  98. self.norm = nn.BatchNorm1d(out_channels)
  99. self.elu = nn.ELU()
  100. self.dropout = nn.Dropout(droprate)
  101. def forward(self, x):
  102. x = self.dense(x)
  103. x = self.norm(x)
  104. x = self.elu(x)
  105. x = self.dropout(x)
  106. return x
  107. class SepConvBlock(nn.Module):
  108. def __init__(
  109. self,
  110. in_channels,
  111. out_channels,
  112. kernel_size,
  113. stride=(1, 1, 1),
  114. padding="valid",
  115. droprate=None,
  116. pool=False,
  117. ):
  118. super(SepConvBlock, self).__init__()
  119. self.conv = SepConv3d(in_channels, out_channels, kernel_size, stride, padding)
  120. self.norm = nn.BatchNorm3d(out_channels)
  121. self.elu = nn.ELU()
  122. self.dropout = nn.Dropout(droprate)
  123. if pool:
  124. self.maxpool = nn.MaxPool3d(3, stride=2)
  125. else:
  126. self.maxpool = None
  127. def forward(self, x):
  128. x = self.conv(x)
  129. x = self.norm(x)
  130. x = self.elu(x)
  131. if self.maxpool:
  132. x = self.maxpool(x)
  133. x = self.dropout(x)
  134. return x