from torch import nn
import torch

import utils.layers as ly


class Parameters:
    """Unpacks a configuration dictionary into named attributes."""

    def __init__(self, param_dict):
        self.CNN_w_regularizer = param_dict["CNN_w_regularizer"]
        self.RNN_w_regularizer = param_dict["RNN_w_regularizer"]
        self.CNN_batch_size = param_dict["CNN_batch_size"]
        self.RNN_batch_size = param_dict["RNN_batch_size"]
        self.CNN_drop_rate = param_dict["CNN_drop_rate"]
        self.RNN_drop_rate = param_dict["RNN_drop_rate"]
        self.epochs = param_dict["epochs"]
        self.gpu = param_dict["gpu"]
        self.model_filepath = param_dict["model_filepath"] + "/net.h5"
        self.num_clinical = param_dict["num_clinical"]
        self.image_shape = param_dict["image_shape"]
        self.final_layer_size = param_dict["final_layer_size"]
        self.optimizer = param_dict["optimizer"]


class CNN_Net(nn.Module):
    """Two-branch network: a 3D convolutional pathway for image volumes and a
    fully connected pathway for clinical data, fused before the final classifier."""

    def __init__(self, image_channels, clin_data_channels, droprate):
        super().__init__()

        # Initial convolutional blocks
        self.conv1 = ly.ConvolutionalBlock(
            image_channels, 192, (11, 13, 11),
            stride=(4, 4, 4), droprate=droprate, pool=True,
        )
        self.conv2 = ly.ConvolutionalBlock(
            192, 384, (5, 6, 5), droprate=droprate, pool=True
        )

        # Mid-flow block
        self.midflow = ly.MidFlowBlock(384, droprate)

        # Split convolutional block
        self.splitconv = ly.SplitConvBlock(384, 192, 96, 4, droprate)

        # Fully connected block for the image branch
        self.fc_image = ly.FullyConnectedBlock(96, 20, droprate=droprate)

        # Clinical-data branch, fully connected
        self.fc_clin1 = ly.FullyConnectedBlock(clin_data_channels, 64, droprate=droprate)
        self.fc_clin2 = ly.FullyConnectedBlock(64, 20, droprate=droprate)

        # Final dense layers acting on the concatenated (20 + 20) features
        self.dense1 = nn.Linear(40, 5)
        self.dense2 = nn.Linear(5, 2)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """x is a tuple of (image, clin_data) tensors; returns class probabilities."""
        image, clin_data = x

        # Image branch (shape prints are kept for debugging)
        print("Input image shape:", image.shape)
        image = self.conv1(image)
        print("Conv1 shape:", image.shape)
        image = self.conv2(image)
        print("Conv2 shape:", image.shape)
        image = self.midflow(image)
        print("Midflow shape:", image.shape)
        image = self.splitconv(image)
        print("Splitconv shape:", image.shape)
        image = torch.flatten(image, 1)
        print("Flatten shape:", image.shape)
        image = self.fc_image(image)

        # Clinical-data branch
        clin_data = self.fc_clin1(clin_data)
        clin_data = self.fc_clin2(clin_data)

        # Fuse both branches and classify
        x = torch.cat((image, clin_data), dim=1)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.softmax(x)

        return x
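

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). All parameter values below are placeholder
# assumptions, not values from this repository, and the sketch assumes that the
# blocks in utils.layers reduce the input volume so that flattening the
# SplitConvBlock output yields exactly the 96 features expected by fc_image;
# the actual shapes depend on how those blocks are implemented.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    params = Parameters({
        "CNN_w_regularizer": None,
        "RNN_w_regularizer": None,
        "CNN_batch_size": 4,
        "RNN_batch_size": 4,
        "CNN_drop_rate": 0.1,
        "RNN_drop_rate": 0.1,
        "epochs": 10,
        "gpu": False,
        "model_filepath": "./weights",
        "num_clinical": 14,            # assumed number of clinical features
        "image_shape": (91, 109, 91),  # assumed 3D volume size
        "final_layer_size": 20,
        "optimizer": "adam",
    })

    net = CNN_Net(
        image_channels=1,
        clin_data_channels=params.num_clinical,
        droprate=params.CNN_drop_rate,
    )

    # Dummy batch: a single-channel 3D volume plus a clinical feature vector.
    image = torch.randn(params.CNN_batch_size, 1, *params.image_shape)
    clin = torch.randn(params.CNN_batch_size, params.num_clinical)

    with torch.no_grad():
        probs = net((image, clin))
    print("Output probabilities:", probs.shape)  # expected: (batch_size, 2)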