import torch
import torchvision

# FOR DATA
from utils.preprocess import prepare_datasets
from utils.show_image import show_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor
# import nonechucks as nc  # Used to load data in pytorch even when images are corrupted / unavailable (skips them)

# FOR IMAGE VISUALIZATION
import nibabel as nib

# GENERAL PURPOSE
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
from datetime import datetime

# FOR TRAINING
import torch.optim as optim
import utils.models as models
import utils.layers as ly

# FOR TESTING
import torchsummary

# data & training properties:
val_split = 0.2  # % of val and test, rest will be train
runs = 1
epochs = 100
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
time_stamp = timestamp  # alias of `timestamp`, kept so either name keeps working
seeds = [np.random.randint(0, 1000) for _ in range(runs)]
mri_datapath = './ADNI_volumes_customtemplate_float32/'
xls_file = './Lp_ADNIMERGE.csv'

# TODO: Datasets include multiple labels, such as medical info


def _accuracy(model, dataloader):
    """Return the percentage of samples in *dataloader* that *model* classifies correctly.

    Batches are unpacked as ``(mri, xls, label)`` triples, the same layout the
    training loop uses, and the model is called with ``(mri, xls)`` cast to
    double. The model is switched to eval mode for the pass and restored to
    its previous mode afterwards. Returns ``float('nan')`` if the loader
    yields no samples (avoids a ZeroDivisionError).
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for mri, xls, label in dataloader:
            outputs = model((mri.double(), xls.double()))
            # Predicted class = index of the largest output along dim 1.
            _, predicted = torch.max(outputs, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    model.train(was_training)
    if total == 0:
        return float("nan")
    return 100 * correct / total


def evaluate_model(seed):
    """Train a fresh ``CNN_Net`` on one seeded data split and print its accuracy.

    Args:
        seed: Passed to ``prepare_datasets`` to select the train/val/test split.
            It also seeds torch's RNG, so weight initialisation and training-set
            shuffling are reproducible for a given seed.

    Prints the training loss periodically and per epoch, validation accuracy
    per epoch, and the final test accuracy.
    """
    torch.manual_seed(seed)
    training_data, val_data, test_data = prepare_datasets(mri_datapath, xls_file, val_split, seed)
    batch_size = 64

    # Create data loaders. Evaluation order does not change accuracy, so only
    # the training loader is shuffled.
    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # NOTE(review): the meaning of CNN_Net's arguments is defined in utils.models.
    # If the second argument is the number of output classes, CrossEntropyLoss over
    # a single logit is always 0 and the network cannot learn -- confirm.
    model_CNN = models.CNN_Net(1, 1, 0.5).double()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model_CNN.parameters(), lr=0.001)

    print("Seed: ", seed)
    log_every = 1000  # batches between intermediate loss reports

    for epoch in range(epochs):
        model_CNN.train()  # _accuracy switches to eval mode; make sure we train in train mode
        running_loss = 0.0
        epoch_loss = 0.0
        n_batches = 0
        for i, (mri, xls, label) in enumerate(train_dataloader):
            optimizer.zero_grad()
            outputs = model_CNN((mri.double(), xls.double()))
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_loss += batch_loss
            n_batches += 1
            if i % log_every == log_every - 1:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0

        mean_loss = epoch_loss / n_batches if n_batches else float("nan")
        val_acc = _accuracy(model_CNN, val_dataloader)
        print("Epoch: ", epoch + 1, "Mean loss: ", mean_loss, "Val accuracy: ", val_acc)

    # Test model
    print("Model Accuracy: ", _accuracy(model_CNN, test_dataloader))


if __name__ == "__main__":
    print("--- RUNNING ---")
    print("Pytorch Version: " + torch.__version__)
    for seed in seeds:
        evaluate_model(seed)
    print("--- END ---")