layers.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. from torch import nn
  2. from torchvision.transforms import ToTensor
  3. import os
  4. import pandas as pd
  5. import numpy as np
  6. import torch
  7. import torchvision
  8. class SeperableConv3d(nn.Module):
  9. def __init__(
  10. self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False
  11. ):
  12. super(SeperableConv3d, self).__init__()
  13. self.depthwise = nn.Conv3d(
  14. in_channels,
  15. out_channels,
  16. kernel_size,
  17. groups=out_channels,
  18. padding=padding,
  19. bias=bias,
  20. stride=stride,
  21. )
  22. def forward(self, x):
  23. x = self.depthwise(x)
  24. return x
  25. class SplitConvBlock(nn.Module):
  26. def __init__(self, in_channels, mid_channels, out_channels, split_dim, drop_rate):
  27. super(SplitConvBlock, self).__init__()
  28. self.split_dim = split_dim
  29. self.leftconv_1 = SeperableConvolutionalBlock(
  30. in_channels //2, mid_channels //2, (3, 4, 3), droprate=drop_rate
  31. )
  32. self.rightconv_1 = SeperableConvolutionalBlock(
  33. in_channels //2, mid_channels //2, (3, 4, 3), droprate=drop_rate
  34. )
  35. self.leftconv_2 = SeperableConvolutionalBlock(
  36. mid_channels //2, out_channels //2, (3, 4, 3), droprate=drop_rate
  37. )
  38. self.rightconv_2 = SeperableConvolutionalBlock(
  39. mid_channels //2, out_channels //2, (3, 4, 3), droprate=drop_rate
  40. )
  41. def forward(self, x):
  42. (left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
  43. self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
  44. self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)
  45. left = self.leftblock(left)
  46. right = self.rightblock(right)
  47. x = torch.cat((left, right), dim=self.split_dim)
  48. return x
  49. class MidFlowBlock(nn.Module):
  50. def __init__(self, channels, drop_rate):
  51. super(MidFlowBlock, self).__init__()
  52. self.conv1 = ConvolutionalBlock(
  53. channels, channels, (3, 3, 3), droprate=drop_rate
  54. )
  55. self.conv2 = ConvolutionalBlock(
  56. channels, channels, (3, 3, 3), droprate=drop_rate
  57. )
  58. self.conv3 = ConvolutionalBlock(
  59. channels, channels, (3, 3, 3), droprate=drop_rate
  60. )
  61. #self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
  62. self.block = self.conv1
  63. def forward(self, x):
  64. x = nn.ELU(self.block(x) + x)
  65. return
  66. class ConvolutionalBlock(nn.Module):
  67. def __init__(
  68. self,
  69. in_channels,
  70. out_channels,
  71. kernel_size,
  72. stride=(1, 1, 1),
  73. padding="valid",
  74. droprate=None,
  75. pool=False,
  76. ):
  77. super(ConvolutionalBlock, self).__init__()
  78. self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
  79. self.norm = nn.BatchNorm3d(out_channels)
  80. self.elu = nn.ELU()
  81. self.dropout = nn.Dropout(droprate)
  82. if pool:
  83. self.maxpool = nn.MaxPool3d(3, stride=2)
  84. else:
  85. self.maxpool = None
  86. def forward(self, x):
  87. x = self.conv(x)
  88. x = self.norm(x)
  89. x = self.elu(x)
  90. if self.maxpool:
  91. x = self.maxpool(x)
  92. x = self.dropout(x)
  93. return x
  94. class FullyConnectedBlock(nn.Module):
  95. def __init__(self, in_channels, out_channels, droprate=0.0):
  96. super(FullyConnectedBlock, self).__init__()
  97. self.dense = nn.Linear(in_channels, out_channels)
  98. self.norm = nn.BatchNorm1d(out_channels)
  99. self.elu = nn.ELU()
  100. self.dropout = nn.Dropout(droprate)
  101. def forward(self, x):
  102. x = self.dense(x)
  103. x = self.norm(x)
  104. x = self.elu(x)
  105. x = self.dropout(x)
  106. return x
  107. class SeperableConvolutionalBlock(nn.Module):
  108. def __init__(
  109. self,
  110. in_channels,
  111. out_channels,
  112. kernel_size,
  113. stride = (1, 1, 1),
  114. padding = "valid",
  115. droprate = None,
  116. pool = False,
  117. ):
  118. super(SeperableConvolutionalBlock, self).__init__()
  119. self.conv = SeperableConv3d(in_channels, out_channels, kernel_size, stride, padding)
  120. self.norm = nn.BatchNorm3d(out_channels)
  121. self.elu = nn.ELU()
  122. self.dropout = nn.Dropout(droprate)
  123. if pool:
  124. self.maxpool = nn.MaxPool3d(3, stride=2)
  125. else:
  126. self.maxpool = None
  127. def forward(self, x):
  128. x = self.conv(x)
  129. x = self.norm(x)
  130. x = self.elu(x)
  131. if self.maxpool:
  132. x = self.maxpool(x)
  133. x = self.dropout(x)
  134. return x