Parcourir la source

Pre-rewrite commit

Nicholas Schense il y a 4 mois
Parent
commit
a26abdc067
7 fichiers modifiés avec 139 ajouts et 286 suppressions
  1. 4 0
      .vscode/settings.json
  2. 2 2
      config.toml
  3. 0 10
      coverage.csv
  4. BIN
      coverage.png
  5. 0 232
      predictions.csv
  6. 6 0
      ruff.toml
  7. 127 42
      threshold.py

+ 4 - 0
.vscode/settings.json

@@ -0,0 +1,4 @@
+{
+    "ruff.nativeServer": true,
+    "ruff.configuration": "${workspaceFolder}/ruff.toml"
+}

+ 2 - 2
config.toml

@@ -16,7 +16,7 @@ validation_split = 0.4 #Splits the dataset into the train and validation/test se
 #|splt*0.5  | split*0.5      | 1-split   |
 
 [model]
-name = 'cnn-100x30'
+name = 'cnn-10x10'
 image_channels = 1
 clin_data_channels = 2
 
@@ -29,5 +29,5 @@ droprate = 0.5
 silent = false
 
 [ensemble]
-name = 'cnn-100x30'
+name = 'cnn-10x10'
 prune_threshold = 0.0 # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning

+ 0 - 10
coverage.csv

@@ -1,10 +0,0 @@
-,Threshold,Accuracy,Quantile,F1,AUC
-0,0.07056796699762344,1.0,0.1,1.0,1.0
-1,0.0776021957397461,0.9354838709677419,0.2,0.9666666666666667,0.9947712418300654
-2,0.0822245217859745,0.9139784946236559,0.30000000000000004,0.9550561797752809,0.98
-3,0.08846938908100128,0.9105691056910569,0.4,0.9531914893617022,0.9721840659340659
-4,0.09666759520769119,0.9285714285714286,0.5,0.9629629629629629,0.9819397993311036
-5,0.10368437767028807,0.9405405405405406,0.6,0.9693593314763231,0.9885496183206106
-6,0.11318791583180428,0.9441860465116279,0.7000000000000001,0.9712918660287081,0.9910283619525372
-7,0.12286047786474227,0.9512195121951219,0.8,0.975,0.9934722222222222
-8,0.13621583431959153,0.9530685920577617,0.9,0.9759704251386322,0.9944568809295384

BIN
coverage.png


+ 0 - 232
predictions.csv

