123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
# GENERAL PURPOSE (standard library)
import platform
import time

# THIRD-PARTY
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from captum.attr import GuidedGradCam  # INTERPRETABILITY
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader
from torchvision import datasets

# FOR DATA (project-local helpers). train/evaluate/predict are only referenced by
# the optional, commented-out calls in the model-setup section.
from utils.CNN import CNN_Net
from utils.preprocess import prepare_datasets
from utils.train_methods import train, load, evaluate, predict

# NOTE: `from symbol import parameters` was dropped. The stdlib `symbol` module was
# deprecated in Python 3.9 and removed in 3.10, so that import raised
# ModuleNotFoundError there, and `parameters` was never used in this script.

# Run banner: start timestamp plus library versions, to identify this run in logs.
current_time = time.localtime()
print(time.strftime("%Y-%m-%d_%H:%M", current_time))
print("--- RUNNING ---")
print("Pytorch Version: " + torch.__version__)
print("Python Version: " + platform.python_version())
# RUN CONFIGURATION
# Fraction (not percent) of the data held out for validation and test; the
# remainder is used for training. Both values are passed to prepare_datasets.
val_split = 0.2
seed = 12  # TODO Randomize seed

# Settings consumed by CNN_Net (prps=...) and the training helpers; batch_size
# also sizes every DataLoader.
properties = dict(
    batch_size=32,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    padding_mode="zeros",
    drop_rate=0,
    epochs=20,
    lr=[1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6],  # Unused
    momentum=[0.99, 0.97, 0.95, 0.9],  # Unused
    weight_decay=[1e-3, 1e-4, 1e-5, 0],  # Unused
)

# Where the model lives; the checkpoint file sits directly inside that directory.
model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
CNN_filepath = f'{model_filepath}/cnn_net.pth'  # weights file given to load() / train()

# Input data. To use the small test set, swap which mri_datapath line is commented.
# mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/' # Small Test
mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'  # Real data
csv_datapath = 'LP_ADNIMERGE.csv'  # relative path; presumably the label/metadata table
# annotations_file = pd.read_csv(annotations_datapath) # DataFrame
# show_image(17508)
# TODO: Datasets include multiple labels, such as medical info
# BUILD DATASETS AND LOADERS
# prepare_datasets returns the train / validation / test datasets for the
# configured split fraction and seed.
training_data, val_data, test_data = prepare_datasets(
    mri_datapath, csv_datapath, val_split, seed
)

# One DataLoader per split, all sharing the configured batch size and all
# reshuffled on every pass; only the training loader drops an incomplete final
# batch. val_dataloader is used during training, test_dataloader at the end for graphs.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split, batch_size=properties['batch_size'], shuffle=True, drop_last=drop)
    for split, drop in ((training_data, True), (val_data, False), (test_data, False))
)
# PREVIEW A FEW TRAINING IMAGES (sanity check of the loader output)
# Set num_preview_images to a positive number to display that many training
# scans. 0 skips the preview. The old `x = 0; while x < 0` loop was equivalent
# to 0, because its condition was never true and the body was unreachable.
num_preview_images = 0
preview_slice = 40  # index into the last axis of each squeezed volume

if num_preview_images > 0:
    # One iterator for all previews, so consecutive images come from different
    # batches; zip also stops cleanly if the loader has fewer batches than requested.
    for preview_idx, (train_features, train_labels) in zip(
        range(num_preview_images), train_dataloader
    ):
        print(f"Feature batch shape: {train_features.size()}")
        img = train_features[0].squeeze()  # first sample, singleton dims removed
        print(f"Volume shape: {img.size()}")
        image = img[:, :, preview_slice]
        print(f"Slice shape: {image.size()}")
        label = train_labels[0]
        print(f"Label: {label}")
        plt.imshow(image, cmap="gray")
        # plt.savefig(f"./Image{preview_idx}_label{int(label)}.png")
        plt.show()
# MODEL SETUP
roc = True  # not read by the active code in this script

# final_layer_size=2: a two-unit output layer (presumably one unit per class).
CNN = CNN_Net(final_layer_size=2, prps=properties)
CNN.cuda()  # moves the parameters to the GPU; fails without a CUDA device

# This run only restores saved weights. Uncomment train/evaluate/predict as needed.
# train(CNN, train_dataloader, val_dataloader, CNN_filepath, properties, graphs=True)
load(CNN, CNN_filepath)
# evaluate(CNN, test_dataloader)
# predict(CNN, test_dataloader)

print(CNN)
CNN.eval()  # switch to inference mode before computing attributions
# GUIDED GRAD-CAM ON ONE TRAINING SAMPLE
guided_gc = GuidedGradCam(CNN, CNN.conv5_sepConv)  # Performed on LAST convolution layer
attr_slice = 40  # index into the last axis of the volume (same slice as the image preview)


def _first_nonzero_sample(loader):
    """Return ``(features, label)`` for the first sample in ``loader`` whose label is non-zero.

    Walks the loader at most once, in its own iteration order. ``features`` keeps a
    leading batch dimension of size 1 (``features[idx:idx + 1]``), so it can be
    passed to Captum directly instead of attributing a whole batch.

    Raises:
        RuntimeError: if a full pass over ``loader`` yields no sample with a non-zero label.
    """
    for features, labels in loader:
        for idx, label in enumerate(labels):
            if label != 0:
                return features[idx:idx + 1], label
    raise RuntimeError("No sample with a non-zero label found in the data loader.")


