|
@@ -23,18 +23,20 @@ else:
|
|
|
# This function returns a list of the accuracies given a threshold
|
|
|
def threshold(config):
|
|
|
# First, get the model data
|
|
|
- ts, vs, test_set = prepare_datasets(
|
|
|
- config['paths']['mri_data'],
|
|
|
- config['paths']['xls_data'],
|
|
|
- config['dataset']['validation_split'],
|
|
|
- 944,
|
|
|
- config['training']['device'],
|
|
|
+ test_set = torch.load(
|
|
|
+ config['paths']['model_output']
|
|
|
+ + config['ensemble']['name']
|
|
|
+ + '/test_dataset.pt'
|
|
|
+ )
|
|
|
+
|
|
|
+ vs = torch.load(
|
|
|
+ config['paths']['model_output'] + config['ensemble']['name'] + '/val_dataset.pt'
|
|
|
)
|
|
|
|
|
|
test_set = test_set + vs
|
|
|
|
|
|
models, _ = ens.load_models(
|
|
|
- config['paths']['model_output'] + config['ensemble']['name'] + '/',
|
|
|
+ config['paths']['model_output'] + config['ensemble']['name'] + '/models/',
|
|
|
config['training']['device'],
|
|
|
)
|
|
|
|