# cnn.py

import os

import numpy as np
import pandas as pd
import torch
import torchvision
from torch import nn
from torchvision.transforms import ToTensor

import utils.models.layers as ly


class Parameters:
    def __init__(self, param_dict):
        self.CNN_w_regularizer = param_dict["CNN_w_regularizer"]
        self.RNN_w_regularizer = param_dict["RNN_w_regularizer"]
        self.CNN_batch_size = param_dict["CNN_batch_size"]
        self.RNN_batch_size = param_dict["RNN_batch_size"]
        self.CNN_drop_rate = param_dict["CNN_drop_rate"]
        self.RNN_drop_rate = param_dict["RNN_drop_rate"]
        self.epochs = param_dict["epochs"]
        self.gpu = param_dict["gpu"]
        self.model_filepath = param_dict["model_filepath"] + "/net.h5"
        self.num_clinical = param_dict["num_clinical"]
        self.image_shape = param_dict["image_shape"]
        self.final_layer_size = param_dict["final_layer_size"]
        self.optimizer = param_dict["optimizer"]
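
# Illustration only: a hypothetical param_dict with the keys Parameters expects.
# The values below are placeholders, not the project's defaults.
#
# params = Parameters({
#     "CNN_w_regularizer": 1e-4,
#     "RNN_w_regularizer": 1e-4,
#     "CNN_batch_size": 8,
#     "RNN_batch_size": 8,
#     "CNN_drop_rate": 0.1,
#     "RNN_drop_rate": 0.1,
#     "epochs": 50,
#     "gpu": True,
#     "model_filepath": "./checkpoints",  # "/net.h5" is appended automatically
#     "num_clinical": 20,
#     "image_shape": (96, 112, 96),       # placeholder 3D volume shape
#     "final_layer_size": 2,
#     "optimizer": "adam",
# })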


class CNN(nn.Module):
    def __init__(self, image_channels, clin_data_channels, droprate):
        super().__init__()

        # Image Section
        self.image_section = CNN_Image_Section(image_channels, droprate)

        # Data Layers, fully connected
        self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
        self.fc_clin2 = ly.FullConnBlock(64, 20, droprate=droprate)

        # Final Dense Layer
        self.dense1 = nn.Linear(40, 5)
        self.dense2 = nn.Linear(5, 2)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        image, clin_data = x

        image = self.image_section(image)

        clin_data = self.fc_clin1(clin_data)
        clin_data = self.fc_clin2(clin_data)

        x = torch.cat((image, clin_data), dim=1)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.softmax(x)

        return x
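
# Note: CNN.forward expects x to be a tuple (image, clin_data). The image branch
# and the clinical branch each emit 20 features, so the concatenated vector has
# 40 features, which dense1/dense2 reduce to 2 softmax class probabilities.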


class CNN_Image_Section(nn.Module):
    def __init__(self, image_channels, droprate):
        super().__init__()

        # Initial Convolutional Blocks
        self.conv1 = ly.ConvBlock(
            image_channels,
            192,
            (11, 13, 11),
            stride=(4, 4, 4),
            droprate=droprate,
            pool=False,
        )
        self.conv2 = ly.ConvBlock(192, 384, (5, 6, 5), droprate=droprate, pool=False)

        # Midflow Block
        self.midflow = ly.MidFlowBlock(384, droprate)

        # Split Convolutional Block
        self.splitconv = ly.SplitConvBlock(384, 192, 96, 1, droprate)

        # Fully Connected Block
        self.fc_image = ly.FullConnBlock(227136, 20, droprate=droprate)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.midflow(x)
        x = self.splitconv(x)
        x = torch.flatten(x, 1)
        x = self.fc_image(x)
        return x
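
# Caveat: FullConnBlock(227136, 20) hard-codes the flattened size of the split-conv
# output, so CNN_Image_Section only accepts the fixed input volume size the model was
# built around (presumably the image_shape carried by Parameters). A hedged smoke
# test under that assumption, with hypothetical channel counts, might look like:
#
#     model = CNN(image_channels=1, clin_data_channels=20, droprate=0.1)
#     image = torch.randn(2, 1, D, H, W)  # D/H/W must match the expected volume shape
#     clin = torch.randn(2, 20)
#     probs = model((image, clin))        # -> shape (2, 2), rows sum to 1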