models.py

import os

import numpy as np
import pandas as pd
import torch
import torchvision
from torch import nn
from torchvision.transforms import ToTensor


class SeperableConv3d(nn.Module):
    """Depthwise-separable 3D convolution: a depthwise conv followed by a 1x1x1 pointwise conv."""

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False
    ):
        super(SeperableConv3d, self).__init__()
        self.depthwise = nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size,
            groups=in_channels,
            padding=padding,
            bias=bias,
            stride=stride,
        )
        self.pointwise = nn.Conv3d(
            in_channels, out_channels, 1, padding=padding, bias=bias, stride=stride
        )

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x
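

# Example (illustrative): for an input of shape (N, C, D, H, W) = (2, 4, 16, 16, 16),
# SeperableConv3d(4, 8, 3) applies the depthwise conv (kernel 3, no padding), giving
# (2, 4, 14, 14, 14), and the 1x1x1 pointwise conv then maps 4 -> 8 channels, giving
# (2, 8, 14, 14, 14). A runnable version of this check sits in the __main__ block at
# the bottom of the file.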


class SplitConvBlock(nn.Module):
    """Splits the input along split_dim and runs each half through its own separable conv stack."""

    def __init__(self, in_channels, mid_channels, out_channels, split_dim, drop_rate):
        super(SplitConvBlock, self).__init__()
        self.split_dim = split_dim
        # The halved channel counts assume split_dim is the channel axis, so each
        # half carries in_channels // 2 channels.
        self.leftconv_1 = CNN_Net.SeperableConvolutionalBlock(
            (3, 4, 3), in_channels // 2, mid_channels // 2, droprate=drop_rate
        )
        self.rightconv_1 = CNN_Net.SeperableConvolutionalBlock(
            (4, 3, 3), in_channels // 2, mid_channels // 2, droprate=drop_rate
        )
        self.leftconv_2 = CNN_Net.SeperableConvolutionalBlock(
            (3, 4, 3), mid_channels // 2, out_channels // 2, droprate=drop_rate
        )
        self.rightconv_2 = CNN_Net.SeperableConvolutionalBlock(
            (4, 3, 3), mid_channels // 2, out_channels // 2, droprate=drop_rate
        )
        # Build the two branches once here instead of on every forward pass.
        self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
        self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)

    def forward(self, x):
        left, right = torch.tensor_split(x, 2, dim=self.split_dim)
        left = self.leftblock(left)
        right = self.rightblock(right)
        return torch.cat((left, right), dim=self.split_dim)


class MidFlowBlock(nn.Module):
    """Three separable conv blocks with a residual (skip) connection followed by an ELU."""

    def __init__(self, channels, drop_rate):
        super(MidFlowBlock, self).__init__()
        # "same" padding keeps the spatial dimensions unchanged so the residual sum
        # in forward() is well defined (the default "valid" padding would shrink the
        # feature map and break the skip connection).
        self.conv1 = CNN_Net.SeperableConvolutionalBlock(
            (3, 3, 3), channels, channels, padding="same", droprate=drop_rate
        )
        self.conv2 = CNN_Net.SeperableConvolutionalBlock(
            (3, 3, 3), channels, channels, padding="same", droprate=drop_rate
        )
        self.conv3 = CNN_Net.SeperableConvolutionalBlock(
            (3, 3, 3), channels, channels, padding="same", droprate=drop_rate
        )
        self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
        self.elu = nn.ELU()

    def forward(self, x):
        # Apply the activation to the residual sum; calling nn.ELU(...) on the sum,
        # as the original did, only constructs a module instead of computing the activation.
        return self.elu(self.block(x) + x)


class Parameters:
    def __init__(self, param_dict):
        self.CNN_w_regularizer = param_dict["CNN_w_regularizer"]
        self.RNN_w_regularizer = param_dict["RNN_w_regularizer"]
        self.CNN_batch_size = param_dict["CNN_batch_size"]
        self.RNN_batch_size = param_dict["RNN_batch_size"]
        self.CNN_drop_rate = param_dict["CNN_drop_rate"]
        self.RNN_drop_rate = param_dict["RNN_drop_rate"]
        self.epochs = param_dict["epochs"]
        self.gpu = param_dict["gpu"]
        self.model_filepath = param_dict["model_filepath"] + "/net.h5"
        self.num_clinical = param_dict["num_clinical"]
        self.image_shape = param_dict["image_shape"]
        self.final_layer_size = param_dict["final_layer_size"]
        self.optimizer = param_dict["optimizer"]
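

# Example param_dict for Parameters (the keys are the ones read above; the values
# shown are illustrative placeholders, not the original training configuration):
#
# param_dict = {
#     "CNN_w_regularizer": None,
#     "RNN_w_regularizer": None,
#     "CNN_batch_size": 8,
#     "RNN_batch_size": 8,
#     "CNN_drop_rate": 0.1,
#     "RNN_drop_rate": 0.1,
#     "epochs": 50,
#     "gpu": True,
#     "model_filepath": "./output",
#     "num_clinical": 10,
#     "image_shape": (91, 109, 91),
#     "final_layer_size": 32,
#     "optimizer": "adam",
# }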


class CNN_Net(nn.Module):
    @staticmethod
    def ConvolutionalBlock(
        kernel_size,
        in_channels,
        out_channels,
        stride=(1, 1, 1),
        padding="valid",
        droprate=0.0,
        pool=False,
    ):
        """Conv3d -> BatchNorm3d -> ELU -> (optional MaxPool3d) -> Dropout."""
        # Note: droprate defaults to 0.0 (not None) so nn.Dropout always receives
        # a valid probability.
        conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
        norm = nn.BatchNorm3d(out_channels)
        elu = nn.ELU()
        dropout = nn.Dropout(droprate)
        if pool:
            maxpool = nn.MaxPool3d(3, stride=2)
            return nn.Sequential(conv, norm, elu, maxpool, dropout)
        return nn.Sequential(conv, norm, elu, dropout)

    @staticmethod
    def FullyConnectedBlock(in_channels, out_channels, droprate=0.0):
        """Linear -> BatchNorm1d -> ELU -> Dropout."""
        dense = nn.Linear(in_channels, out_channels)
        norm = nn.BatchNorm1d(out_channels)
        elu = nn.ELU()
        dropout = nn.Dropout(droprate)
        return nn.Sequential(dense, norm, elu, dropout)

    @staticmethod
    def SeperableConvolutionalBlock(
        kernel_size,
        in_channels,
        out_channels,
        stride=(1, 1, 1),
        padding="valid",
        droprate=0.0,
        pool=False,
    ):
        """SeperableConv3d -> BatchNorm3d -> ELU -> (optional MaxPool3d) -> Dropout."""
        conv = SeperableConv3d(in_channels, out_channels, kernel_size, stride, padding)
        norm = nn.BatchNorm3d(out_channels)
        elu = nn.ELU()
        dropout = nn.Dropout(droprate)
        if pool:
            maxpool = nn.MaxPool3d(3, stride=2)
            return nn.Sequential(conv, norm, elu, maxpool, dropout)
        return nn.Sequential(conv, norm, elu, dropout)

    def __init__(self, image_channels, clin_data_channels, droprate, final_layer_size):
        super().__init__()

        # Initial convolutional blocks
        self.conv1 = CNN_Net.ConvolutionalBlock(
            (11, 13, 11), image_channels, 192, stride=(4, 4, 4), droprate=droprate, pool=True
        )
        self.conv2 = CNN_Net.ConvolutionalBlock(
            (5, 6, 5), 192, 384, droprate=droprate, pool=True
        )

        # Mid-flow block
        self.midflow = MidFlowBlock(384, droprate)

        # Combined image feature extractor
        self.combined = nn.Sequential(self.conv1, self.conv2, self.midflow)

        # Split convolutional block; the split axis is 1 (the channel axis in PyTorch's
        # NCDHW layout), matching the halved channel counts inside SplitConvBlock.
        # The original value of 4 corresponds to a channels-last layout.
        self.splitconv = SplitConvBlock(384, 192, 96, 1, droprate)

        self.image_layers = nn.Sequential(self.combined, self.splitconv).double()

        # Image branch head: flatten the 96 conv features and map them to 20.
        # Renamed from fc1, which the clinical-data branch below used to shadow.
        self.flatten = nn.Flatten()
        self.fc_image = CNN_Net.FullyConnectedBlock(96, 20, droprate=droprate).double()

        # Clinical-data branch, fully connected
        self.fc1 = CNN_Net.FullyConnectedBlock(clin_data_channels, 64, droprate=droprate)
        self.fc2 = CNN_Net.FullyConnectedBlock(64, 20, droprate=droprate)
        self.data_layers = nn.Sequential(self.fc1, self.fc2).double()

        # Final dense layers; cast to double to match the two branches above.
        self.dense1 = nn.Linear(40, final_layer_size)
        self.dense2 = nn.Linear(final_layer_size, 2)
        self.softmax = nn.Softmax(dim=1)
        self.final_layers = nn.Sequential(self.dense1, self.dense2, self.softmax).double()

    def forward(self, image, clin_data):
        image = self.image_layers(image)
        # The 96 -> 20 head assumes the convolutional stack has reduced the spatial
        # dimensions to 1 x 1 x 1, so flattening yields exactly 96 features.
        image = self.fc_image(self.flatten(image))
        clin_data = self.data_layers(clin_data)
        x = torch.cat((image, clin_data), dim=1)
        x = self.final_layers(x)
        return x
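

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): the channel counts, droprate and
    # final_layer_size used here are placeholder values, not the original
    # training configuration.
    model = CNN_Net(image_channels=1, clin_data_channels=10, droprate=0.1, final_layer_size=32)
    print(f"CNN_Net parameters: {sum(p.numel() for p in model.parameters())}")

    # Shape check for the separable convolution described after SeperableConv3d above.
    sep = SeperableConv3d(4, 8, 3)
    print(sep(torch.randn(2, 4, 16, 16, 16)).shape)  # torch.Size([2, 8, 14, 14, 14])

    # The clinical-data branch can be exercised on its own; it was cast to double
    # precision in __init__, so the dummy input is cast to match.
    clin_data = torch.randn(2, 10).double()
    print(model.data_layers(clin_data).shape)  # torch.Size([2, 20])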