cnn.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from typing import Tuple
  2. from torch import nn
  3. import utils.models.layers as ly
  4. import torch
  5. class Parameters:
  6. def __init__(self, param_dict):
  7. self.CNN_w_regularizer = param_dict["CNN_w_regularizer"]
  8. self.RNN_w_regularizer = param_dict["RNN_w_regularizer"]
  9. self.CNN_batch_size = param_dict["CNN_batch_size"]
  10. self.RNN_batch_size = param_dict["RNN_batch_size"]
  11. self.CNN_drop_rate = param_dict["CNN_drop_rate"]
  12. self.RNN_drop_rate = param_dict["RNN_drop_rate"]
  13. self.epochs = param_dict["epochs"]
  14. self.gpu = param_dict["gpu"]
  15. self.model_filepath = param_dict["model_filepath"] + "/net.h5"
  16. self.num_clinical = param_dict["num_clinical"]
  17. self.image_shape = param_dict["image_shape"]
  18. self.final_layer_size = param_dict["final_layer_size"]
  19. self.optimizer = param_dict["optimizer"]
  20. class CNN(nn.Module):
  21. def __init__(self, image_channels, clin_data_channels, droprate):
  22. super().__init__()
  23. # Image Section
  24. self.image_section = CNN_Image_Section(image_channels, droprate)
  25. # Data Layers, fully connected
  26. self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
  27. self.fc_clin2 = ly.FullConnBlock(64, 20, droprate=droprate)
  28. # Final Dense Layer
  29. self.dense1 = nn.Linear(40, 5)
  30. self.dense2 = nn.Linear(5, 2)
  31. self.softmax = nn.Softmax(dim=1)
  32. def forward(self, x_in: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
  33. image, clin_data = x_in
  34. image = self.image_section(image)
  35. clin_data = self.fc_clin1(clin_data)
  36. clin_data = self.fc_clin2(clin_data)
  37. x = torch.cat((image, clin_data), dim=1)
  38. x = self.dense1(x)
  39. x = self.dense2(x)
  40. x = self.softmax(x)
  41. return x
  42. class CNN_Image_Section(nn.Module):
  43. def __init__(self, image_channels, droprate):
  44. super().__init__()
  45. # Initial Convolutional Blocks
  46. self.conv1 = ly.ConvBlock(
  47. image_channels,
  48. 192,
  49. (11, 13, 11),
  50. stride=(4, 4, 4),
  51. droprate=droprate,
  52. pool=False,
  53. )
  54. self.conv2 = ly.ConvBlock(192, 384, (5, 6, 5), droprate=droprate, pool=False)
  55. # Midflow Block
  56. self.midflow = ly.MidFlowBlock(384, droprate)
  57. # Split Convolutional Block
  58. self.splitconv = ly.SplitConvBlock(384, 192, 96, 1, droprate)
  59. # Fully Connected Block
  60. self.fc_image = ly.FullConnBlock(227136, 20, droprate=droprate)
  61. def forward(self, x):
  62. x = self.conv1(x)
  63. x = self.conv2(x)
  64. x = self.midflow(x)
  65. x = self.splitconv(x)
  66. x = torch.flatten(x, 1)
  67. x = self.fc_image(x)
  68. return x