@@ -1,232 +0,0 @@
-,Prediction,Actual,Stdev,Correct
-30,0.08064605295658112,0.0,0.05030234903097153,True
-75,0.08053117245435715,0.0,0.051269371062517166,True
-76,0.07893505692481995,0.0,0.05134830251336098,True
-189,0.08240099996328354,0.0,0.053257476538419724,True
-114,0.08727512508630753,0.0,0.0539369061589241,True
-215,0.0875755175948143,0.0,0.056368160992860794,True
-45,0.09081317484378815,0.0,0.0573900043964386,True
-22,0.10104575753211975,0.0,0.058140166103839874,True
-106,0.10128451883792877,0.0,0.0595875158905983,True
-120,0.10675643384456635,0.0,0.060996584594249725,True
-49,0.10084360092878342,0.0,0.061394911259412766,True
-95,0.10570310056209564,0.0,0.06182017922401428,True
-127,0.10709960013628006,0.0,0.06241276487708092,True
-176,0.11404267698526382,0.0,0.06323474645614624,True
-47,0.11116322875022888,0.0,0.0641709491610527,True
-162,0.1157153844833374,0.0,0.06449578702449799,True
-110,0.10752685368061066,0.0,0.06566949933767319,True
-222,0.9176755547523499,1.0,0.06679949164390564,True
-35,0.9077042937278748,1.0,0.06700856983661652,True
-84,0.12747898697853088,0.0,0.06788456439971924,True
-70,0.11196906864643097,0.0,0.06800825148820877,True
-51,0.9217737913131714,1.0,0.06806794553995132,True
-193,0.12001504749059677,0.0,0.06834182143211365,True
-164,0.11728300899267197,0.0,0.06834923475980759,True
-147,0.12182613462209702,0.0,0.06854480504989624,True
-187,0.12404181063175201,0.0,0.06857491284608841,True
-41,0.12975123524665833,0.0,0.06927058100700378,True
-200,0.12243807315826416,0.0,0.07033567875623703,True
-42,0.13068115711212158,0.0,0.07091105729341507,True
-135,0.12150847911834717,0.0,0.07092347741127014,True
-220,0.14413994550704956,0.0,0.07116097211837769,True
-98,0.12036634236574173,0.0,0.07116296142339706,True
-197,0.13115964829921722,0.0,0.07131355255842209,True
-4,0.13327959179878235,0.0,0.07139390707015991,True
-212,0.912270724773407,1.0,0.07254747301340103,True
-171,0.9018120169639587,1.0,0.07269398123025894,True
-186,0.1429961919784546,0.0,0.07408874481916428,True
-86,0.9191048741340637,1.0,0.07431873679161072,True
-138,0.12933339178562164,0.0,0.07432941347360611,True
-90,0.13462725281715393,0.0,0.07481259107589722,True
-151,0.1324373334646225,0.0,0.07491683959960938,True
-66,0.1392369568347931,0.0,0.07665596902370453,True
-152,0.9079922437667847,1.0,0.07676095515489578,True
-89,0.12983743846416473,0.0,0.076902374625206,True
-56,0.15645432472229004,0.0,0.07711644470691681,True
-16,0.13897286355495453,0.0,0.07711756974458694,True
-185,0.41366952657699585,0.0,0.07714840024709702,True
-54,0.415810763835907,0.0,0.07721669971942902,True
-224,0.4179881811141968,0.0,0.07733853161334991,True
-105,0.4179896414279938,1.0,0.07733859866857529,False
-145,0.4053601622581482,1.0,0.07737884670495987,False
-96,0.9032279849052429,1.0,0.07767074555158615,True
-206,0.42247098684310913,0.0,0.07768560200929642,True
-124,0.4271637201309204,1.0,0.07808028161525726,False
-140,0.14246824383735657,0.0,0.07816833257675171,True
-132,0.16837045550346375,0.0,0.07818591594696045,True
-39,0.42956098914146423,1.0,0.07843437045812607,False
-3,0.14686696231365204,0.0,0.07875282317399979,True
-180,0.14776481688022614,0.0,0.078982874751091,True
-11,0.1500343233346939,0.0,0.079056516289711,True
-60,0.17619368433952332,0.0,0.0796256810426712,True
-227,0.42440199851989746,1.0,0.07999620586633682,False
-226,0.8894540071487427,1.0,0.08009441941976547,True
-38,0.8856544494628906,1.0,0.08010639250278473,True
-194,0.16516436636447906,0.0,0.08040788024663925,True
-134,0.4363965392112732,0.0,0.08042025566101074,True
-205,0.8954517841339111,1.0,0.08050797134637833,True
-166,0.15232831239700317,0.0,0.08051760494709015,True
-63,0.43758124113082886,0.0,0.08097495883703232,True
-74,0.14938735961914062,0.0,0.08104454725980759,True
-37,0.16830533742904663,0.0,0.08133205771446228,True
-55,0.43862971663475037,0.0,0.08157221972942352,True
-196,0.43862971663475037,0.0,0.08157221972942352,True
-111,0.8900614976882935,1.0,0.08185040950775146,True
-117,0.1613903045654297,0.0,0.0823027715086937,True
-91,0.4371449947357178,0.0,0.08239345252513885,True
-129,0.4371453821659088,0.0,0.08239352703094482,True
-12,0.1437193602323532,0.0,0.08240661025047302,True
-50,0.8827714323997498,1.0,0.08248955756425858,True
-183,0.1775076985359192,0.0,0.08252869546413422,True
-179,0.44035619497299194,1.0,0.08286683261394501,False
-144,0.44063085317611694,1.0,0.08292903006076813,False
-175,0.16922545433044434,0.0,0.08365222811698914,True
-168,0.1605507880449295,0.0,0.08430176228284836,True
-24,0.1561395227909088,0.0,0.08435925841331482,True
-34,0.16665950417518616,0.0,0.08514362573623657,True
-157,0.8834100365638733,1.0,0.08561325073242188,True
-61,0.8795567154884338,1.0,0.08580030500888824,True
-13,0.8699986338615417,1.0,0.08597418665885925,True
-141,0.16704654693603516,0.0,0.08632723987102509,True
-181,0.16343875229358673,0.0,0.08685992658138275,True
-64,0.15710464119911194,0.0,0.08711321651935577,True
-19,0.17187774181365967,0.0,0.08759700506925583,True
-113,0.8657154440879822,1.0,0.08773587644100189,True
-213,0.17988014221191406,0.0,0.08796938508749008,True
-221,0.8197951316833496,1.0,0.08860242366790771,True
-53,0.44390758872032166,0.0,0.08881814777851105,True
-65,0.17992322146892548,0.0,0.0888797715306282,True
-210,0.18149317800998688,0.0,0.08959396928548813,True
-158,0.864574134349823,1.0,0.0906127393245697,True
-46,0.8631331920623779,1.0,0.09067254513502121,True
-195,0.1460513472557068,0.0,0.09078666567802429,True
-133,0.19544997811317444,0.0,0.09096378087997437,True
-125,0.17670367658138275,0.0,0.0911884531378746,True
-126,0.878413200378418,1.0,0.09187516570091248,True
-214,0.16936077177524567,0.0,0.0920468270778656,True
-68,0.18125255405902863,0.0,0.09219114482402802,True
-130,0.194206103682518,0.0,0.09236117452383041,True
-128,0.18882118165493011,0.0,0.0924619808793068,True
-43,0.2029598355293274,0.0,0.09254654496908188,True
-184,0.17665491998195648,0.0,0.09256598353385925,True
-149,0.19111061096191406,0.0,0.09315288811922073,True
-69,0.8676502704620361,1.0,0.09340927004814148,True
-103,0.17929165065288544,0.0,0.09348834306001663,True
-154,0.18596304953098297,0.0,0.09386551380157471,True
-230,0.16378547251224518,0.0,0.09413038194179535,True
-78,0.8370398283004761,1.0,0.09446965903043747,True
-118,0.8590189218521118,1.0,0.09481213986873627,True
-67,0.8633883595466614,1.0,0.0950653925538063,True
-123,0.1977778524160385,0.0,0.09530282765626907,True
-88,0.8346447348594666,1.0,0.0953729897737503,True
-170,0.18664516508579254,0.0,0.09701987355947495,True
-33,0.8454685807228088,1.0,0.09758497774600983,True
-211,0.1807575225830078,0.0,0.09802558273077011,True
-148,0.22706817090511322,0.0,0.09901139885187149,True
-7,0.22125060856342316,0.0,0.09902060776948929,True
-228,0.1942971795797348,0.0,0.09912717342376709,True
-52,0.19154150784015656,0.0,0.09937300533056259,True
-207,0.1952042281627655,0.0,0.09943529963493347,True
-20,0.8431762456893921,1.0,0.09954006224870682,True
-122,0.19631659984588623,0.0,0.09989665448665619,True
-23,0.21731746196746826,0.0,0.09994123131036758,True
-177,0.2133563905954361,0.0,0.10010655969381332,True
-143,0.8304190039634705,1.0,0.10014497488737106,True
-153,0.8434374928474426,1.0,0.1012844443321228,True
-216,0.24578765034675598,0.0,0.10183314979076385,True
-99,0.8257945775985718,1.0,0.10200571268796921,True
-201,0.8297196626663208,1.0,0.10201609134674072,True
-203,0.8440954685211182,1.0,0.1024559885263443,True
-59,0.24426139891147614,0.0,0.10260660946369171,True
-217,0.23986373841762543,0.0,0.10291758179664612,True
-146,0.22073577344417572,0.0,0.1031135618686676,True
-218,0.8017507195472717,1.0,0.10448488593101501,True
-17,0.19681228697299957,0.0,0.10469834506511688,True
-163,0.2411302775144577,0.0,0.10534659028053284,True
-159,0.2364622801542282,0.0,0.10561579465866089,True
-165,0.23870216310024261,0.0,0.10607533156871796,True
-79,0.22018252313137054,0.0,0.10616738349199295,True
-97,0.8139686584472656,1.0,0.10667204856872559,True
-31,0.8227530717849731,1.0,0.10737413913011551,True
-137,0.250531405210495,0.0,0.10740893334150314,True
-62,0.8103730082511902,1.0,0.1082184836268425,True
-116,0.7791247367858887,1.0,0.10977177321910858,True
-1,0.2479267120361328,0.0,0.10978859663009644,True
-174,0.28649798035621643,0.0,0.11131870746612549,True
-208,0.23735979199409485,0.0,0.11184616386890411,True
-198,0.8244062662124634,1.0,0.11186125129461288,True
-83,0.7983875870704651,1.0,0.11210570484399796,True
-40,0.8164024353027344,1.0,0.11215054243803024,True
-93,0.8240574598312378,1.0,0.11299832910299301,True
-82,0.2304438203573227,0.0,0.11315932124853134,True
-169,0.7665874361991882,1.0,0.11339464783668518,True
-80,0.7979971170425415,1.0,0.11344446241855621,True
-29,0.7858783006668091,1.0,0.11361948400735855,True
-28,0.8008655905723572,1.0,0.11364077031612396,True
-142,0.7987484335899353,1.0,0.11443256586790085,True
-167,0.8054426312446594,1.0,0.11540836095809937,True
-136,0.788724958896637,1.0,0.11613089591264725,True
-192,0.7759780883789062,1.0,0.1170915886759758,True
-104,0.7606977820396423,1.0,0.11716601997613907,True
-219,0.8869320750236511,1.0,0.11787694692611694,True
-72,0.7505139708518982,1.0,0.11789941042661667,True
-44,0.7945115566253662,1.0,0.11895372718572617,True
-119,0.766346275806427,1.0,0.11902948468923569,True
-190,0.7719056010246277,1.0,0.11941025406122208,True
-225,0.7407894730567932,1.0,0.12043941766023636,True
-209,0.7970132231712341,1.0,0.12057658284902573,True
-21,0.7409422397613525,1.0,0.12115594744682312,True
-15,0.7798006534576416,1.0,0.12195061147212982,True
-223,0.7598530650138855,1.0,0.12308055907487869,True
-48,0.2917308509349823,0.0,0.12323271483182907,True
-9,0.6980370879173279,1.0,0.12332192808389664,True
-173,0.7499210834503174,1.0,0.12339523434638977,True
-73,0.2615889012813568,0.0,0.12378858029842377,True
-204,0.7525201439857483,1.0,0.1241786926984787,True
-92,0.7257311940193176,1.0,0.12514808773994446,True
-2,0.6848272681236267,1.0,0.12516646087169647,True
-191,0.7565965056419373,1.0,0.1252429187297821,True
-32,0.7577760219573975,1.0,0.12556825578212738,True
-109,0.6913741827011108,1.0,0.12734569609165192,True
-188,0.6995864510536194,1.0,0.12773779034614563,True
-71,0.3063671290874481,0.0,0.12778043746948242,True
-139,0.747707188129425,1.0,0.12852038443088531,True
-18,0.722952663898468,1.0,0.1286218762397766,True
-112,0.288512259721756,1.0,0.1296192854642868,False
-161,0.724658191204071,1.0,0.13051988184452057,True
-160,0.6982055306434631,1.0,0.13111035525798798,True
-85,0.7322865724563599,1.0,0.13182096183300018,True
-0,0.6256694793701172,1.0,0.13254255056381226,True
-100,0.629258930683136,1.0,0.13331060111522675,True
-26,0.38021254539489746,0.0,0.13438573479652405,True
-14,0.6863818168640137,1.0,0.13650083541870117,True
-115,0.33127251267433167,0.0,0.136766254901886,True
-10,0.711577296257019,1.0,0.13678012788295746,True
-107,0.5779245495796204,1.0,0.13765454292297363,True
-182,0.6030951738357544,1.0,0.13834097981452942,True
-101,0.6345707774162292,1.0,0.13905014097690582,True
-150,0.5895432829856873,1.0,0.13923202455043793,True
-178,0.6916580200195312,1.0,0.13963013887405396,True
-5,0.6439388990402222,1.0,0.1400848776102066,True
-202,0.35062551498413086,0.0,0.14039146900177002,True
-172,0.7234364151954651,1.0,0.140615776181221,True
-6,0.5680787563323975,0.0,0.14148592948913574,False
-94,0.6395781636238098,1.0,0.14206628501415253,True
-27,0.5554025769233704,1.0,0.14241153001785278,True
-102,0.6225922107696533,1.0,0.14244593679904938,True
-77,0.5900291204452515,1.0,0.1430179625749588,True
-25,0.5724360346794128,1.0,0.1433231234550476,True
-199,0.6326014399528503,1.0,0.14344511926174164,True
-229,0.5214329361915588,1.0,0.144350066781044,True
-57,0.6745962500572205,1.0,0.14562147855758667,True
-121,0.4844725728034973,1.0,0.14744895696640015,False
-155,0.39916175603866577,0.0,0.14898072183132172,True
-8,0.6057963967323303,1.0,0.1497642546892166,True
-156,0.6141930818557739,1.0,0.15126197040081024,True
-81,0.56984543800354,1.0,0.1522326022386551,True
-58,0.44251352548599243,1.0,0.1530144065618515,False
-131,0.5784888863563538,1.0,0.15553662180900574,True
-36,0.4905702769756317,0.0,0.15639249980449677,True
-108,0.49245527386665344,1.0,0.15889272093772888,False
-87,0.45883727073669434,0.0,0.1617383509874344,True

