TALOS_main.py

import torch
import talos
# FOR DATA
from utils.preprocess import prepare_datasets
from utils.train_methods import train, load, evaluate, predict
from utils.CNN import CNN_Net
from torch.utils.data import DataLoader
from torchvision import datasets
from sklearn.model_selection import KFold
# GENERAL PURPOSE
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import platform
import time

current_time = time.localtime()
print(time.strftime("%Y-%m-%d_%H:%M", current_time))
print("--- RUNNING ---")
print("Pytorch Version: " + torch.__version__)
print("Python Version: " + platform.python_version())
# LOADING DATA
val_split = 0.2  # fraction used for validation and for test; the remainder is train
seed = 12  # TODO: randomize seed

# Hyperparameter search space: tuples appear to be (min, max, steps) sweep ranges,
# lists are discrete candidate values, scalars are fixed settings.
params = {
    "batch_size": (15, 40, 5),
    "padding": 0,
    "dilation": 1,
    "groups": 1,
    "bias": True,
    "padding_mode": "zeros",
    "drop_rate": [0, 0.1, 0.2],
    "epochs": (10, 30, 5),
    "lr": [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6],
    "momentum": [0.99, 0.97, 0.95, 0.9],
    "weight_decay": [1e-3, 1e-4, 1e-5, 0],
    # "optimizer": 'adam',
}
model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'  # cnn_net.pth
# small dataset
# mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'  # Small Test
# big dataset
mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'  # Real data
annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'
# annotations_file = pd.read_csv(annotations_datapath)  # DataFrame
# show_image(17508)

# TODO: Datasets include multiple labels, such as medical info
training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split, seed)
training_data_list = list(training_data)
val_data_list = list(val_data)
test_data_list = list(test_data)

# Create data loaders. params["batch_size"] holds a sweep range, so a concrete
# value is needed here; the lower bound of the range is used (assumption).
batch_size = params["batch_size"][0]
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)    # Used during training
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)  # Used at end for graphs
# Loads a few images to test. The loop is disabled by default (the bound is 0);
# raise the bound to preview that many sample slices.
x = 0
while x < 0:
    train_features, train_labels = next(iter(train_dataloader))
    # print(f"Feature batch shape: {train_features.size()}")
    img = train_features[0].squeeze()
    print(f"Feature shape: {img.size()}")
    image = img[:, :, 40]
    print(f"Slice shape: {image.size()}")
    label = train_labels[0]
    print(f"Label: {label}")
    plt.imshow(image, cmap="gray")
    plt.savefig(f"./Image{x}_IS:{label}.png")
    plt.show()
    x = x + 1
# epochs = 20
roc = True

CNN = CNN_Net(prps=params, final_layer_size=2)
CNN.cuda()

# scan_object = talos.Scan(
train(CNN, train_dataloader, val_dataloader, CNN_filepath, params=params, graphs=True)
# load(CNN, CNN_filepath)
evaluate(CNN, test_dataloader)
predict(CNN, test_dataloader)
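# Illustrative sketch (disabled by default): how the imported-but-unused KFold
# could drive cross-validation over the training set, reusing CNN_Net, train,
# and evaluate with the same call signatures as above. The flag name run_kfold,
# the fold count, and graphs=False are assumptions, not part of the original run.
from torch.utils.data import Subset

run_kfold = False
if run_kfold:
    kfold = KFold(n_splits=5, shuffle=True, random_state=seed)
    for fold, (train_idx, val_idx) in enumerate(kfold.split(np.arange(len(training_data_list)))):
        print(f"--- Fold {fold} ---")
        fold_train_loader = DataLoader(Subset(training_data, train_idx.tolist()),
                                       batch_size=batch_size, shuffle=True, drop_last=True)
        fold_val_loader = DataLoader(Subset(training_data, val_idx.tolist()),
                                     batch_size=batch_size, shuffle=True)
        fold_model = CNN_Net(prps=params, final_layer_size=2)
        fold_model.cuda()
        train(fold_model, fold_train_loader, fold_val_loader, CNN_filepath, params=params, graphs=False)
        evaluate(fold_model, fold_val_loader)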
# EXTRA
# # PREDICT MODE TO TEST INDIVIDUAL IMAGES
# if(predict):
#     on = True
#     print("---- Predict mode ----")
#     print("Integer for image")
#     print("x or X for exit")
#
#     while(on):
#         inp = input("Next image: ")
#         if(inp == None or inp.lower() == 'x' or not inp.isdigit()): on = False
#         else:
#             dataloader = DataLoader(prepare_predict(mri_datapath, [inp]), batch_size=params['batch_size'], shuffle=True)
#             prediction = CNN.predict(dataloader)
#
#             features, labels = next(iter(dataloader), )
#             img = features[0].squeeze()
#             image = img[:, :, 40]
#             print(f"Expected class: {labels}")
#             print(f"Prediction: {prediction}")
#             plt.imshow(image, cmap="gray")
#             plt.show()
#
# print("--- END ---")
# params = {
#     "target_rows": 91,
#     "target_cols": 109,
#     "depth": 91,
#     "axis": 1,
#     "num_clinical": 2,
#     "CNN_drop_rate": 0.3,
#     "RNN_drop_rate": 0.1,
#     # "CNN_w_regularizer": regularizers.l2(2e-2),
#     # "RNN_w_regularizer": regularizers.l2(1e-6),
#     "CNN_batch_size": 10,
#     "RNN_batch_size": 5,
#     "val_split": 0.2,
#     "final_layer_size": 5
# }
'''
params_dict = {'CNN_w_regularizer': CNN_w_regularizer, 'RNN_w_regularizer': RNN_w_regularizer,
               'CNN_batch_size': CNN_batch_size, 'RNN_batch_size': RNN_batch_size,
               'CNN_drop_rate': CNN_drop_rate, 'epochs': 30,
               'gpu': "/gpu:0", 'model_filepath': model_filepath,
               'image_shape': (target_rows, target_cols, depth, axis),
               'num_clinical': num_clinical,
               'final_layer_size': final_layer_size,
               'optimizer': optimizer, 'RNN_drop_rate': RNN_drop_rate}
params = Parameters(params_dict)

# WHAT WAS THIS AGAIN?
seeds = [np.random.randint(1, 5000) for _ in range(1)]

# READ THIS TO UNDERSTAND TRAIN VS VALIDATION DATA
def evaluate_net(seed):
    n_classes = 2
    data_loader = DataLoader((target_rows, target_cols, depth, axis), seed=seed)
    train_data, val_data, test_data, rnn_HdataT1, rnn_HdataT2, rnn_HdataT3, rnn_AdataT1, rnn_AdataT2, rnn_AdataT3, test_mri_nonorm = data_loader.get_train_val_test(val_split, mri_datapath)
    print('Length Val Data[0]: ', len(val_data[0]))
'''