"""Dataset preparation utilities for the MRI + clinical-data CNN.

Builds ``CustomDataset`` objects for training, validation, testing and
prediction from a directory of NIfTI images plus a clinical CSV table
(``LP_ADNIMERGE.csv`` by default).
"""

import glob
import random
import re

import nibabel as nib
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision.transforms import v2

# Default clinical table, resolved relative to the current working directory.
DEFAULT_CSV_PATH = "LP_ADNIMERGE.csv"

# Image IDs are embedded in file names as "..._I<digits>_...".
_IMAGE_ID_RE = re.compile(r"_I(\d+)_")

# Encoding of the 'Sex' column into the clinical feature tuple.
_SEX_CODES = {"M": 0, "F": 1}


def _default_transform():
    """Return the transform applied to every MRI volume (train, val, test, predict)."""
    return v2.Compose([
        v2.Normalize([0.5], [0.5]),  # TODO Get Vals from dataset
        # TODO CHOOSE WHAT TRANSFORMATIONS TO DO
    ])


def _label_for(path):
    """Return 0 if *path* contains "NL", else 1 if it contains "AD", else None.

    "NL" is checked first, so a path containing both markers is labelled 0.
    """
    if "NL" in path:
        return 0
    if "AD" in path:
        return 1
    return None


def _get_clinical_data(csv, image_path):
    """Return the ``(sex, age)`` clinical tuple for *image_path*.

    Sex is encoded M -> 0, F -> 1; age is the 'Age (current)' value of the
    first CSV row whose 'Image Data ID' equals the numeric ID in the file name.

    Raises:
        RuntimeError: if the file name has no ``_I<digits>_`` ID, the ID has
            no row in the CSV, or the row's sex is neither 'M' nor 'F'.
    """
    match = _IMAGE_ID_RE.search(image_path)
    # Check the match object itself: calling .group() on None would raise
    # AttributeError before any None check could run.
    if match is None:
        raise RuntimeError(f"No image ID (_I<digits>_) in file name: {image_path}")
    image_id = int(match.group(1))

    rows = csv.loc[csv["Image Data ID"] == image_id, ["Sex", "Age (current)"]]
    if rows.empty:
        raise RuntimeError("Image not found in CSV!")

    sex, age = rows.iloc[0].tolist()
    if sex not in _SEX_CODES:
        raise RuntimeError("Incorrect sex in an image")
    return (_SEX_CODES[sex], age)


def prepare_datasets(mri_dir, val_split=0.2, seed=50, csv_path=DEFAULT_CSV_PATH):
    """Prepare CustomDatasets for training, validating and testing the CNN.

    Args:
        mri_dir: Prefix passed to ``glob`` (``mri_dir + "*"``); usually a
            directory path ending in a separator.
        val_split: Split fraction forwarded to ``get_train_val_test``.
        seed: Seed for the shuffles, so the split is reproducible.
        csv_path: Path of the clinical CSV table.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``. Each item is
        ``((mri_tensor, (sex, age)), class_id)`` with class 0 = NL, 1 = AD.
        Files whose path contains neither "NL" nor "AD" are skipped.

    Raises:
        RuntimeError: propagated from the clinical-data lookup.
    """
    rndm = random.Random(seed)
    csv = pd.read_csv(csv_path)
    # glob order is filesystem-dependent; sort so the seeded shuffle is reproducible.
    raw_data = sorted(glob.glob(mri_dir + "*"))
    AD_list = []
    NL_list = []

    print("--- DATA INFO ---")
    print("Amount of images: " + str(len(raw_data)))

    for image in raw_data:
        class_id = _label_for(image)
        if class_id is None:
            continue
        entry = (image, class_id, _get_clinical_data(csv, image))
        (NL_list if class_id == 0 else AD_list).append(entry)

    print("Total AD: " + str(len(AD_list)))
    print("Total NL: " + str(len(NL_list)))

    rndm.shuffle(AD_list)
    rndm.shuffle(NL_list)

    train_list, val_list, test_list = get_train_val_test(AD_list, NL_list, val_split)

    rndm.shuffle(train_list)
    rndm.shuffle(val_list)
    rndm.shuffle(test_list)

    print("DATA INITIALIZATION")
    print(f"Training size: {len(train_list)}")
    print(f"Validation size: {len(val_list)}")
    print(f"Test size: {len(test_list)}")

    transformation = _default_transform()
    train_dataset = CustomDataset(train_list, transformation)
    val_dataset = CustomDataset(val_list, transformation)
    test_dataset = CustomDataset(test_list, transformation)
    return train_dataset, val_dataset, test_dataset


