|
@@ -5,7 +5,7 @@ import pandas as pd
|
|
import random
|
|
import random
|
|
# import torch
|
|
# import torch
|
|
from torch.utils.data import Dataset
|
|
from torch.utils.data import Dataset
|
|
-import torchvision.transforms as transforms
|
|
|
|
|
|
+from torchvision.transforms import v2
|
|
import re
|
|
import re
|
|
|
|
|
|
|
|
|
|
@@ -63,20 +63,17 @@ def prepare_datasets(mri_dir, val_split=0.2, seed=50):
|
|
print(f"Validation size: {len(val_list)}")
|
|
print(f"Validation size: {len(val_list)}")
|
|
print(f"Test size: {len(test_list)}")
|
|
print(f"Test size: {len(test_list)}")
|
|
|
|
|
|
|
|
+ transformation = v2.Compose([
|
|
|
|
+ v2.Normalize([0.5],[0.5]), # TODO Get Vals from dataset
|
|
|
|
+ # TODO CHOOSE WHAT TRANSFORMATIONS TO DO
|
|
|
|
+ ])
|
|
|
|
|
|
- # # TRANSFORM
|
|
|
|
- # transform = transforms.Compose([
|
|
|
|
- # transforms.Grayscale(num_output_channels=1)
|
|
|
|
- # ])
|
|
|
|
-
|
|
|
|
- train_dataset = CustomDataset(train_list)
|
|
|
|
- val_dataset = CustomDataset(val_list)
|
|
|
|
- test_dataset = CustomDataset(test_list)
|
|
|
|
|
|
+ train_dataset = CustomDataset(train_list, transformation)
|
|
|
|
+ val_dataset = CustomDataset(val_list, transformation)
|
|
|
|
+ test_dataset = CustomDataset(test_list, transformation)
|
|
|
|
|
|
return train_dataset, val_dataset, test_dataset
|
|
return train_dataset, val_dataset, test_dataset
|
|
|
|
|
|
- # TODO Normalize data? Later add
|
|
|
|
-
|
|
|
|
|
|
|
|
def prepare_predict(mri_dir, IDs):
|
|
def prepare_predict(mri_dir, IDs):
|
|
|
|
|
|
@@ -133,8 +130,9 @@ def get_train_val_test(AD_list, NL_list, val_split):
|
|
|
|
|
|
|
|
|
|
class CustomDataset(Dataset):
|
|
class CustomDataset(Dataset):
|
|
- def __init__(self, list):
|
|
|
|
|
|
+ def __init__(self, list, transform):
|
|
self.data = list # INPUT DATA: (image_dir, class_id, (clinical_data))
|
|
self.data = list # INPUT DATA: (image_dir, class_id, (clinical_data))
|
|
|
|
+ self.transform = transform
|
|
|
|
|
|
def __len__(self):
|
|
def __len__(self):
|
|
return len(self.data)
|
|
return len(self.data)
|
|
@@ -144,6 +142,7 @@ class CustomDataset(Dataset):
|
|
mri = nib.load(mri_path)
|
|
mri = nib.load(mri_path)
|
|
image = np.asarray(mri.dataobj)
|
|
image = np.asarray(mri.dataobj)
|
|
mri_data = np.asarray(np.expand_dims(image, axis=0))
|
|
mri_data = np.asarray(np.expand_dims(image, axis=0))
|
|
|
|
+ mri_data = self.transform(mri_data)
|
|
|
|
|
|
# mri_data = mri.get_fdata()
|
|
# mri_data = mri.get_fdata()
|
|
# mri_array = np.array(mri)
|
|
# mri_array = np.array(mri)
|