Browse Source

Implemented dataset saving and loading

Nicholas Schense, 3 months ago
parent
commit
75b4ff2484
3 changed files with 80 additions and 81 deletions
  1. ensemble_predict.py (+21 -23)
  2. train_cnn.py (+48 -44)
  3. utils/data/datasets.py (+11 -14)

ensemble_predict.py  (+21 -23)

@@ -7,31 +7,29 @@ import math
 import torch
 
 # CONFIGURATION
-if os.getenv("ADL_CONFIG_PATH") is None:
-    with open("config.toml", "rb") as f:
+if os.getenv('ADL_CONFIG_PATH') is None:
+    with open('config.toml', 'rb') as f:
         config = toml.load(f)
 else:
-    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
+    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
         config = toml.load(f)
 
 # Force cuDNN initialization
-force_init_cudnn(config["training"]["device"])
+force_init_cudnn(config['training']['device'])
 
 
-ensemble_folder = config["paths"]["model_output"] + config["ensemble"]["name"] + "/"
-models, model_descs = ens.load_models(ensemble_folder, config["training"]["device"])
+ensemble_folder = (
+    config['paths']['model_output'] + config['ensemble']['name'] + '/models/'
+)
+models, model_descs = ens.load_models(ensemble_folder, config['training']['device'])
 models, model_descs = ens.prune_models(
-    models, model_descs, ensemble_folder, config["ensemble"]["prune_threshold"]
+    models, model_descs, ensemble_folder, config['ensemble']['prune_threshold']
 )
 
 # Load test data
-test_dataset = prepare_datasets(
-    config["paths"]["mri_data"],
-    config["paths"]["xls_data"],
-    config["dataset"]["validation_split"],
-    0,
-    config["training"]["device"],
-)[2]
+test_dataset = torch.load(
+    config['paths']['model_output'] + config['ensemble']['name'] + '/test_dataset.pt'
+)
 
 # Evaluate ensemble and uncertainty test set
 correct = 0
@@ -69,22 +67,22 @@ accuracy = correct / total
 with open(
     ensemble_folder
     + f"ensemble_test_results_{config['ensemble']['prune_threshold']}.txt",
-    "w",
+    'w',
 ) as f:
-    f.write("Accuracy: " + str(accuracy) + "\n")
-    f.write("Correct: " + str(correct) + "\n")
-    f.write("Total: " + str(total) + "\n")
+    f.write('Accuracy: ' + str(accuracy) + '\n')
+    f.write('Correct: ' + str(correct) + '\n')
+    f.write('Total: ' + str(total) + '\n')
 
     for exp, pred, stdev in zip(actual, predictions, stdevs):
         f.write(
             str(exp)
-            + ", "
+            + ', '
             + str(pred)
-            + ", "
+            + ', '
             + str(stdev)
-            + ", "
+            + ', '
             + str(yes_votes)
-            + ", "
+            + ', '
             + str(no_votes)
-            + "\n"
+            + '\n'
         )
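
For reference, the torch.load call introduced above unpickles the entire ADNIDataset object that train_cnn.py saved, tensors included. A minimal sketch of that round trip, assuming the saved file is trusted (the path and the stand-in Dataset are illustrative; note that on PyTorch 2.6+ torch.load defaults to weights_only=True, which rejects arbitrary pickled classes, so weights_only=False is required for this pattern):

    import torch
    from torch.utils.data import TensorDataset

    # Stand-in for an ADNIDataset split; any picklable Dataset behaves the same.
    test_dataset = TensorDataset(torch.zeros(4, 3))
    torch.save(test_dataset, 'test_dataset.pt')

    # weights_only=False allows unpickling the Dataset class itself;
    # map_location keeps tensors saved on a GPU loadable on a CPU-only box.
    restored = torch.load('test_dataset.pt', weights_only=False, map_location='cpu')
    print(len(restored))  # the Dataset API survives the round trip -> 4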

train_cnn.py  (+48 -44)

@@ -23,57 +23,69 @@ from utils.system import force_init_cudnn
 
 
 # CONFIGURATION
-if os.getenv("ADL_CONFIG_PATH") is None:
-    with open("config.toml", "rb") as f:
+if os.getenv('ADL_CONFIG_PATH') is None:
+    with open('config.toml', 'rb') as f:
         config = toml.load(f)
 else:
-    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
+    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
         config = toml.load(f)
 
 # Force cuDNN initialization
-force_init_cudnn(config["training"]["device"])
+force_init_cudnn(config['training']['device'])
 # Generate seed for each set of runs
 seed = rand.randint(0, 1000)
 
-for i in range(config["training"]["runs"]):
+# Prepare data
+train_dataset, val_dataset, test_dataset = prepare_datasets(
+    config['paths']['mri_data'],
+    config['paths']['xls_data'],
+    config['dataset']['validation_split'],
+    seed,
+    config['training']['device'],
+)
+train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
+    train_dataset,
+    val_dataset,
+    test_dataset,
+    config['hyperparameters']['batch_size'],
+)
+
+# Save datasets
+model_folder_path = (
+    config['paths']['model_output'] + '/' + str(config['model']['name']) + '/'
+)
+
+if not os.path.exists(model_folder_path):
+    os.makedirs(model_folder_path)
+
+torch.save(train_dataset, model_folder_path + 'train_dataset.pt')
+torch.save(val_dataset, model_folder_path + 'val_dataset.pt')
+torch.save(test_dataset, model_folder_path + 'test_dataset.pt')
+
+for i in range(config['training']['runs']):
     # Set up the model
     model = (
         cnn.CNN(
-            config["model"]["image_channels"],
-            config["model"]["clin_data_channels"],
-            config["hyperparameters"]["droprate"],
+            config['model']['image_channels'],
+            config['model']['clin_data_channels'],
+            config['hyperparameters']['droprate'],
         )
         .float()
-        .to(config["training"]["device"])
+        .to(config['training']['device'])
     )
     criterion = nn.BCELoss()
     optimizer = optim.Adam(
-        model.parameters(), lr=config["hyperparameters"]["learning_rate"]
+        model.parameters(), lr=config['hyperparameters']['learning_rate']
     )
 
-    # Prepare data
-    train_dataset, val_dataset, test_dataset = prepare_datasets(
-        config["paths"]["mri_data"],
-        config["paths"]["xls_data"],
-        config["dataset"]["validation_split"],
-        seed,
-        config["training"]["device"],
-    )
-    train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
-        train_dataset,
-        val_dataset,
-        test_dataset,
-        config["hyperparameters"]["batch_size"],
-    )
+    runs_num = config['training']['runs']
 
-    runs_num = config["training"]["runs"]
-
-    if not config["operation"]["silent"]:
-        print(f"Training model {i + 1} / {runs_num} with seed {seed}...")
+    if not config['operation']['silent']:
+        print(f'Training model {i + 1} / {runs_num} with seed {seed}...')
 
     # Train the model
     with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
+        warnings.simplefilter('ignore')
 
         history = train.train_model(
             model, train_dataloader, val_dataloader, criterion, optimizer, config
@@ -84,31 +96,23 @@ for i in range(config["training"]["runs"]):
 
     # Save model
     if not os.path.exists(
-        config["paths"]["model_output"] + "/" + str(config["model"]["name"])
+        config['paths']['model_output'] + '/' + str(config['model']['name'])
     ):
         os.makedirs(
-            config["paths"]["model_output"] + "/" + str(config["model"]["name"])
+            config['paths']['model_output'] + '/' + str(config['model']['name'])
         )
 
-    model_save_path = (
-        config["paths"]["model_output"]
-        + "/"
-        + str(config["model"]["name"])
-        + "/"
-        + str(i + 1)
-        + "_s-"
-        + str(seed)
-    )
+    model_save_path = model_folder_path + 'models/' + str(i + 1) + '_s-' + str(seed)
 
     torch.save(
         model,
-        model_save_path + ".pt",
+        model_save_path + '.pt',
     )
 
     history.to_csv(
-        model_save_path + "_history.csv",
+        model_save_path + '_history.csv',
         index=True,
     )
 
-    with open(model_save_path + "_test_acc.txt", "w") as f:
-        f.write(str(tes_acc))
+    with open(model_folder_path + 'summary.txt', 'a') as f:
+        f.write(f'{i + 1}: Test Accuracy: {tes_acc}\n')
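
This hunk moves dataset preparation out of the run loop: the split is computed once per invocation and persisted, so every run in the ensemble trains on the identical partition and ensemble_predict.py can recover the exact held-out test set. A minimal sketch of the save-once pattern with illustrative names (the repo concatenates path strings; os.path.join is shown here only for readability):

    import os
    import torch
    from torch.utils.data import TensorDataset

    # Stand-ins for the prepare_datasets() splits.
    train_dataset = TensorDataset(torch.zeros(6, 3))
    val_dataset = TensorDataset(torch.zeros(2, 3))
    test_dataset = TensorDataset(torch.zeros(2, 3))

    model_folder = os.path.join('out', 'example_model')  # hypothetical folder
    os.makedirs(model_folder, exist_ok=True)  # one call replaces the exists()/makedirs pair

    # Saved before the training loop: all runs share one split.
    torch.save(train_dataset, os.path.join(model_folder, 'train_dataset.pt'))
    torch.save(val_dataset, os.path.join(model_folder, 'val_dataset.pt'))
    torch.save(test_dataset, os.path.join(model_folder, 'test_dataset.pt'))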

utils/data/datasets.py  (+11 -14)

@@ -16,20 +16,21 @@ Prepares CustomDatasets for training, validating, and testing CNN
 """
 
 
-def prepare_datasets(
-    mri_dir, xls_file, val_split=0.2, seed=50, device=torch.device("cpu")
-):
+def prepare_datasets(mri_dir, xls_file, val_split=0.2, seed=50, device=None):
+    if device is None:
+        device = torch.device('cpu')
+
     rndm = random.Random(seed)
-    xls_data = pd.read_csv(xls_file).set_index("Image Data ID")
-    raw_data = glob.glob(mri_dir + "*")
+    xls_data = pd.read_csv(xls_file).set_index('Image Data ID')
+    raw_data = glob.glob(mri_dir + '*')
     AD_list = []
     NL_list = []
 
     # TODO Check that image is in CSV?
     for image in raw_data:
-        if "NL" in image:
+        if 'NL' in image:
             NL_list.append(image)
-        elif "AD" in image:
+        elif 'AD' in image:
             AD_list.append(image)
 
     rndm.shuffle(AD_list)
@@ -37,10 +38,6 @@ def prepare_datasets(
 
     train_list, val_list, test_list = get_train_val_test(AD_list, NL_list, val_split)
 
-    rndm.shuffle(train_list)
-    rndm.shuffle(val_list)
-    rndm.shuffle(test_list)
-
     train_dataset = ADNIDataset(train_list, xls_data, device=device)
     val_dataset = ADNIDataset(val_list, xls_data, device=device)
     test_dataset = ADNIDataset(test_list, xls_data, device=device)
@@ -93,7 +90,7 @@ def get_train_val_test(AD_list, NL_list, val_split):
 
 
 class ADNIDataset(Dataset):
-    def __init__(self, mri, xls: pd.DataFrame, device=torch.device("cpu")):
+    def __init__(self, mri, xls: pd.DataFrame, device=torch.device('cpu')):
         self.mri_data = mri  # DATA IS A LIST WITH TUPLES (image_dir, class_id)
         self.xls_data = xls
         self.device = device
@@ -105,9 +102,9 @@ class ADNIDataset(Dataset):
         # Get used data
 
         # data = xls_data.loc[['Sex', 'Age (current)', 'PTID', 'DXCONFID (1=uncertain, 2= mild, 3= moderate, 4=high confidence)', 'Alz_csf']]
-        data = xls_data.loc[["Sex", "Age (current)"]]
+        data = xls_data.loc[['Sex', 'Age (current)']]
 
-        data.replace({"M": 0, "F": 1}, inplace=True)
+        data.replace({'M': 0, 'F': 1}, inplace=True)
 
         # Convert to tensor
         xls_tensor = torch.tensor(data.values.astype(float))
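
The device=None signature change in prepare_datasets follows the standard idiom for avoiding work in default arguments: defaults are evaluated once, when the def statement executes, so resolving them inside the body defers the torch.device construction to each call and lets callers pass None explicitly. A standalone sketch of the idiom (names are illustrative):

    import torch

    # Defaults are evaluated when 'def' executes, not per call; a None
    # sentinel defers construction of the fallback into the body.
    def resolve_device(device=None):
        if device is None:
            device = torch.device('cpu')
        return device

    print(resolve_device())                     # cpu
    print(resolve_device(torch.device('cpu')))  # explicit value wins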