+ 6 - 0
ruff.toml

@@ -0,0 +1,6 @@
+[lint]
+select = ["E4", "E7", "E9", "F", "B"]
+
+
+[format]
+quote-style = "single"

+ 127 - 42
threshold.py

@@ -9,12 +9,14 @@ import matplotlib.pyplot as plt
 import sklearn.metrics as metrics
 from tqdm import tqdm
 
+RUN = True
+
 # CONFIGURATION
-if os.getenv("ADL_CONFIG_PATH") is None:
-    with open("config.toml", "rb") as f:
+if os.getenv('ADL_CONFIG_PATH') is None:
+    with open('config.toml', 'rb') as f:
         config = toml.load(f)
 else:
-    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
+    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
         config = toml.load(f)
 
 
@@ -22,21 +24,24 @@ else:
 def threshold(config):
     # First, get the model data
     ts, vs, test_set = prepare_datasets(
-        config["paths"]["mri_data"],
-        config["paths"]["xls_data"],
-        config["dataset"]["validation_split"],
+        config['paths']['mri_data'],
+        config['paths']['xls_data'],
+        config['dataset']['validation_split'],
         944,
-        config["training"]["device"],
+        config['training']['device'],
     )
 
     test_set = test_set + vs
 
     models, _ = ens.load_models(
-        config["paths"]["model_output"] + config["ensemble"]["name"] + "/",
-        config["training"]["device"],
+        config['paths']['model_output'] + config['ensemble']['name'] + '/',
+        config['training']['device'],
     )
 
