# models.py

from torch import nn
from torchvision.transforms import ToTensor
import os
import pandas as pd
import numpy as np
import layers as ly
import torch
import torchvision
import torchsummary as ts


class Parameters:
    def __init__(self, param_dict):
        self.CNN_w_regularizer = param_dict["CNN_w_regularizer"]
        self.RNN_w_regularizer = param_dict["RNN_w_regularizer"]
        self.CNN_batch_size = param_dict["CNN_batch_size"]
        self.RNN_batch_size = param_dict["RNN_batch_size"]
        self.CNN_drop_rate = param_dict["CNN_drop_rate"]
        self.RNN_drop_rate = param_dict["RNN_drop_rate"]
        self.epochs = param_dict["epochs"]
        self.gpu = param_dict["gpu"]
        self.model_filepath = param_dict["model_filepath"] + "/net.h5"
        self.num_clinical = param_dict["num_clinical"]
        self.image_shape = param_dict["image_shape"]
        self.final_layer_size = param_dict["final_layer_size"]
        self.optimizer = param_dict["optimizer"]
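
# A minimal sketch of the param_dict that Parameters expects. The keys come from
# the class above; the values below are only placeholders, not the project's
# actual settings:
#
#     param_dict = {
#         "CNN_w_regularizer": 1e-4,
#         "RNN_w_regularizer": 1e-4,
#         "CNN_batch_size": 8,
#         "RNN_batch_size": 8,
#         "CNN_drop_rate": 0.5,
#         "RNN_drop_rate": 0.5,
#         "epochs": 50,
#         "gpu": True,
#         "model_filepath": "./checkpoints",
#         "num_clinical": 2,
#         "image_shape": (91, 109, 91),
#         "final_layer_size": 20,
#         "optimizer": "adam",
#     }
#     params = Parameters(param_dict)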


class CNN_Net(nn.Module):
    def __init__(self, image_channels, clin_data_channels, droprate):
        super().__init__()

        # Initial convolutional blocks
        self.conv1 = ly.ConvBlock(
            image_channels, 192, (11, 13, 11), stride=(4, 4, 4), droprate=droprate, pool=False
        )
        self.conv2 = ly.ConvBlock(
            192, 384, (5, 6, 5), droprate=droprate, pool=False
        )

        # Midflow block
        self.midflow = ly.MidFlowBlock(384, droprate)

        # Split convolutional block
        self.splitconv = ly.SplitConvBlock(384, 192, 96, 4, droprate)

        # Fully connected block for the image branch
        self.fc_image = ly.FullConnBlock(96, 20, droprate=droprate)

        # Clinical-data layers, fully connected
        self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
        self.fc_clin2 = ly.FullConnBlock(64, 20, droprate=droprate)

        # Final dense layers
        self.dense1 = nn.Linear(40, 5)
        self.dense2 = nn.Linear(5, 2)
        self.softmax = nn.Softmax(dim=1)  # softmax over the class dimension

    def forward(self, x):
        image, clin_data = x

        # Image branch
        image = self.conv1(image)
        image = self.conv2(image)
        image = self.midflow(image)
        image = self.splitconv(image)
        image = torch.flatten(image, 1)
        image = self.fc_image(image)

        # Clinical-data branch
        clin_data = self.fc_clin1(clin_data)
        clin_data = self.fc_clin2(clin_data)

        # Fuse both branches and classify
        x = torch.cat((image, clin_data), dim=1)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.softmax(x)
        return x
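
# Usage sketch (illustrative only; shapes are assumptions): the forward pass
# takes an (image, clinical_data) tuple. The 1 x 91 x 109 x 91 volume matches
# the summary call at the bottom of this file; the two clinical features are
# just an example value.
#
#     model = CNN_Net(image_channels=1, clin_data_channels=2, droprate=0.5)
#     mri = torch.randn(4, 1, 91, 109, 91)    # (batch, channels, D, H, W)
#     clin = torch.randn(4, 2)                # (batch, clinical features)
#     probs = model((mri, clin))              # (batch, 2) class probabilities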


class CNN_Image_Section(nn.Module):
    def __init__(self, image_channels, droprate):
        super().__init__()

        # Initial convolutional blocks
        self.conv1 = ly.ConvBlock(
            image_channels, 192, (11, 13, 11), stride=(4, 4, 4), droprate=droprate, pool=False
        )
        self.conv2 = ly.ConvBlock(
            192, 384, (5, 6, 5), droprate=droprate, pool=False
        )

        # Midflow block
        self.midflow = ly.MidFlowBlock(384, droprate)

        # Split convolutional block
        self.splitconv = ly.SplitConvBlock(384, 192, 96, 1, droprate)

        # Fully connected block
        self.fc_image = ly.FullConnBlock(227136, 20, droprate=droprate)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.midflow(x)
        x = self.splitconv(x)
        print(x.shape)  # debug: pre-flatten feature-map shape
        x = torch.flatten(x, 1)
        x = self.fc_image(x)
        return x


print(ts.summary(CNN_Image_Section(1, 0.5).cuda(), (1, 91, 109, 91)))
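
# The summary call above assumes a CUDA device. If torchsummary's optional
# device argument is available in the installed version, a CPU-only equivalent
# would be roughly:
#
#     print(ts.summary(CNN_Image_Section(1, 0.5), (1, 91, 109, 91), device="cpu"))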