123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- from typing import Tuple
- from torch import nn
- import torch
- import model.layers as ly
- from jaxtyping import Float
class CNN_Image_Section(nn.Module):
    """Convolutional feature extractor for the 3D image input.

    The input passes through two ``ly.ConvBlock`` stages, a
    ``ly.MidFlowBlock`` and a ``ly.SplitConvBlock``. The result is then
    flattened (batch dimension kept) and mapped to 20 outputs by a
    ``ly.FullConnBlock``.

    Args:
        image_channels: Number of channels in the image input.
        droprate: Dropout rate forwarded to every sub-block.
        fc_in_features: Width of the flattened ``splitconv`` output fed to
            ``fc_image``. The default, 227136, is the value the architecture
            originally hard-coded. It only matches the input volume size the
            model was designed for, so pass the matching value when using
            inputs with a different D/H/W.
    """

    def __init__(
        self,
        image_channels: int,
        droprate: float = 0.0,
        fc_in_features: int = 227136,
    ):
        super().__init__()
        # Initial convolutional blocks; both are built with pool=False.
        self.conv1 = ly.ConvBlock(
            image_channels,
            192,
            (11, 13, 11),
            stride=(4, 4, 4),
            droprate=droprate,
            pool=False,
        )
        self.conv2 = ly.ConvBlock(192, 384, (5, 6, 5), droprate=droprate, pool=False)
        # Midflow block; 384 matches conv2's output-channel argument.
        self.midflow = ly.MidFlowBlock(384, droprate)
        # Split convolutional block.
        self.splitconv = ly.SplitConvBlock(384, 192, 96, 1, droprate)
        # Fully connected block. Its input width must equal the number of
        # elements per sample after flattening splitconv's output. That count
        # depends on the input volume size, hence the parameter.
        self.fc_image = ly.FullConnBlock(fc_in_features, 20, droprate=droprate)

    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
        """Extract image features.

        Args:
            x: Batch of image volumes, annotated as ``(N, C, D, H, W)``.

        Returns:
            The ``fc_image`` output. Presumably ``(N, 20)``, since 20 is the
            output width passed to ``ly.FullConnBlock``; confirm against
            ``model.layers``.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.midflow(x)
        x = self.splitconv(x)
        # Keep the batch dimension and collapse the rest for the FC block.
        x = torch.flatten(x, 1)
        x = self.fc_image(x)
        return x
class CNN3D(nn.Module):
    """Classifier that fuses 3D image features with tabular clinical data.

    The image goes through ``CNN_Image_Section``. The clinical vector goes
    through two ``ly.FullConnBlock`` layers (``clin_data_channels`` -> 64 ->
    20). The image and clinical features are concatenated along dim 1 and
    passed through two linear layers (40 -> 10 -> ``num_classes``).

    Args:
        image_channels: Number of channels in the image input.
        clin_data_channels: Number of clinical features per sample.
        num_classes: Number of output classes.
        droprate: Dropout rate forwarded to the image section and the
            clinical fully connected blocks.
        apply_softmax: If True (default, the original behavior), ``forward``
            returns softmax probabilities over dim 1. If False, ``forward``
            returns raw logits, which is what ``nn.CrossEntropyLoss``
            expects, since it applies log-softmax internally.
    """

    def __init__(
        self,
        image_channels: int,
        clin_data_channels: int,
        num_classes: int,
        droprate: float = 0.0,
        *,
        apply_softmax: bool = True,
    ):
        super().__init__()
        self.image_section = CNN_Image_Section(image_channels, droprate=droprate)
        self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
        self.fc_clin2 = ly.FullConnBlock(64, 20, droprate=droprate)
        # Input width: 20 image features + 20 clinical features.
        self.dense1 = nn.Linear(20 + 20, 10)
        # NOTE(review): there is no nonlinearity between dense1 and dense2, so
        # the two layers compose to a single affine map. Left unchanged so
        # existing checkpoints keep producing the same outputs.
        self.dense2 = nn.Linear(10, num_classes)
        # Neither module has parameters, so swapping in Identity leaves the
        # state_dict unchanged. Keeping the attribute name means forward needs
        # no flag lookup and previously pickled models still run.
        self.softmax = nn.Softmax(dim=1) if apply_softmax else nn.Identity()

    def forward(
        self, x_in: Tuple[Float[torch.Tensor, "N C D H W"], Float[torch.Tensor, "N F"]]
    ):
        """Classify a batch of (image, clinical data) pairs.

        Args:
            x_in: ``(image_data, clin_data)``. ``image_data`` is annotated as
                ``(N, C, D, H, W)`` and ``clin_data`` as ``(N, F)``.

        Returns:
            Tensor with ``num_classes`` columns along dim 1. It holds softmax
            probabilities when ``apply_softmax`` was True, otherwise raw logits.
        """
        image_data, clin_data = x_in
        image_out = self.image_section(image_data)
        clin_out = self.fc_clin2(self.fc_clin1(clin_data))
        # Fuse the two feature vectors along the feature dimension.
        combined = torch.cat((image_out, clin_out), dim=1)
        x = self.dense1(combined)
        x = self.dense2(x)
        x = self.softmax(x)
        return x
|