models.py

from torch import device, cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import utils.CNN_methods as CNN
# METHODS: CONV3D, CONV2D, MAXPOOL, LINEAR, ...
class CNN_Net(nn.Module):
    """Wraps the xalex3D feature extractor and owns device placement."""

    # Defines all properties / layers that can be used
    def __init__(self, mri_volume, params):
        super().__init__()
        # self.parameters = nn.ParameterList(params)
        self.model = xalex3D(mri_volume)
        self.device = device('cuda:0' if cuda.is_available() else 'cpu')
        print(f"CNN initialized. Using: {self.device}")

    # Runs the layers on x: one forward pass over a batch
    def forward(self, x):
        x = F.relu(self.model.f(x, []))  # TODO Add Clinical
        return x
    # Training loop. Named train_model so it does not shadow nn.Module.train(),
    # which toggles train/eval mode.
    def train_model(self, trainloader, PATH):
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters(), lr=1e-5)

        for epoch in range(2):  # loop over the dataset multiple times
            running_loss = 0.0
            for i, data in enumerate(trainloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data[0].to(self.device), data[1].to(self.device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:  # print every 2000 mini-batches
                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                    running_loss = 0.0

        print('Finished Training')
        torch.save(self.state_dict(), PATH)
    # Evaluation loop: reports top-1 accuracy over the test set.
    def test(self, testloader):
        correct = 0
        total = 0
        # since we're not training, we don't need to calculate gradients
        with torch.no_grad():
            for data in testloader:
                images, labels = data[0].to(self.device), data[1].to(self.device)
                # calculate outputs by running images through the network
                outputs = self(images)
                # the class with the highest score is what we choose as prediction
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'Accuracy of the network: {100 * correct // total} %')
'''
XAlex3D model.

Functions used (from utils.CNN_methods):
- conv_elu_maxpool_drop(in_channel, filters, kernel_size, stride=(1, 1, 1), padding=0,
  dilation=1, groups=1, bias=True, padding_mode='zeros', pool=False, drop_rate=0,
  sep_conv=False)
'''
# TODO: figure out IN_CHANNEL / in_channel
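
# Hypothetical sketch of the helper documented above (the real implementation
# lives in utils.CNN_methods, which is not shown here): a factory returning
# Conv3d -> ELU with optional pooling and dropout. It assumes the first
# argument is the input shape (a torch.Size) with channels at index 1, matching
# how the helper is called below; sep_conv is accepted but not implemented.
def _conv_elu_maxpool_drop_sketch(in_shape, filters, kernel_size, stride=(1, 1, 1),
                                  padding=0, dilation=1, groups=1, bias=True,
                                  padding_mode='zeros', pool=False, drop_rate=0,
                                  sep_conv=False):
    layers = [nn.Conv3d(in_shape[1], filters, kernel_size, stride=stride, padding=padding,
                        dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode),
              nn.ELU()]
    if pool:
        layers.append(nn.MaxPool3d(kernel_size=3, stride=2))  # pool size is an assumption
    if drop_rate:
        layers.append(nn.Dropout3d(drop_rate))
    return nn.Sequential(*layers)
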
class xalex3D(nn.Module):
    def __init__(self, mri_volume, drop_rate=0, final_layer_size=50):
        super().__init__()  # required so nn.Module bookkeeping is set up
        self.drop_rate = drop_rate
        self.final_layer_size = final_layer_size

        # NOTE: building the layers here (as sketched below) would register them
        # as submodules, so CNN_Net.parameters() would actually see their weights.
        # self.conv1 = CNN.conv_elu_maxpool_drop(len(next(iter(mri_volume))), 192, (11, 13, 11), stride=(4, 4, 4), drop_rate=self.drop_rate, pool=True)(next(iter(mri_volume)))
        # self.conv2 = CNN.conv_elu_maxpool_drop(self.conv1.shape(), 384, (5, 6, 5), stride=(1, 1, 1), drop_rate=self.drop_rate, pool=True)(self.conv1)
        # self.conv_mid_3 = CNN.mid_flow(self.conv2.shape(), self.drop_rate, filters=384)
        # self.groupConv4 = CNN.conv_elu_maxpool_drop(self.conv_mid_3.shape(), 96, (3, 4, 3), stride=(1, 1, 1), drop_rate=self.drop_rate,
        #                                             pool=True, groups=2)(self.conv_mid_3)
        # self.groupConv5 = CNN.conv_elu_maxpool_drop(self.groupConv4.shape(), 48, (3, 4, 3), stride=(1, 1, 1), drop_rate=self.drop_rate,
        #                                             pool=True, groups=2)(self.groupConv4)
        #
        # self.fc1 = CNN.fc_elu_drop(self.groupConv5.shape(), 20, drop_rate=self.drop_rate)(self.groupConv5)
        #
        # self.fc2 = CNN.fc_elu_drop(self.fc1.shape(), 50, drop_rate=self.drop_rate)(self.fc1)
    def f(self, mri_volume, clinical_inputs):
        # NOTE: the helper layers are re-created on every call, so their weights
        # are fresh each forward pass and invisible to the optimizer; see the
        # commented-out __init__ above for the intended fix.
        conv1 = CNN.conv_elu_maxpool_drop(mri_volume.size(), 192, (11, 13, 11), stride=(4, 4, 4),
                                          drop_rate=self.drop_rate, pool=True)(mri_volume)
        conv2 = CNN.conv_elu_maxpool_drop(conv1.size(), 384, (5, 6, 5), stride=(1, 1, 1),
                                          drop_rate=self.drop_rate, pool=True)(conv1)

        # MIDDLE FLOW: 3x (separable conv + ELU) with a residual connection
        print(f"Residual: {conv2.shape}")
        conv_mid_3 = CNN.mid_flow(conv2, self.drop_rate, filters=384)

        # CONV in 2 groups (left & right)
        groupConv4 = CNN.conv_elu_maxpool_drop(conv_mid_3.size(), 96, (3, 4, 3), stride=(1, 1, 1),
                                               drop_rate=self.drop_rate, pool=True, groups=2)(conv_mid_3)
        groupConv5 = CNN.conv_elu_maxpool_drop(groupConv4.size(), 48, (3, 4, 3), stride=(1, 1, 1),
                                               drop_rate=self.drop_rate, pool=True, groups=2)(groupConv4)

        # Fully connected head
        fc1 = CNN.fc_elu_drop(groupConv5.size(), 20, drop_rate=self.drop_rate)(groupConv5)
        fc2 = CNN.fc_elu_drop(fc1.size(), 50, drop_rate=self.drop_rate)(fc1)
        return fc2
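
# Hypothetical sketch of CNN.mid_flow (see the Keras reference below): three
# separable-conv + ELU blocks applied to x, with the input added back as a
# residual. Kernel size, padding, and dropout placement are assumptions chosen
# to preserve the input shape; the real helper lives in utils.CNN_methods.
def _mid_flow_sketch(x, drop_rate, filters=384):
    residual = x
    for _ in range(3):
        depthwise = nn.Conv3d(filters, filters, kernel_size=3, padding=1, groups=filters)
        pointwise = nn.Conv3d(filters, filters, kernel_size=1)
        x = F.dropout3d(F.elu(pointwise(depthwise(x))), p=drop_rate)
    return F.elu(x + residual)
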
  95. """ LAST PART:
  96. # Flatten 3D conv network representations
  97. flat_conv_6 = Reshape((np.prod(K.int_shape(conv6_concat)[1:]),))(conv6_concat)
  98. # 2-layer Dense network for clinical features
  99. vol_fc1 = _fc_bn_relu_drop(64, w_regularizer=w_regularizer,
  100. drop_rate=drop_rate)(clinical_inputs)
  101. flat_volume = _fc_bn_relu_drop(20, w_regularizer=w_regularizer,
  102. drop_rate=drop_rate)(vol_fc1)
  103. # Combine image and clinical features embeddings
  104. fc1 = _fc_bn_relu_drop(20, w_regularizer, drop_rate=drop_rate, name='final_conv')(flat_conv_6)
  105. flat = concatenate([fc1, flat_volume])
  106. # Final 4D embedding
  107. fc2 = Dense(units=final_layer_size, activation='linear', kernel_regularizer=w_regularizer, name='features')(
  108. flat) # was linear activation"""
''' FULL CODE (original Keras implementation):
# First layer
conv1_left = _conv_bn_relu_pool_drop(192, 11, 13, 11, strides=(4, 4, 4), w_regularizer=w_regularizer,
                                     drop_rate=drop_rate, pool=True)(mri_volume)

# Second layer
conv2_left = _conv_bn_relu_pool_drop(384, 5, 6, 5, w_regularizer=w_regularizer, drop_rate=drop_rate,
                                     pool=True)(conv1_left)

# Introduce Middle Flow (separable convolutions with a residual connection)
print('residual shape ' + str(conv2_left.shape))
conv_mid_1 = mid_flow(conv2_left, drop_rate, w_regularizer,
                      filters=384)  # changed input to conv2_left from conv2_concat

# Split channels for grouped-style convolution
conv_mid_1_1 = Lambda(lambda x: x[:, :, :, :, :192])(conv_mid_1)
conv_mid_1_2 = Lambda(lambda x: x[:, :, :, :, 192:])(conv_mid_1)

conv5_left = _conv_bn_relu_pool_drop(96, 3, 4, 3, w_regularizer=w_regularizer, drop_rate=drop_rate,
                                     pool=True)(conv_mid_1_1)
conv5_right = _conv_bn_relu_pool_drop(96, 3, 4, 3, w_regularizer=w_regularizer, drop_rate=drop_rate,
                                      pool=True)(conv_mid_1_2)
conv6_left = _conv_bn_relu_pool_drop(48, 3, 4, 3, w_regularizer=w_regularizer, drop_rate=drop_rate,
                                     pool=True)(conv5_left)
conv6_right = _conv_bn_relu_pool_drop(48, 3, 4, 3, w_regularizer=w_regularizer, drop_rate=drop_rate,
                                      pool=True)(conv5_right)
conv6_concat = concatenate([conv6_left, conv6_right], axis=-1)

# convExtra = Conv3D(48, (20, 30, 20),
#                    strides=(1, 1, 1), kernel_initializer="he_normal",
#                    padding="same", kernel_regularizer=w_regularizer)(conv6_concat)

# Flatten 3D conv network representations
flat_conv_6 = Reshape((np.prod(K.int_shape(conv6_concat)[1:]),))(conv6_concat)

# 2-layer Dense network for clinical features
vol_fc1 = _fc_bn_relu_drop(64, w_regularizer=w_regularizer,
                           drop_rate=drop_rate)(clinical_inputs)
flat_volume = _fc_bn_relu_drop(20, w_regularizer=w_regularizer,
                               drop_rate=drop_rate)(vol_fc1)

# Combine image and clinical feature embeddings
fc1 = _fc_bn_relu_drop(20, w_regularizer, drop_rate=drop_rate, name='final_conv')(flat_conv_6)
# fc2 = _fc_bn_relu_drop(40, w_regularizer, drop_rate=drop_rate)(fc1)
flat = concatenate([fc1, flat_volume])

# Final 4D embedding
fc2 = Dense(units=final_layer_size, activation='linear', kernel_regularizer=w_regularizer,
            name='features')(flat)  # was linear activation
'''
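
# Minimal smoke-test sketch (hypothetical, not part of the original file). It
# assumes utils.CNN_methods is importable and that the network accepts
# (N, C, D, H, W) float volumes with integer class labels; the volume shape,
# label count, and save path below are placeholders.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    volumes = torch.randn(8, 1, 91, 109, 91)  # assumed MRI volume shape
    labels = torch.randint(0, 2, (8,))        # e.g. binary diagnosis labels
    loader = DataLoader(TensorDataset(volumes, labels), batch_size=2)

    net = CNN_Net(volumes, params=[])
    net.to(net.device)
    net.train_model(loader, PATH='./cnn_net.pth')
    net.test(loader)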