123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222 |
- from torch import device, cuda
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
- import utils.CNN_methods as CNN
- # METHODS: CONV3D, CONV2D, MAXPOOL, LINEAR, ...
class CNN_Net(nn.Module):
    """Training / evaluation wrapper around an ``xalex3D`` model.

    At construction time the wrapper picks a device (``cuda:0`` when CUDA is
    available, otherwise the CPU); ``train`` and ``test`` move every batch to
    that device before running the forward pass.
    """

    def __init__(self, mri_volume, params):
        """Create the wrapped model and select the compute device.

        Args:
            mri_volume: Forwarded unchanged to the ``xalex3D`` constructor.
            params: Currently unused; kept so existing callers keep working.
        """
        super().__init__()
        self.model = xalex3D(mri_volume)
        self.device = device('cuda:0' if cuda.is_available() else 'cpu')
        print("CNN Initialized. Using: " + str(self.device))

    def forward(self, x):
        """Run ``x`` through the wrapped model and apply a ReLU to its output."""
        # TODO: pass real clinical inputs instead of an empty list.
        return F.relu(self.model.f(x, []))

    # NOTE(review): this method shadows ``nn.Module.train(mode)``. Because
    # ``nn.Module.eval()`` is implemented as ``self.train(False)``, calling
    # ``eval()`` on this class raises TypeError (missing ``PATH``). Renaming
    # this method would break callers, so it is only flagged here.
    def train(self, trainloader, PATH, epochs=2, lr=1e-5):
        """Optimise the network with Adam + cross-entropy and save the weights.

        Args:
            trainloader: Iterable of batches; element 0 of each batch is the
                input tensor and element 1 the label tensor.
            PATH: Destination handed to ``torch.save`` for the state dict.
            epochs: Number of passes over ``trainloader`` (default 2).
            lr: Adam learning rate (default 1e-5).
        """
        log_every = 2000  # report the mean loss every 2000 mini-batches
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters(), lr=lr)

        for epoch in range(epochs):
            running_loss = 0.0
            for i, data in enumerate(trainloader):
                # Each batch is indexable as [inputs, labels].
                inputs = data[0].to(self.device)
                labels = data[1].to(self.device)

                # Gradients accumulate by default, so clear them per batch.
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                if i % log_every == log_every - 1:
                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
                    running_loss = 0.0

        print('Finished Training')
        torch.save(self.state_dict(), PATH)

    def test(self, testloader):
        """Print the classification accuracy (integer percent) on ``testloader``.

        Args:
            testloader: Iterable of batches; element 0 of each batch is the
                input tensor and element 1 the label tensor.
        """
        correct = 0
        total = 0
        # No gradients are needed while evaluating.
        with torch.no_grad():
            for data in testloader:
                images = data[0].to(self.device)
                labels = data[1].to(self.device)
                outputs = self(images)
                # The index of the largest output is the predicted class.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        if total == 0:
            # Avoid ZeroDivisionError when the loader yields no samples.
            print('Accuracy of the network: n/a (no test samples)')
            return
        print(f'Accuracy of the network: {100 * correct // total} %')
- '''
- XAlex3D model.
- Functions used:
- - conv_elu_maxpool_drop(in_channel, filters, kernel_size, stride=(1,1,1), padding=0, dilation=1,
- groups=1, bias=True, padding_mode='zeros', pool=False, drop_rate=0, sep_conv = False)
- '''
# TODO: work out the correct in_channel argument for each CNN.conv_elu_maxpool_drop
# call in xalex3D (the calls currently pass the full tensor size).
class xalex3D(nn.Module):
    """XAlex3D network assembled from the ``CNN`` helper-layer builders.

    The forward computation lives in :meth:`f` and chains
    ``CNN.conv_elu_maxpool_drop``, ``CNN.mid_flow`` and ``CNN.fc_elu_drop``.
    """

    def __init__(self, mri_volume, drop_rate=0, final_layer_size=50):
        """Store the hyper-parameters used by :meth:`f`.

        Args:
            mri_volume: Currently unused; kept for interface compatibility.
            drop_rate: Drop rate forwarded to every ``CNN`` helper call.
            final_layer_size: Width of the last fully connected layer.
        """
        # nn.Module must be initialised first: without it the internal
        # registries are missing and parameters(), state_dict() and to()
        # raise AttributeError once this object is used as a submodule.
        super().__init__()
        self.drop_rate = drop_rate
        self.final_layer_size = final_layer_size
        # NOTE(review): every layer is created inside f() on each call and is
        # never stored on ``self``. If the CNN helpers create trainable
        # layers (TODO confirm), their weights are therefore not registered
        # with this module and are rebuilt on every forward pass. Building
        # the layers once here needs the input channel count first.

    def f(self, mri_volume, clinical_inputs):
        """Run the XAlex3D forward pass on ``mri_volume``.

        Args:
            mri_volume: Input volume; must expose ``.size()`` and be accepted
                by the callable returned from ``CNN.conv_elu_maxpool_drop``.
            clinical_inputs: Currently unused (clinical branch not ported).

        Returns:
            The output of the final ``CNN.fc_elu_drop`` layer.
        """
        # NOTE(review): the first argument handed to the helpers is the full
        # ``.size()`` of the incoming tensor, while the helper signature quoted
        # in this file names that parameter ``in_channel`` -- TODO confirm.
        conv1 = CNN.conv_elu_maxpool_drop(
            mri_volume.size(), 192, (11, 13, 11), stride=(4, 4, 4),
            drop_rate=self.drop_rate, pool=True)(mri_volume)
        conv2 = CNN.conv_elu_maxpool_drop(
            conv1.size(), 384, (5, 6, 5), stride=(1, 1, 1),
            drop_rate=self.drop_rate, pool=True)(conv1)

        # MIDDLE FLOW, 3 times sepConv & ELU()
        print(f"Residual: {conv2.shape}")
        conv_mid_3 = CNN.mid_flow(conv2, self.drop_rate, filters=384)

        # CONV in 2 groups (left & right)
        groupConv4 = CNN.conv_elu_maxpool_drop(
            conv_mid_3.size(), 96, (3, 4, 3), stride=(1, 1, 1),
            drop_rate=self.drop_rate, pool=True, groups=2)(conv_mid_3)
        groupConv5 = CNN.conv_elu_maxpool_drop(
            groupConv4.size(), 48, (3, 4, 3), stride=(1, 1, 1),
            drop_rate=self.drop_rate, pool=True, groups=2)(groupConv4)

        # FCs
        fc1 = CNN.fc_elu_drop(groupConv5.size(), 20, drop_rate=self.drop_rate)(groupConv5)
        # The last layer honours final_layer_size (default 50, the value that
        # used to be hard-coded here).
        fc2 = CNN.fc_elu_drop(fc1.size(), self.final_layer_size, drop_rate=self.drop_rate)(fc1)
        return fc2
