123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- # NEEDS TO BE FINISHED
- # TODO CHECK ABOUT IMAGE DIMENSIONS
- # TODO ENSURE ITERATION WORKS
import glob
import os
import random

import nibabel as nib
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
def prepare_datasets(mri_dir, xls_file, val_split=0.2, seed=50):
    """Prepare the ADNIDataset objects used to train, validate and test the CNN.

    Image files are found by globbing ``mri_dir`` and labelled from their file
    names: a name containing ``"NL"`` goes to the NL list (class 0), otherwise a
    name containing ``"AD"`` goes to the AD list (class 1); any other file is
    ignored.

    Args:
        mri_dir: Directory holding the MRI files. A value that is not an
            existing directory is still treated as a path prefix and globbed as
            ``mri_dir + "*"``, exactly as before.
        xls_file: Path to the clinical CSV; read with ``pd.read_csv`` and indexed
            by its ``'Image Data ID'`` column.
        val_split: Fraction forwarded to ``get_train_val_test`` to size the test
            and validation splits.
        seed: Seed for the private ``random.Random`` used for every shuffle.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``, each an ``ADNIDataset``
        sharing the same clinical table.
    """
    rndm = random.Random(seed)
    xls_data = pd.read_csv(xls_file).set_index('Image Data ID')

    # Accept a directory with or without a trailing separator; anything else
    # keeps the original "prefix + *" glob pattern.
    if os.path.isdir(mri_dir):
        pattern = os.path.join(mri_dir, "*")
    else:
        pattern = mri_dir + "*"
    # glob returns entries in arbitrary filesystem order; sort so the seeded
    # shuffles below produce the same split on every machine.
    raw_data = sorted(glob.glob(pattern))

    AD_list = []
    NL_list = []
    # TODO Check that image is in CSV?
    for image in raw_data:
        # Label from the file name only, so a parent directory whose name
        # contains "AD" or "NL" cannot change the class of every file.
        name = os.path.basename(image)
        if "NL" in name:
            NL_list.append(image)
        elif "AD" in name:
            AD_list.append(image)

    rndm.shuffle(AD_list)
    rndm.shuffle(NL_list)
    train_list, val_list, test_list = get_train_val_test(AD_list, NL_list, val_split)
    # get_train_val_test emits all AD entries before NL ones; mix them.
    rndm.shuffle(train_list)
    rndm.shuffle(val_list)
    rndm.shuffle(test_list)

    train_dataset = ADNIDataset(train_list, xls_data)
    val_dataset = ADNIDataset(val_list, xls_data)
    test_dataset = ADNIDataset(test_list, xls_data)
    return train_dataset, val_dataset, test_dataset
# TODO: Normalize the data? Later, add/extract clinical data — which fields?
def get_train_val_test(AD_list, NL_list, val_split):
    """Split AD and NL image paths into train, validation and test lists.

    Each class is split independently (stratified):
    ``num_test = int(n * val_split)`` and
    ``num_val = int((n - num_test) * val_split)``. The first ``num_val`` items
    go to validation, the next ``num_test`` to test, and the rest to training.
    Items are taken in the order given, so callers shuffle beforehand.

    Args:
        AD_list: Image paths of the AD class (labelled 1).
        NL_list: Image paths of the NL class (labelled 0).
        val_split: Fraction used to size the test split, and then the
            validation split of the remainder.

    Returns:
        ``(train_list, val_list, test_list)``, each a list of
        ``(image, class_id)`` tuples with all AD entries before NL entries.
    """
    train_list, val_list, test_list = [], [], []
    for images, class_id in ((AD_list, 1), (NL_list, 0)):
        num_test = int(len(images) * val_split)
        num_val = int((len(images) - num_test) * val_split)
        val_end = num_val
        # The test slice must span num_test items after the validation slice;
        # previously it stopped at index num_test and lost num_val test images.
        test_end = num_val + num_test
        val_list.extend((image, class_id) for image in images[:val_end])
        test_list.extend((image, class_id) for image in images[val_end:test_end])
        train_list.extend((image, class_id) for image in images[test_end:])
    return train_list, val_list, test_list