+    indv_model = models[0]
+
     predictions = []
+    indv_predictions = []
 
     # Evaluate ensemble and uncertainty test set
     for mdata, target in tqdm(test_set, total=len(test_set)):
@@ -57,50 +62,85 @@ def threshold(config):
 
         predictions.append(
             {
-                "Prediction": prediction,
-                "Actual": target.item(),
-                "Stdev": stdev.item(),
-                "Correct": correct,
+                'Prediction': prediction,
+                'Actual': target.item(),
+                'Stdev': stdev.item(),
+                'Correct': correct,
+            }
+        )
+
+        i_mean = indv_model(mdata)[:, 1].item()
+        i_correct = (i_mean < 0.5 and int(target.item()) == 0) or (
+            i_mean >= 0.5 and int(target.item()) == 1
+        )
+
+        indv_predictions.append(
+            {
+                'Prediction': i_mean,
+                'Actual': target.item(),
+                'Stdev': 0,
+                'Correct': i_correct,
             }
         )
 
     # Sort the predictions by the uncertainty
-    predictions = pd.DataFrame(predictions).sort_values(by="Stdev")
+    predictions = pd.DataFrame(predictions).sort_values(by='Stdev')
+
+    # Calculate the metrics for the individual model
+    indv_predictions = pd.DataFrame(indv_predictions)
+    indv_correct = indv_predictions['Correct'].sum()
+    indv_accuracy = indv_correct / len(indv_predictions)
+    indv_false_pos = len(
+        indv_predictions[
+            (indv_predictions['Prediction'] >= 0.5) & (indv_predictions['Actual'] == 0)
+        ]
+    )
+    indv_false_neg = len(
+        indv_predictions[
+            (indv_predictions['Prediction'] < 0.5) & (indv_predictions['Actual'] == 1)
+        ]
+    )
+    indv_f1 = 2 * indv_correct / (2 * indv_correct + indv_false_pos + indv_false_neg)
+    indv_auc = metrics.roc_auc_score(
+        indv_predictions['Actual'], indv_predictions['Prediction']
+    )
+
+    indv_metrics = {'Accuracy': indv_accuracy, 'F1': indv_f1, 'AUC': indv_auc}
 
     thresholds = []
     quantiles = np.arange(0.1, 1, 0.1)
     # get uncertainty quantiles
     for quantile in quantiles:
-        thresholds.append(predictions["Stdev"].quantile(quantile))
+        thresholds.append(predictions['Stdev'].quantile(quantile))
 
     # Calculate the accuracy of the model for each threshold
     accuracies = []
     # Calculate the accuracy of the model for each threshold
     for threshold, quantile in zip(thresholds, quantiles):
-        filtered = predictions[predictions["Stdev"] <= threshold]
-        correct = filtered["Correct"].sum()
+        filtered = predictions[predictions['Stdev'] <= threshold]
+        correct = filtered['Correct'].sum()
         total = len(filtered)
         accuracy = correct / total
 
         false_positives = len(
-            filtered[(filtered["Prediction"] >= 0.5) & (filtered["Actual"] == 0)]
+            filtered[(filtered['Prediction'] >= 0.5) & (filtered['Actual'] == 0)]
         )
 
         false_negatives = len(
-            filtered[(filtered["Prediction"] < 0.5) & (filtered["Actual"] == 1)]
+            filtered[(filtered['Prediction'] < 0.5) & (filtered['Actual'] == 1)]
         )
 
         f1 = 2 * correct / (2 * correct + false_positives + false_negatives)
 
-        auc = metrics.roc_auc_score(filtered["Actual"], filtered["Prediction"])
+        auc = metrics.roc_auc_score(filtered['Actual'], filtered['Prediction'])
 
         accuracies.append(
             {
-                "Threshold": threshold,
-                "Accuracy": accuracy,
-                "Quantile": quantile,
-                "F1": f1,
-                "AUC": auc,
+                'Threshold': threshold,
+                'Accuracy': accuracy,
+                'Quantile': quantile,
+                'F1': f1,
+                'AUC': auc,
             }
         )
 