- """ LAST PART:
-
- # Flatten 3D conv network representations
- flat_conv_6 = Reshape((np.prod(K.int_shape(conv6_concat)[1:]),))(conv6_concat)
- # 2-layer Dense network for clinical features
- vol_fc1 = _fc_bn_relu_drop(64, w_regularizer=w_regularizer,
- drop_rate=drop_rate)(clinical_inputs)
- flat_volume = _fc_bn_relu_drop(20, w_regularizer=w_regularizer,
- drop_rate=drop_rate)(vol_fc1)
- # Combine image and clinical features embeddings
- fc1 = _fc_bn_relu_drop(20, w_regularizer, drop_rate=drop_rate, name='final_conv')(flat_conv_6)
- flat = concatenate([fc1, flat_volume])
- # Final 4D embedding
- fc2 = Dense(units=final_layer_size, activation='linear', kernel_regularizer=w_regularizer, name='features')(
- flat) # was linear activation"""
- ''' FULL CODE:
- # First layer
- conv1_left = _conv_bn_relu_pool_drop(192, 11, 13, 11, strides=(4, 4, 4), w_regularizer=w_regularizer,
- drop_rate=drop_rate, pool=True)(mri_volume)
-
- # Second layer
- conv2_left = _conv_bn_relu_pool_drop(384, 5, 6, 5, w_regularizer=w_regularizer, drop_rate=drop_rate, pool=True)(
- conv1_left)
- # Introduce Middle Flow (separable convolutions with a residual connection)
- print('residual shape ' + str(conv2_left.shape))
- conv_mid_1 = mid_flow(conv2_left, drop_rate, w_regularizer,
- filters=384) # changed input to conv2_left from conv2_concat
-
- # Split channels for grouped-style convolution
- conv_mid_1_1 = Lambda(lambda x: x[:, :, :, :, :192])(conv_mid_1)
- conv_mid_1_2 = Lambda(lambda x: x[:, :, :, :, 192:])(conv_mid_1)
- conv5_left = _conv_bn_relu_pool_drop(96, 3, 4, 3, w_regularizer=w_regularizer, drop_rate=drop_rate, pool=True)(
- conv_mid_1_1)
- conv5_right = _conv_bn_relu_pool_drop(96, 3, 4, 3, w_regularizer=w_regularizer, drop_rate=drop_rate, pool=True)(
- conv_mid_1_2)
- conv6_left = _conv_bn_relu_pool_drop(48, 3, 4, 3, w_regularizer=w_regularizer, drop_rate=drop_rate, pool=True)(
- conv5_left)
- conv6_right = _conv_bn_relu_pool_drop(48, 3, 4, 3, w_regularizer=w_regularizer, drop_rate=drop_rate, pool=True)(
- conv5_right)
- conv6_concat = concatenate([conv6_left, conv6_right], axis=-1)
- # convExtra = Conv3D(48, (20,30,20),
- # strides = (1,1,1), kernel_initializer="he_normal",
- # padding="same", kernel_regularizer = w_regularizer)(conv6_concat)
- # Flatten 3D conv network representations
- flat_conv_6 = Reshape((np.prod(K.int_shape(conv6_concat)[1:]),))(conv6_concat)
- # 2-layer Dense network for clinical features
- vol_fc1 = _fc_bn_relu_drop(64, w_regularizer=w_regularizer,
- drop_rate=drop_rate)(clinical_inputs)
- flat_volume = _fc_bn_relu_drop(20, w_regularizer=w_regularizer,
- drop_rate=drop_rate)(vol_fc1)
- # Combine image and clinical features embeddings
- fc1 = _fc_bn_relu_drop(20, w_regularizer, drop_rate=drop_rate, name='final_conv')(flat_conv_6)
- # fc2 = _fc_bn_relu_drop (40, w_regularizer, drop_rate = drop_rate) (fc1)
- flat = concatenate([fc1, flat_volume])
- # Final 4D embedding
- fc2 = Dense(units=final_layer_size, activation='linear', kernel_regularizer=w_regularizer, name='features')(
- flat) # was linear activation
- '''
|