import numpy as np
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import label_binarize
from keras import regularizers
import pickle
from utils.preprocess import DataLoader
from utils.models import Parameters, CNN_Net, RNN_Net
from utils.heatmapPlotting import heatmapPlotter
from matplotlib import pyplot as plt
import pandas as pd
from scipy import interp
from keras.models import Model, load_model
from keras.layers import Input
from keras.optimizers import Adam
import tensorflow as tf
from IPython.display import Image
import matplotlib.cm as cm
import SimpleITK as sitk
import csv
from copy import deepcopy
import matplotlib.colors as mcolors
import nibabel as nib
import math
import sys
sys.path.append('//data/data_wnx3/data_wnx1/rschuurs/CNN+RNN-2class-1cnn-CLEAN/utils')
from sepconv3D import SeparableConv3D

## For the 2-class CNN + RNN ##
# Dummy feature vectors are appended to the CNN feature vectors (the CNN itself is fed only the images).
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use the id from $ nvidia-smi
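
# Optional sanity check (an illustrative addition, not part of the original pipeline): confirm
# that TensorFlow actually sees the GPU pinned above before committing to a long training run.
# device_lib is TF's internal device-listing helper and is available in the TF1.x era this uses.
from tensorflow.python.client import device_lib
print('Visible devices: ', [d.name for d in device_lib.list_local_devices()])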
target_rows = 91
target_cols = 109
depth = 91
axis = 1
num_clinical = 2
CNN_drop_rate = 0.3
RNN_drop_rate = 0.1
CNN_w_regularizer = regularizers.l2(2e-2)
RNN_w_regularizer = regularizers.l2(1e-6)
CNN_batch_size = 10
RNN_batch_size = 5
val_split = 0.2
optimizer = Adam(lr=1e-5)
final_layer_size = 5

model_filepath = '//data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data'
mri_datapath = '//data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/ADNI_volumes_customtemplate_float32'

params_dict = {'CNN_w_regularizer': CNN_w_regularizer, 'RNN_w_regularizer': RNN_w_regularizer,
               'CNN_batch_size': CNN_batch_size, 'RNN_batch_size': RNN_batch_size,
               'CNN_drop_rate': CNN_drop_rate, 'epochs': 30,
               'gpu': "/gpu:0", 'model_filepath': model_filepath,
               'image_shape': (target_rows, target_cols, depth, axis),
               'num_clinical': num_clinical,
               'final_layer_size': final_layer_size,
               'optimizer': optimizer, 'RNN_drop_rate': RNN_drop_rate}

params = Parameters(params_dict)
seeds = [np.random.randint(1, 5000) for _ in range(1)]
def evaluate_net(seed):
    n_classes = 2
    data_loader = DataLoader((target_rows, target_cols, depth, axis), seed=seed)
    train_data, val_data, test_data, rnn_HdataT1, rnn_HdataT2, rnn_HdataT3, rnn_AdataT1, rnn_AdataT2, rnn_AdataT3, test_mri_nonorm = data_loader.get_train_val_test(val_split, mri_datapath)
    print('Length Val Data[0]: ', len(val_data[0]))

    #RUN THE CNN:
    netCNN = CNN_Net(params)
    historyCNN, featuresModel_CNN = netCNN.train((train_data, val_data))
    test_lossCNN, test_accCNN = netCNN.evaluate(test_data)
    test_predsCNN = netCNN.predict(test_data)

    """
    #TO LOAD A PREVIOUS MODEL FOR HEATMAPS: uncomment this chunk, comment out the chunk above,
    #and add the pickle file and saved weights to the model_filepath folder.
    #Loading the whole model has not worked so far (it is still saved anyway); only the weights
    #can be saved and loaded, so the model must be recompiled with exactly the same architecture.
    picklename = '1820'
    netCNN = CNN_Net(params)
    netCNN.load_the_weights("SavedCNNWeights")
    pickle_in = open(model_filepath+'/'+picklename+'.pickle', 'rb')
    pickle0 = pickle.load(pickle_in)
    pickle_in.close()
    test_data = pickle0[5][0]
    pickle0 = 0  #to save memory
    test_lossCNN, test_accCNN = netCNN.evaluate(test_data)
    test_predsCNN = netCNN.predict(test_data)
    print('check_lossCNN, check_accCNN: '+str(test_lossCNN)+', '+str(test_accCNN))
    """

    ##PREP DATA FOR THE RNN
    #Get the feature vectors from the final layer for each training image at each timepoint:
    rnn_HpredsT1 = featuresModel_CNN.predict([rnn_HdataT1[0], rnn_HdataT1[1], rnn_HdataT1[2]])
    rnn_HpredsT2 = featuresModel_CNN.predict([rnn_HdataT2[0], rnn_HdataT2[1], rnn_HdataT2[2]])
    rnn_HpredsT3 = featuresModel_CNN.predict([rnn_HdataT3[0], rnn_HdataT3[1], rnn_HdataT3[2]])
    rnn_ApredsT1 = featuresModel_CNN.predict([rnn_AdataT1[0], rnn_AdataT1[1], rnn_AdataT1[2]])
    rnn_ApredsT2 = featuresModel_CNN.predict([rnn_AdataT2[0], rnn_AdataT2[1], rnn_AdataT2[2]])
    rnn_ApredsT3 = featuresModel_CNN.predict([rnn_AdataT3[0], rnn_AdataT3[1], rnn_AdataT3[2]])
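
    # The six predict calls above could equivalently be driven by a loop; a minimal sketch,
    # assuming each rnn_*dataT* tuple keeps the same [mri, clinical, extra] layout indexed 0-2:
    # feature_preds = {}
    # for name, d in [('HT1', rnn_HdataT1), ('HT2', rnn_HdataT2), ('HT3', rnn_HdataT3),
    #                 ('AT1', rnn_AdataT1), ('AT2', rnn_AdataT2), ('AT3', rnn_AdataT3)]:
    #     feature_preds[name] = featuresModel_CNN.predict([d[0], d[1], d[2]])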
    #grab the PTIDs for each dataset
    rnn_HptidT1 = rnn_HdataT1[4]
    rnn_HptidT2 = rnn_HdataT2[4]
    rnn_HptidT3 = rnn_HdataT3[4]
    rnn_AptidT1 = rnn_AdataT1[4]
    rnn_AptidT2 = rnn_AdataT2[4]
    rnn_AptidT3 = rnn_AdataT3[4]
    #grab the imageIDs for each dataset
    rnn_HimageIDT1 = rnn_HdataT1[5]
    rnn_HimageIDT2 = rnn_HdataT2[5]
    rnn_HimageIDT3 = rnn_HdataT3[5]
    rnn_AimageIDT1 = rnn_AdataT1[5]
    rnn_AimageIDT2 = rnn_AdataT2[5]
    rnn_AimageIDT3 = rnn_AdataT3[5]
    #add dummy feature vectors to all missing timepoints
    dummyVector = np.full((final_layer_size), -1)
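
    # The padding loops below align the three timepoints patient-by-patient with nested scans.
    # A dictionary-based sketch of the same idea (illustrative only; the code below does not use it):
    # build ptid -> feature-vector maps per timepoint, then emit one padded row per patient,
    # skipping patients who have a T1 scan and nothing else (as the loops below also do).
    # byT1 = dict(zip(rnn_HptidT1, rnn_HpredsT1))
    # byT2 = dict(zip(rnn_HptidT2, rnn_HpredsT2))
    # byT3 = dict(zip(rnn_HptidT3, rnn_HpredsT3))
    # for ptid in set(byT1) | set(byT2) | set(byT3):
    #     if ptid in byT1 and ptid not in byT2 and ptid not in byT3:
    #         continue  #T1-only patients are dropped
    #     row = [byT1.get(ptid, dummyVector), byT2.get(ptid, dummyVector), byT3.get(ptid, dummyVector)]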
    #Healthy patients
    rnn_HpredsT1_padded = []
    rnn_HpredsT2_padded = []
    rnn_HpredsT3_padded = []
    rnn_HptidT1_padded = []
    rnn_HptidT2_padded = []
    rnn_HptidT3_padded = []
    rnn_HimageIDT1_padded = []
    rnn_HimageIDT2_padded = []
    rnn_HimageIDT3_padded = []
    j = 0
    HrnnT1T2T3 = 0
    HrnnT1T2 = 0
    HrnnT1T3 = 0
    HrnnT1 = 0
    HrnnT2 = 0
    HrnnT2T3 = 0
    HrnnT3 = 0
    HrnnT1Removed = 0
    for ptidT1 in rnn_HptidT1:
        rnn_HpredsT1_padded.append(rnn_HpredsT1[j])
        rnn_HptidT1_padded.append(ptidT1)
        rnn_HimageIDT1_padded.append(rnn_HimageIDT1[j])
        j += 1
        c = 0
        k = 0
        t2 = False
        t3 = False
        for ptidT2 in rnn_HptidT2:
            c += 1
            if ptidT1 == ptidT2:
                rnn_HpredsT2_padded.append(rnn_HpredsT2[c-1])
                rnn_HptidT2_padded.append(ptidT2)
                rnn_HimageIDT2_padded.append(rnn_HimageIDT2[c-1])
                t2 = True
                for ptidT3 in rnn_HptidT3:
                    k += 1
                    if ptidT1 == ptidT3:
                        rnn_HpredsT3_padded.append(rnn_HpredsT3[k-1])
                        rnn_HptidT3_padded.append(ptidT3)
                        rnn_HimageIDT3_padded.append(rnn_HimageIDT3[k-1])
                        HrnnT1T2T3 += 1
                        t3 = True
                        break
                if t3 == False:
                    rnn_HpredsT3_padded.append(dummyVector)
                    rnn_HptidT3_padded.append(ptidT1)
                    rnn_HimageIDT3_padded.append('dummy')
                    HrnnT1T2 += 1
                break
        if t2 == False:
            rnn_HpredsT2_padded.append(dummyVector)
            rnn_HptidT2_padded.append(ptidT1)
            rnn_HimageIDT2_padded.append('dummy')
            for ptidT3 in rnn_HptidT3:
                k += 1
                if ptidT1 == ptidT3:
                    rnn_HpredsT3_padded.append(rnn_HpredsT3[k-1])
                    rnn_HptidT3_padded.append(ptidT3)
                    rnn_HimageIDT3_padded.append(rnn_HimageIDT3[k-1])
                    HrnnT1T3 += 1
                    t3 = True
                    break
            if t3 == False:
                #rnn_HpredsT3_padded.append(dummyVector)
                HrnnT1 += 1
                rnn_HpredsT1_padded.pop(-1)  #remove any scans that have only T1
                rnn_HpredsT2_padded.pop(-1)
                rnn_HptidT1_padded.pop(-1)
                rnn_HptidT2_padded.pop(-1)
                rnn_HimageIDT1_padded.pop(-1)
                rnn_HimageIDT2_padded.pop(-1)
                HrnnT1Removed += 1
    c = 0
    for ptidT2 in rnn_HptidT2:
        c += 1
        j = 0
        k = 0
        match = False
        t3 = False
        for ptidT1 in rnn_HptidT1:
            j += 1
            if ptidT2 == ptidT1:
                match = True
        if match == False:
            rnn_HpredsT2_padded.append(rnn_HpredsT2[c-1])
            rnn_HpredsT1_padded.append(dummyVector)
            rnn_HptidT2_padded.append(ptidT2)
            rnn_HimageIDT2_padded.append(rnn_HimageIDT2[c-1])
            rnn_HptidT1_padded.append(ptidT2)  #was ptidT1, which is stale here; the dummy T1 row belongs to this patient
            rnn_HimageIDT1_padded.append('dummy')
            for ptidT3 in rnn_HptidT3:
                k += 1
                if ptidT2 == ptidT3:
                    rnn_HpredsT3_padded.append(rnn_HpredsT3[k-1])
                    rnn_HptidT3_padded.append(ptidT2)
                    rnn_HimageIDT3_padded.append(rnn_HimageIDT3[k-1])
                    t3 = True
                    HrnnT2T3 += 1
                    break
            if t3 == False:
                rnn_HpredsT3_padded.append(dummyVector)
                rnn_HptidT3_padded.append(ptidT2)  #was ptidT1 (stale loop variable)
                rnn_HimageIDT3_padded.append('dummy')
                HrnnT2 += 1
    k = 0
    for ptidT3 in rnn_HptidT3:
        k += 1
        j = 0
        c = 0
        match1 = False
        for ptidT1 in rnn_HptidT1:
            j += 1
            if ptidT3 == ptidT1:
                match1 = True
        # if match1 == True:
        #     break
        if match1 == False:
            match2 = False
            for ptidT2 in rnn_HptidT2:
                c += 1
                if ptidT3 == ptidT2:
                    match2 = True
            # if match2 == True:
            #     break
            if match2 == False:
                rnn_HpredsT3_padded.append(rnn_HpredsT3[k-1])
                rnn_HptidT3_padded.append(ptidT3)
                rnn_HimageIDT3_padded.append(rnn_HimageIDT3[k-1])
                rnn_HpredsT1_padded.append(dummyVector)
                rnn_HptidT1_padded.append(ptidT3)  #was ptidT1 (stale loop variable)
                rnn_HimageIDT1_padded.append('dummy')
                rnn_HpredsT2_padded.append(dummyVector)
                rnn_HptidT2_padded.append(ptidT3)  #was ptidT1 (stale loop variable)
                rnn_HimageIDT2_padded.append('dummy')
                HrnnT3 += 1

    #move the data from a list to an array
    j = 0
    c = 0
    k = 0
    LenPadded = len(rnn_HpredsT1_padded)
    rnn_HpredsT1_padArray = np.zeros((LenPadded, final_layer_size), dtype=object)
    rnn_HpredsT2_padArray = np.zeros((LenPadded, final_layer_size), dtype=object)
    rnn_HpredsT3_padArray = np.zeros((LenPadded, final_layer_size), dtype=object)
    for vector in rnn_HpredsT1_padded:
        rnn_HpredsT1_padArray[j] = vector
        j += 1
    for vector in rnn_HpredsT2_padded:
        rnn_HpredsT2_padArray[c] = vector
        c += 1
    for vector in rnn_HpredsT3_padded:
        rnn_HpredsT3_padArray[k] = vector
        k += 1
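
    # The three copy loops above amount to stacking a list of equal-length vectors;
    # a one-line equivalent sketch (same (N, final_layer_size) shape, same object dtype):
    # rnn_HpredsT1_padArray = np.array(rnn_HpredsT1_padded, dtype=object)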

    with open(model_filepath+'/figures/paddedPreds.txt', 'w') as paddedPreds:
        paddedPreds.write('Train Preds Sizes: '+'\n')
        paddedPreds.write('Type of rnn_HpredsT1: '+str(type(rnn_HpredsT1))+'\n')
        paddedPreds.write('Type of rnn_HpredsT1_padded: '+str(type(rnn_HpredsT1_padded))+'\n')
        paddedPreds.write('Type of rnn_HpredsT1_padArray: '+str(type(rnn_HpredsT1_padArray))+'\n')
        paddedPreds.write('Type of rnn_HpredsT1 elements: '+str(type(rnn_HpredsT1[0]))+'\n')
        paddedPreds.write('Type of rnn_HpredsT1_padded elements: '+str(type(rnn_HpredsT1_padded[0]))+'\n')
        paddedPreds.write('Type of rnn_HpredsT1_padArray elements: '+str(type(rnn_HpredsT1_padArray[0]))+'\n')
        paddedPreds.write('Length of rnn_HpredsT1: '+str(len(rnn_HpredsT1))+'\n')
        paddedPreds.write('Length of rnn_HpredsT1_padded: '+str(len(rnn_HpredsT1_padded))+'\n')
        paddedPreds.write('Length of rnn_HpredsT2_padded: '+str(len(rnn_HpredsT2_padded))+'\n')
        paddedPreds.write('Length of rnn_HpredsT3_padded: '+str(len(rnn_HpredsT3_padded))+'\n')
        paddedPreds.write('Length of rnn_HpredsT1_padArray: '+str(len(rnn_HpredsT1_padArray))+'\n')
        paddedPreds.write('Length of rnn_HpredsT2_padArray: '+str(len(rnn_HpredsT2_padArray))+'\n')
        paddedPreds.write('Length of rnn_HpredsT3_padArray: '+str(len(rnn_HpredsT3_padArray))+'\n')
        paddedPreds.write('Length of rnn_HptidT1_padded: '+str(len(rnn_HptidT1_padded))+'\n')
        paddedPreds.write('Length of rnn_HptidT2_padded: '+str(len(rnn_HptidT2_padded))+'\n')
        paddedPreds.write('Length of rnn_HptidT3_padded: '+str(len(rnn_HptidT3_padded))+'\n')
        paddedPreds.write('Length of rnn_HimageIDT1_padded: '+str(len(rnn_HimageIDT1_padded))+'\n')
        paddedPreds.write('Length of rnn_HimageIDT2_padded: '+str(len(rnn_HimageIDT2_padded))+'\n')
        paddedPreds.write('Length of rnn_HimageIDT3_padded: '+str(len(rnn_HimageIDT3_padded))+'\n')
        paddedPreds.write('RNN_HpredsT1_padded: '+str(rnn_HpredsT1_padded)+'\n')
        paddedPreds.write('Compare to RNN_HpredsT1: '+str(rnn_HpredsT1)+'\n')
        paddedPreds.write('RNN_HpredsT1_padArray: '+str(rnn_HpredsT1_padArray)+'\n')
        paddedPreds.write('RNN_HpredsT2_padArray: '+str(rnn_HpredsT2_padArray)+'\n')
        paddedPreds.write('RNN_HpredsT3_padArray: '+str(rnn_HpredsT3_padArray)+'\n')
        paddedPreds.write('Shape of RNN_HpredsT1_padArray: '+str(rnn_HpredsT1_padArray.shape)+'\n')
        paddedPreds.write('Shape of RNN_HpredsT1: '+str(rnn_HpredsT1.shape)+'\n')
        paddedPreds.write('RNN_HpredsT1[0]: '+str(rnn_HpredsT1[0])+'\n')
        paddedPreds.write('rnn_HpredsT1[0][0]: '+str(rnn_HpredsT1[0][0])+'\n')
        paddedPreds.write('rnn_HpredsT1_padArray[0]: '+str(rnn_HpredsT1_padArray[0])+'\n')
        paddedPreds.write('rnn_HpredsT1_padArray[0][0]: '+str(rnn_HpredsT1_padArray[0][0])+'\n')
        paddedPreds.write('# of Hrnn T1 only: '+str(HrnnT1)+'\n')
        paddedPreds.write('# of Hrnn T1 only Removed: '+str(HrnnT1Removed)+'\n')
        paddedPreds.write('# of Hrnn T1+T2: '+str(HrnnT1T2)+'\n')
        paddedPreds.write('# of Hrnn T1+T2+T3: '+str(HrnnT1T2T3)+'\n')
        paddedPreds.write('# of Hrnn T1+T3: '+str(HrnnT1T3)+'\n')
        paddedPreds.write('# of Hrnn T2 only: '+str(HrnnT2)+'\n')
        paddedPreds.write('# of Hrnn T2+T3: '+str(HrnnT2T3)+'\n')
        paddedPreds.write('# of Hrnn T3 only: '+str(HrnnT3)+'\n')
    #AD patients
    rnn_ApredsT1_padded = []
    rnn_ApredsT2_padded = []
    rnn_ApredsT3_padded = []
    rnn_AptidT1_padded = []
    rnn_AptidT2_padded = []
    rnn_AptidT3_padded = []
    rnn_AimageIDT1_padded = []
    rnn_AimageIDT2_padded = []
    rnn_AimageIDT3_padded = []
    j = 0
    ArnnT1T2T3 = 0
    ArnnT1T2 = 0
    ArnnT1T3 = 0
    ArnnT1 = 0
    ArnnT2 = 0
    ArnnT2T3 = 0
    ArnnT3 = 0
    ArnnT1Removed = 0
    for ptidT1 in rnn_AptidT1:
        rnn_ApredsT1_padded.append(rnn_ApredsT1[j])
        rnn_AptidT1_padded.append(ptidT1)
        rnn_AimageIDT1_padded.append(rnn_AimageIDT1[j])
        j += 1
        c = 0
        k = 0
        t2 = False
        t3 = False
        for ptidT2 in rnn_AptidT2:
            c += 1
            if ptidT1 == ptidT2:
                rnn_ApredsT2_padded.append(rnn_ApredsT2[c-1])
                rnn_AptidT2_padded.append(ptidT2)
                rnn_AimageIDT2_padded.append(rnn_AimageIDT2[c-1])
                t2 = True
                for ptidT3 in rnn_AptidT3:
                    k += 1
                    if ptidT1 == ptidT3:
                        rnn_ApredsT3_padded.append(rnn_ApredsT3[k-1])
                        rnn_AptidT3_padded.append(ptidT3)
                        rnn_AimageIDT3_padded.append(rnn_AimageIDT3[k-1])
                        ArnnT1T2T3 += 1
                        t3 = True
                        break
                if t3 == False:
                    rnn_ApredsT3_padded.append(dummyVector)
                    rnn_AptidT3_padded.append(ptidT1)
                    rnn_AimageIDT3_padded.append('dummy')
                    ArnnT1T2 += 1
                break
        if t2 == False:
            rnn_ApredsT2_padded.append(dummyVector)
            rnn_AptidT2_padded.append(ptidT1)
            rnn_AimageIDT2_padded.append('dummy')
            for ptidT3 in rnn_AptidT3:
                k += 1
                if ptidT1 == ptidT3:
                    rnn_ApredsT3_padded.append(rnn_ApredsT3[k-1])
                    rnn_AptidT3_padded.append(ptidT3)
                    rnn_AimageIDT3_padded.append(rnn_AimageIDT3[k-1])
                    ArnnT1T3 += 1
                    t3 = True
                    break
            if t3 == False:
                #rnn_ApredsT3_padded.append(dummyVector)
                ArnnT1 += 1
                rnn_ApredsT1_padded.pop(-1)  #remove any scans that have only T1
                rnn_ApredsT2_padded.pop(-1)
                rnn_AptidT1_padded.pop(-1)
                rnn_AimageIDT1_padded.pop(-1)
                rnn_AptidT2_padded.pop(-1)
                rnn_AimageIDT2_padded.pop(-1)
                ArnnT1Removed += 1
    c = 0
    for ptidT2 in rnn_AptidT2:
        c += 1
        j = 0
        k = 0
        match = False
        t3 = False
        for ptidT1 in rnn_AptidT1:
            j += 1
            if ptidT2 == ptidT1:
                match = True
        if match == False:
            rnn_ApredsT2_padded.append(rnn_ApredsT2[c-1])
            rnn_AptidT2_padded.append(ptidT2)
            rnn_AimageIDT2_padded.append(rnn_AimageIDT2[c-1])
            rnn_ApredsT1_padded.append(dummyVector)
            rnn_AptidT1_padded.append(ptidT2)  #was ptidT1 (stale loop variable); the dummy T1 row belongs to this patient
            rnn_AimageIDT1_padded.append('dummy')
            for ptidT3 in rnn_AptidT3:
                k += 1
                if ptidT2 == ptidT3:
                    rnn_ApredsT3_padded.append(rnn_ApredsT3[k-1])
                    rnn_AptidT3_padded.append(ptidT3)
                    rnn_AimageIDT3_padded.append(rnn_AimageIDT3[k-1])
                    t3 = True
                    ArnnT2T3 += 1
                    break
            if t3 == False:
                rnn_ApredsT3_padded.append(dummyVector)
                rnn_AptidT3_padded.append(ptidT2)  #was ptidT1 (stale loop variable)
                rnn_AimageIDT3_padded.append('dummy')
                ArnnT2 += 1
    k = 0
    for ptidT3 in rnn_AptidT3:
        k += 1
        j = 0
        c = 0
        match1 = False
        for ptidT1 in rnn_AptidT1:
            j += 1
            if ptidT3 == ptidT1:
                match1 = True
        # if match1 == True:
        #     break
        if match1 == False:
            match2 = False
            for ptidT2 in rnn_AptidT2:
                c += 1
                if ptidT3 == ptidT2:
                    match2 = True
            # if match2 == True:
            #     break
            if match2 == False:
                rnn_ApredsT3_padded.append(rnn_ApredsT3[k-1])
                rnn_AptidT3_padded.append(ptidT3)
                rnn_AimageIDT3_padded.append(rnn_AimageIDT3[k-1])
                rnn_ApredsT1_padded.append(dummyVector)
                rnn_AptidT1_padded.append(ptidT3)  #was ptidT1 (stale loop variable)
                rnn_AimageIDT1_padded.append('dummy')
                rnn_ApredsT2_padded.append(dummyVector)
                rnn_AptidT2_padded.append(ptidT3)  #was ptidT1 (stale loop variable)
                rnn_AimageIDT2_padded.append('dummy')
                ArnnT3 += 1

    #move the data from a list to an array
    j = 0
    c = 0
    k = 0
    LenPadded = len(rnn_ApredsT1_padded)
    rnn_ApredsT1_padArray = np.zeros((LenPadded, final_layer_size), dtype=object)
    rnn_ApredsT2_padArray = np.zeros((LenPadded, final_layer_size), dtype=object)
    rnn_ApredsT3_padArray = np.zeros((LenPadded, final_layer_size), dtype=object)
    for vector in rnn_ApredsT1_padded:
        rnn_ApredsT1_padArray[j] = vector
        j += 1
    for vector in rnn_ApredsT2_padded:
        rnn_ApredsT2_padArray[c] = vector
        c += 1
    for vector in rnn_ApredsT3_padded:
        rnn_ApredsT3_padArray[k] = vector
        k += 1

    with open(model_filepath+'/figures/paddedPreds.txt', 'a') as paddedPreds:
        paddedPreds.write('Length of rnn_ApredsT1_padArray: '+str(len(rnn_ApredsT1_padArray))+'\n')
        paddedPreds.write('Length of rnn_ApredsT2_padArray: '+str(len(rnn_ApredsT2_padArray))+'\n')
        paddedPreds.write('Length of rnn_ApredsT3_padArray: '+str(len(rnn_ApredsT3_padArray))+'\n')
        paddedPreds.write('Length of rnn_AptidT1_padded: '+str(len(rnn_AptidT1_padded))+'\n')
        paddedPreds.write('Length of rnn_AptidT2_padded: '+str(len(rnn_AptidT2_padded))+'\n')
        paddedPreds.write('Length of rnn_AptidT3_padded: '+str(len(rnn_AptidT3_padded))+'\n')
        paddedPreds.write('Length of rnn_AimageIDT1_padded: '+str(len(rnn_AimageIDT1_padded))+'\n')
        paddedPreds.write('Length of rnn_AimageIDT2_padded: '+str(len(rnn_AimageIDT2_padded))+'\n')
        paddedPreds.write('Length of rnn_AimageIDT3_padded: '+str(len(rnn_AimageIDT3_padded))+'\n')
        paddedPreds.write('# of Arnn T1 only: '+str(ArnnT1)+'\n')
        paddedPreds.write('# of Arnn T1 only Removed: '+str(ArnnT1Removed)+'\n')
        paddedPreds.write('# of Arnn T1+T2: '+str(ArnnT1T2)+'\n')
        paddedPreds.write('# of Arnn T1+T2+T3: '+str(ArnnT1T2T3)+'\n')
        paddedPreds.write('# of Arnn T1+T3: '+str(ArnnT1T3)+'\n')
        paddedPreds.write('# of Arnn T2 only: '+str(ArnnT2)+'\n')
        paddedPreds.write('# of Arnn T2+T3: '+str(ArnnT2T3)+'\n')
        paddedPreds.write('# of Arnn T3 only: '+str(ArnnT3)+'\n')
    #Balance the datasets: drop the last scans from the H arrays so the A and H datasets are
    #the same size. (These should be different patients each run because get_filenames shuffles.)
    diff = len(rnn_HpredsT1_padArray) - len(rnn_ApredsT1_padArray)
    for i in range(diff):
        rnn_HpredsT1_padArray = np.delete(rnn_HpredsT1_padArray, -1, 0)
        rnn_HpredsT2_padArray = np.delete(rnn_HpredsT2_padArray, -1, 0)
        rnn_HpredsT3_padArray = np.delete(rnn_HpredsT3_padArray, -1, 0)
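
    # Equivalent in one step per array (illustrative; np.delete in a loop reallocates every pass):
    # n = len(rnn_ApredsT1_padArray)
    # rnn_HpredsT1_padArray = rnn_HpredsT1_padArray[:n]
    # rnn_HpredsT2_padArray = rnn_HpredsT2_padArray[:n]
    # rnn_HpredsT3_padArray = rnn_HpredsT3_padArray[:n]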
    dummyCountHT1 = 0
    dummyCountHT2 = 0
    dummyCountHT3 = 0
    dummyCountAT1 = 0
    dummyCountAT2 = 0
    dummyCountAT3 = 0
    for i in range(len(rnn_HpredsT1_padArray)):
        if rnn_HpredsT1_padArray[i][0] == -1:
            dummyCountHT1 += 1
        if rnn_HpredsT2_padArray[i][0] == -1:
            dummyCountHT2 += 1
        if rnn_HpredsT3_padArray[i][0] == -1:
            dummyCountHT3 += 1
    for i in range(len(rnn_ApredsT1_padArray)):
        if rnn_ApredsT1_padArray[i][0] == -1:
            dummyCountAT1 += 1
        if rnn_ApredsT2_padArray[i][0] == -1:
            dummyCountAT2 += 1
        if rnn_ApredsT3_padArray[i][0] == -1:
            dummyCountAT3 += 1
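
    # Vectorized equivalent of the counting loops above (a dummy row starts with -1, the same
    # first-element test the loops use):
    # dummyCountHT1 = int(np.sum(rnn_HpredsT1_padArray[:, 0] == -1))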
    with open(model_filepath+'/figures/paddedPreds.txt', 'a') as paddedPreds:
        paddedPreds.write('Length of rnn_HpredsT1_padArray popped: '+str(len(rnn_HpredsT1_padArray))+'\n')
        paddedPreds.write('Length of rnn_HpredsT2_padArray popped: '+str(len(rnn_HpredsT2_padArray))+'\n')
        paddedPreds.write('Length of rnn_HpredsT3_padArray popped: '+str(len(rnn_HpredsT3_padArray))+'\n')
    with open(model_filepath+'/figures/DataList.txt', 'a') as datalist:
        datalist.write('Number of scans in HT1 (excluding dummies): '+str(len(rnn_HpredsT1_padArray)-dummyCountHT1)+'\n')
        datalist.write('Number of scans in HT2 (excluding dummies): '+str(len(rnn_HpredsT2_padArray)-dummyCountHT2)+'\n')
        datalist.write('Number of scans in HT3 (excluding dummies): '+str(len(rnn_HpredsT3_padArray)-dummyCountHT3)+'\n')
        datalist.write('Number of scans in AT1 (excluding dummies): '+str(len(rnn_ApredsT1_padArray)-dummyCountAT1)+'\n')
        datalist.write('Number of scans in AT2 (excluding dummies): '+str(len(rnn_ApredsT2_padArray)-dummyCountAT2)+'\n')
        datalist.write('Number of scans in AT3 (excluding dummies): '+str(len(rnn_ApredsT3_padArray)-dummyCountAT3)+'\n')

    #Split RNN data into train/val/test
    (train_predsT1_padArray, train_predsT2_padArray, train_predsT3_padArray,
     val_predsT1_padArray, val_predsT2_padArray, val_predsT3_padArray,
     test_predsT1_padArray, test_predsT2_padArray, test_predsT3_padArray,
     train_labels_padArray, val_labels_padArray, test_labels_padArray,
     test_ptidT1, test_ptidT2, test_ptidT3,
     test_imageIDT1, test_imageIDT2, test_imageIDT3) = data_loader.split_data_RNN(
         rnn_HpredsT1_padArray, rnn_HpredsT2_padArray, rnn_HpredsT3_padArray,
         rnn_ApredsT1_padArray, rnn_ApredsT2_padArray, rnn_ApredsT3_padArray,
         rnn_HptidT1_padded, rnn_HptidT2_padded, rnn_HptidT3_padded,
         rnn_HimageIDT1_padded, rnn_HimageIDT2_padded, rnn_HimageIDT3_padded,
         rnn_AptidT1_padded, rnn_AptidT2_padded, rnn_AptidT3_padded,
         rnn_AimageIDT1_padded, rnn_AimageIDT2_padded, rnn_AimageIDT3_padded,
         val_split)
    #RUN THE RNN:
    netRNN = RNN_Net(params)
    historyRNN = netRNN.train(([train_predsT1_padArray, train_predsT2_padArray, train_predsT3_padArray], train_labels_padArray,
                               [val_predsT1_padArray, val_predsT2_padArray, val_predsT3_padArray], val_labels_padArray))
    #EVALUATE RNN:
    test_lossRNN, test_accRNN = netRNN.evaluate(([test_predsT1_padArray, test_predsT2_padArray, test_predsT3_padArray], test_labels_padArray))
    test_predsRNN = netRNN.predict(([test_predsT1_padArray, test_predsT2_padArray, test_predsT3_padArray], test_labels_padArray))

    """
    #TO LOAD A PREVIOUS MODEL INSTEAD: uncomment this chunk and comment out everything above,
    #up through the RNN data prep. (This is not needed for CNN heatmaps.)
    #Loading the whole model has not worked so far (it is still saved anyway); only the weights
    #can be saved and loaded, so the model must be recompiled with exactly the same architecture.
    #To verify a load, rerun the same test set and check that the results match the original run.
    netRNN = RNN_Net(params)
    netRNN.load_the_weights("SavedRNNWeights")
    pickle_in = open(model_filepath+'/'+picklename+'.pickle', 'rb')  #change this to the pickle filename
    pickle0 = pickle.load(pickle_in)
    pickle_in.close()
    test_predsT1_padArray = pickle0[5][1]
    test_predsT2_padArray = pickle0[5][2]
    test_predsT3_padArray = pickle0[5][3]
    test_labels_padArray = pickle0[5][4]
    test_labels_padArray = np.delete(test_labels_padArray, 0)
    pickle0 = 0
    print('test_labels_padArray: ', test_labels_padArray)
    test_lossRNN, test_accRNN = netRNN.evaluate(([test_predsT1_padArray, test_predsT2_padArray, test_predsT3_padArray], test_labels_padArray))
    test_predsRNN = netRNN.predict(([test_predsT1_padArray, test_predsT2_padArray, test_predsT3_padArray], test_labels_padArray))
    print('check_lossRNN, check_accRNN: '+str(test_lossRNN)+', '+str(test_accRNN))
    """
    #PLOTS FOR THE CNN ALONE
    #plot accuracy learning curves
    plt.figure()
    plt.plot(historyCNN['acc'], color='red')
    plt.plot(historyCNN['val_acc'], color='blue')
    plt.title('CNN model accuracy learning curve')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['training', 'validation'], loc='upper left')
    plt.savefig(model_filepath+'/figures/CNN_LCacc'+str(seed)+'.png', bbox_inches='tight')
    #plot loss learning curves
    plt.figure()
    plt.plot(historyCNN['loss'], color='orange')
    plt.plot(historyCNN['val_loss'], color='purple')
    plt.title('CNN model loss learning curve')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['training', 'validation'], loc='upper right')
    plt.savefig(model_filepath+'/figures/CNN_LCloss'+str(seed)+'.png', bbox_inches='tight')

    #plot test ROC curve
    fpr_testCNN = dict()
    tpr_testCNN = dict()
    thresholds_testCNN = dict()
    areaundercurveCNN = dict()
    plt.figure()
    test_predsCNN_class = np.argmax(test_predsCNN, axis=-1)
    test_predsCNN_count = np.bincount(test_predsCNN_class, minlength=n_classes)
    print('test_labelsCNN: ', test_data[3])
    print('test_predsCNN_class: ', test_predsCNN_class)
    print('test_predsCNN_count: ', test_predsCNN_count)

    tROC = True
    for i in range(n_classes):
        if test_predsCNN_count[i] == 0:  #skip the ROC curve when one class is never predicted
            print('Class ', i, 'is predicted 0 times in CNN testing.')
            print('Cannot plot Test ROC curve for CNN.')
            tROC = False
            break
    if tROC == True:
        if n_classes == 2:
            #fpr_testCNN, tpr_testCNN, thresholds_testCNN = roc_curve(np.array(pd.get_dummies(test_data[3]))[:, 1], np.array(test_predsCNN)[:, 1])
            fpr_testCNN, tpr_testCNN, thresholds_testCNN = roc_curve(test_data[3], test_predsCNN[:, 1])
            areaundercurveCNN = auc(fpr_testCNN, tpr_testCNN)
            lw = 3
            class_name = ['AD', 'Healthy']
            plt.plot(fpr_testCNN, tpr_testCNN, lw=lw)
            plt.title('CNN ROC')
            plt.xlabel('1 - Specificity', fontsize=13)
            plt.ylabel('Sensitivity', fontsize=13)
        else:
            for i in range(n_classes):
                fpr_testCNN[i], tpr_testCNN[i], thresholds_testCNN[i] = roc_curve(np.array(pd.get_dummies(test_data[3]))[:, i], np.array(test_predsCNN)[:, i])
                areaundercurveCNN[i] = auc(fpr_testCNN[i], tpr_testCNN[i])
                lw = 3
                class_name = ['AD', 'Healthy']
                plt.plot(fpr_testCNN[i], tpr_testCNN[i],
                         lw=lw, label=str(class_name[i]))
                plt.title('CNN ROC')
                plt.xlabel('1 - Specificity', fontsize=13)
                plt.ylabel('Sensitivity', fontsize=13)
    if tROC == True:  #skip the ROC plot and TPRs when one class is never predicted
        #plot testROC
        plt.legend(loc="lower right")
        plt.savefig(model_filepath+'/figures/CNN_ROC'+str(seed)+'.png', bbox_inches='tight')
        #print TPRs for each class
        #print('TPR_AD_CNN = '+str(tpr_testCNN[0]))
        #print('TPR_Healthy_CNN = '+str(tpr_testCNN[1]))
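
    # The commented-out output fields near the end of this script reference an optimal G-mean
    # threshold (ixC, bestThreshCNN, highGmeanCNN); a minimal sketch of how those could be
    # computed from the 2-class ROC arrays above (assumes tROC is True and n_classes == 2,
    # so fpr/tpr/thresholds are plain arrays rather than dicts):
    # gmeansCNN = np.sqrt(tpr_testCNN * (1 - fpr_testCNN))
    # ixC = np.argmax(gmeansCNN)
    # bestThreshCNN, highGmeanCNN = thresholds_testCNN[ixC], gmeansCNN[ixC]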

    #Confusion matrix
    mci_conf_matrix_testCNN = confusion_matrix(y_true=test_data[3], y_pred=test_predsCNN_class)
    plt.figure()
    ax = plt.subplot()
    cax = ax.matshow(mci_conf_matrix_testCNN)
    plt.title('Full CNN T1 Confusion Matrix')
    plt.colorbar(cax)
    ax.set_xticklabels(['', 'AD', 'Healthy'], fontsize=11)
    ax.set_yticklabels(['', 'AD', 'Healthy'], fontsize=11)
    plt.xlabel('Predicted', fontsize=13)
    plt.ylabel('True', fontsize=13)
    plt.savefig(model_filepath+'/figures/CNN_ConfMatrix'+str(seed)+'.png', bbox_inches='tight')
    #Normalized confusion matrix (each row is divided by its total, so rows sum to 1)
    mci_conf_matrix_test_normedCNN = mci_conf_matrix_testCNN/(mci_conf_matrix_testCNN.sum(axis=1)[:, np.newaxis])
    plt.figure()
    ax = plt.subplot()
    cax = ax.matshow(mci_conf_matrix_test_normedCNN)
    plt.title('Full CNN T1 Normalized Confusion Matrix')
    plt.colorbar(cax)
    ax.set_xticklabels(['', 'AD', 'Healthy'], fontsize=11)
    ax.set_yticklabels(['', 'AD', 'Healthy'], fontsize=11)
    plt.xlabel('Predicted', fontsize=13)
    plt.ylabel('True', fontsize=13)
    plt.savefig(model_filepath+'/figures/CNN_ConfMatrixNormed'+str(seed)+'.png', bbox_inches='tight')
    #validation ROC
    val_lossCNN, val_accCNN = netCNN.evaluate(val_data)
    val_predsCNN = netCNN.predict(val_data)
    val_predsCNN_class = np.argmax(val_predsCNN, axis=-1)
    fpr_valCNN = dict()
    tpr_valCNN = dict()
    thresholds_valCNN = dict()
    val_predsCNN_count = np.bincount(val_predsCNN_class, minlength=n_classes)
    print('val_predsCNN_count: ', val_predsCNN_count)
    vROC = True
    for i in range(n_classes):
        if val_predsCNN_count[i] == 0:  #skip the ROC curve when one class is never predicted
            print('Class ', i, 'is predicted 0 times in CNN validation.')
            print('Cannot plot vROC curve for CNN.')
            vROC = False
            break
    if vROC == True:
        if n_classes == 2:
            fpr_valCNN, tpr_valCNN, thresholds_valCNN = roc_curve(np.array(pd.get_dummies(val_data[3]))[:, 1], np.array(val_predsCNN)[:, 1])
        else:
            for i in range(n_classes):  #this loop was missing, leaving i stale from the loop above
                fpr_valCNN[i], tpr_valCNN[i], thresholds_valCNN[i] = roc_curve(np.array(pd.get_dummies(val_data[3]))[:, i], np.array(val_predsCNN)[:, i])
    mci_conf_matrix_valCNN = confusion_matrix(y_true=val_data[3], y_pred=val_predsCNN_class)
    mci_conf_matrix_val_normedCNN = mci_conf_matrix_valCNN/(mci_conf_matrix_valCNN.sum(axis=1)[:, np.newaxis])
    print("Test CNN accuracy: "+str(test_accCNN))
    print("CNN AUC: "+str(areaundercurveCNN))

    #PLOTS FOR THE RNN
    #plot accuracy learning curves
    plt.figure()
    plt.plot(historyRNN['acc'], color='red')
    plt.plot(historyRNN['val_acc'], color='blue')
    plt.title('RNN model accuracy learning curve')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['training', 'validation'], loc='upper left')
    plt.savefig(model_filepath+'/figures/RNN_LCacc'+str(seed)+'.png', bbox_inches='tight')
    #plot loss learning curves
    plt.figure()
    plt.plot(historyRNN['loss'], color='orange')
    plt.plot(historyRNN['val_loss'], color='purple')
    plt.title('RNN model loss learning curve')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['training', 'validation'], loc='upper right')
    plt.savefig(model_filepath+'/figures/RNN_LCloss'+str(seed)+'.png', bbox_inches='tight')

    #plot 2-class test ROC curve
    fpr_testRNN = dict()
    tpr_testRNN = dict()
    thresholds_testRNN = dict()
    areaundercurveRNN = dict()
    plt.figure()
    test_predsRNN_class = np.argmax(test_predsRNN, axis=-1)
    test_predsRNN_count = np.bincount(test_predsRNN_class, minlength=n_classes)
    print('test_labelsRNN: ', test_labels_padArray)
    print('test_predsRNN_class: ', test_predsRNN_class)
    print('test_predsRNN_count: ', test_predsRNN_count)

    tROC = True
    for i in range(n_classes):
        if test_predsRNN_count[i] == 0:  #skip the ROC curve when one class is never predicted
            print('Class ', i, 'is predicted 0 times in RNN testing.')
            print('Cannot plot Test ROC curve for RNN.')
            tROC = False
            break
    if tROC == True:
        if n_classes == 2:
            fpr_testRNN, tpr_testRNN, thresholds_testRNN = roc_curve(np.array(pd.get_dummies(test_labels_padArray))[:, 1], np.array(test_predsRNN)[:, 1])  #was [:, i], relying on a stale loop variable
            areaundercurveRNN = auc(fpr_testRNN, tpr_testRNN)
            lw = 3
            class_name = ['AD', 'Healthy']
            plt.plot(fpr_testRNN, tpr_testRNN, lw=lw)
            plt.title('RNN ROC')
            plt.xlabel('1 - Specificity', fontsize=13)
            plt.ylabel('Sensitivity', fontsize=13)
        else:
            for i in range(n_classes):
                fpr_testRNN[i], tpr_testRNN[i], thresholds_testRNN[i] = roc_curve(np.array(pd.get_dummies(test_labels_padArray))[:, i], np.array(test_predsRNN)[:, i])
                areaundercurveRNN[i] = auc(fpr_testRNN[i], tpr_testRNN[i])
                lw = 3
                class_name = ['AD', 'Healthy']
                plt.plot(fpr_testRNN[i], tpr_testRNN[i],
                         lw=lw, label=str(class_name[i]))
                plt.title('RNN ROC')
                plt.xlabel('1 - Specificity', fontsize=13)
                plt.ylabel('Sensitivity', fontsize=13)
    if tROC == True:  #skip the ROC plot and TPRs when one class is never predicted
        #plot testROC
        plt.legend(loc="lower right")
        plt.savefig(model_filepath+'/figures/RNN_ROC'+str(seed)+'.png', bbox_inches='tight')
        #print TPRs for each class
        #print('TPR_AD_RNN = '+str(tpr_testRNN[0]))
        #print('TPR_Healthy_RNN = '+str(tpr_testRNN[1]))

    #Confusion matrix
    mci_conf_matrix_testRNN = confusion_matrix(y_true=test_labels_padArray, y_pred=test_predsRNN_class)
    plt.figure()
    ax = plt.subplot()
    cax = ax.matshow(mci_conf_matrix_testRNN)
    plt.title('RNN Confusion Matrix')
    plt.colorbar(cax)
    ax.set_xticklabels(['', 'AD', 'Healthy'], fontsize=11)
    ax.set_yticklabels(['', 'AD', 'Healthy'], fontsize=11)
    plt.xlabel('Predicted', fontsize=13)
    plt.ylabel('True', fontsize=13)
    plt.savefig(model_filepath+'/figures/RNN_ConfMatrix'+str(seed)+'.png', bbox_inches='tight')
    #Normalized confusion matrix (rows sum to 1)
    mci_conf_matrix_test_normedRNN = mci_conf_matrix_testRNN/(mci_conf_matrix_testRNN.sum(axis=1)[:, np.newaxis])
    plt.figure()
    ax = plt.subplot()
    cax = ax.matshow(mci_conf_matrix_test_normedRNN)
    plt.title('RNN Normalized Confusion Matrix')
    plt.colorbar(cax)
    ax.set_xticklabels(['', 'AD', 'Healthy'], fontsize=11)
    ax.set_yticklabels(['', 'AD', 'Healthy'], fontsize=11)
    plt.xlabel('Predicted', fontsize=13)
    plt.ylabel('True', fontsize=13)
    plt.savefig(model_filepath+'/figures/RNN_ConfMatrixNormed'+str(seed)+'.png', bbox_inches='tight')

    #validation ROC
    val_lossRNN, val_accRNN = netRNN.evaluate(([val_predsT1_padArray, val_predsT2_padArray, val_predsT3_padArray], val_labels_padArray))
    val_predsRNN = netRNN.predict(([val_predsT1_padArray, val_predsT2_padArray, val_predsT3_padArray], val_labels_padArray))
    val_predsRNN_class = np.argmax(val_predsRNN, axis=-1)
    fpr_valRNN = dict()
    tpr_valRNN = dict()
    thresholds_valRNN = dict()
    val_predsRNN_count = np.bincount(val_predsRNN_class, minlength=n_classes)
    print('val_predsRNN_count: ', val_predsRNN_count)
    vROC = True
    for i in range(n_classes):
        if val_predsRNN_count[i] == 0:  #skip the ROC curve when one class is never predicted
            print('Class ', i, 'is predicted 0 times in RNN validation.')
            print('Cannot plot vROC curve for RNN.')
            vROC = False
            break
    if vROC == True:
        if n_classes == 2:
            fpr_valRNN, tpr_valRNN, thresholds_valRNN = roc_curve(np.array(pd.get_dummies(val_labels_padArray))[:, 1], np.array(val_predsRNN)[:, 1])
        else:
            for i in range(n_classes):
                fpr_valRNN[i], tpr_valRNN[i], thresholds_valRNN[i] = roc_curve(np.array(pd.get_dummies(val_labels_padArray))[:, i], np.array(val_predsRNN)[:, i])
    mci_conf_matrix_valRNN = confusion_matrix(y_true=val_labels_padArray, y_pred=val_predsRNN_class)
    mci_conf_matrix_val_normedRNN = mci_conf_matrix_valRNN/(mci_conf_matrix_valRNN.sum(axis=1)[:, np.newaxis])
    print("Test RNN accuracy: "+str(test_accRNN))
    print("RNN AUC: "+str(areaundercurveRNN))

    #TEST SET TABLES
    test_table_CNN = (test_data[4], test_data[5], test_data[3], test_data[6], test_data[7], test_predsCNN_class, test_predsCNN[:, 0], test_predsCNN[:, 1])  #per-class probability columns (was rows [0] and [1])
    test_table_RNN = (test_ptidT1, test_imageIDT1, test_imageIDT2, test_imageIDT3, test_labels_padArray, test_predsRNN_class, test_predsRNN[:, 0], test_predsRNN[:, 1])

    #WRITE THE OUTPUT FILE
    with open(model_filepath+'/figures/Outputs'+str(seed)+'.txt', 'w') as outputs:
        #RNN
        outputs.write('RNN Confusion Matrix Values:'+'\n')
        outputs.write(str(mci_conf_matrix_testRNN)+'\n')
        outputs.write('RNN Normalized Confusion Matrix Values:'+'\n')
        outputs.write(str(mci_conf_matrix_test_normedRNN)+'\n')
        outputs.write('RNN Test accuracy:'+'\n')
        outputs.write(str(test_accRNN)+'\n')
        outputs.write('RNN AUC:'+'\n')
        outputs.write(str(areaundercurveRNN)+'\n')
        outputs.write('RNN Test Prediction Probabilities'+'\n')
        outputs.write(str(test_predsRNN)+'\n')
        outputs.write('RNN Test Predictions MaxProb Class'+'\n')
        outputs.write(str(test_predsRNN_class)+'\n')
        #CNN
        outputs.write('Full CNN Confusion Matrix Values:'+'\n')
        outputs.write(str(mci_conf_matrix_testCNN)+'\n')
        outputs.write('Full CNN Normalized Confusion Matrix Values:'+'\n')
        outputs.write(str(mci_conf_matrix_test_normedCNN)+'\n')
        outputs.write('Full CNN Test accuracy:'+'\n')
        outputs.write(str(test_accCNN)+'\n')
        outputs.write('Full CNN AUC:'+'\n')
        outputs.write(str(areaundercurveCNN)+'\n')
        outputs.write('Full CNN Test Prediction Probabilities'+'\n')
        outputs.write(str(test_predsCNN)+'\n')
        outputs.write('Full CNN Test Predictions MaxProb Class'+'\n')
        outputs.write(str(test_predsCNN_class)+'\n')
        #outputs.write('Index of best CNN Gmean'+'\n')
        #outputs.write(str(ixC)+'\n')
        #outputs.write('Optimal CNN Threshold'+'\n')
        #outputs.write(str(bestThreshCNN)+'\n')
        #outputs.write('Value of highest Gmean'+'\n')
        #outputs.write(str(highGmeanCNN)+'\n')
        #outputs.write('CNN Accuracy at Optimized Threshold'+'\n')
        #outputs.write(str(OptAccCNN)+'\n'+'\n')
        #Testset output tables
        outputs.write('test_table_CNN'+'\n')
        outputs.write(str(test_table_CNN)+'\n'+'\n')
        outputs.write('test_table_RNN'+'\n')
        outputs.write(str(test_table_RNN)+'\n'+'\n')
    #TEST SET TABLES (CSV)
    Cptid = test_data[4]
    Cptid = np.insert(Cptid, 0, 'PTID')
    CimageID = test_data[5]
    CimageID = np.insert(CimageID, 0, 'imgID')
    Cconfid = test_data[6]
    Cconfid = np.insert(Cconfid, 0, 'DxConfidence')
    Ccsf = test_data[7]
    Ccsf = np.insert(Ccsf, 0, 'CSF_Path')
    Clabels = test_data[3]
    Clabels = np.insert(Clabels.astype(str), 0, 'label')
    test_predsCNN_class = np.insert(test_predsCNN_class.astype(str), 0, 'prediction')
    probsCAD = [item[0] for item in test_predsCNN]
    probsCNC = [item[1] for item in test_predsCNN]
    probsCAD.insert(0, 'prediction probabilities AD')
    probsCNC.insert(0, 'prediction probabilities NC')
    test_ptidT1 = np.insert(test_ptidT1, 0, 'PTID')
    test_imageIDT1 = np.insert(test_imageIDT1, 0, 'imIDT1')
    test_imageIDT2 = np.insert(test_imageIDT2, 0, 'imIDT2')
    test_imageIDT3 = np.insert(test_imageIDT3, 0, 'imIDT3')
    test_labels_padArray = np.insert(test_labels_padArray.astype(str), 0, 'label')
    test_predsRNN_class = np.insert(test_predsRNN_class.astype(str), 0, 'prediction')
    probsRAD = [item[0] for item in test_predsRNN]
    probsRNC = [item[1] for item in test_predsRNN]
    probsRAD.insert(0, 'prediction probabilities AD')
    probsRNC.insert(0, 'prediction probabilities NC')
    with open(model_filepath+'/figures/test_table_'+str(seed)+'.csv', 'w') as Testcsv:
        Testcsv_writer = csv.writer(Testcsv, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        Testcsv_writer.writerow(['CNN'])
        Testcsv_writer.writerow(Cptid)
        Testcsv_writer.writerow(CimageID)
        Testcsv_writer.writerow(Clabels)
        Testcsv_writer.writerow(test_predsCNN_class)
        Testcsv_writer.writerow(Cconfid)
        Testcsv_writer.writerow(Ccsf)
        Testcsv_writer.writerow(probsCAD)
        Testcsv_writer.writerow(probsCNC)
        Testcsv_writer.writerow(' ')
        Testcsv_writer.writerow(' ')
        Testcsv_writer.writerow(['RNN'])
        Testcsv_writer.writerow(test_ptidT1)
        Testcsv_writer.writerow(test_imageIDT1)
        Testcsv_writer.writerow(test_imageIDT2)
        Testcsv_writer.writerow(test_imageIDT3)
        Testcsv_writer.writerow(test_labels_padArray)
        Testcsv_writer.writerow(test_predsRNN_class)
        Testcsv_writer.writerow(probsRAD)
        Testcsv_writer.writerow(probsRNC)
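
    # The CSV is written row-per-field, so a read-back needs a transpose; a hedged sketch
    # (the 'PTID'/'label'/etc. headers land in column 0 after transposing; nrows=8 covers the
    # eight CNN field rows written after the 'CNN' banner row):
    # cnn_rows = pd.read_csv(model_filepath+'/figures/test_table_'+str(seed)+'.csv',
    #                        header=None, skiprows=1, nrows=8)
    # cnn_table = cnn_rows.set_index(0).T  #one test case per row, fields as columns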
- """
- #HEATMAPS:
- HeatmapPlotter = heatmapPlotter(seed)
-
- #if plotting NC data (need to load AD avg):
- picklename = 'loadedGC4912' #CHANGE to LRP or GGC and change seed
- pickle_in = open(model_filepath+'/'+picklename+'.pickle', 'rb')
- pickle0=pickle.load(pickle_in)
- pickle_in.close()
- mean_map_AD = pickle0[0]["AD"]
- pickle0=0
-
- #if plotting AD data:
- mean_map_AD = np.zeros((91,109,91))
-
- #RUN LRP
- case_maps_LRP, counts = HeatmapPlotter.LRP(test_data, test_mri_nonorm, model_filepath, netCNN, test_predsCNN) #Removed CNN_LRP to save memory
- mean_maps_LRP = HeatmapPlotter.plot_avg_maps(case_maps_LRP, counts, 'LRP', test_mri_nonorm, model_filepath, mean_map_AD)
- #WRITE A PICKLE FILE
- with open(model_filepath+'/figures/loadedLRP' + str(seed)+'.pickle', 'wb') as f:
- pickle.dump([mean_maps_LRP, case_maps_LRP, counts], f) #Removed CNN_LRP to save memory
-
- #RUN GGC
- case_maps_GGC, counts = HeatmapPlotter.GuidedGradCAM(test_data, test_mri_nonorm, model_filepath, netCNN, test_predsCNN) #Removed CNN_gradcam, CNN_gb, CNN_guided_gradcam to save memory
- mean_maps_GGC = HeatmapPlotter.plot_avg_maps(case_maps_GGC, counts, 'Guided GradCAM', test_mri_nonorm, model_filepath, mean_map_AD)
- #WRITE A PICKLE FILE
- with open(model_filepath+'/figures/loadedGC' + str(seed)+'.pickle', 'wb') as f:
- pickle.dump([mean_maps_GGC, case_maps_GGC, counts], f) #Removed CNN_gradcam, CNN_gb, CNN_guided_gradcam to save memory
- """
    #WRITE THE PICKLE FILE
    with open(model_filepath+'/figures/'+str(seed)+'.pickle', 'wb') as f:
        pickle.dump([[fpr_testRNN, tpr_testRNN, thresholds_testRNN, test_lossRNN, test_accRNN, mci_conf_matrix_testRNN, mci_conf_matrix_test_normedRNN, test_predsRNN, test_predsRNN_class, test_labels_padArray, val_predsRNN, val_labels_padArray],
                     [fpr_valRNN, tpr_valRNN, thresholds_valRNN, val_lossRNN, val_accRNN, mci_conf_matrix_valRNN, mci_conf_matrix_val_normedRNN],
                     [fpr_testCNN, tpr_testCNN, thresholds_testCNN, test_lossCNN, test_accCNN, mci_conf_matrix_testCNN, mci_conf_matrix_test_normedCNN, test_predsCNN, test_predsCNN_class, test_labels_padArray, val_predsCNN, val_labels_padArray],
                     [fpr_valCNN, tpr_valCNN, thresholds_valCNN, val_lossCNN, val_accCNN, mci_conf_matrix_valCNN, mci_conf_matrix_val_normedCNN],
                     [test_table_CNN, test_table_RNN],
                     [test_data, test_predsT1_padArray, test_predsT2_padArray, test_predsT3_padArray, test_labels_padArray]], f)
                     #[CNN_LRP, mean_maps_LRP],
                     #[CNN_gradcam, CNN_gb, CNN_guided_gradcam, mean_maps_GGC]], f)
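
    # Read-back sketch for this run's pickle (indices follow the dump order above; e.g. entry
    # [5] is what the commented-out "LOAD A PREVIOUS MODEL" chunks index as pickle0[5]):
    # with open(model_filepath+'/figures/'+str(seed)+'.pickle', 'rb') as f:
    #     results = pickle.load(f)
    # fpr_testRNN_loaded, test_accRNN_loaded = results[0][0], results[0][4]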

    #For recovery of old runs
    # pickle_in = open('/data_wnx1/_Data/AlzheimersDL/MCI-spasov-3class/pickles/' + str(seed)+'.pickle', 'rb')
    # seedpickle = pickle.load(pickle_in)
    # print(seedpickle)

    #Plots to potentially add in the future
    #add sMCI+pMCI ROC curve
    # mci_fpr_test = fpr_test[1]+fpr_test[2]
    # mci_tpr_test = tpr_test[1]+tpr_test[2]
    # areaundercurve['all_mci'] = auc(mci_fpr_test, mci_tpr_test)
    # plt.plot(mci_fpr_test, mci_tpr_test,
    #          lw=lw, label='All MCI')

    #Calculate macros
    # all_fpr_test = np.unique(np.concatenate([fpr_test[i] for i in range(n_classes)]))
    # mean_tpr_test = np.zeros_like(all_fpr_test)
    # for i in range(n_classes):
    #     mean_tpr_test += interp(all_fpr_test, fpr_test[i], tpr_test[i])
    # mean_tpr_test /= n_classes
    # fpr_test["macro"] = all_fpr_test
    # tpr_test["macro"] = mean_tpr_test
    # areaundercurve["macro"] = auc(fpr_test["macro"], tpr_test["macro"])

    # plt.plot(fpr_test["macro"], tpr_test["macro"], label='ROC curve of macro-average')
    # plt.legend(loc="lower right")
    # plt.savefig(model_filepath+'/figures/ROC'+str(seed)+'.png', bbox_inches='tight')

    #Calculate micros
    # test_data_bin = label_binarize(test_data[-1], classes=[0, 1, 2])
    # test_preds_class_bin = label_binarize(test_preds_class, classes=[0, 1, 2])
    # fpr_test["micro"], tpr_test["micro"], _ = roc_curve(test_data_bin.ravel(), test_preds_class_bin.ravel())
    # areaundercurve["micro"] = auc(fpr_test["micro"], tpr_test["micro"])

    # plt.plot(fpr_test["micro"], tpr_test["micro"], label='ROC curve of micro-average')  #was plotting the macro arrays under a micro label
    # plt.legend(loc="lower right")
    # plt.savefig(model_filepath+'/figures/ROC'+str(seed)+'.png', bbox_inches='tight')

#RUN IT!
for seed in seeds:
    #Load data
    print('Processing seed number ', seed)
    # data_loader = DataLoader((target_rows, target_cols, depth, axis), seed=seed)
    # train_data, val_data, test_data, healthy_dict = data_loader.get_train_val_test()
    # print('length of train data '+str(len(train_data)), '; length of val data '+str(len(val_data)), '; length of test data '+str(len(test_data)))
    # print('length of healthy_dict[mri] '+str(len(healthy_dict)))
    evaluate_net(seed)