123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- import glob
- import nibabel as nib
- import numpy as np
- import pandas as pd
- import random
- # import torch
- from torch.utils.data import Dataset
- from torchvision.transforms import v2
- import re
- '''
- Prepares CustomDatasets for training, validating, and testing CNN
- '''
def prepare_datasets(mri_dir, val_split=0.2, seed=50):
    """Prepare CustomDatasets for training, validating, and testing the CNN.

    Scans ``mri_dir`` for image files, looks up each image's clinical data
    (sex, current age) in LP_ADNIMERGE.csv, splits the images into
    train/val/test lists stratified by class (AD vs. NL), and wraps each
    list in a CustomDataset.

    Args:
        mri_dir: Directory prefix (glob pattern ``mri_dir + "*"`` is used)
            containing the MRI image files.
        val_split: Fraction used for the validation split and, applied to the
            remainder, for the test split (see get_train_val_test).
        seed: Seed for a private RNG so the shuffles/splits are reproducible.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)`` of CustomDatasets.

    Raises:
        RuntimeError: If an image filename has no parseable ``_I<digits>_`` ID,
            or the CSV holds a sex value other than 'M'/'F'.
    """
    rndm = random.Random(seed)
    # NOTE(review): CSV path is hard-coded relative to the CWD — confirm.
    clinical_csv = pd.read_csv("LP_ADNIMERGE.csv")
    raw_data = glob.glob(mri_dir + "*")
    AD_list = []
    NL_list = []
    print("--- DATA INFO ---")
    print("Amount of images: " + str(len(raw_data)))
    # Hoisted out of the loop: compile the ID pattern once.
    id_pattern = re.compile(r"_I(\d+)_")
    for image in raw_data:
        # FIND IMAGE IN CSV, GET ITS CLINICAL DATA
        match = id_pattern.search(image)
        # BUG FIX: the original called .group(1) before checking for None, so a
        # non-matching filename raised AttributeError and the check never ran.
        if match is None:
            raise RuntimeError("Image not found in CSV!")
        image_ID = match.group(1)
        image_data = clinical_csv[clinical_csv["Image Data ID"] == int(image_ID)].loc[:, ['Sex', 'Age (current)']]  # M: 0, F: 1
        image_data = image_data.iloc[0].tolist()
        if image_data[0] == 'M':
            image_data[0] = 0
        elif image_data[0] == 'F':
            image_data[0] = 1
        else:
            raise RuntimeError("Incorrect sex in an image")
        image_data = tuple(image_data)
        # Class is encoded in the filename: NL -> 0, AD -> 1.
        if "NL" in image:
            NL_list.append((image, 0, image_data))
        elif "AD" in image:
            AD_list.append((image, 1, image_data))
    print("Total AD: " + str(len(AD_list)))
    print("Total NL: " + str(len(NL_list)))
    rndm.shuffle(AD_list)
    rndm.shuffle(NL_list)
    train_list, val_list, test_list = get_train_val_test(AD_list, NL_list, val_split)
    # Shuffle again so AD and NL entries are interleaved within each split.
    rndm.shuffle(train_list)
    rndm.shuffle(val_list)
    rndm.shuffle(test_list)
    print("DATA INITIALIZATION")
    print(f"Training size: {len(train_list)}")
    print(f"Validation size: {len(val_list)}")
    print(f"Test size: {len(test_list)}")
    transformation = v2.Compose([
        v2.Normalize([0.5], [0.5]),  # TODO Get Vals from dataset
        # TODO CHOOSE WHAT TRANSFORMATIONS TO DO
    ])
    train_dataset = CustomDataset(train_list, transformation)
    val_dataset = CustomDataset(val_list, transformation)
    test_dataset = CustomDataset(test_list, transformation)
    return train_dataset, val_dataset, test_dataset
def prepare_predict(mri_dir, IDs):
    """Build a CustomDataset of labeled images matching the given IDs.

    Args:
        mri_dir: Directory prefix (glob pattern ``mri_dir + "*"`` is used).
        IDs: Iterable of ID substrings; each should match exactly one file.

    Returns:
        CustomDataset over ``(path, class_id, clinical_data)`` entries with no
        transform applied (class_id: NL -> 0, AD -> 1; clinical_data empty).
    """
    raw_data = glob.glob(mri_dir + "*")
    image_list = []
    # Gets all images and prepares them for Dataset
    for ID in IDs:
        # BUG FIX: the original compiled the ID as a regex; plain substring
        # matching is what is intended (IDs are not patterns).
        matches = [item for item in raw_data if ID in item]
        if len(matches) != 1:
            print("No image found, or more than one")
        for match in matches:
            # BUG FIX: entries now carry an empty clinical-data tuple so they
            # match the 3-tuple layout CustomDataset.__getitem__ unpacks.
            if "NL" in match:
                image_list.append((match, 0, ()))
            if "AD" in match:
                image_list.append((match, 1, ()))
    # BUG FIX: the original called CustomDataset(image_list) without the
    # required transform argument (TypeError); pass None for "no transform".
    return CustomDataset(image_list, None)
'''
Returns train_list, val_list and test_list, each a list of
(image, class_id, clinical_data) tuples
'''
def get_train_val_test(AD_list, NL_list, val_split):
    """Split the AD and NL sample lists into train/val/test lists.

    The split is stratified: each class is divided independently, then the
    per-class pieces are concatenated (AD entries first, then NL entries).

    Args:
        AD_list: Samples labeled AD, e.g. (image, class_id, clinical_data).
        NL_list: Samples labeled NL, same tuple layout.
        val_split: Validation takes ``val_split / 2`` of each class; test takes
            ``val_split`` of what remains after validation is set aside.

    Returns:
        Tuple ``(train_list, val_list, test_list)``.
    """
    def _split(samples):
        # Sizes follow the original scheme: val from the whole class, test
        # from the remainder after val is reserved.
        n_val = int(len(samples) * val_split / 2)
        n_test = int((len(samples) - n_val) * val_split)
        test = list(samples[:n_test])
        val = list(samples[n_test:n_test + n_val])
        train = list(samples[n_test + n_val:])
        return train, val, test

    ad_train, ad_val, ad_test = _split(AD_list)
    nl_train, nl_val, nl_test = _split(NL_list)
    return ad_train + nl_train, ad_val + nl_val, ad_test + nl_test
class CustomDataset(Dataset):
    """Dataset of MRI volumes with per-image clinical data.

    Each entry of ``data`` is a 3-tuple ``(mri_path, class_id, clinical_data)``
    where class_id is 0 (NL) or 1 (AD) and clinical_data is a tuple such as
    (sex, age).
    """

    def __init__(self, data, transform=None):
        # BUG FIX: the parameter was named `list`, shadowing the builtin, and
        # `transform` had no default even though prepare_predict omits it.
        self.data = data  # [(image_dir, class_id, (clinical_data)), ...]
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``((mri_data, clinical_data), class_id)`` for sample ``idx``.

        The volume is loaded lazily via nibabel and given a leading channel
        axis; the transform (if any) is applied to the channel-first array.
        """
        mri_path, class_id, clinical_data = self.data[idx]
        mri = nib.load(mri_path)
        image = np.asarray(mri.dataobj)
        mri_data = np.expand_dims(image, axis=0)  # add channel dimension
        # BUG FIX: only apply the transform when one was provided.
        if self.transform is not None:
            mri_data = self.transform(mri_data)
        return (mri_data, clinical_data), class_id
|