from typing import Tuple

import torch
from torch import nn
from jaxtyping import Float

import model.layers as ly


class CNN_Image_Section(nn.Module):
    """Convolutional feature extractor for the 3D image input."""

    def __init__(self, image_channels: int, droprate: float = 0.0):
        super().__init__()

        # Initial Convolutional Blocks
        self.conv1 = ly.ConvBlock(
            image_channels,
            192,
            (11, 13, 11),
            stride=(4, 4, 4),
            droprate=droprate,
            pool=False,
        )
        self.conv2 = ly.ConvBlock(192, 384, (5, 6, 5), droprate=droprate, pool=False)

        # Midflow Block
        self.midflow = ly.MidFlowBlock(384, droprate)

        # Split Convolutional Block
        self.splitconv = ly.SplitConvBlock(384, 192, 96, 1, droprate)

        # Fully Connected Block
        self.fc_image = ly.FullConnBlock(227136, 20, droprate=droprate)

    def forward(
        self, x: Float[torch.Tensor, "N C D H W"]
    ) -> Float[torch.Tensor, "N 20"]:
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.midflow(x)
        x = self.splitconv(x)
        x = torch.flatten(x, 1)  # flatten all but the batch dimension
        x = self.fc_image(x)
        return x


class CNN3D(nn.Module):
    """3D CNN that fuses image features with clinical data for classification."""

    def __init__(
        self,
        image_channels: int,
        clin_data_channels: int,
        num_classes: int,
        droprate: float = 0.0,
    ):
        super().__init__()

        # Image branch
        self.image_section = CNN_Image_Section(image_channels, droprate=droprate)

        # Clinical-data branch
        self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
        self.fc_clin2 = ly.FullConnBlock(64, 20, droprate=droprate)

        # Fusion head: 20 image features + 20 clinical features -> class probabilities
        self.dense1 = nn.Linear(20 + 20, 10)
        self.dense2 = nn.Linear(10, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(
        self,
        x_in: Tuple[Float[torch.Tensor, "N C D H W"], Float[torch.Tensor, "N F"]],
    ) -> Float[torch.Tensor, "N num_classes"]:
        image_data, clin_data = x_in

        image_out = self.image_section(image_data)
        clin_out = self.fc_clin2(self.fc_clin1(clin_data))

        combined = torch.cat((image_out, clin_out), dim=1)
        x = self.dense1(combined)
        x = self.dense2(x)
        x = self.softmax(x)
        return x
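

# --- Usage sketch (illustrative only) ---
# A minimal sketch of how this module might be instantiated, assuming
# model.layers is importable as above. The channel counts below
# (image_channels=1, clin_data_channels=10, num_classes=2) are placeholder
# assumptions, not values taken from this repository.
if __name__ == "__main__":
    model = CNN3D(image_channels=1, clin_data_channels=10, num_classes=2)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"CNN3D parameters: {n_params:,}")

    # Inputs are passed to forward() as a tuple:
    #   (images of shape [N, C, D, H, W], clinical data of shape [N, clin_data_channels]).
    # The image spatial size must be chosen so that the flattened output of
    # CNN_Image_Section has exactly 227136 features, as expected by fc_image.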