models.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from torch import nn
  2. from torchvision.transforms import ToTensor
  3. import os
  4. import pandas as pd
  5. import numpy as np
  6. import utils.layers as ly
  7. import torch
  8. import torchvision
  9. import torchsummary as ts
  10. class Parameters:
  11. def __init__(self, param_dict):
  12. self.CNN_w_regularizer = param_dict["CNN_w_regularizer"]
  13. self.RNN_w_regularizer = param_dict["RNN_w_regularizer"]
  14. self.CNN_batch_size = param_dict["CNN_batch_size"]
  15. self.RNN_batch_size = param_dict["RNN_batch_size"]
  16. self.CNN_drop_rate = param_dict["CNN_drop_rate"]
  17. self.RNN_drop_rate = param_dict["RNN_drop_rate"]
  18. self.epochs = param_dict["epochs"]
  19. self.gpu = param_dict["gpu"]
  20. self.model_filepath = param_dict["model_filepath"] + "/net.h5"
  21. self.num_clinical = param_dict["num_clinical"]
  22. self.image_shape = param_dict["image_shape"]
  23. self.final_layer_size = param_dict["final_layer_size"]
  24. self.optimizer = param_dict["optimizer"]
  25. class CNN_Net(nn.Module):
  26. def __init__(self, image_channels, clin_data_channels, droprate):
  27. super().__init__()
  28. #Image Section
  29. self.image_section = CNN_Image_Section(image_channels, droprate)
  30. #Data Layers, fully connected
  31. self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
  32. self.fc_clin2 = ly.FullConnBlock(64, 20, droprate=droprate)
  33. #Final Dense Layer
  34. self.dense1 = nn.Linear(40, 5)
  35. self.dense2 = nn.Linear(5, 2)
  36. self.softmax = nn.Softmax(dim = 1)
  37. def forward(self, x):
  38. image, clin_data = x
  39. image = self.image_section(image)
  40. clin_data = self.fc_clin1(clin_data)
  41. clin_data = self.fc_clin2(clin_data)
  42. x = torch.cat((image, clin_data), dim=1)
  43. x = self.dense1(x)
  44. x = self.dense2(x)
  45. x = self.softmax(x)
  46. return x
  47. class CNN_Image_Section(nn.Module):
  48. def __init__(self, image_channels, droprate):
  49. super().__init__()
  50. # Initial Convolutional Blocks
  51. self.conv1 = ly.ConvBlock(
  52. image_channels, 192, (11, 13, 11), stride=(4, 4, 4), droprate=droprate, pool=False
  53. )
  54. self.conv2 = ly.ConvBlock(
  55. 192, 384, (5, 6, 5), droprate=droprate, pool=False
  56. )
  57. # Midflow Block
  58. self.midflow = ly.MidFlowBlock(384, droprate)
  59. # Split Convolutional Block
  60. self.splitconv = ly.SplitConvBlock(384, 192, 96, 1, droprate)
  61. #Fully Connected Block
  62. self.fc_image = ly.FullConnBlock(227136, 20, droprate=droprate)
  63. def forward(self, x):
  64. x = self.conv1(x)
  65. x = self.conv2(x)
  66. x = self.midflow(x)
  67. x = self.splitconv(x)
  68. x = torch.flatten(x, 1)
  69. x = self.fc_image(x)
  70. return x