@@ -108,24 +148,52 @@ def threshold(config):
         f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
     )
 
-    return pd.DataFrame(accuracies)
+    indv_predictions.to_csv(
+        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_predictions.csv"
+    )
+
+    return pd.DataFrame(accuracies), indv_metrics
 
 
-result = threshold(config)
-result.to_csv("coverage.csv")
+if RUN:
+    result, indv = threshold(config)
+    result.to_csv(
+        f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
+    )
+    indv = pd.DataFrame([indv])
+    indv.to_csv(
+        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
+    )
 
-result = pd.read_csv("coverage.csv")
+result = pd.read_csv(
+    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
+)
 predictions = pd.read_csv(
     f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
 )
-print(result)
+indv = pd.read_csv(
+    f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
+)
+
+print(indv)
 
 
 plt.figure()
 
-plt.plot(result["Quantile"], result["Accuracy"])
-plt.xlabel("Coverage")
-plt.ylabel("Accuracy")
+plt.plot(result['Quantile'], result['Accuracy'], label='Ensemble Accuracy')
+
+plt.plot(
+    result['Quantile'],
+    [indv['Accuracy']] * len(result['Quantile']),
+    label='Individual Accuracy',
+    linestyle='--',
+)
+plt.legend()
+
+plt.title('Accuracy vs Coverage')
+
+plt.xlabel('Coverage')
+plt.ylabel('Accuracy')
 plt.gca().invert_xaxis()
 
 plt.savefig(
@@ -133,9 +201,18 @@ plt.savefig(
 )
 
 plt.figure()
-plt.plot(result["Quantile"], result["F1"])
-plt.xlabel("Coverage")
-plt.ylabel("F1")
+plt.plot(result['Quantile'], result['F1'], label='Ensemble F1')
+plt.plot(
+    result['Quantile'],
+    [indv['F1']] * len(result['Quantile']),
+    label='Individual F1',
+    linestyle='--',
+)
+plt.legend()
+plt.title('F1 vs Coverage')
+
+plt.xlabel('Coverage')
+plt.ylabel('F1')
 plt.gca().invert_xaxis()
 
 plt.savefig(
@@ -143,9 +220,17 @@ plt.savefig(
 )
 
 plt.figure()
-plt.plot(result["Quantile"], result["AUC"])
-plt.xlabel("Coverage")
-plt.ylabel("AUC")
+plt.plot(result['Quantile'], result['AUC'], label='Ensemble AUC')
+plt.plot(
+    result['Quantile'],
+    [indv['AUC']] * len(result['Quantile']),
+    label='Individual AUC',
+    linestyle='--',
+)
+plt.legend()
+plt.title('AUC vs Coverage')
+plt.xlabel('Coverage')
+plt.ylabel('AUC')
 plt.gca().invert_xaxis()
 
 plt.savefig(
@@ -154,9 +239,9 @@ plt.savefig(
 
 # create histogram of the incorrect predictions vs the uncertainty
 plt.figure()
-plt.hist(predictions[~predictions["Correct"]]["Stdev"], bins=10)
-plt.xlabel("Uncertainty")
-plt.ylabel("Number of incorrect predictions")
+plt.hist(predictions[~predictions['Correct']]['Stdev'], bins=10)
+plt.xlabel('Uncertainty')
+plt.ylabel('Number of incorrect predictions')
 plt.savefig(
     f"{config['paths']['model_output']}{config['ensemble']['name']}/incorrect_predictions.png"
 )