Переглянути джерело

Small update for Andrej use

Ruben Aguilo Schuurs 4 місяців тому
батько
коміт
ad21d569c0
2 змінених файлів з 48 додано та 26 видалено
  1. 20 0
      README
  2. 28 26
      main.py

+ 20 - 0
README

@@ -0,0 +1,20 @@
+Pytorch CNN model to diagnose Alzheimer's Disease on PET Scans
+PyTorch version Ruben Aguilo Schuurs (UW-Madison), reworked from Dr. Alison Deatsch's Tensorflow Version
+
+Files:
+- main.py: Controls everything. Calls all of the following files
+- utils/CNN.py & utils/CNN_Layers.py: CNN Model structure
+- preprocess.py: prepares CustomDatasets training, validation, and testing
+- train_methods.py: CNN training code
+- Graphs.py: Multiple graphing / results functions
+- K-fold.py: Specific graphs
+- dataset_sd_mean_finder.py: unused
+- show_image.py: show image
+- TALOS_main.py: unsuccessful attempt at implementing talos
+
+To run (in main.py):
+1. Define the CNN filepath and dataset filepath
+2. Change the properties
+3. Comment or uncomment the functions (train, load, evaluate, predict) depending on your interest.
+-- You don't need to load again after training.
+4. See it in action! Training will take a while with large dataset!!

+ 28 - 26
main.py

@@ -45,9 +45,9 @@ properties = {
     'momentum':[0.99, 0.97, 0.95, 0.9],  # Unused
     'weight_decay':[1e-3, 1e-4, 1e-5, 0]    # Unused
 }
+roc = True  # If True, will make ROC curve
 
-
-model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
+model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'     # TODO, MUST UPDATE BEFORE RUNNING!
 CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'       # cnn_net.pth
 # small dataset
 # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'   # Small Test
@@ -58,7 +58,7 @@ csv_datapath = 'LP_ADNIMERGE.csv'
 # annotations_file = pd.read_csv(annotations_datapath)    # DataFrame
 # show_image(17508)
 
-# TODO: Datasets include multiple labels, such as medical info
+# Datasets include multiple labels, such as medical info
 training_data, val_data, test_data = prepare_datasets(mri_datapath, csv_datapath, val_split, seed)
 
 # Create data loaders
@@ -69,7 +69,7 @@ test_dataloader = DataLoader(test_data, batch_size=properties['batch_size'], shu
 
 # loads a few images to test
 x = 0
-while x < 0:
+while x < 1:
     train_features, train_labels = next(iter(train_dataloader))
     # print(f"Feature batch shape: {train_features.size()}")
     img = train_features[0].squeeze()
@@ -85,10 +85,11 @@ while x < 0:
     x = x+1
 
 
-roc = True
+
 CNN = CNN_Net(prps=properties, final_layer_size=2)
 CNN.cuda()
 
+# UNCOMMENT THE METHODS TO BE PERFORMED
 # train(CNN, train_dataloader, val_dataloader, CNN_filepath, properties, graphs=True)
 load(CNN, CNN_filepath)
 # evaluate(CNN, test_dataloader)
@@ -97,28 +98,8 @@ load(CNN, CNN_filepath)
 print(CNN)
 CNN.eval()
 
-guided_gc = GuidedGradCam(CNN, CNN.conv5_sepConv)   # Performed on LAST convolution layer
-# input = torch.randn(1, 1, 91, 109, 91, requires_grad=True).cuda()
-
-# TODO MAKE BATCH SIZE 1 FOR THIS TO WORK??
-train_features, train_labels = next(iter(train_dataloader))
-while(train_labels[0] == 0):
-    train_features, train_labels = next(iter(train_dataloader))
-
-attr = guided_gc.attribute(train_features.cuda(), 0) #, interpolate_mode="area")
-
-# draw the attributes
-attr = attr.unsqueeze(0)
-attr = attr.cpu().detach().numpy()
-attr = np.clip(attr, 0, 1)
-plt.imshow(attr)
-plt.show()
-
-print("Done w/ attributions")
-print(attr)
-
-# EXTRA
 
+# EXTRA TESTS AND IMPLEMENTATION ATTEMPTS, CAN BE IGNORED
 
 # # PREDICT MODE TO TEST INDIVIDUAL IMAGES
 # if(predict):
@@ -184,3 +165,24 @@ def evaluate_net (seed):
 
     print('Length Val Data[0]: ',len(val_data[0]))
 '''
+
+
+# Failed attempt, to be ignored
+# guided_gc = GuidedGradCam(CNN, CNN.conv5_sepConv)   # Performed on LAST convolution layer
+# input = torch.randn(1, 1, 91, 109, 91, requires_grad=True).cuda()
+
+# train_features, train_labels = next(iter(train_dataloader))
+# while(train_labels[0] == 0):
+#     train_features, train_labels = next(iter(train_dataloader))
+
+# attr = guided_gc.attribute(train_features.cuda(), 0) #, interpolate_mode="area")
+
+# # draw the attributes
+# attr = attr.unsqueeze(0)
+# attr = attr.cpu().detach().numpy()
+# attr = np.clip(attr, 0, 1)
+# plt.imshow(attr)
+# plt.show()
+
+# print("Done w/ attributions")
+# print(attr)