# A single bounded pass replaces re-creating the loader iterator until batch[0] had
# a non-zero label, which never terminated when no such label existed.
sample_features, sample_label = _first_nonzero_sample(train_dataloader)
print(f"Attributing a training sample with label {sample_label}")

# NOTE(review): target=0 attributes the class-0 output (kept from the original call)
# although the sample was chosen for its non-zero label; confirm the intended target.
attr = guided_gc.attribute(sample_features.cuda(), target=0)  #, interpolate_mode="area")

# Captum returns attributions shaped like the input, here a batch of one volume
# (assumed single-channel, cf. the commented randn(1, 1, 91, 109, 91) shape).
# Squeeze the singleton dimensions and take one 2-D slice, because plt.imshow only
# displays 2-D (or RGB/RGBA) arrays; unsqueeze(0) on a full batch could not be shown.
attr = attr.detach().cpu().squeeze().numpy()
# NOTE(review): raw attributions are not normalised to [0, 1]; clipping may saturate.
attr = np.clip(attr[:, :, attr_slice], 0, 1)
plt.imshow(attr)
plt.show()
print("Done w/ attributions")
print(attr)
# EXTRA
# ---------------------------------------------------------------------------
# Inactive reference material: everything below is either commented out or a
# bare triple-quoted string expression, so none of it executes.
# NOTE(review): it would not run as-is if re-enabled:
#   * `if(predict):` tests the `predict` helper imported from utils.train_methods,
#     so it is presumably always truthy rather than a user-facing switch;
#   * `prepare_predict` and `Parameters` are never imported in this file;
#   * the predict loop reads params['batch_size'], but `params` is only defined
#     further down (and only in comments);
#   * the string block calls DataLoader(shape_tuple, seed=...) and
#     .get_train_val_test(...), which does not match the torch.utils.data.DataLoader
#     imported at the top; it presumably targets an older loader class.
# ---------------------------------------------------------------------------
# # PREDICT MODE TO TEST INDIVIDUAL IMAGES
# if(predict):
# on = True
# print("---- Predict mode ----")
# print("Integer for image")
# print("x or X for exit")
#
# while(on):
# inp = input("Next image: ")
# if(inp == None or inp.lower() == 'x' or not inp.isdigit()): on = False
# else:
# dataloader = DataLoader(prepare_predict(mri_datapath, [inp]), batch_size=params['batch_size'], shuffle=True)
# prediction = CNN.predict(dataloader)
#
# features, labels = next(iter(dataloader), )
# img = features[0].squeeze()
# image = img[:, :, 40]
# print(f"Expected class: {labels}")
# print(f"Prediction: {prediction}")
# plt.imshow(image, cmap="gray")
# plt.show()
#
# print("--- END ---")
# Legacy hyper-parameter dict. Its values (e.g. final_layer_size 5, batch sizes
# 10/5) differ from the `properties` used by the active code; the regularizer
# entries look Keras-style. Treat as historical notes.
# params = {
# "target_rows": 91,
# "target_cols": 109,
# "depth": 91,
# "axis": 1,
# "num_clinical": 2,
# "CNN_drop_rate": 0.3,
# "RNN_drop_rate": 0.1,
# # "CNN_w_regularizer": regularizers.l2(2e-2),
# # "RNN_w_regularizer": regularizers.l2(1e-6),
# "CNN_batch_size": 10,
# "RNN_batch_size": 5,
# "val_split": 0.2,
# "final_layer_size": 5
# }
# The triple-quoted string below is evaluated and discarded at import time. Its
# contents are notes from an earlier pipeline and reference names (CNN_w_regularizer,
# target_rows, optimizer, ...) that are not defined in this file.
'''
params_dict = { 'CNN_w_regularizer': CNN_w_regularizer, 'RNN_w_regularizer': RNN_w_regularizer,
'CNN_batch_size': CNN_batch_size, 'RNN_batch_size': RNN_batch_size,
'CNN_drop_rate': CNN_drop_rate, 'epochs': 30,
'gpu': "/gpu:0", 'model_filepath': model_filepath,
'image_shape': (target_rows, target_cols, depth, axis),
'num_clinical': num_clinical,
'final_layer_size': final_layer_size,
'optimizer': optimizer, 'RNN_drop_rate': RNN_drop_rate,}
params = Parameters(params_dict)
# WHAT WAS THIS AGAIN?
seeds = [np.random.randint(1, 5000) for _ in range(1)]
# READ THIS TO UNDERSTAND TRAIN VS VALIDATION DATA
def evaluate_net (seed):
n_classes = 2
data_loader = DataLoader((target_rows, target_cols, depth, axis), seed = seed)
train_data, val_data, test_data,rnn_HdataT1,rnn_HdataT2,rnn_HdataT3,rnn_AdataT1,rnn_AdataT2,rnn_AdataT3, test_mri_nonorm = data_loader.get_train_val_test(val_split, mri_datapath)
print('Length Val Data[0]: ',len(val_data[0]))
'''
|