models.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
import logging
import os

import numpy as np
import pandas as pd
import torch
import torchvision
from torch import nn
from torchvision.transforms import ToTensor

import utils.layers as ly
  9. class Parameters:
  10. def __init__(self, param_dict):
  11. self.CNN_w_regularizer = param_dict["CNN_w_regularizer"]
  12. self.RNN_w_regularizer = param_dict["RNN_w_regularizer"]
  13. self.CNN_batch_size = param_dict["CNN_batch_size"]
  14. self.RNN_batch_size = param_dict["RNN_batch_size"]
  15. self.CNN_drop_rate = param_dict["CNN_drop_rate"]
  16. self.RNN_drop_rate = param_dict["RNN_drop_rate"]
  17. self.epochs = param_dict["epochs"]
  18. self.gpu = param_dict["gpu"]
  19. self.model_filepath = param_dict["model_filepath"] + "/net.h5"
  20. self.num_clinical = param_dict["num_clinical"]
  21. self.image_shape = param_dict["image_shape"]
  22. self.final_layer_size = param_dict["final_layer_size"]
  23. self.optimizer = param_dict["optimizer"]
  24. class CNN_Net(nn.Module):
  25. def __init__(self, image_channels, clin_data_channels, droprate):
  26. super().__init__()
  27. # Initial Convolutional Blocks
  28. self.conv1 = ly.ConvolutionalBlock(
  29. image_channels, 192, (11, 13, 11), stride=(4, 4, 4), droprate=droprate, pool=True
  30. )
  31. self.conv2 = ly.ConvolutionalBlock(
  32. 192, 384, (5, 6, 5), droprate=droprate, pool=True
  33. )
  34. # Midflow Block
  35. self.midflow = ly.MidFlowBlock(384, droprate)
  36. # Split Convolutional Block
  37. self.splitconv = ly.SplitConvBlock(384, 192, 96, 4, droprate)
  38. #Fully Connected Block
  39. self.fc_image = ly.FullyConnectedBlock(96, 20, droprate=droprate)
  40. #Data Layers, fully connected
  41. self.fc_clin1 = ly.FullyConnectedBlock(clin_data_channels, 64, droprate=droprate)
  42. self.fc_clin2 = ly.FullyConnectedBlock(64, 20, droprate=droprate)
  43. #Final Dense Layer
  44. self.dense1 = nn.Linear(40, 5)
  45. self.dense2 = nn.Linear(5, 2)
  46. self.softmax = nn.Softmax()
  47. def forward(self, x):
  48. image, clin_data = x
  49. print("Input image shape:", image.shape)
  50. image = self.conv1(image)
  51. print("Conv1 shape:", image.shape)
  52. image = self.conv2(image)
  53. print("Conv2 shape:", image.shape)
  54. image = self.midflow(image)
  55. print("Midflow shape:", image.shape)
  56. image = self.splitconv(image)
  57. print("Splitconv shape:", image.shape)
  58. image = torch.flatten(image, 1)
  59. print("Flatten shape:", image.shape)
  60. image = self.fc_image(image)
  61. clin_data = self.fc_clin1(clin_data)
  62. clin_data = self.fc_clin2(clin_data)
  63. x = torch.cat((image, clin_data), dim=1)
  64. x = self.dense1(x)
  65. x = self.dense2(x)
  66. x = self.softmax(x)
  67. return x