main.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import torch
  2. import torchvision
  3. # FOR DATA
  4. from utils.preprocess import prepare_datasets
  5. from utils.show_image import show_image
  6. from torch.utils.data import DataLoader
  7. from torchvision import datasets
  8. from torch import nn
  9. import torch.nn.functional as F
  10. from torchvision.transforms import ToTensor
  11. # import nonechucks as nc # Used to load data in pytorch even when images are corrupted / unavailable (skips them)
  12. # FOR IMAGE VISUALIZATION
  13. import nibabel as nib
  14. # GENERAL PURPOSE
  15. import os
  16. import pandas as pd
  17. import numpy as np
  18. import matplotlib.pyplot as plt
  19. import glob
  20. from datetime import datetime
  21. # FOR TRAINING
  22. import torch.optim as optim
  23. import utils.models as models
  24. import utils.layers as ly
  25. #FOR TESTING
  26. import torchsummary
  print("--- RUNNING ---")
  print(f"Pytorch Version: {torch.__version__}")

  # ----- Experiment configuration -----
  # Fraction of the data held out from training. The original note says it
  # covers both validation and test; the actual split is done inside
  # prepare_datasets.
  val_split = 0.2
  runs = 1  # number of independent runs, each gets its own random seed
  epochs = 100  # training epochs per run

  # Start-of-run timestamp, exposed under both names the script defines.
  timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  time_stamp = timestamp

  # One random seed per run, drawn from [0, 1000).
  seeds = [np.random.randint(0, 1000) for _run in range(runs)]

  # Input locations: directory of MRI volumes and the ADNIMERGE CSV file.
  mri_datapath = './ADNI_volumes_customtemplate_float32/'
  xls_file = './Lp_ADNIMERGE.csv'
  # TODO: Datasets include multiple labels, such as medical info
  def evaluate_model(seed):
      """Train a fresh CNN for one seed and report its test-set accuracy.

      ``seed`` is forwarded to ``prepare_datasets`` (presumably to fix the
      train/val/test split) and also passed to ``torch.manual_seed``. This ties
      torch-side randomness, such as weight init and DataLoader shuffling, to
      the same seed.

      Each epoch prints the mean training loss and the validation accuracy.

      Args:
          seed: Integer seed for this run.

      Returns:
          Test accuracy in percent, or ``None`` if the test set is empty.
      """
      training_data, val_data, test_data = prepare_datasets(mri_datapath, xls_file, val_split, seed)
      batch_size = 64
      # Seed torch after the split so prepare_datasets sees the same RNG state
      # as before, while model init and shuffling become reproducible.
      torch.manual_seed(seed)

      # Create data loaders. Only training needs shuffling; evaluation order
      # does not affect accuracy.
      train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
      val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
      test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

      # NOTE(review): argument meaning is defined in utils.models. If the
      # second argument is the number of output classes, 1 class makes
      # CrossEntropyLoss constant and argmax always 0 -- confirm.
      model_CNN = models.CNN_Net(1, 1, 0.5)
      criterion = nn.CrossEntropyLoss()
      optimizer = optim.Adam(model_CNN.parameters(), lr=0.001)
      print("Seed: ", seed)

      for epoch in range(epochs):
          # _accuracy() switches to eval mode, so re-enable training mode each epoch.
          model_CNN.train()
          running_loss = 0.0
          n_batches = 0
          for mri, xls, label in train_dataloader:
              optimizer.zero_grad()
              # The model takes the image volume and tabular features as one tuple.
              outputs = model_CNN((mri, xls))
              loss = criterion(outputs, label)
              loss.backward()
              optimizer.step()
              running_loss += loss.item()
              n_batches += 1
          mean_loss = running_loss / n_batches if n_batches else float("nan")
          val_accuracy = _accuracy(model_CNN, val_dataloader)
          print("Epoch: ", epoch + 1, "Loss: ", mean_loss, "Val accuracy: ", val_accuracy)

      # Test model
      test_accuracy = _accuracy(model_CNN, test_dataloader)
      print("Model Accuracy: ", test_accuracy)
      return test_accuracy


  def _accuracy(model, dataloader):
      """Return the percentage of correctly classified samples in ``dataloader``.

      Batches are unpacked as ``(mri, xls, labels)`` (the layout used by the
      training loop), and the model is called with the ``(mri, xls)`` tuple.
      The predicted class is the argmax over dimension 1 of the model output.
      Returns ``None`` when the loader yields no samples, instead of dividing
      by zero.
      """
      # Inference mode for layers such as Dropout/BatchNorm.
      model.eval()
      correct = 0
      total = 0
      with torch.no_grad():
          for mri, xls, labels in dataloader:
              predicted = model((mri, xls)).argmax(dim=1)
              total += labels.size(0)
              correct += (predicted == labels).sum().item()
      return 100 * correct / total if total else None
  if __name__ == "__main__":
      # Launch the experiments only when executed as a script, so importing
      # this module (e.g. to reuse evaluate_model) does not start training.
      for seed in seeds:
          evaluate_model(seed)
      print("--- END ---")