def prepare_predict(mri_dir, IDs, csv_path=DEFAULT_CSV_PATH):
    """Build a CustomDataset of the images matching each pattern in *IDs*.

    Each ID is compiled as a regular expression and searched in every file
    path under ``mri_dir + "*"``. A message is printed when an ID matches
    zero or several files; all matches are kept either way.

    Args:
        mri_dir: Glob prefix, as in ``prepare_datasets``.
        IDs: Iterable of regex patterns identifying images.
        csv_path: Path of the clinical CSV table.

    Returns:
        A CustomDataset with the same item format and normalization as the
        datasets returned by ``prepare_datasets``.
    """
    csv = pd.read_csv(csv_path)
    raw_data = sorted(glob.glob(mri_dir + "*"))
    image_list = []

    for ID in IDs:
        pattern = re.compile(ID)
        matches = [item for item in raw_data if pattern.search(item)]
        if len(matches) != 1:
            print("No image found, or more than one")
        for match in matches:
            class_id = _label_for(match)
            if class_id is None:
                continue
            # Same 3-tuple layout as training data so CustomDataset can unpack it.
            image_list.append((match, class_id, _get_clinical_data(csv, match)))

    return CustomDataset(image_list, _default_transform())


def get_train_val_test(AD_list, NL_list, val_split):
    """Split AD and NL samples into train/val/test lists, stratified by class.

    For each class list of length ``n``::

        num_val  = int(n * val_split / 2)
        num_test = int((n - num_val) * val_split)

    The first ``num_test`` items go to test, the next ``num_val`` to
    validation, and the remainder to training. AD items precede NL items
    in every output list.

    Returns:
        ``(train_list, val_list, test_list)``, each a list of the input
        entries, i.e. ``(image_path, class_id, clinical_data)`` tuples.
    """
    train_list, val_list, test_list = [], [], []
    for class_list in (AD_list, NL_list):
        num_val = int(len(class_list) * val_split / 2)
        num_test = int((len(class_list) - num_val) * val_split)
        test_list.extend(class_list[:num_test])
        val_list.extend(class_list[num_test:num_test + num_val])
        train_list.extend(class_list[num_test + num_val:])
    return train_list, val_list, test_list


class CustomDataset(Dataset):
    """Dataset of ``(image_path, class_id, clinical_data)`` entries.

    Items are returned as ``((mri_data, clinical_data), class_id)``. Here
    ``mri_data`` is a float32 tensor of shape ``(1, *volume_shape)``, after
    the optional transform.
    """

    def __init__(self, list, transform=None):
        # NOTE: ``list`` shadows the builtin; the name is kept so keyword callers still work.
        self.data = list  # INPUT DATA: (image_dir, class_id, (clinical_data))
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # RETURNS TUPLE: ((mri_data, clinical_data), class_id)
        mri_path, class_id, clinical_data = self.data[idx]
        mri = nib.load(mri_path)
        # torchvision v2 transforms leave numpy arrays untouched, so the volume
        # must be a floating-point tensor for Normalize to actually be applied.
        volume = torch.from_numpy(mri.get_fdata(dtype=np.float32))
        mri_data = volume.unsqueeze(0)  # add channel axis -> (1, *volume.shape)
        if self.transform is not None:
            mri_data = self.transform(mri_data)
        # TODO return class_id as a tensor or just the id (0, 1)?
        return (mri_data, clinical_data), class_id