# newCNN.py

from torch import device, cuda
import torch
import torch.nn as nn
import torch.optim as optim

import utils.newCNN_Layers as CustomLayers


class CNN_Net(nn.Module):

    def __init__(self, input, prps, final_layer_size=5):
        super(CNN_Net, self).__init__()
        self.final_layer_size = final_layer_size
        self.device = device('cuda:0' if cuda.is_available() else 'cpu')
        print(f"CNN Initialized. Using: {self.device}")
        # GETS FIRST IMAGE FOR SIZE
        data_iter = iter(input)
        first_batch = next(data_iter)
        first_features = first_batch[0]
        image = first_features[0]
        # LAYERS
        print(f"CNN Model Initialization. Input size: {image.size()}")
        self.conv1 = CustomLayers.Conv_elu_maxpool_drop(1, 192, (11, 13, 11), stride=(4, 4, 4), pool=True, prps=prps)
        self.conv2 = CustomLayers.Conv_elu_maxpool_drop(192, 384, (5, 6, 5), stride=(1, 1, 1), pool=True, prps=prps)
        self.conv3_mid_flow = CustomLayers.Mid_flow(384, 384, prps=prps)
        self.conv4_sepConv = CustomLayers.Conv_elu_maxpool_drop(384, 96, (3, 4, 3), stride=(1, 1, 1), pool=True,
                                                                prps=prps, sep_conv=True)
        self.conv5_sepConv = CustomLayers.Conv_elu_maxpool_drop(96, 48, (3, 4, 3), stride=(1, 1, 1), pool=True,
                                                                prps=prps, sep_conv=True)
        self.fc1 = CustomLayers.Fc_elu_drop(113568, 20, prps=prps)  # TODO: concatenate clinical data after this
        self.fc2 = CustomLayers.Fc_elu_drop(20, final_layer_size, prps=prps)

        # Move all layers to the selected device so the parameters live where
        # train_model/evaluate_model send the input batches.
        self.to(self.device)

    # FORWARDS
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3_mid_flow(x)
        x = self.conv4_sepConv(x)
        x = self.conv5_sepConv(x)

        # FLATTEN x
        flatten_size = x.size(1) * x.size(2) * x.size(3) * x.size(4)
        x = x.view(-1, flatten_size)

        x = self.fc1(x)
        x = self.fc2(x)
        return x

    # TRAIN
    def train_model(self, trainloader, PATH, epochs):
        self.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters(), lr=1e-5)

        for epoch in range(epochs):  # loop over the dataset multiple times
            print(f"Training... {epoch + 1}/{epochs}")
            running_loss = 0.0
            for i, data in enumerate(trainloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data[0].to(self.device), data[1].to(self.device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self.forward(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:  # print every 2000 mini-batches
                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                    running_loss = 0.0

        print('Finished Training')
        torch.save(self.state_dict(), PATH)

    # TEST
    def evaluate_model(self, testloader):
        correct = 0
        total = 0
        self.eval()
        # since we're not training, we don't need to calculate the gradients for our outputs
        with torch.no_grad():
            for data in testloader:
                images, labels = data[0].to(self.device), data[1].to(self.device)
                # calculate outputs by running images through the network
                outputs = self.forward(images)
                # the class with the highest energy is what we choose as prediction
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                print(f"Predicted class vals: {predicted}")
                correct += (predicted == labels).sum().item()

        print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
        self.train()

    # PREDICT
    def predict(self, loader):
        self.eval()
        predictions = []
        with torch.no_grad():
            for data in loader:
                images = data[0].to(self.device)
                outputs = self.forward(images)
                # the class with the highest energy is what we choose as prediction
                _, predicted = torch.max(outputs.data, 1)
                predictions.append(predicted)
        self.train()
        # Concatenate per-batch predictions so every batch is returned, not just the last one
        return torch.cat(predictions)
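

# Minimal usage sketch (not part of the original module): shows how CNN_Net can be
# constructed and exercised given the interfaces above. The dummy volume shape,
# the empty `prps` dict, and the save path are placeholders chosen for illustration;
# the real pipeline supplies its own DataLoaders and the hyperparameter dict that
# utils.newCNN_Layers expects, and the fully-connected input size (113568) assumes
# a specific scan resolution.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    # Fake single-channel 3D volumes; shape is an assumption, not the project's real data.
    volumes = torch.randn(8, 1, 91, 109, 91)
    labels = torch.randint(0, 5, (8,))
    loader = DataLoader(TensorDataset(volumes, labels), batch_size=4)

    prps = {}  # placeholder; real keys depend on utils.newCNN_Layers
    model = CNN_Net(loader, prps, final_layer_size=5)
    model.train_model(loader, PATH="newCNN_weights.pth", epochs=2)
    model.evaluate_model(loader)
    print(model.predict(loader))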