layers.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. from torch import nn
  2. from torchvision.transforms import ToTensor
  3. import os
  4. import pandas as pd
  5. import numpy as np
  6. import torch
  7. import torchvision
  8. class SeperableConv3d(nn.Module):
  9. def __init__(
  10. self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False
  11. ):
  12. super(SeperableConv3d, self).__init__()
  13. self.depthwise = nn.Conv3d(
  14. in_channels,
  15. out_channels,
  16. kernel_size,
  17. groups=out_channels,
  18. padding=padding,
  19. bias=bias,
  20. stride=stride,
  21. )
  22. def forward(self, x):
  23. x = self.depthwise(x)
  24. return x
  25. class SplitConvBlock(nn.Module):
  26. def __init__(self, in_channels, mid_channels, out_channels, split_dim, drop_rate):
  27. super(SplitConvBlock, self).__init__()
  28. self.split_dim = split_dim
  29. self.leftconv_1 = SeperableConvolutionalBlock(
  30. in_channels //2, mid_channels //2, (3, 4, 3), droprate=drop_rate
  31. )
  32. self.rightconv_1 = SeperableConvolutionalBlock(
  33. in_channels //2, mid_channels //2, (3, 4, 3), droprate=drop_rate
  34. )
  35. self.leftconv_2 = SeperableConvolutionalBlock(
  36. mid_channels //2, out_channels //2, (3, 4, 3), droprate=drop_rate
  37. )
  38. self.rightconv_2 = SeperableConvolutionalBlock(
  39. mid_channels //2, out_channels //2, (3, 4, 3), droprate=drop_rate
  40. )
  41. def forward(self, x):
  42. print("SplitConvBlock in: ", x.shape)
  43. (left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
  44. self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
  45. self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)
  46. left = self.leftblock(left)
  47. right = self.rightblock(right)
  48. x = torch.cat((left, right), dim=self.split_dim)
  49. print("SplitConvBlock out: ", x.shape)
  50. return x
  51. class MidFlowBlock(nn.Module):
  52. def __init__(self, channels, drop_rate):
  53. super(MidFlowBlock, self).__init__()
  54. self.conv1 = ConvolutionalBlock(
  55. channels, channels, (3, 3, 3), droprate=drop_rate
  56. )
  57. self.conv2 = ConvolutionalBlock(
  58. channels, channels, (3, 3, 3), droprate=drop_rate
  59. )
  60. self.conv3 = ConvolutionalBlock(
  61. channels, channels, (3, 3, 3), droprate=drop_rate
  62. )
  63. #self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
  64. self.block = self.conv1
  65. def forward(self, x):
  66. print("MidFlowBlock in: ", x.shape)
  67. x = nn.ELU(self.block(x) + x)
  68. print("MidFlowBlock out: ", x.shape)
  69. return
  70. class ConvolutionalBlock(nn.Module):
  71. def __init__(
  72. self,
  73. in_channels,
  74. out_channels,
  75. kernel_size,
  76. stride=(1, 1, 1),
  77. padding="valid",
  78. droprate=None,
  79. pool=False,
  80. ):
  81. super(ConvolutionalBlock, self).__init__()
  82. self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
  83. self.norm = nn.BatchNorm3d(out_channels)
  84. self.elu = nn.ELU()
  85. self.dropout = nn.Dropout(droprate)
  86. if pool:
  87. self.maxpool = nn.MaxPool3d(3, stride=2)
  88. else:
  89. self.maxpool = None
  90. def forward(self, x):
  91. print("ConvBlock in: ", x.shape)
  92. x = self.conv(x)
  93. x = self.norm(x)
  94. x = self.elu(x)
  95. if self.maxpool:
  96. x = self.maxpool(x)
  97. x = self.dropout(x)
  98. print("ConvBlock out: ", x.shape)
  99. return x
  100. class FullyConnectedBlock(nn.Module):
  101. def __init__(self, in_channels, out_channels, droprate=0.0):
  102. super(FullyConnectedBlock, self).__init__()
  103. self.dense = nn.Linear(in_channels, out_channels)
  104. self.norm = nn.BatchNorm1d(out_channels)
  105. self.elu = nn.ELU()
  106. self.dropout = nn.Dropout(droprate)
  107. def forward(self, x):
  108. print("FullyConnectedBlock in: ", x.shape)
  109. x = self.dense(x)
  110. x = self.norm(x)
  111. x = self.elu(x)
  112. x = self.dropout(x)
  113. print("FullyConnectedBlock out: ", x.shape)
  114. return x
  115. class SeperableConvolutionalBlock(nn.Module):
  116. def __init__(
  117. self,
  118. in_channels,
  119. out_channels,
  120. kernel_size,
  121. stride = (1, 1, 1),
  122. padding = "valid",
  123. droprate = None,
  124. pool = False,
  125. ):
  126. super(SeperableConvolutionalBlock, self).__init__()
  127. self.conv = SeperableConv3d(in_channels, out_channels, kernel_size, stride, padding)
  128. self.norm = nn.BatchNorm3d(out_channels)
  129. self.elu = nn.ELU()
  130. self.dropout = nn.Dropout(droprate)
  131. if pool:
  132. self.maxpool = nn.MaxPool3d(3, stride=2)
  133. else:
  134. self.maxpool = None
  135. def forward(self, x):
  136. print("SeperableConvBlock in: ", x.shape)
  137. x = self.conv(x)
  138. x = self.norm(x)
  139. x = self.elu(x)
  140. if self.maxpool:
  141. x = self.maxpool(x)
  142. x = self.dropout(x)
  143. print("SeperableConvBlock out: ", x.shape)
  144. return x