浏览代码

General work - forgot to commit!

Nicholas Schense 5 月之前
父节点
当前提交
3a42f81d92
共有 10 个文件被更改,包括 507 次插入26 次删除
  1. 1 0
      .gitignore
  2. 5 0
      README.md
  3. 63 0
      bayesian.py
  4. 6 4
      config.toml
  5. 10 0
      coverage.csv
  6. 二进制
      coverage.png
  7. 232 0
      predictions.csv
  8. 162 0
      threshold.py
  9. 3 3
      train_cnn.py
  10. 25 19
      utils/data/datasets.py

+ 1 - 0
.gitignore

@@ -1,5 +1,6 @@
 #Custom gitignore
 saved_models/
+nohup.out
 
 
 # Byte-compiled / optimized / DLL files

+ 5 - 0
README.md

@@ -1 +1,6 @@
 # Alzheimers Diagnosis Neural Net Project Rewrite
+
+## TODO
+- [ ] Implement config saving for ensembles
+- [ ] Run more models
+- [ ] 

+ 63 - 0
bayesian.py

@@ -0,0 +1,63 @@
+from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss
+
+import torch
+import torch.nn as nn
+import os
+import tomli as toml
+from tqdm import tqdm
+
+from utils.models import cnn
+from utils.data.datasets import prepare_datasets, initalize_dataloaders
+
+# CONFIGURATION
+if os.getenv("ADL_CONFIG_PATH") is None:
+    with open("config.toml", "rb") as f:
+        config = toml.load(f)
+else:
+    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
+        config = toml.load(f)
+
+
+model = cnn.CNN()
+
+# Convert the model to a Bayesian model
+model = dnn_to_bnn(model, prior_mu=0, prior_sigma=0.1)
+
+criterion = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(
+    model.parameters(), config["hyperparameters"]["learning_rate"]
+)
+
+
+train_set, val_set, test_set = prepare_datasets(
+    config["paths"]["mri_data"],
+    config["paths"]["xls_data"],
+    config["dataset"]["validation_split"],
+    config["training"]["device"],
+)
+train_loader, val_loader, test_loader = initalize_dataloaders(
+    train_set, val_set, test_set, config["training"]["batch_size"]
+)
+
+# Train the model
+
+for epoch in range(config["training"]["epochs"]):
+    print(f"Epoch {epoch + 1} / {config['training']['epochs']}")
+    model.train()
+    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
+        optimizer.zero_grad()
+        output = model(data)
+        loss = criterion(output, target)
+        loss += get_kl_loss(model, config["hyperparameters"]["kl_weight"])
+        loss = loss / len(data)
+        loss.backward()
+        optimizer.step()
+        
+
+#Test the model
+model.eval()
+with torch.no_grad():
+    output_li
+        
+    
+

+ 6 - 4
config.toml

@@ -11,10 +11,12 @@ runs = 10
 max_epochs = 10
 
 [dataset]
-validation_split = 0.3
+validation_split = 0.4 #Splits the dataset into the train and validation/test set, 50% each for validation and test
+#|---TEST---|---VALIDATION---|---TRAIN---|
+#|splt*0.5  | split*0.5      | 1-split   |
 
 [model]
-name = 'cnn-ensemble10x10'
+name = 'cnn-100x30'
 image_channels = 1
 clin_data_channels = 2
 
@@ -27,5 +29,5 @@ droprate = 0.5
 silent = false
 
 [ensemble]
-name = 'cnn-ensemble10x10'
-prune_threshold = 0.7      # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning
+name = 'cnn-100x30'
+prune_threshold = 0.0 # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning

+ 10 - 0
coverage.csv

