# Dataset/normalization setup: load ADNI MRI splits and report per-channel mean/std.
- from preprocess import prepare_datasets
- from train_methods import train, load, evaluate, predict
- from CNN import CNN_Net
- from torch.utils.data import DataLoader
# Checkpoint locations.
model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'  # cnn_net.pth

# small dataset
# mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'  # Small Test
# big dataset
mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'  # Real data
csv_datapath = '../LP_ADNIMERGE.csv'
# '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv')

# LOADING DATA
val_split = 0.2  # fraction for val and for test each; the remainder is train
seed = 12        # TODO Randomize seed

# Hyper-parameters handed to the CNN / training routines.
properties = {
    "batch_size": 32,
    "padding": 0,
    "dilation": 1,
    "groups": 1,
    "bias": True,
    "padding_mode": "zeros",
    "drop_rate": 0,
    "epochs": 20,
    "lr": [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6],  # Unused
    'momentum': [0.99, 0.97, 0.95, 0.9],         # Unused
    'weight_decay': [1e-3, 1e-4, 1e-5, 0],       # Unused
}
# TODO: Datasets include multiple labels, such as medical info
# Split the MRI volumes into train/val/test according to val_split and seed.
training_data, val_data, test_data = prepare_datasets(mri_datapath, csv_datapath, val_split, seed)

# Create data loaders. Training drops the last partial batch so every
# optimization step sees a full batch; val/test keep all samples.
train_dataloader = DataLoader(training_data,
                              batch_size=properties['batch_size'],
                              shuffle=True,
                              drop_last=True)
val_dataloader = DataLoader(val_data,
                            batch_size=properties['batch_size'],
                            shuffle=True)  # Used during training
test_dataloader = DataLoader(test_data,
                             batch_size=properties['batch_size'],
                             shuffle=True)  # Used at end for graphs
print("STARTING")


# HERE'S ACTUAL CODE
def _channel_mean_std(dataloader):
    """Return per-channel (mean, std) tensors over all samples in *dataloader*.

    Each batch is reshaped to (batch, channels, -1) and reduced over the
    flattened spatial dimension. NOTE: the returned std is the *average of
    per-sample stds*, not the true pooled dataset std — kept as-is to match
    the original computation.

    NOTE(review): assumes the loader yields bare tensors, not
    (volume, label) tuples — confirm against prepare_datasets.
    """
    mean = 0.
    std = 0.
    nb_samples = 0.
    for data in dataloader:
        batch_samples = data.size(0)
        # Flatten spatial dims so mean/std reduce over voxels (dim 2).
        data = data.view(batch_samples, data.size(1), -1)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        nb_samples += batch_samples
    return mean / nb_samples, std / nb_samples


# The same statistics were previously copy-pasted three times (and the first
# copy printed every raw batch — debug leftover, removed). Report per split.
for _loader in (train_dataloader, val_dataloader, test_dataloader):
    mean, std = _channel_mean_std(_loader)
    print(mean)
    print(std)
|