main.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import torch
  2. from torch import nn
  3. # GENERAL PURPOSE
  4. import numpy as np
  5. from datetime import datetime
  6. import pandas as pd
  7. import os
  8. # FOR TRAINING
  9. import torch.optim as optim
  10. import utils.models as models
  11. from utils.training import train_model, test_model, initalize_dataloaders, plot_results
  12. #Set Default GPU
  13. cuda_device = torch.device('cuda:1')
  14. torch.set_default_device(cuda_device)
  15. print("--- RUNNING ---")
  16. print("Pytorch Version: " + torch. __version__)
  17. # data & training properties:
  18. val_split = 0.2 # % of val and test, rest will be train
  19. runs = 1
  20. epochs = 30
  21. seeds = [np.random.randint(0, 1000) for _ in range(runs)]
  22. #Data Path
  23. mri_path = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'
  24. #Local Path
  25. local_path = '/export/home/nschense/alzheimers/Pytorch_CNN-RNN'
  26. xls_path = local_path + '/LP_ADNIMERGE.csv'
  27. saved_model_path = local_path + '/saved_models/'
  28. plot_path = local_path + '/plots/'
  29. training_record_path = local_path + '/training_records/'
  30. DEBUG = False
  31. model_CNN = models.CNN_Net(1, 2, 0.5).to(cuda_device)
  32. criterion = nn.BCELoss()
  33. optimizer = optim.Adam(model_CNN.parameters(), lr=0.001)
  34. for seed in seeds:
  35. time_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  36. train_loader, val_loader, test_loader = initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=cuda_device)
  37. train_results = train_model(model_CNN, seed, time_stamp, epochs, train_loader, val_loader, saved_model_path, "CNN", optimizer, criterion, cuda_device=cuda_device)
  38. test_model(model_CNN, test_loader, cuda_device=cuda_device)
  39. #Plot results
  40. plot_results(train_results["train_acc"], train_results["train_loss"], train_results["val_acc"], train_results["val_loss"], "CNN", time_stamp, plot_path)
  41. #Save training results
  42. if not os.path.exists(training_record_path):
  43. os.makedirs(training_record_path)
  44. train_results.to_csv(training_record_path + "CNN_t-" + time_stamp + "_s-" + str(seed) + "_e-" + str(epochs) + ".csv")
  45. print("--- END ---")