@@ -0,0 +1,10 @@
+,Threshold,Accuracy,Quantile,F1,AUC
+0,0.07056796699762344,1.0,0.1,1.0,1.0
+1,0.0776021957397461,0.9354838709677419,0.2,0.9666666666666667,0.9947712418300654
+2,0.0822245217859745,0.9139784946236559,0.30000000000000004,0.9550561797752809,0.98
+3,0.08846938908100128,0.9105691056910569,0.4,0.9531914893617022,0.9721840659340659
+4,0.09666759520769119,0.9285714285714286,0.5,0.9629629629629629,0.9819397993311036
+5,0.10368437767028807,0.9405405405405406,0.6,0.9693593314763231,0.9885496183206106
+6,0.11318791583180428,0.9441860465116279,0.7000000000000001,0.9712918660287081,0.9910283619525372
+7,0.12286047786474227,0.9512195121951219,0.8,0.975,0.9934722222222222
+8,0.13621583431959153,0.9530685920577617,0.9,0.9759704251386322,0.9944568809295384

二进制
coverage.png


+ 232 - 0
predictions.csv

@@ -0,0 +1,232 @@
+,Prediction,Actual,Stdev,Correct
+30,0.08064605295658112,0.0,0.05030234903097153,True
+75,0.08053117245435715,0.0,0.051269371062517166,True
+76,0.07893505692481995,0.0,0.05134830251336098,True
+189,0.08240099996328354,0.0,0.053257476538419724,True
+114,0.08727512508630753,0.0,0.0539369061589241,True
+215,0.0875755175948143,0.0,0.056368160992860794,True
+45,0.09081317484378815,0.0,0.0573900043964386,True
+22,0.10104575753211975,0.0,0.058140166103839874,True
+106,0.10128451883792877,0.0,0.0595875158905983,True
+120,0.10675643384456635,0.0,0.060996584594249725,True
+49,0.10084360092878342,0.0,0.061394911259412766,True
+95,0.10570310056209564,0.0,0.06182017922401428,True
+127,0.10709960013628006,0.0,0.06241276487708092,True
+176,0.11404267698526382,0.0,0.06323474645614624,True
+47,0.11116322875022888,0.0,0.0641709491610527,True
+162,0.1157153844833374,0.0,0.06449578702449799,True
+110,0.10752685368061066,0.0,0.06566949933767319,True
+222,0.9176755547523499,1.0,0.06679949164390564,True
+35,0.9077042937278748,1.0,0.06700856983661652,True
+84,0.12747898697853088,0.0,0.06788456439971924,True
+70,0.11196906864643097,0.0,0.06800825148820877,True
+51,0.9217737913131714,1.0,0.06806794553995132,True
+193,0.12001504749059677,0.0,0.06834182143211365,True
+164,0.11728300899267197,0.0,0.06834923475980759,True
+147,0.12182613462209702,0.0,0.06854480504989624,True
+187,0.12404181063175201,0.0,0.06857491284608841,True
+41,0.12975123524665833,0.0,0.06927058100700378,True
+200,0.12243807315826416,0.0,0.07033567875623703,True
+42,0.13068115711212158,0.0,0.07091105729341507,True
+135,0.12150847911834717,0.0,0.07092347741127014,True
+220,0.14413994550704956,0.0,0.07116097211837769,True
+98,0.12036634236574173,0.0,0.07116296142339706,True
+197,0.13115964829921722,0.0,0.07131355255842209,True
+4,0.13327959179878235,0.0,0.07139390707015991,True
+212,0.912270724773407,1.0,0.07254747301340103,True
+171,0.9018120169639587,1.0,0.07269398123025894,True
+186,0.1429961919784546,0.0,0.07408874481916428,True
+86,0.9191048741340637,1.0,0.07431873679161072,True
+138,0.12933339178562164,0.0,0.07432941347360611,True
+90,0.13462725281715393,0.0,0.07481259107589722,True
+151,0.1324373334646225,0.0,0.07491683959960938,True
+66,0.1392369568347931,0.0,0.07665596902370453,True
+152,0.9079922437667847,1.0,0.07676095515489578,True
+89,0.12983743846416473,0.0,0.076902374625206,True
+56,0.15645432472229004,0.0,0.07711644470691681,True
+16,0.13897286355495453,0.0,0.07711756974458694,True
+185,0.41366952657699585,0.0,0.07714840024709702,True
+54,0.415810763835907,0.0,0.07721669971942902,True
+224,0.4179881811141968,0.0,0.07733853161334991,True
+105,0.4179896414279938,1.0,0.07733859866857529,False
+145,0.4053601622581482,1.0,0.07737884670495987,False
+96,0.9032279849052429,1.0,0.07767074555158615,True
+206,0.42247098684310913,0.0,0.07768560200929642,True
+124,0.4271637201309204,1.0,0.07808028161525726,False
+140,0.14246824383735657,0.0,0.07816833257675171,True
+132,0.16837045550346375,0.0,0.07818591594696045,True
+39,0.42956098914146423,1.0,0.07843437045812607,False
+3,0.14686696231365204,0.0,0.07875282317399979,True
+180,0.14776481688022614,0.0,0.078982874751091,True
+11,0.1500343233346939,0.0,0.079056516289711,True
+60,0.17619368433952332,0.0,0.0796256810426712,True
+227,0.42440199851989746,1.0,0.07999620586633682,False
+226,0.8894540071487427,1.0,0.08009441941976547,True
+38,0.8856544494628906,1.0,0.08010639250278473,True
+194,0.16516436636447906,0.0,0.08040788024663925,True
+134,0.4363965392112732,0.0,0.08042025566101074,True
+205,0.8954517841339111,1.0,0.08050797134637833,True
+166,0.15232831239700317,0.0,0.08051760494709015,True
+63,0.43758124113082886,0.0,0.08097495883703232,True
+74,0.14938735961914062,0.0,0.08104454725980759,True
+37,0.16830533742904663,0.0,0.08133205771446228,True
+55,0.43862971663475037,0.0,0.08157221972942352,True
+196,0.43862971663475037,0.0,0.08157221972942352,True
+111,0.8900614976882935,1.0,0.08185040950775146,True
+117,0.1613903045654297,0.0,0.0823027715086937,True
+91,0.4371449947357178,0.0,0.08239345252513885,True
+129,0.4371453821659088,0.0,0.08239352703094482,True
+12,0.1437193602323532,0.0,0.08240661025047302,True
+50,0.8827714323997498,1.0,0.08248955756425858,True
+183,0.1775076985359192,0.0,0.08252869546413422,True
+179,0.44035619497299194,1.0,0.08286683261394501,False
+144,0.44063085317611694,1.0,0.08292903006076813,False
+175,0.16922545433044434,0.0,0.08365222811698914,True
+168,0.1605507880449295,0.0,0.08430176228284836,True
+24,0.1561395227909088,0.0,0.08435925841331482,True
+34,0.16665950417518616,0.0,0.08514362573623657,True
+157,0.8834100365638733,1.0,0.08561325073242188,True
+61,0.8795567154884338,1.0,0.08580030500888824,True
+13,0.8699986338615417,1.0,0.08597418665885925,True
+141,0.16704654693603516,0.0,0.08632723987102509,True
+181,0.16343875229358673,0.0,0.08685992658138275,True
+64,0.15710464119911194,0.0,0.08711321651935577,True
+19,0.17187774181365967,0.0,0.08759700506925583,True
+113,0.8657154440879822,1.0,0.08773587644100189,True
+213,0.17988014221191406,0.0,0.08796938508749008,True
+221,0.8197951316833496,1.0,0.08860242366790771,True
+53,0.44390758872032166,0.0,0.08881814777851105,True
+65,0.17992322146892548,0.0,0.0888797715306282,True
+210,0.18149317800998688,0.0,0.08959396928548813,True
+158,0.864574134349823,1.0,0.0906127393245697,True
+46,0.8631331920623779,1.0,0.09067254513502121,True
+195,0.1460513472557068,0.0,0.09078666567802429,True
+133,0.19544997811317444,0.0,0.09096378087997437,True
+125,0.17670367658138275,0.0,0.0911884531378746,True
+126,0.878413200378418,1.0,0.09187516570091248,True
+214,0.16936077177524567,0.0,0.0920468270778656,True
+68,0.18125255405902863,0.0,0.09219114482402802,True
+130,0.194206103682518,0.0,0.09236117452383041,True
+128,0.18882118165493011,0.0,0.0924619808793068,True
+43,0.2029598355293274,0.0,0.09254654496908188,True
+184,0.17665491998195648,0.0,0.09256598353385925,True
+149,0.19111061096191406,0.0,0.09315288811922073,True
+69,0.8676502704620361,1.0,0.09340927004814148,True
+103,0.17929165065288544,0.0,0.09348834306001663,True
+154,0.18596304953098297,0.0,0.09386551380157471,True
+230,0.16378547251224518,0.0,0.09413038194179535,True
+78,0.8370398283004761,1.0,0.09446965903043747,True
+118,0.8590189218521118,1.0,0.09481213986873627,True
+67,0.8633883595466614,1.0,0.0950653925538063,True
+123,0.1977778524160385,0.0,0.09530282765626907,True
+88,0.8346447348594666,1.0,0.0953729897737503,True
+170,0.18664516508579254,0.0,0.09701987355947495,True
+33,0.8454685807228088,1.0,0.09758497774600983,True
+211,0.1807575225830078,0.0,0.09802558273077011,True
+148,0.22706817090511322,0.0,0.09901139885187149,True
+7,0.22125060856342316,0.0,0.09902060776948929,True
+228,0.1942971795797348,0.0,0.09912717342376709,True
+52,0.19154150784015656,0.0,0.09937300533056259,True
+207,0.1952042281627655,0.0,0.09943529963493347,True
+20,0.8431762456893921,1.0,0.09954006224870682,True
+122,0.19631659984588623,0.0,0.09989665448665619,True
+23,0.21731746196746826,0.0,0.09994123131036758,True
+177,0.2133563905954361,0.0,0.10010655969381332,True
+143,0.8304190039634705,1.0,0.10014497488737106,True
+153,0.8434374928474426,1.0,0.1012844443321228,True
+216,0.24578765034675598,0.0,0.10183314979076385,True
+99,0.8257945775985718,1.0,0.10200571268796921,True
+201,0.8297196626663208,1.0,0.10201609134674072,True
+203,0.8440954685211182,1.0,0.1024559885263443,True
+59,0.24426139891147614,0.0,0.10260660946369171,True
+217,0.23986373841762543,0.0,0.10291758179664612,True
+146,0.22073577344417572,0.0,0.1031135618686676,True
+218,0.8017507195472717,1.0,0.10448488593101501,True
+17,0.19681228697299957,0.0,0.10469834506511688,True
+163,0.2411302775144577,0.0,0.10534659028053284,True
+159,0.2364622801542282,0.0,0.10561579465866089,True
+165,0.23870216310024261,0.0,0.10607533156871796,True
+79,0.22018252313137054,0.0,0.10616738349199295,True
+97,0.8139686584472656,1.0,0.10667204856872559,True
+31,0.8227530717849731,1.0,0.10737413913011551,True
+137,0.250531405210495,0.0,0.10740893334150314,True
+62,0.8103730082511902,1.0,0.1082184836268425,True
+116,0.7791247367858887,1.0,0.10977177321910858,True
+1,0.2479267120361328,0.0,0.10978859663009644,True
+174,0.28649798035621643,0.0,0.11131870746612549,True
+208,0.23735979199409485,0.0,0.11184616386890411,True
+198,0.8244062662124634,1.0,0.11186125129461288,True
+83,0.7983875870704651,1.0,0.11210570484399796,True
+40,0.8164024353027344,1.0,0.11215054243803024,True
+93,0.8240574598312378,1.0,0.11299832910299301,True
+82,0.2304438203573227,0.0,0.11315932124853134,True
+169,0.7665874361991882,1.0,0.11339464783668518,True
+80,0.7979971170425415,1.0,0.11344446241855621,True
+29,0.7858783006668091,1.0,0.11361948400735855,True
+28,0.8008655905723572,1.0,0.11364077031612396,True
+142,0.7987484335899353,1.0,0.11443256586790085,True
+167,0.8054426312446594,1.0,0.11540836095809937,True
+136,0.788724958896637,1.0,0.11613089591264725,True
+192,0.7759780883789062,1.0,0.1170915886759758,True
+104,0.7606977820396423,1.0,0.11716601997613907,True
+219,0.8869320750236511,1.0,0.11787694692611694,True
+72,0.7505139708518982,1.0,0.11789941042661667,True
+44,0.7945115566253662,1.0,0.11895372718572617,True
+119,0.766346275806427,1.0,0.11902948468923569,True
+190,0.7719056010246277,1.0,0.11941025406122208,True
+225,0.7407894730567932,1.0,0.12043941766023636,True
+209,0.7970132231712341,1.0,0.12057658284902573,True
+21,0.7409422397613525,1.0,0.12115594744682312,True
+15,0.7798006534576416,1.0,0.12195061147212982,True
+223,0.7598530650138855,1.0,0.12308055907487869,True
+48,0.2917308509349823,0.0,0.12323271483182907,True
+9,0.6980370879173279,1.0,0.12332192808389664,True
+173,0.7499210834503174,1.0,0.12339523434638977,True
+73,0.2615889012813568,0.0,0.12378858029842377,True
+204,0.7525201439857483,1.0,0.1241786926984787,True
+92,0.7257311940193176,1.0,0.12514808773994446,True
+2,0.6848272681236267,1.0,0.12516646087169647,True
+191,0.7565965056419373,1.0,0.1252429187297821,True
+32,0.7577760219573975,1.0,0.12556825578212738,True
+109,0.6913741827011108,1.0,0.12734569609165192,True
+188,0.6995864510536194,1.0,0.12773779034614563,True
+71,0.3063671290874481,0.0,0.12778043746948242,True
+139,0.747707188129425,1.0,0.12852038443088531,True
+18,0.722952663898468,1.0,0.1286218762397766,True
+112,0.288512259721756,1.0,0.1296192854642868,False
+161,0.724658191204071,1.0,0.13051988184452057,True
+160,0.6982055306434631,1.0,0.13111035525798798,True
+85,0.7322865724563599,1.0,0.13182096183300018,True
+0,0.6256694793701172,1.0,0.13254255056381226,True
+100,0.629258930683136,1.0,0.13331060111522675,True
+26,0.38021254539489746,0.0,0.13438573479652405,True
+14,0.6863818168640137,1.0,0.13650083541870117,True
+115,0.33127251267433167,0.0,0.136766254901886,True
+10,0.711577296257019,1.0,0.13678012788295746,True
+107,0.5779245495796204,1.0,0.13765454292297363,True
+182,0.6030951738357544,1.0,0.13834097981452942,True
+101,0.6345707774162292,1.0,0.13905014097690582,True
+150,0.5895432829856873,1.0,0.13923202455043793,True
+178,0.6916580200195312,1.0,0.13963013887405396,True
+5,0.6439388990402222,1.0,0.1400848776102066,True
+202,0.35062551498413086,0.0,0.14039146900177002,True
+172,0.7234364151954651,1.0,0.140615776181221,True
+6,0.5680787563323975,0.0,0.14148592948913574,False
+94,0.6395781636238098,1.0,0.14206628501415253,True
+27,0.5554025769233704,1.0,0.14241153001785278,True
+102,0.6225922107696533,1.0,0.14244593679904938,True
+77,0.5900291204452515,1.0,0.1430179625749588,True
+25,0.5724360346794128,1.0,0.1433231234550476,True
+199,0.6326014399528503,1.0,0.14344511926174164,True
+229,0.5214329361915588,1.0,0.144350066781044,True
+57,0.6745962500572205,1.0,0.14562147855758667,True
+121,0.4844725728034973,1.0,0.14744895696640015,False
+155,0.39916175603866577,0.0,0.14898072183132172,True
+8,0.6057963967323303,1.0,0.1497642546892166,True
+156,0.6141930818557739,1.0,0.15126197040081024,True
+81,0.56984543800354,1.0,0.1522326022386551,True
+58,0.44251352548599243,1.0,0.1530144065618515,False
+131,0.5784888863563538,1.0,0.15553662180900574,True
+36,0.4905702769756317,0.0,0.15639249980449677,True
+108,0.49245527386665344,1.0,0.15889272093772888,False
+87,0.45883727073669434,0.0,0.1617383509874344,True

