123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- import torch
- from torch import nn
- # GENERAL PURPOSE
- import numpy as np
- from datetime import datetime
- import pandas as pd
- import os
- # FOR TRAINING
- import torch.optim as optim
- import utils.models as models
- from utils.training import train_model, test_model, initalize_dataloaders, plot_results
# --- Device selection --------------------------------------------------------
# Register the second CUDA device as torch's default device: tensors created
# without an explicit `device=` argument are allocated there.
# NOTE(review): 'cuda:1' is hard-coded; work will fail on hosts without a
# second GPU.
cuda_device = torch.device("cuda:1")
torch.set_default_device(cuda_device)

# --- Startup banner ----------------------------------------------------------
print("--- RUNNING ---")
print(f"Pytorch Version: {torch.__version__}")
# --- Data & training properties ----------------------------------------------
# Fraction (not a percentage) of the samples reserved for validation and test;
# the remainder is used for training. How the held-out part is divided between
# validation and test is decided by `initalize_dataloaders`.
val_split = 0.2
runs = 1     # number of independent training runs (one seed per run)
epochs = 30  # training epochs per run

# One random seed in [0, 1000) per run. The seeds are not themselves seeded,
# so they differ between invocations of the script.
seeds = [np.random.randint(0, 1000) for _run in range(runs)]

# --- Paths -------------------------------------------------------------------
# Directory of input imaging volumes handed to `initalize_dataloaders`.
mri_path = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'

# Project root on this machine; the metadata file and all outputs live below it.
local_path = '/export/home/nschense/alzheimers/Pytorch_CNN-RNN'
xls_path = f'{local_path}/LP_ADNIMERGE.csv'  # metadata table (a CSV, despite the name)
saved_model_path = f'{local_path}/saved_models/'
plot_path = f'{local_path}/plots/'
training_record_path = f'{local_path}/training_records/'

# NOTE(review): DEBUG is not referenced anywhere in the visible script.
DEBUG = False
# Loss function shared by every run; BCELoss has no learnable parameters, so a
# single instance carries no state from one run to the next.
criterion = nn.BCELoss()


def _build_model_and_optimizer():
    """Create a freshly initialised CNN on ``cuda_device`` plus an Adam optimizer for it.

    Returns:
        tuple: ``(model, optimizer)`` where ``optimizer`` is bound to
        ``model.parameters()`` with learning rate 0.001.

    A new pair is built for every run so that a run never starts from the
    weights or Adam moment estimates left behind by the previous run. If the
    data split depends on the seed, reusing the model could also let it see
    this run's test samples during an earlier run's training.
    """
    model = models.CNN_Net(1, 2, 0.5).to(cuda_device)
    return model, optim.Adam(model.parameters(), lr=0.001)


for seed in seeds:
    # Seed torch's RNGs (CPU and all CUDA devices) before building the model so
    # weight initialisation is reproducible from the seed stored in the
    # training-record filename.
    torch.manual_seed(seed)
    model_CNN, optimizer = _build_model_and_optimizer()

    time_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    train_loader, val_loader, test_loader = initalize_dataloaders(
        mri_path, xls_path, val_split, seed, cuda_device=cuda_device
    )
    train_results = train_model(
        model_CNN, seed, time_stamp, epochs, train_loader, val_loader,
        saved_model_path, "CNN", optimizer, criterion, cuda_device=cuda_device,
    )
    # NOTE(review): the return value of test_model is discarded; if it returns
    # metrics, consider recording them alongside the training results.
    test_model(model_CNN, test_loader, cuda_device=cuda_device)

    # Plot accuracy/loss curves for this run.
    plot_results(
        train_results["train_acc"], train_results["train_loss"],
        train_results["val_acc"], train_results["val_loss"],
        "CNN", time_stamp, plot_path,
    )

    # Save training results. exist_ok=True avoids the race between checking
    # for the directory and creating it.
    os.makedirs(training_record_path, exist_ok=True)
    record_name = f"CNN_t-{time_stamp}_s-{seed}_e-{epochs}.csv"
    train_results.to_csv(os.path.join(training_record_path, record_name))

print("--- END ---")
|