class ADNIDataset(Dataset):
    """Map-style dataset yielding ``((mri_tensor, xls_tensor), one_hot_class)``.

    Args:
        mri: List of ``(image_path, class_id)`` tuples; ``class_id`` is 0 or 1
            and is one-hot encoded with two classes.
        xls: Clinical table; each row must provide the columns listed in
            ``CLINICAL_COLUMNS``.
        id_fn: Optional callable mapping an image path to the ``xls`` row label
            belonging to that image. When given, the clinical row is fetched
            with ``xls.loc[id_fn(path)]`` (labels assumed unique). When omitted,
            the row is taken positionally with ``xls.iloc[idx]``, as before.
    """

    # Clinical columns fed to the model, in tensor order.
    CLINICAL_COLUMNS = ('Sex', 'Age (current)')
    # Numeric codes substituted for categorical values before float conversion.
    SEX_ENCODING = {'M': 0, 'F': 1}

    def __init__(self, mri, xls: pd.DataFrame, id_fn=None):
        self.mri_data = mri  # list of (image_path, class_id) tuples
        self.xls_data = xls
        self.id_fn = id_fn

    def __len__(self):
        return len(self.mri_data)

    def _xls_to_tensor(self, xls_data: pd.Series):
        """Select the clinical columns of one row and return them as a float64 tensor.

        'M'/'F' values are replaced by 0/1; every selected value must then be
        convertible to float.
        """
        # A list (not tuple) indexer selects several labels; replace() returns a
        # new Series instead of mutating the .loc result in place.
        data = xls_data.loc[list(self.CLINICAL_COLUMNS)].replace(self.SEX_ENCODING)
        return torch.tensor(data.values.astype(float))

    def _clinical_row(self, idx, mri_path):
        """Return the clinical row for sample ``idx`` whose image is ``mri_path``."""
        if self.id_fn is not None:
            return self.xls_data.loc[self.id_fn(mri_path)]
        # NOTE(review): positional lookup — nothing here ties row ``idx`` of the
        # clinical table to ``self.mri_data[idx]``, so clinical features may be
        # paired with the wrong image. Pass ``id_fn`` to look rows up by ID.
        return self.xls_data.iloc[idx]

    def __getitem__(self, idx):
        """Return ``((mri_tensor, xls_tensor), class_id)`` for sample ``idx``.

        All three tensors are float32; ``class_id`` is a one-hot vector of
        length 2.
        """
        mri_path, class_id = self.mri_data[idx]
        mri_data = nib.load(mri_path).get_fdata()
        xls_tensor = self._xls_to_tensor(self._clinical_row(idx, mri_path)).float()
        # Prepend a singleton axis (presumably the CNN's channel dimension).
        mri_tensor = torch.from_numpy(mri_data).unsqueeze(0).float()

        one_hot = torch.nn.functional.one_hot(torch.tensor([class_id]), num_classes=2)
        class_tensor = one_hot.squeeze(0).float()
        return (mri_tensor, xls_tensor), class_tensor
-
-
def initalize_dataloaders(training_data, val_data, test_data, cuda_device=torch.device('cuda:0'), batch_size=64):
    """Wrap the three datasets in shuffling DataLoaders.

    Args:
        training_data: Dataset for training batches of ``batch_size``.
        val_data: Dataset for validation batches of ``batch_size``.
        test_data: Dataset for test batches of ``batch_size // 4`` (at least 1).
        cuda_device: Currently unused; kept for call compatibility.
        batch_size: Batch size for the training and validation loaders.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    # batch_size // 4 is 0 for batch_size < 4, which DataLoader rejects.
    test_batch_size = max(1, batch_size // 4)
    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=test_batch_size, shuffle=True)
    val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
    return train_dataloader, val_dataloader, test_dataloader
|