+ 162 - 0
threshold.py

@@ -0,0 +1,162 @@
+import pandas as pd
+import numpy as np
+import os
+import tomli as toml
+from utils.data.datasets import prepare_datasets
+import utils.ensemble as ens
+import torch
+import matplotlib.pyplot as plt
+import sklearn.metrics as metrics
+from tqdm import tqdm
+
+# CONFIGURATION
+if os.getenv("ADL_CONFIG_PATH") is None:
+    with open("config.toml", "rb") as f:
+        config = toml.load(f)
+else:
+    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
+        config = toml.load(f)
+
+
+# This function returns a list of the accuracies given a threshold
+def threshold(config):
+    # First, get the model data
+    ts, vs, test_set = prepare_datasets(
+        config["paths"]["mri_data"],
+        config["paths"]["xls_data"],
+        config["dataset"]["validation_split"],
+        944,
+        config["training"]["device"],
+    )
+
+    test_set = test_set + vs
+
+    models, _ = ens.load_models(
+        config["paths"]["model_output"] + config["ensemble"]["name"] + "/",
+        config["training"]["device"],
+    )
+
+    predictions = []
+
+    # Evaluate ensemble and uncertainty test set
+    for mdata, target in tqdm(test_set, total=len(test_set)):
+        mri, xls = mdata
+        mri = mri.unsqueeze(0)
+        xls = xls.unsqueeze(0)
+        mdata = (mri, xls)
+        mean, variance = ens.ensemble_predict(models, mdata)
+        stdev = torch.sqrt(variance)
+        prediction = mean.item()
+
+        target = target[1]
+
+        # Check if the prediction is correct
+        correct = (prediction < 0.5 and int(target.item()) == 0) or (
+            prediction >= 0.5 and int(target.item()) == 1
+        )
+
+        predictions.append(
+            {
+                "Prediction": prediction,
+                "Actual": target.item(),
+                "Stdev": stdev.item(),
+                "Correct": correct,
+            }
+        )
+
+    # Sort the predictions by the uncertainty
+    predictions = pd.DataFrame(predictions).sort_values(by="Stdev")
+
+    thresholds = []
+    quantiles = np.arange(0.1, 1, 0.1)
+    # get uncertainty quantiles
+    for quantile in quantiles:
+        thresholds.append(predictions["Stdev"].quantile(quantile))
+
+    # Calculate the accuracy of the model for each threshold
+    accuracies = []
+    # Calculate the accuracy of the model for each threshold
+    for threshold, quantile in zip(thresholds, quantiles):
+        filtered = predictions[predictions["Stdev"] <= threshold]
+        correct = filtered["Correct"].sum()
+        total = len(filtered)
+        accuracy = correct / total
+
+        false_positives = len(
+            filtered[(filtered["Prediction"] >= 0.5) & (filtered["Actual"] == 0)]
+        )
+
+        false_negatives = len(
+            filtered[(filtered["Prediction"] < 0.5) & (filtered["Actual"] == 1)]
+        )
+
+        f1 = 2 * correct / (2 * correct + false_positives + false_negatives)
+
+        auc = metrics.roc_auc_score(filtered["Actual"], filtered["Prediction"])
+
+        accuracies.append(
+            {
+                "Threshold": threshold,
+                "Accuracy": accuracy,
+                "Quantile": quantile,
+                "F1": f1,
+                "AUC": auc,
+            }
+        )
+
+    predictions.to_csv(
+        f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
+    )
+
+    return pd.DataFrame(accuracies)
+
+
+result = threshold(config)
+result.to_csv("coverage.csv")
+
+result = pd.read_csv("coverage.csv")
+predictions = pd.read_csv(
+    f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
+)
+print(result)
+
+
+plt.figure()
+
+plt.plot(result["Quantile"], result["Accuracy"])
+plt.xlabel("Coverage")
+plt.ylabel("Accuracy")
+plt.gca().invert_xaxis()
+
+plt.savefig(
+    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.png"
+)
+
+plt.figure()
+plt.plot(result["Quantile"], result["F1"])
+plt.xlabel("Coverage")
+plt.ylabel("F1")
+plt.gca().invert_xaxis()
+
+plt.savefig(
+    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_f1.png"
+)
+
+plt.figure()
+plt.plot(result["Quantile"], result["AUC"])
+plt.xlabel("Coverage")
+plt.ylabel("AUC")
+plt.gca().invert_xaxis()
+
+plt.savefig(
+    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_auc.png"
+)
+
+# create histogram of the incorrect predictions vs the uncertainty
+plt.figure()
+plt.hist(predictions[~predictions["Correct"]]["Stdev"], bins=10)
+plt.xlabel("Uncertainty")
+plt.ylabel("Number of incorrect predictions")
+plt.savefig(
+    f"{config['paths']['model_output']}{config['ensemble']['name']}/incorrect_predictions.png"
+)

