# main.py
  1. import torch
  2. import torchvision
  3. # FOR DATA
  4. from utils.preprocess import prepare_datasets, prepare_predict
  5. from utils.show_image import show_image
  6. from utils.newCNN import CNN_Net
  7. from torch.utils.data import DataLoader
  8. from torchvision import datasets
  9. from torch import nn
  10. from torchvision.transforms import ToTensor
  11. # import nonechucks as nc # Used to load data in pytorch even when images are corrupted / unavailable (skips them)
  12. # FOR IMAGE VISUALIZATION
  13. import nibabel as nib
  14. # GENERAL PURPOSE
  15. import os
  16. import pandas as pd
  17. import numpy as np
  18. import matplotlib.pyplot as plt
  19. import glob
  20. print("--- RUNNING ---")
  21. print("Pytorch Version: " + torch. __version__)
  22. # MAYBE??
  23. '''
  24. import sys
  25. sys.path.append('//data/data_wnx3/data_wnx1/rschuurs/CNN+RNN-2class-1cnn-CLEAN/utils')
  26. import os
  27. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  28. os.environ["CUDA_VISIBLE_DEVICES"] = "0" # use id from $ nvidia-smi
  29. '''
  30. # LOADING DATA
  31. # data & training properties:
  32. val_split = 0.2 # % of val and test, rest will be train
  33. seed = 12 # TODO Randomize seed
  34. # params = {
  35. # "target_rows": 91,
  36. # "target_cols": 109,
  37. # "depth": 91,
  38. # "axis": 1,
  39. # "num_clinical": 2,
  40. # "CNN_drop_rate": 0.3,
  41. # "RNN_drop_rate": 0.1,
  42. # # "CNN_w_regularizer": regularizers.l2(2e-2),
  43. # # "RNN_w_regularizer": regularizers.l2(1e-6),
  44. # "CNN_batch_size": 10,
  45. # "RNN_batch_size": 5,
  46. # "val_split": 0.2,
  47. # "final_layer_size": 5
  48. # }
  49. properties = {
  50. "batch_size":4,
  51. "padding":0,
  52. "dilation":1,
  53. "groups":1,
  54. "bias":True,
  55. "padding_mode":"zeros",
  56. "drop_rate":0
  57. }
  58. # Might have to replace datapaths or separate between training and testing
  59. model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
  60. CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'
  61. mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/MRI_volumes_customtemplate_float32/'
  62. annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'
  63. # annotations_file = pd.read_csv(annotations_datapath) # DataFrame
  64. # show_image(17508)
  65. # TODO: Datasets include multiple labels, such as medical info
  66. training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split, seed)
  67. # Create data loaders
  68. train_dataloader = DataLoader(training_data, batch_size=properties['batch_size'], shuffle=True)
  69. test_dataloader = DataLoader(test_data, batch_size=properties['batch_size'], shuffle=True)
  70. val_dataloader = DataLoader(val_data, batch_size=properties['batch_size'], shuffle=True)
  71. # for X, y in train_dataloader:
  72. # print(f"Shape of X [Channels (colors), Y, X, Z]: {X.shape}") # X & Y are from TOP LOOKING DOWN
  73. # print(f"Shape of Y (Dataset?): {y.shape} {y.dtype}")
  74. # break
  75. # Display 4 images and labels.
  76. # x = 1
  77. # while x < 1:
  78. # train_features, train_labels = next(iter(train_dataloader))
  79. # print(f"Feature batch shape: {train_features.size()}")
  80. # img = train_features[0].squeeze()
  81. # print(f"Feature batch shape: {img.size()}")
  82. # image = img[:, :, 40]
  83. # print(f"Feature batch shape: {image.size()}")
  84. # label = train_labels[0]
  85. # print(f"Label: {label}")
  86. # plt.imshow(image, cmap="gray")
  87. # plt.show()
  88. # x = x+1
  89. train = False
  90. predict = True
  91. CNN = CNN_Net(train_dataloader, prps=properties, final_layer_size=2)
  92. CNN.cuda()
  93. # RUN CNN
  94. if(train):
  95. CNN.train_model(train_dataloader, CNN_filepath, epochs=10)
  96. CNN.evaluate_model(val_dataloader)
  97. else:
  98. CNN.load_state_dict(torch.load(CNN_filepath))
  99. CNN.evaluate_model(val_dataloader)
  100. # PREDICT MODE TO TEST INDIVIDUAL IMAGES
  101. if(predict):
  102. on = True
  103. print("---- Predict mode ----")
  104. print("Integer for image")
  105. print("x or X for exit")
  106. while(on):
  107. inp = input("Next image: ")
  108. if(inp == None or inp.lower() == 'x' or not inp.isdigit()): on = False
  109. else:
  110. dataloader = DataLoader(prepare_predict(mri_datapath, [inp]), batch_size=properties['batch_size'], shuffle=True)
  111. prediction = CNN.predict(dataloader)
  112. features, labels = next(iter(dataloader), )
  113. img = features[0].squeeze()
  114. image = img[:, :, 40]
  115. print(f"Expected class: {labels}")
  116. print(f"Prediction: {prediction}")
  117. plt.imshow(image, cmap="gray")
  118. plt.show()
  119. print("--- END ---")
  120. # EXTRA
  121. # will I need these params?
  122. '''
  123. params_dict = { 'CNN_w_regularizer': CNN_w_regularizer, 'RNN_w_regularizer': RNN_w_regularizer,
  124. 'CNN_batch_size': CNN_batch_size, 'RNN_batch_size': RNN_batch_size,
  125. 'CNN_drop_rate': CNN_drop_rate, 'epochs': 30,
  126. 'gpu': "/gpu:0", 'model_filepath': model_filepath,
  127. 'image_shape': (target_rows, target_cols, depth, axis),
  128. 'num_clinical': num_clinical,
  129. 'final_layer_size': final_layer_size,
  130. 'optimizer': optimizer, 'RNN_drop_rate': RNN_drop_rate,}
  131. params = Parameters(params_dict)
  132. # WHAT WAS THIS AGAIN?
  133. seeds = [np.random.randint(1, 5000) for _ in range(1)]
  134. # READ THIS TO UNDERSTAND TRAIN VS VALIDATION DATA
  135. def evaluate_net (seed):
  136. n_classes = 2
  137. data_loader = DataLoader((target_rows, target_cols, depth, axis), seed = seed)
  138. train_data, val_data, test_data,rnn_HdataT1,rnn_HdataT2,rnn_HdataT3,rnn_AdataT1,rnn_AdataT2,rnn_AdataT3, test_mri_nonorm = data_loader.get_train_val_test(val_split, mri_datapath)
  139. print('Length Val Data[0]: ',len(val_data[0]))
  140. '''