main.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import torch
  2. import torchvision
  3. # FOR DATA
  4. from utils.preprocess import prepare_datasets
  5. from utils.show_image import show_image
  6. from torch.utils.data import DataLoader
  7. from torchvision import datasets
  8. from torch import nn
  9. import torch.nn.functional as F
  10. from torchvision.transforms import ToTensor
  11. # import nonechucks as nc # Used to load data in pytorch even when images are corrupted / unavailable (skips them)
  12. # FOR IMAGE VISUALIZATION
  13. import nibabel as nib
  14. # GENERAL PURPOSE
  15. import os
  16. import pandas as pd
  17. import numpy as np
  18. import matplotlib.pyplot as plt
  19. import glob
  20. from datetime import datetime
  21. # FOR TRAINING
  22. import torch.optim as optim
  23. import utils.models as models
  24. import utils.layers as ly
  25. from tqdm import tqdm
  26. #Set Default GPU
  27. cuda_device = torch.device('cuda:1')
  28. torch.set_default_device(cuda_device)
  29. print("--- RUNNING ---")
  30. print("Pytorch Version: " + torch. __version__)
  31. # data & training properties:
  32. val_split = 0.2 # % of val and test, rest will be train
  33. runs = 1
  34. epochs = 5
  35. time_stamp = timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  36. seeds = [np.random.randint(0, 1000) for _ in range(runs)]
  37. #Data Path
  38. mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'
  39. #Local Path
  40. local_path = '/export/home/nschense/alzheimers/Pytorch_CNN-RNN'
  41. xls_path = local_path + '/LP_ADNIMERGE.csv'
  42. saved_model_path = local_path + 'saved_models/'
  43. DEBUG = False
  44. # TODO: Datasets include multiple labels, such as medical info
  45. def evaluate_model(seed):
  46. training_data, val_data, test_data = prepare_datasets(mri_datapath, xls_path, val_split, seed)
  47. batch_size = 64
  48. # Create data loaders
  49. train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=cuda_device))
  50. test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=cuda_device))
  51. val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=cuda_device))
  52. #Print Shape of Image Data
  53. print("Shape of MRI Data: ", training_data[0][0].shape)
  54. print("Shape of XLS Data: ", training_data[0][1].shape)
  55. #Print Training Data Length
  56. print("Length of Training Data: ", len(train_dataloader))
  57. print("--- INITIALIZING MODEL ---")
  58. model_CNN = models.CNN_Net(1, 2, 0.5).to(cuda_device)
  59. criterion = nn.BCELoss()
  60. optimizer = optim.Adam(model_CNN.parameters(), lr=0.001)
  61. print("Seed: ", seed)
  62. epoch_number = 0
  63. print("--- TRAINING MODEL ---")
  64. for epoch in range(epochs):
  65. running_loss = 0.0
  66. length = len(train_dataloader)
  67. for i, data in tqdm(enumerate(train_dataloader, 0), total=length, desc="Epoch " + str(epoch), unit="batch"):
  68. mri, xls, label = data
  69. optimizer.zero_grad()
  70. mri = mri.to(cuda_device).float()
  71. xls = xls.to(cuda_device).float()
  72. label = label.to(cuda_device).float()
  73. outputs = model_CNN((mri, xls))
  74. if DEBUG:
  75. print(outputs.shape, label.shape)
  76. loss = criterion(outputs, label)
  77. loss.backward()
  78. optimizer.step()
  79. running_loss += loss.item()
  80. if i % 1000 == 999:
  81. print("Epoch: ", epoch_number, "Batch: ", i+1, "Loss: ", running_loss / 1000, "Accuracy: ", )
  82. print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 1000))
  83. running_loss = 0.0
  84. epoch_number += 1
  85. print("--- TESTING MODEL ---")
  86. #Test model
  87. correct = 0
  88. total = 0
  89. with torch.no_grad():
  90. length = len(test_dataloader)
  91. for i, data in tqdm(enumerate(test_dataloader, 0), total=length, desc="Testing", unit="batch"):
  92. mri, xls, label = data
  93. mri = mri.to(cuda_device).float()
  94. xls = xls.to(cuda_device).float()
  95. label = label.to(cuda_device).float()
  96. outputs = model_CNN((mri, xls))
  97. if DEBUG:
  98. print(outputs.shape, label.shape)
  99. _, predicted = torch.max(outputs.data, 1)
  100. _, labels = torch.max(label.data, 1)
  101. if DEBUG:
  102. print("Predicted: ", predicted)
  103. print("Labels: ", labels)
  104. total += labels.size(0)
  105. correct += (predicted == labels).sum().item()
  106. print("Model Accuracy: ", 100 * correct / total)
  107. for seed in seeds:
  108. evaluate_model(seed)
  109. print("--- END ---")