+ 3 - 3
train_cnn.py

@@ -2,6 +2,7 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
+import shutil
 
 # GENERAL USE
 import random as rand
@@ -31,6 +32,8 @@ else:
 
 # Force cuDNN initialization
 force_init_cudnn(config["training"]["device"])
+# Generate seed for each set of runs
+seed = rand.randint(0, 1000)
 
 for i in range(config["training"]["runs"]):
     # Set up the model
@@ -48,9 +51,6 @@ for i in range(config["training"]["runs"]):
         model.parameters(), lr=config["hyperparameters"]["learning_rate"]
     )
 
-    # Generate seed for each run
-    seed = rand.randint(0, 1000)
-
     # Prepare data
     train_dataset, val_dataset, test_dataset = prepare_datasets(
         config["paths"]["mri_data"],

+ 25 - 19
utils/data/datasets.py

@@ -8,6 +8,7 @@ import torch
 from torch.utils.data import Dataset
 import pandas as pd
 from torch.utils.data import DataLoader
+import math
 
 
 """
@@ -56,32 +57,37 @@ Returns train_list, val_list and test_list in format [(image, id), ...] each
 
 def get_train_val_test(AD_list, NL_list, val_split):
     train_list, val_list, test_list = [], [], []
+    # For the purposes of this split, the val_split constitutes the validation and testing split, as they are divided evenly
 
-    num_test_ad = int(len(AD_list) * val_split)
-    num_test_nl = int(len(NL_list) * val_split)
+    # get the overall length of the data
+    AD_len = len(AD_list)
+    NL_len = len(NL_list)
 
-    num_val_ad = int((len(AD_list) - num_test_ad) * val_split)
-    num_val_nl = int((len(NL_list) - num_test_nl) * val_split)
+    # First, determine the length of each of the sets
+    AD_val_len = int(math.ceil(AD_len * val_split * 0.5))
+    NL_val_len = int(math.ceil(NL_len * val_split * 0.5))
 
-    # Sets up ADs
-    for image in AD_list[0:num_val_ad]:
-        val_list.append((image, 1))
+    AD_test_len = int(math.floor(AD_len * val_split * 0.5))
+    NL_test_len = int(math.floor(NL_len * val_split * 0.5))
 
-    for image in AD_list[num_val_ad:num_test_ad]:
-        test_list.append((image, 1))
+    AD_train_len = AD_len - AD_val_len - AD_test_len
+    NL_train_len = NL_len - NL_val_len - NL_test_len
 
-    for image in AD_list[num_test_ad:]:
-        train_list.append((image, 1))
+    # Add the data to the sets
+    for i in range(AD_train_len):
+        train_list.append((AD_list[i], 1))
+    for i in range(NL_train_len):
+        train_list.append((NL_list[i], 0))
 
-    # Sets up NLs
-    for image in NL_list[0:num_val_nl]:
-        val_list.append((image, 0))
+    for i in range(AD_train_len, AD_train_len + AD_val_len):
+        val_list.append((AD_list[i], 1))
+    for i in range(NL_train_len, NL_train_len + NL_val_len):
+        val_list.append((NL_list[i], 0))
 
-    for image in NL_list[num_val_nl:num_test_nl]:
-        test_list.append((image, 0))
-
-    for image in NL_list[num_test_nl:]:
-        train_list.append((image, 0))
+    for i in range(AD_train_len + AD_val_len, AD_len):
+        test_list.append((AD_list[i], 1))
+    for i in range(NL_train_len + NL_val_len, NL_len):
+        test_list.append((NL_list[i], 0))
 
     return train_list, val_list, test_list