소스 검색

Model Trains!

Nicholas Schense 1 년 전
부모
커밋
83cc6dd1cf
4개의 변경된 파일21개의 추가작업 그리고 39개의 파일을 삭제
  1. 10 4
      main.py
  2. 0 1
      utils/layers.py
  3. 8 34
      utils/models.py
  4. 3 0
      utils/preprocess.py

+ 10 - 4
main.py

@@ -67,14 +67,18 @@ def evaluate_model(seed):
     print("Shape of MRI Data: ", training_data[0][0].shape)
     print("Shape of XLS Data: ", training_data[0][1].shape)
 
+    #Print Training Data Length
+    print("Length of Training Data: ", len(train_dataloader))
 
 
-    model_CNN = models.CNN_Net(1, 1, 0.5).double()
-    criterion = nn.CrossEntropyLoss()
+    print("Initializing Model...")
+    model_CNN = models.CNN_Net(1, 2, 0.5).cuda()
+    criterion = nn.BCELoss()
     optimizer = optim.Adam(model_CNN.parameters(), lr=0.001)
     print("Seed: ", seed)
     epoch_number = 0
 
+    print("Training Model...")
     for epoch in range(epochs):
         running_loss = 0.0
         for i, data in enumerate(train_dataloader, 0):
@@ -82,11 +86,13 @@ def evaluate_model(seed):
 
             optimizer.zero_grad()
 
-            mri = mri.double()
-            xls = xls.double()
+            mri = mri.cuda().float()
+            xls = xls.cuda().float()
+            label = label.cuda().float()
 
 
             outputs = model_CNN((mri, xls))
+
             loss = criterion(outputs, label)
             loss.backward()
             optimizer.step()

+ 0 - 1
utils/layers.py

@@ -53,7 +53,6 @@ class SplitConvBlock(nn.Module):
     def forward(self, x):
         (left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
 
-        print(left.shape, right.shape)
 
         self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
         self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)

+ 8 - 34
utils/models.py

@@ -3,7 +3,7 @@ from torchvision.transforms import ToTensor
 import os
 import pandas as pd
 import numpy as np
-import layers as ly
+import utils.layers as ly
 
 import torch
 import torchvision
@@ -32,26 +32,8 @@ class CNN_Net(nn.Module):
     def __init__(self, image_channels, clin_data_channels, droprate):
         super().__init__()
 
-        # Initial Convolutional Blocks
-        self.conv1 = ly.ConvBlock(
-            image_channels, 192, (11, 13, 11), stride=(4, 4, 4), droprate=droprate, pool=False
-        )
-        self.conv2 = ly.ConvBlock(
-            192, 384, (5, 6, 5), droprate=droprate, pool=False
-        )
-
-        # Midflow Block
-        self.midflow = ly.MidFlowBlock(384, droprate)
-
-        
-
-        # Split Convolutional Block
-        self.splitconv = ly.SplitConvBlock(384, 192, 96, 4, droprate)
-
-        #Fully Connected Block
-        self.fc_image = ly.FullConnBlock(96, 20, droprate=droprate)
-
-
+        #Image Section
+        self.image_section = CNN_Image_Section(image_channels, droprate)
 
         #Data Layers, fully connected
         self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
@@ -61,7 +43,7 @@ class CNN_Net(nn.Module):
         #Final Dense Layer
         self.dense1 = nn.Linear(40, 5)
         self.dense2 = nn.Linear(5, 2)
-        self.softmax = nn.Softmax()
+        self.softmax = nn.Softmax(dim = 1)
 
        
 
@@ -69,19 +51,14 @@ class CNN_Net(nn.Module):
 
         image, clin_data = x
 
-    
-        image = self.conv1(image)
-        image = self.conv2(image)
-        image = self.midflow(image)
-        image = self.splitconv(image)
-        image = torch.flatten(image, 1)
-        image = self.fc_image(image)
+        image = self.image_section(image)
+
+        
 
         clin_data = self.fc_clin1(clin_data)
         clin_data = self.fc_clin2(clin_data)
 
 
-
         x = torch.cat((image, clin_data), dim=1)
         x = self.dense1(x)
         x = self.dense2(x)
@@ -119,12 +96,9 @@ class CNN_Image_Section(nn.Module):
         x = self.conv2(x)
         x = self.midflow(x)
         x = self.splitconv(x)
-        print(x.shape)
         x = torch.flatten(x, 1)
         x = self.fc_image(x)
 
+        return x
 
 
-print(ts.summary(CNN_Image_Section(1, 0.5).cuda(), (1, 91, 109, 91)))
-    
-        

+ 3 - 0
utils/preprocess.py

@@ -125,4 +125,7 @@ class CustomDataset(Dataset):
         mri_tensor = torch.from_numpy(mri_data).unsqueeze(0)
         
         class_id = torch.tensor([class_id])
+        #Convert to one-hot and squeeze
+        class_id = torch.nn.functional.one_hot(class_id, num_classes=2).squeeze(0)
+
         return mri_tensor, xls_tensor, class_id