# main.py
  1. import torch
  2. import torchvision
  3. # FOR DATA
  4. from utils.preprocess import prepare_datasets
  5. from utils.show_image import show_image
  6. from torch.utils.data import DataLoader
  7. from torchvision import datasets
  8. from torch import nn
  9. from torchvision.transforms import ToTensor
  10. # import nonechucks as nc # Used to load data in pytorch even when images are corrupted / unavailable (skips them)
  11. # FOR IMAGE VISUALIZATION
  12. import nibabel as nib
  13. # GENERAL PURPOSE
  14. import os
  15. import pandas as pd
  16. import numpy as np
  17. import matplotlib.pyplot as plt
  18. import glob
  19. print("--- RUNNING ---")
  20. print("Pytorch Version: " + torch. __version__)
  21. # MAYBE??
  22. '''
  23. import sys
  24. sys.path.append('//data/data_wnx3/data_wnx1/rschuurs/CNN+RNN-2class-1cnn-CLEAN/utils')
  25. import os
  26. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  27. os.environ["CUDA_VISIBLE_DEVICES"] = "0" # use id from $ nvidia-smi
  28. '''
  29. # LOADING DATA
  30. # data & training properties:
  31. val_split = 0.2 # % of val and test, rest will be train
  32. seed = 12 # TODO Randomize seed
  33. '''
  34. target_rows = 91
  35. target_cols = 109
  36. depth = 91
  37. axis = 1
  38. num_clinical = 2
  39. CNN_drop_rate = 0.3
  40. RNN_drop_rate = 0.1
  41. CNN_w_regularizer = regularizers.l2(2e-2)
  42. RNN_w_regularizer = regularizers.l2(1e-6)
  43. CNN_batch_size = 10
  44. RNN_batch_size = 5
  45. val_split = 0.2
  46. optimizer = Adam(lr=1e-5)
  47. final_layer_size = 5
  48. '''
  49. # Might have to replace datapaths or separate between training and testing
  50. model_filepath = '//data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
  51. mri_datapath = './ADNI_volumes_customtemplate_float32/'
  52. annotations_datapath = './LP_ADNIMERGE.csv'
  53. # annotations_file = pd.read_csv(annotations_datapath) # DataFrame
  54. # show_image(17508)
  55. # TODO: Datasets include multiple labels, such as medical info
  56. training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split, seed)
  57. batch_size = 64
  58. # Create data loaders
  59. train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
  60. test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
  61. val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
  62. for X, y in train_dataloader:
  63. print(f"Shape of X [N, C, H, W]: {X.shape}")
  64. print(f"Shape of y: {y.shape} {y.dtype}")
  65. break
  66. # Display 10 images and labels.
  67. x = 0
  68. while x < 10:
  69. train_features, train_labels = next(iter(train_dataloader))
  70. print(f"Feature batch shape: {train_features.size()}")
  71. img = train_features[0].squeeze()
  72. image = img[:, :, 40]
  73. label = train_labels[0]
  74. plt.imshow(image, cmap="gray")
  75. plt.show()
  76. print(f"Label: {label}")
  77. x = x+1
  78. print("--- END ---")
  79. # EXTRA
  80. # will I need these params?
  81. '''
  82. params_dict = { 'CNN_w_regularizer': CNN_w_regularizer, 'RNN_w_regularizer': RNN_w_regularizer,
  83. 'CNN_batch_size': CNN_batch_size, 'RNN_batch_size': RNN_batch_size,
  84. 'CNN_drop_rate': CNN_drop_rate, 'epochs': 30,
  85. 'gpu': "/gpu:0", 'model_filepath': model_filepath,
  86. 'image_shape': (target_rows, target_cols, depth, axis),
  87. 'num_clinical': num_clinical,
  88. 'final_layer_size': final_layer_size,
  89. 'optimizer': optimizer, 'RNN_drop_rate': RNN_drop_rate,}
  90. params = Parameters(params_dict)
  91. # WHAT WAS THIS AGAIN?
  92. seeds = [np.random.randint(1, 5000) for _ in range(1)]
  93. # READ THIS TO UNDERSTAND TRAIN VS VALIDATION DATA
  94. def evaluate_net (seed):
  95. n_classes = 2
  96. data_loader = DataLoader((target_rows, target_cols, depth, axis), seed = seed)
  97. train_data, val_data, test_data,rnn_HdataT1,rnn_HdataT2,rnn_HdataT3,rnn_AdataT1,rnn_AdataT2,rnn_AdataT3, test_mri_nonorm = data_loader.get_train_val_test(val_split, mri_datapath)
  98. print('Length Val Data[0]: ',len(val_data[0]))
  99. '''