CNN.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import torch
  2. from torch import device, cuda, cat, stack
  3. import torch.nn as nn
  4. import utils.CNN_Layers as CustomLayers
  5. class CNN_Net(nn.Module):
  6. def __init__(self, prps, final_layer_size=5):
  7. super(CNN_Net, self).__init__()
  8. self.final_layer_size = final_layer_size
  9. self.device = device('cuda:0' if cuda.is_available() else 'cpu')
  10. print("CNN Initialized. Using: " + str(self.device))
  11. # LAYERS
  12. print(f"CNN Model Initialization")
  13. self.conv1 = CustomLayers.Conv_elu_maxpool_drop(1, 192, (11, 13, 11), stride=(4,4,4), pool=True, prps=prps)
  14. self.conv2 = CustomLayers.Conv_elu_maxpool_drop(192, 384, (5, 6, 5), stride=(1,1,1), pool=True, prps=prps)
  15. self.conv3_mid_flow = CustomLayers.Mid_flow(384, 384, prps=prps)
  16. self.conv4_sepConv = CustomLayers.Conv_elu_maxpool_drop(384, 96,(3, 4, 3), stride=(1,1,1), pool=True, prps=prps,
  17. sep_conv=True)
  18. self.conv5_sepConv = CustomLayers.Conv_elu_maxpool_drop(96, 48, (3, 4, 3), stride=(1, 1, 1), pool=True,
  19. prps=prps, sep_conv=True)
  20. self.fc1 = CustomLayers.Fc_elu_drop(113568, 10, prps=prps, softmax=False) # TODO, concatenate clinical data after this
  21. self.fc2 = CustomLayers.Fc_elu_drop(10, final_layer_size, prps=prps, softmax=True) # For now this works as output layer, though may be incorrect
  22. self.fc_clinical1 = CustomLayers.Fc_elu_drop(6, 30, prps=prps, softmax=False)
  23. self.fc_clinical2 = CustomLayers.Fc_elu_drop(30,10, prps=prps, softmax=False)
  24. # FORWARDS
  25. def forward(self, x):
  26. clinical_data = x[1].to(torch.float32)
  27. x = x[0]
  28. x = self.conv1(x)
  29. x = self.conv2(x)
  30. x = self.conv3_mid_flow(x)
  31. x = self.conv4_sepConv(x)
  32. x = self.conv5_sepConv(x)
  33. # FLATTEN x
  34. flatten_size = x.size(1) * x.size(2) * x.size(3) * x.size(4)
  35. x = x.view(-1, flatten_size)
  36. x = self.fc1(x)
  37. # Clinical
  38. clinical_data = self.fc_clinical1(clinical_data)
  39. clinical_data = self.fc_clinical2(clinical_data)
  40. x = cat((x, clinical_data), dim=1)
  41. print(x.shape)
  42. x = self.fc2(x)
  43. return x