# dataset_sd_mean_finder.py
# Computes the per-channel mean and standard deviation of the train/val/test
# MRI datasets (used to pick normalization constants).
  1. from preprocess import prepare_datasets
  2. from train_methods import train, load, evaluate, predict
  3. from CNN import CNN_Net
  4. from torch.utils.data import DataLoader
  5. model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
  6. CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth' # cnn_net.pth
  7. # small dataset
  8. # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/' # Small Test
  9. # big dataset
  10. mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/' # Real data
  11. csv_datapath = '../LP_ADNIMERGE.csv'
  12. # '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv')
  13. # LOADING DATA
  14. val_split = 0.2 # % of val and test, rest will be train
  15. seed = 12 # TODO Randomize seed
  16. properties = {
  17. "batch_size":32,
  18. "padding":0,
  19. "dilation":1,
  20. "groups":1,
  21. "bias":True,
  22. "padding_mode":"zeros",
  23. "drop_rate":0,
  24. "epochs": 20,
  25. "lr": [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6], # Unused
  26. 'momentum':[0.99, 0.97, 0.95, 0.9], # Unused
  27. 'weight_decay':[1e-3, 1e-4, 1e-5, 0] # Unused
  28. }
  29. # TODO: Datasets include multiple labels, such as medical info
  30. training_data, val_data, test_data = prepare_datasets(mri_datapath, csv_datapath, val_split, seed)
  31. # Create data loaders
  32. train_dataloader = DataLoader(training_data, batch_size=properties['batch_size'], shuffle=True, drop_last=True)
  33. val_dataloader = DataLoader(val_data, batch_size=properties['batch_size'], shuffle=True) # Used during training
  34. test_dataloader = DataLoader(test_data, batch_size=properties['batch_size'], shuffle=True) # Used at end for graphs
  35. print("STARTING")
  36. # HERE'S ACTUAL CODE
  37. mean = 0.
  38. std = 0.
  39. nb_samples = 0.
  40. for data in train_dataloader:
  41. print(data)
  42. batch_samples = data.size(0)
  43. data = data.view(batch_samples, data.size(1), -1)
  44. mean += data.mean(2).sum(0)
  45. std += data.std(2).sum(0)
  46. nb_samples += batch_samples
  47. mean /= nb_samples
  48. std /= nb_samples
  49. print(mean)
  50. print(std)
  51. mean = 0.
  52. std = 0.
  53. nb_samples = 0.
  54. for data in val_dataloader:
  55. batch_samples = data.size(0)
  56. data = data.view(batch_samples, data.size(1), -1)
  57. mean += data.mean(2).sum(0)
  58. std += data.std(2).sum(0)
  59. nb_samples += batch_samples
  60. mean /= nb_samples
  61. std /= nb_samples
  62. print(mean)
  63. print(std)
  64. mean = 0.
  65. std = 0.
  66. nb_samples = 0.
  67. for data in test_dataloader:
  68. batch_samples = data.size(0)
  69. data = data.view(batch_samples, data.size(1), -1)
  70. mean += data.mean(2).sum(0)
  71. std += data.std(2).sum(0)
  72. nb_samples += batch_samples
  73. mean /= nb_samples
  74. std /= nb_samples
  75. print(mean)
  76. print(std)