|
@@ -0,0 +1,97 @@
|
|
|
+from preprocess import prepare_datasets
|
|
|
+from train_methods import train, load, evaluate, predict
|
|
|
+from CNN import CNN_Net
|
|
|
+from torch.utils.data import DataLoader
|
|
|
+
|
|
|
+model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
|
|
|
+CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth' # cnn_net.pth
|
|
|
+# small dataset
|
|
|
+# mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/' # Small Test
|
|
|
+# big dataset
|
|
|
+mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/' # Real data
|
|
|
+annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'
|
|
|
+
|
|
|
+
|
|
|
+# LOADING DATA
|
|
|
+val_split = 0.2 # % of val and test, rest will be train
|
|
|
+seed = 12 # TODO Randomize seed
|
|
|
+
|
|
|
+properties = {
|
|
|
+ "batch_size":32,
|
|
|
+ "padding":0,
|
|
|
+ "dilation":1,
|
|
|
+ "groups":1,
|
|
|
+ "bias":True,
|
|
|
+ "padding_mode":"zeros",
|
|
|
+ "drop_rate":0,
|
|
|
+ "epochs": 20,
|
|
|
+ "lr": [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6], # Unused
|
|
|
+ 'momentum':[0.99, 0.97, 0.95, 0.9], # Unused
|
|
|
+ 'weight_decay':[1e-3, 1e-4, 1e-5, 0] # Unused
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+# TODO: Datasets include multiple labels, such as medical info
|
|
|
+training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split, seed)
|
|
|
+
|
|
|
+# Create data loaders
|
|
|
+train_dataloader = DataLoader(training_data, batch_size=properties['batch_size'], shuffle=True, drop_last=True)
|
|
|
+val_dataloader = DataLoader(val_data, batch_size=properties['batch_size'], shuffle=True) # Used during training
|
|
|
+test_dataloader = DataLoader(test_data, batch_size=properties['batch_size'], shuffle=True) # Used at end for graphs
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+# HERE'S ACTUAL CODE
|
|
|
+mean = 0.
|
|
|
+std = 0.
|
|
|
+nb_samples = 0.
|
|
|
+for data in train_dataloader:
|
|
|
+ batch_samples = data.size(0)
|
|
|
+ data = data.view(batch_samples, data.size(1), -1)
|
|
|
+ mean += data.mean(2).sum(0)
|
|
|
+ std += data.std(2).sum(0)
|
|
|
+ nb_samples += batch_samples
|
|
|
+
|
|
|
+mean /= nb_samples
|
|
|
+std /= nb_samples
|
|
|
+
|
|
|
+print(mean)
|
|
|
+print(std)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+mean = 0.
|
|
|
+std = 0.
|
|
|
+nb_samples = 0.
|
|
|
+for data in val_dataloader:
|
|
|
+ batch_samples = data.size(0)
|
|
|
+ data = data.view(batch_samples, data.size(1), -1)
|
|
|
+ mean += data.mean(2).sum(0)
|
|
|
+ std += data.std(2).sum(0)
|
|
|
+ nb_samples += batch_samples
|
|
|
+
|
|
|
+mean /= nb_samples
|
|
|
+std /= nb_samples
|
|
|
+
|
|
|
+print(mean)
|
|
|
+print(std)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+mean = 0.
|
|
|
+std = 0.
|
|
|
+nb_samples = 0.
|
|
|
+for data in test_dataloader:
|
|
|
+ batch_samples = data.size(0)
|
|
|
+ data = data.view(batch_samples, data.size(1), -1)
|
|
|
+ mean += data.mean(2).sum(0)
|
|
|
+ std += data.std(2).sum(0)
|
|
|
+ nb_samples += batch_samples
|
|
|
+
|
|
|
+mean /= nb_samples
|
|
|
+std /= nb_samples
|
|
|
+
|
|
|
+print(mean)
|
|
|
+print(std)
|