cnn.py

from typing import Tuple

import torch
from jaxtyping import Float
from torch import nn

import model.layers as ly
class CNN_Image_Section(nn.Module):
    def __init__(self, image_channels: int, droprate: float = 0.0):
        super().__init__()
        # Initial Convolutional Blocks
        self.conv1 = ly.ConvBlock(
            image_channels,
            192,
            (11, 13, 11),
            stride=(4, 4, 4),
            droprate=droprate,
            pool=False,
        )
        self.conv2 = ly.ConvBlock(192, 384, (5, 6, 5), droprate=droprate, pool=False)

        # Midflow Block
        self.midflow = ly.MidFlowBlock(384, droprate)

        # Split Convolutional Block
        self.splitconv = ly.SplitConvBlock(384, 192, 96, 1, droprate)

        # Fully Connected Block
        self.fc_image = ly.FullConnBlock(227136, 20, droprate=droprate)

    def forward(self, x: Float[torch.Tensor, "N C D H W"]):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.midflow(x)
        x = self.splitconv(x)
        x = torch.flatten(x, 1)
        x = self.fc_image(x)
        return x


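# CNN3D fuses the image branch above with a small MLP over clinical
# covariates: each branch is reduced to a 20-dimensional embedding, the two
# embeddings are concatenated, and two dense layers followed by a softmax
# produce the class probabilities.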
class CNN3D(nn.Module):
    def __init__(
        self,
        image_channels: int,
        clin_data_channels: int,
        num_classes: int,
        droprate: float = 0.0,
    ):
        super().__init__()
        self.image_section = CNN_Image_Section(image_channels, droprate=droprate)
        self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
        self.fc_clin2 = ly.FullConnBlock(64, 20, droprate=droprate)
        self.dense1 = nn.Linear(20 + 20, 10)
        self.dense2 = nn.Linear(10, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(
        self, x_in: Tuple[Float[torch.Tensor, "N C D H W"], Float[torch.Tensor, "N F"]]
    ):
        image_data, clin_data = x_in
        image_out = self.image_section(image_data)
        clin_out = self.fc_clin2(self.fc_clin1(clin_data))
        combined = torch.cat((image_out, clin_out), dim=1)
        x = self.dense1(combined)
        x = self.dense2(x)
        x = self.softmax(x)
        return x
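

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original file. The clinical
    # feature count, class count, and dropout rate below are placeholder
    # assumptions. Note that fc_image hard-codes 227136 flattened features,
    # so the image volume's (D, H, W) must match whatever preprocessing the
    # model.layers blocks were designed for.
    model = CNN3D(image_channels=1, clin_data_channels=6, num_classes=2, droprate=0.5)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"CNN3D parameters: {n_params:,}")

    # The forward pass takes an (image, clinical) tuple:
    #   probs = model((image_batch, clin_batch))
    # with image_batch of shape (N, C, D, H, W) and clin_batch of shape (N, F);
    # the output is softmax class probabilities of shape (N, num_classes).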