mci_train.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007
  1. import numpy as np
  2. from sklearn.metrics import roc_curve, auc
  3. from sklearn.metrics import confusion_matrix
  4. from sklearn.preprocessing import label_binarize
  5. from keras import regularizers
  6. import pickle as pickle
  7. from utils.preprocess import DataLoader
  8. from utils.models import Parameters, CNN_Net, RNN_Net
  9. from utils.heatmapPlotting import heatmapPlotter
  10. from matplotlib import pyplot as plt
  11. import pandas as pd
  12. from scipy import interp
  13. from keras.models import Model, load_model#, load_weights
  14. from keras.layers import Input
  15. from keras.optimizers import Adam
  16. import tensorflow as tf
  17. from IPython.display import Image
  18. import matplotlib.cm as cm
  19. import SimpleITK as sitk
  20. import csv
  21. from copy import deepcopy
  22. import matplotlib.colors as mcolors
  23. import nibabel as nib
  24. import math
  25. import sys
  26. sys.path.append('//data/data_wnx3/data_wnx1/rschuurs/CNN+RNN-2class-1cnn-CLEAN/utils')
  27. from sepconv3D import SeparableConv3D
  28. ##for 2 class CNN + RNN ##
  29. #Dummy feature vectors are added to feature vectors from CNN (which are fed only the images)
  30. import os
  # ---------------------------------------------------------------------------
  # Module-level run configuration: GPU selection, input geometry,
  # hyper-parameters, data locations, and the Parameters object handed to the
  # CNN and RNN networks.
  # ---------------------------------------------------------------------------

  # GPU selection (must be in the environment before any GPU work starts).
  os.environ.update({
      "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
      "CUDA_VISIBLE_DEVICES": "0",  # use id from $ nvidia-smi
  })

  # Input volume geometry. These four values are packed, in this order, into
  # the 'image_shape' entry below and are also passed to DataLoader.
  # NOTE(review): `axis` is used as the 4th element of the shape tuple --
  # presumably the channel count; confirm against utils.models.
  target_rows, target_cols, depth, axis = 91, 109, 91, 1

  # Number of clinical inputs (forwarded as 'num_clinical').
  num_clinical = 2

  # Dropout rates for the two networks.
  CNN_drop_rate, RNN_drop_rate = 0.3, 0.1

  # L2 weight regularizers for the two networks.
  CNN_w_regularizer = regularizers.l2(2e-2)
  RNN_w_regularizer = regularizers.l2(1e-6)

  # Mini-batch sizes for the two networks.
  CNN_batch_size, RNN_batch_size = 10, 5

  # Fraction handed to the data loader's train/validation split routines.
  val_split = 0.2

  optimizer = Adam(lr=1e-5)

  # Length of the per-scan feature vector; also the length of the dummy
  # vector used for missing timepoints.
  final_layer_size = 5

  # Where models/figures are written and where the MRI volumes are read from.
  model_filepath = '//data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data'
  mri_datapath = '//data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/ADNI_volumes_customtemplate_float32'

  # Everything the network wrappers need, collected under string keys.
  params_dict = dict(
      CNN_w_regularizer=CNN_w_regularizer,
      RNN_w_regularizer=RNN_w_regularizer,
      CNN_batch_size=CNN_batch_size,
      RNN_batch_size=RNN_batch_size,
      CNN_drop_rate=CNN_drop_rate,
      epochs=30,
      gpu="/gpu:0",
      model_filepath=model_filepath,
      image_shape=(target_rows, target_cols, depth, axis),
      num_clinical=num_clinical,
      final_layer_size=final_layer_size,
      optimizer=optimizer,
      RNN_drop_rate=RNN_drop_rate,
  )
  params = Parameters(params_dict)

  # One randomly drawn seed per run (a single run here), each in [1, 5000).
  seeds = [np.random.randint(1, 5000) for _run in range(1)]
  59. def evaluate_net (seed):
  60. n_classes = 2
  61. data_loader = DataLoader((target_rows, target_cols, depth, axis), seed = seed)
  62. train_data, val_data, test_data,rnn_HdataT1,rnn_HdataT2,rnn_HdataT3,rnn_AdataT1,rnn_AdataT2,rnn_AdataT3, test_mri_nonorm = data_loader.get_train_val_test(val_split, mri_datapath)
  63. print('Length Val Data[0]: ',len(val_data[0]))
  64. #RUN THE CNN:
  65. netCNN = CNN_Net(params)
  66. historyCNN, featuresModel_CNN = netCNN.train((train_data, val_data))
  67. test_lossCNN, test_accCNN = netCNN.evaluate(test_data)
  68. test_predsCNN = netCNN.predict(test_data)
  69. """
  70. #TO LOAD A PREVIOUS MODEL FOR HEATMAPS: (uncomment this chunk and comment above chunk) #Then add the pickle file and savedWeights to the modelfilepath folder
  71. #can't seem to figure out how to load the whole model (but am saving it anyway). I'm only able to save and load the weights, so note that the model needs to be recompiled, so it has to be the correct architecture
  72. picklename = '1820'
  73. netCNN = CNN_Net(params)
  74. netCNN.load_the_weights("SavedCNNWeights")
  75. pickle_in = open(model_filepath+'/'+picklename+'.pickle', 'rb')
  76. pickle0=pickle.load(pickle_in)
  77. pickle_in.close()
  78. test_data = pickle0[5][0]
  79. pickle0 = 0 #to save memory
  80. test_lossCNN, test_accCNN = netCNN.evaluate(test_data)
  81. test_predsCNN = netCNN.predict(test_data)
  82. print('check_lossCNN, check_accCNN: '+ str(test_lossCNN)+', '+ str(test_accCNN))
  83. """
  84. ##PREP DATA FOR THE RNN
  85. #Get the feature vectors from the final layer for each training image at each timepoint:
  86. rnn_HpredsT1 = featuresModel_CNN.predict([rnn_HdataT1[0],rnn_HdataT1[1],rnn_HdataT1[2]])
  87. rnn_HpredsT2 = featuresModel_CNN.predict([rnn_HdataT2[0],rnn_HdataT2[1],rnn_HdataT2[2]])
  88. rnn_HpredsT3 = featuresModel_CNN.predict([rnn_HdataT3[0],rnn_HdataT3[1],rnn_HdataT3[2]])
  89. rnn_ApredsT1 = featuresModel_CNN.predict([rnn_AdataT1[0],rnn_AdataT1[1],rnn_AdataT1[2]])
  90. rnn_ApredsT2 = featuresModel_CNN.predict([rnn_AdataT2[0],rnn_AdataT2[1],rnn_AdataT2[2]])
  91. rnn_ApredsT3 = featuresModel_CNN.predict([rnn_AdataT3[0],rnn_AdataT3[1],rnn_AdataT3[2]])
  92. #grab the PTIDs for each dataset
  93. rnn_HptidT1 = rnn_HdataT1[4]
  94. rnn_HptidT2 = rnn_HdataT2[4]
  95. rnn_HptidT3 = rnn_HdataT3[4]
  96. rnn_AptidT1 = rnn_AdataT1[4]
  97. rnn_AptidT2 = rnn_AdataT2[4]
  98. rnn_AptidT3 = rnn_AdataT3[4]
  99. #grab the imageIDs for each dataset
  100. rnn_HimageIDT1 = rnn_HdataT1[5]
  101. rnn_HimageIDT2 = rnn_HdataT2[5]
  102. rnn_HimageIDT3 = rnn_HdataT3[5]
  103. rnn_AimageIDT1 = rnn_AdataT1[5]
  104. rnn_AimageIDT2 = rnn_AdataT2[5]
  105. rnn_AimageIDT3 = rnn_AdataT3[5]
  106. #add dummy feature vectors to all missing timepoints
  107. dummyVector = np.full((final_layer_size),-1)
  108. #Healthy patients
  109. rnn_HpredsT1_padded = []
  110. rnn_HpredsT2_padded = []
  111. rnn_HpredsT3_padded = []
  112. rnn_HptidT1_padded = []
  113. rnn_HptidT2_padded = []
  114. rnn_HptidT3_padded = []
  115. rnn_HimageIDT1_padded = []
  116. rnn_HimageIDT2_padded = []
  117. rnn_HimageIDT3_padded = []
  118. j=0
  119. HrnnT1T2T3 = 0
  120. HrnnT1T2 = 0
  121. HrnnT1T3 = 0
  122. HrnnT1 = 0
  123. HrnnT2 = 0
  124. HrnnT2T3 = 0
  125. HrnnT3 = 0
  126. HrnnT1Removed = 0
  127. for ptidT1 in rnn_HptidT1:
  128. rnn_HpredsT1_padded.append(rnn_HpredsT1[j])
  129. rnn_HptidT1_padded.append(ptidT1)
  130. rnn_HimageIDT1_padded.append(rnn_HimageIDT1[j])
  131. j+=1
  132. c=0
  133. k=0
  134. t2 = False
  135. t3 = False
  136. for ptidT2 in rnn_HptidT2:
  137. c+=1
  138. if ptidT1 == ptidT2:
  139. rnn_HpredsT2_padded.append(rnn_HpredsT2[c-1])
  140. rnn_HptidT2_padded.append(ptidT2)
  141. rnn_HimageIDT2_padded.append(rnn_HimageIDT2[c-1])
  142. t2 = True
  143. for ptidT3 in rnn_HptidT3:
  144. k+=1
  145. if ptidT1 == ptidT3:
  146. rnn_HpredsT3_padded.append(rnn_HpredsT3[k-1])
  147. rnn_HptidT3_padded.append(ptidT3)
  148. rnn_HimageIDT3_padded.append(rnn_HimageIDT3[k-1])
  149. HrnnT1T2T3+=1
  150. t3 = True
  151. break
  152. if t3 == False:
  153. rnn_HpredsT3_padded.append(dummyVector)
  154. rnn_HptidT3_padded.append(ptidT1)
  155. rnn_HimageIDT3_padded.append('dummy')
  156. HrnnT1T2+=1
  157. break
  158. if t2 == False:
  159. rnn_HpredsT2_padded.append(dummyVector)
  160. rnn_HptidT2_padded.append(ptidT1)
  161. rnn_HimageIDT2_padded.append('dummy')
  162. for ptidT3 in rnn_HptidT3:
  163. k+=1
  164. if ptidT1 == ptidT3:
  165. rnn_HpredsT3_padded.append(rnn_HpredsT3[k-1])
  166. rnn_HptidT3_padded.append(ptidT3)
  167. rnn_HimageIDT3_padded.append(rnn_HimageIDT3[k-1])
  168. HrnnT1T3+=1
  169. t3 = True
  170. break
  171. if t3 == False:
  172. #rnn_HpredsT3_padded.append(dummyVector)
  173. HrnnT1+=1
  174. rnn_HpredsT1_padded.pop(-1) #remove any scans that have only T1
  175. rnn_HpredsT2_padded.pop(-1)
  176. rnn_HptidT1_padded.pop(-1)
  177. rnn_HptidT2_padded.pop(-1)
  178. rnn_HimageIDT1_padded.pop(-1)
  179. rnn_HimageIDT2_padded.pop(-1)
  180. HrnnT1Removed+=1
  181. c=0
  182. for ptidT2 in rnn_HptidT2:
  183. c+=1
  184. j=0
  185. k=0
  186. match = False
  187. t3=False
  188. for ptidT1 in rnn_HptidT1:
  189. j+=1
  190. if ptidT2 == ptidT1:
  191. match = True
  192. if match == False:
  193. rnn_HpredsT2_padded.append(rnn_HpredsT2[c-1])
  194. rnn_HpredsT1_padded.append(dummyVector)
  195. rnn_HptidT2_padded.append(ptidT2)
  196. rnn_HimageIDT2_padded.append(rnn_HimageIDT2[c-1])
  197. rnn_HptidT1_padded.append(ptidT1)
  198. rnn_HimageIDT1_padded.append('dummy')
  199. for ptidT3 in rnn_HptidT3:
  200. k+=1
  201. if ptidT2 == ptidT3:
  202. rnn_HpredsT3_padded.append(rnn_HpredsT3[k-1])
  203. rnn_HptidT3_padded.append(ptidT2)
  204. rnn_HimageIDT3_padded.append(rnn_HimageIDT3[k-1])
  205. t3 = True
  206. HrnnT2T3+=1
  207. break
  208. if t3 == False:
  209. rnn_HpredsT3_padded.append(dummyVector)
  210. rnn_HptidT3_padded.append(ptidT1)
  211. rnn_HimageIDT3_padded.append('dummy')
  212. HrnnT2+=1
  213. k=0
  214. for ptidT3 in rnn_HptidT3:
  215. k+=1
  216. j=0
  217. c=0
  218. match1 = False
  219. for ptidT1 in rnn_HptidT1:
  220. j+=1
  221. if ptidT3 == ptidT1:
  222. match1 = True
  223. # if match1 == True:
  224. # break
  225. if match1 == False:
  226. match2 = False
  227. for ptidT2 in rnn_HptidT2:
  228. c+=1
  229. if ptidT3 == ptidT2:
  230. match2 = True
  231. # if match2 == True:
  232. # break
  233. if match2 == False:
  234. rnn_HpredsT3_padded.append(rnn_HpredsT3[k-1])
  235. rnn_HptidT3_padded.append(ptidT3)
  236. rnn_HimageIDT3_padded.append(rnn_HimageIDT3[k-1])
  237. rnn_HpredsT1_padded.append(dummyVector)
  238. rnn_HptidT1_padded.append(ptidT1)
  239. rnn_HimageIDT1_padded.append('dummy')
  240. rnn_HpredsT2_padded.append(dummyVector)
  241. rnn_HptidT2_padded.append(ptidT1)
  242. rnn_HimageIDT2_padded.append('dummy')
  243. HrnnT3+=1
  244. #move the data from a list to an array
  245. j=0
  246. c=0
  247. k=0
  248. LenPadded = len(rnn_HpredsT1_padded)
  249. rnn_HpredsT1_padArray = np.zeros((LenPadded,final_layer_size), dtype=object)
  250. rnn_HpredsT2_padArray = np.zeros((LenPadded,final_layer_size), dtype=object)
  251. rnn_HpredsT3_padArray = np.zeros((LenPadded,final_layer_size), dtype=object)
  252. for vector in rnn_HpredsT1_padded:
  253. rnn_HpredsT1_padArray[j] = vector
  254. j+=1
  255. for vector in rnn_HpredsT2_padded:
  256. rnn_HpredsT2_padArray[c] = vector
  257. c+=1
  258. for vector in rnn_HpredsT3_padded:
  259. rnn_HpredsT3_padArray[k] = vector
  260. k+=1
  261. with open(model_filepath+'/figures/paddedPreds.txt','w') as paddedPreds:
  262. paddedPreds.write('Train Preds Sizes: '+'\n')
  263. paddedPreds.write('Type of rnn_HpredsT1: '+str(type(rnn_HpredsT1))+'\n')
  264. paddedPreds.write('Type of rnn_HpredsT1_padded: '+str(type(rnn_HpredsT1_padded))+'\n')
  265. paddedPreds.write('Type of rnn_HpredsT1_padArray: '+str(type(rnn_HpredsT1_padArray))+'\n')
  266. paddedPreds.write('Type of rnn_HpredsT1 elements: '+str(type(rnn_HpredsT1[0]))+'\n')
  267. paddedPreds.write('Type of rnn_HpredsT1_padded elements: '+str(type(rnn_HpredsT1_padded[0]))+'\n')
  268. paddedPreds.write('Type of rnn_HpredsT1_padArray elements: '+str(type(rnn_HpredsT1_padArray[0]))+'\n')
  269. paddedPreds.write('Length of rnn_HpredsT1: '+str(len(rnn_HpredsT1))+'\n')
  270. paddedPreds.write('Length of rnn_HpredsT1_padded: '+str(len(rnn_HpredsT1_padded))+'\n')
  271. paddedPreds.write('Length of rnn_HpredsT2_padded: '+str(len(rnn_HpredsT2_padded))+'\n')
  272. paddedPreds.write('Length of rnn_HpredsT3_padded: '+str(len(rnn_HpredsT3_padded))+'\n')
  273. paddedPreds.write('Length of rnn_HpredsT1_padArray: '+str(len(rnn_HpredsT1_padArray))+'\n')
  274. paddedPreds.write('Length of rnn_HpredsT2_padArray: '+str(len(rnn_HpredsT2_padArray))+'\n')
  275. paddedPreds.write('Length of rnn_HpredsT3_padArray: '+str(len(rnn_HpredsT3_padArray))+'\n')
  276. paddedPreds.write('Length of rnn_HptidT1_padded: '+str(len(rnn_HptidT1_padded))+'\n')
  277. paddedPreds.write('Length of rnn_HptidT2_padded: '+str(len(rnn_HptidT2_padded))+'\n')
  278. paddedPreds.write('Length of rnn_HptidT3_padded: '+str(len(rnn_HptidT3_padded))+'\n')
  279. paddedPreds.write('Length of rnn_HimageIDT1_padded: '+str(len(rnn_HimageIDT1_padded))+'\n')
  280. paddedPreds.write('Length of rnn_HimageIDT2_padded: '+str(len(rnn_HimageIDT2_padded))+'\n')
  281. paddedPreds.write('Length of rnn_HimageIDT3_padded: '+str(len(rnn_HimageIDT3_padded))+'\n')
  282. paddedPreds.write('RNN_HpredsT1_padded: '+str(rnn_HpredsT1_padded)+'\n')
  283. paddedPreds.write('Compare to RNN_HpredsT1: '+str(rnn_HpredsT1)+'\n')
  284. paddedPreds.write('RNN_HpredsT1_padArray: '+str(rnn_HpredsT1_padArray)+'\n')
  285. paddedPreds.write('RNN_HpredsT2_padArray: '+str(rnn_HpredsT2_padArray)+'\n')
  286. paddedPreds.write('RNN_HpredsT3_padArray: '+str(rnn_HpredsT3_padArray)+'\n')
  287. paddedPreds.write('Shape of RNN_HpredsT1_padArray: '+str(rnn_HpredsT1_padArray.shape)+'\n')
  288. paddedPreds.write('Shape of RNN_HpredsT1: '+str(rnn_HpredsT1.shape)+'\n')
  289. paddedPreds.write('RNN_HpredsT1[0]: '+str(rnn_HpredsT1[0])+'\n')
  290. paddedPreds.write('rnn_HpredsT1[0][0]: '+str(rnn_HpredsT1[0][0])+'\n')
  291. paddedPreds.write('rnn_HpredsT1_padArray[0]: '+str(rnn_HpredsT1_padArray[0])+'\n')
  292. paddedPreds.write('rnn_HpredsT1_padArray[0][0]: '+str(rnn_HpredsT1_padArray[0][0])+'\n')
  293. paddedPreds.write('# of Hrnn T1 only: '+str(HrnnT1)+'\n')
  294. paddedPreds.write('# of Hrnn T1 only Removed: '+str(HrnnT1Removed)+'\n')
  295. paddedPreds.write('# of Hrnn T1+T2: '+str(HrnnT1T2)+'\n')
  296. paddedPreds.write('# of Hrnn T1+T2+T3: '+str(HrnnT1T2T3)+'\n')
  297. paddedPreds.write('# of Hrnn T1+T3: '+str(HrnnT1T3)+'\n')
  298. paddedPreds.write('# of Hrnn T2 only: '+str(HrnnT2)+'\n')
  299. paddedPreds.write('# of Hrnn T2+T3: '+str(HrnnT2T3)+'\n')
  300. paddedPreds.write('# of Hrnn T3 only: '+str(HrnnT3)+'\n')
  301. #AD patients
  302. rnn_ApredsT1_padded = []
  303. rnn_ApredsT2_padded = []
  304. rnn_ApredsT3_padded = []
  305. rnn_AptidT1_padded = []
  306. rnn_AptidT2_padded = []
  307. rnn_AptidT3_padded = []
  308. rnn_AimageIDT1_padded = []
  309. rnn_AimageIDT2_padded = []
  310. rnn_AimageIDT3_padded = []
  311. j=0
  312. ArnnT1T2T3 = 0
  313. ArnnT1T2 = 0
  314. ArnnT1T3 = 0
  315. ArnnT1 = 0
  316. ArnnT2 = 0
  317. ArnnT2T3 = 0
  318. ArnnT3 = 0
  319. ArnnT1Removed = 0
  320. for ptidT1 in rnn_AptidT1:
  321. rnn_ApredsT1_padded.append(rnn_ApredsT1[j])
  322. rnn_AptidT1_padded.append(ptidT1)
  323. rnn_AimageIDT1_padded.append(rnn_AimageIDT1[j])
  324. j+=1
  325. c=0
  326. k=0
  327. t2 = False
  328. t3 = False
  329. for ptidT2 in rnn_AptidT2:
  330. c+=1
  331. if ptidT1 == ptidT2:
  332. rnn_ApredsT2_padded.append(rnn_ApredsT2[c-1])
  333. rnn_AptidT2_padded.append(ptidT2)
  334. rnn_AimageIDT2_padded.append(rnn_AimageIDT2[c-1])
  335. t2 = True
  336. for ptidT3 in rnn_AptidT3:
  337. k+=1
  338. if ptidT1 == ptidT3:
  339. rnn_ApredsT3_padded.append(rnn_ApredsT3[k-1])
  340. rnn_AptidT3_padded.append(ptidT3)
  341. rnn_AimageIDT3_padded.append(rnn_AimageIDT3[k-1])
  342. ArnnT1T2T3+=1
  343. t3 = True
  344. break
  345. if t3 == False:
  346. rnn_ApredsT3_padded.append(dummyVector)
  347. rnn_AptidT3_padded.append(ptidT1)
  348. rnn_AimageIDT3_padded.append('dummy')
  349. ArnnT1T2+=1
  350. break
  351. if t2 == False:
  352. rnn_ApredsT2_padded.append(dummyVector)
  353. rnn_AptidT2_padded.append(ptidT1)
  354. rnn_AimageIDT2_padded.append('dummy')
  355. for ptidT3 in rnn_AptidT3:
  356. k+=1
  357. if ptidT1 == ptidT3:
  358. rnn_ApredsT3_padded.append(rnn_ApredsT3[k-1])
  359. rnn_AptidT3_padded.append(ptidT3)
  360. rnn_AimageIDT3_padded.append(rnn_AimageIDT3[k-1])
  361. ArnnT1T3+=1
  362. t3 = True
  363. break
  364. if t3 == False:
  365. #rnn_ApredsT3_padded.append(dummyVector)
  366. ArnnT1+=1
  367. rnn_ApredsT1_padded.pop(-1) #remove any scans that have only T1
  368. rnn_ApredsT2_padded.pop(-1)
  369. rnn_AptidT1_padded.pop(-1)
  370. rnn_AimageIDT1_padded.pop(-1)
  371. rnn_AptidT2_padded.pop(-1)
  372. rnn_AimageIDT2_padded.pop(-1)
  373. ArnnT1Removed+=1
  374. c=0
  375. for ptidT2 in rnn_AptidT2:
  376. c+=1
  377. j=0
  378. k=0
  379. match = False
  380. t3=False
  381. for ptidT1 in rnn_AptidT1:
  382. j+=1
  383. if ptidT2 == ptidT1:
  384. match = True
  385. if match == False:
  386. rnn_ApredsT2_padded.append(rnn_ApredsT2[c-1])
  387. rnn_AptidT2_padded.append(ptidT2)
  388. rnn_AimageIDT2_padded.append(rnn_AimageIDT2[c-1])
  389. rnn_ApredsT1_padded.append(dummyVector)
  390. rnn_AptidT1_padded.append(ptidT1)
  391. rnn_AimageIDT1_padded.append('dummy')
  392. for ptidT3 in rnn_AptidT3:
  393. k+=1
  394. if ptidT2 == ptidT3:
  395. rnn_ApredsT3_padded.append(rnn_ApredsT3[k-1])
  396. rnn_AptidT3_padded.append(ptidT3)
  397. rnn_AimageIDT3_padded.append(rnn_AimageIDT3[k-1])
  398. t3 = True
  399. ArnnT2T3+=1
  400. break
  401. if t3 == False:
  402. rnn_ApredsT3_padded.append(dummyVector)
  403. rnn_AptidT3_padded.append(ptidT1)
  404. rnn_AimageIDT3_padded.append('dummy')
  405. ArnnT2+=1
  406. k=0
  407. for ptidT3 in rnn_AptidT3:
  408. k+=1
  409. j=0
  410. c=0
  411. match1 = False
  412. for ptidT1 in rnn_AptidT1:
  413. j+=1
  414. if ptidT3 == ptidT1:
  415. match1 = True
  416. # if match1 == True:
  417. # break
  418. if match1 == False:
  419. match2 = False
  420. for ptidT2 in rnn_AptidT2:
  421. c+=1
  422. if ptidT3 == ptidT2:
  423. match2 = True
  424. # if match2 == True:
  425. # break
  426. if match2 == False:
  427. rnn_ApredsT3_padded.append(rnn_ApredsT3[k-1])
  428. rnn_AptidT3_padded.append(ptidT3)
  429. rnn_AimageIDT3_padded.append(rnn_AimageIDT3[k-1])
  430. rnn_ApredsT1_padded.append(dummyVector)
  431. rnn_AptidT1_padded.append(ptidT1)
  432. rnn_AimageIDT1_padded.append('dummy')
  433. rnn_ApredsT2_padded.append(dummyVector)
  434. rnn_AptidT2_padded.append(ptidT1)
  435. rnn_AimageIDT2_padded.append('dummy')
  436. ArnnT3+=1
  437. #move the data from a list to an array
  438. j=0
  439. c=0
  440. k=0
  441. LenPadded = len(rnn_ApredsT1_padded)
  442. rnn_ApredsT1_padArray = np.zeros((LenPadded,final_layer_size), dtype=object)
  443. rnn_ApredsT2_padArray = np.zeros((LenPadded,final_layer_size), dtype=object)
  444. rnn_ApredsT3_padArray = np.zeros((LenPadded,final_layer_size), dtype=object)
  445. for vector in rnn_ApredsT1_padded:
  446. rnn_ApredsT1_padArray[j] = vector
  447. j+=1
  448. for vector in rnn_ApredsT2_padded:
  449. rnn_ApredsT2_padArray[c] = vector
  450. c+=1
  451. for vector in rnn_ApredsT3_padded:
  452. rnn_ApredsT3_padArray[k] = vector
  453. k+=1
  454. with open(model_filepath+'/figures/paddedPreds.txt','a') as paddedPreds:
  455. paddedPreds.write('Length of rnn_ApredsT1_padArray: '+str(len(rnn_ApredsT1_padArray))+'\n')
  456. paddedPreds.write('Length of rnn_ApredsT2_padArray: '+str(len(rnn_ApredsT2_padArray))+'\n')
  457. paddedPreds.write('Length of rnn_ApredsT3_padArray: '+str(len(rnn_ApredsT3_padArray))+'\n')
  458. paddedPreds.write('Length of rnn_AptidT1_padded: '+str(len(rnn_AptidT1_padded))+'\n')
  459. paddedPreds.write('Length of rnn_AptidT2_padded: '+str(len(rnn_AptidT2_padded))+'\n')
  460. paddedPreds.write('Length of rnn_AptidT3_padded: '+str(len(rnn_AptidT3_padded))+'\n')
  461. paddedPreds.write('Length of rnn_AimageIDT1_padded: '+str(len(rnn_AimageIDT1_padded))+'\n')
  462. paddedPreds.write('Length of rnn_AimageIDT2_padded: '+str(len(rnn_AimageIDT2_padded))+'\n')
  463. paddedPreds.write('Length of rnn_AimageIDT3_padded: '+str(len(rnn_AimageIDT3_padded))+'\n')
  464. paddedPreds.write('# of Arnn T1 only: '+str(ArnnT1)+'\n')
  465. paddedPreds.write('# of Arnn T1 only Removed: '+str(ArnnT1Removed)+'\n')
  466. paddedPreds.write('# of Arnn T1+T2: '+str(ArnnT1T2)+'\n')
  467. paddedPreds.write('# of Arnn T1+T2+T3: '+str(ArnnT1T2T3)+'\n')
  468. paddedPreds.write('# of Arnn T1+T3: '+str(ArnnT1T3)+'\n')
  469. paddedPreds.write('# of Arnn T2 only: '+str(ArnnT2)+'\n')
  470. paddedPreds.write('# of Arnn T2+T3: '+str(ArnnT2T3)+'\n')
  471. paddedPreds.write('# of Arnn T3 only: '+str(ArnnT3)+'\n')
  472. #Balance the datasets: (drop the last scans from the H datasets to make the A and H datasets equal. Should be different patients each time because I shuffled in get_filenames
  473. diff = len(rnn_HpredsT1_padArray)-len(rnn_ApredsT1_padArray)
  474. for i in range(diff):
  475. rnn_HpredsT1_padArray = np.delete(rnn_HpredsT1_padArray,-1,0)
  476. rnn_HpredsT2_padArray = np.delete(rnn_HpredsT2_padArray,-1,0)
  477. rnn_HpredsT3_padArray = np.delete(rnn_HpredsT3_padArray,-1,0)
  478. dummyCountHT1 = 0
  479. dummyCountHT2 = 0
  480. dummyCountHT3 = 0
  481. dummyCountAT1 = 0
  482. dummyCountAT2 = 0
  483. dummyCountAT3 = 0
  484. for i in range(len(rnn_HpredsT1_padArray)):
  485. if rnn_HpredsT1_padArray[i][0] == -1:
  486. dummyCountHT1 += 1
  487. if rnn_HpredsT2_padArray[i][0] == -1:
  488. dummyCountHT2 += 1
  489. if rnn_HpredsT3_padArray[i][0] == -1:
  490. dummyCountHT3 += 1
  491. for i in range(len(rnn_ApredsT1_padArray)):
  492. if rnn_ApredsT1_padArray[i][0] == -1:
  493. dummyCountAT1 += 1
  494. if rnn_ApredsT2_padArray[i][0] == -1:
  495. dummyCountAT2 += 1
  496. if rnn_ApredsT3_padArray[i][0] == -1:
  497. dummyCountAT3 += 1
  498. with open(model_filepath+'/figures/paddedPreds.txt','a') as paddedPreds:
  499. paddedPreds.write('Length of rnn_HpredsT1_padArray popped: '+str(len(rnn_HpredsT1_padArray))+'\n')
  500. paddedPreds.write('Length of rnn_HpredsT2_padArray popped: '+str(len(rnn_HpredsT2_padArray))+'\n')
  501. paddedPreds.write('Length of rnn_HpredsT3_padArray popped: '+str(len(rnn_HpredsT3_padArray))+'\n')
  502. with open(model_filepath+'/figures/DataList.txt','a') as datalist:
  503. datalist.write('Number of scans in HT1 (excluding dummies): '+str(len(rnn_HpredsT1_padArray)-dummyCountHT1)+'\n')
  504. datalist.write('Number of scans in HT2 (excluding dummies): '+str(len(rnn_HpredsT2_padArray)-dummyCountHT2)+'\n')
  505. datalist.write('Number of scans in HT3 (excluding dummies): '+str(len(rnn_HpredsT3_padArray)-dummyCountHT3)+'\n')
  506. datalist.write('Number of scans in AT1 (excluding dummies): '+str(len(rnn_ApredsT1_padArray)-dummyCountAT1)+'\n')
  507. datalist.write('Number of scans in AT2 (excluding dummies): '+str(len(rnn_ApredsT2_padArray)-dummyCountAT2)+'\n')
  508. datalist.write('Number of scans in AT3 (excluding dummies): '+str(len(rnn_ApredsT3_padArray)-dummyCountAT3)+'\n')
  509. #Split RNN data into train/val/test
  510. train_predsT1_padArray,train_predsT2_padArray,train_predsT3_padArray,val_predsT1_padArray,val_predsT2_padArray,val_predsT3_padArray,test_predsT1_padArray,test_predsT2_padArray,test_predsT3_padArray, train_labels_padArray,val_labels_padArray,test_labels_padArray, test_ptidT1,test_ptidT2,test_ptidT3,test_imageIDT1,test_imageIDT2,test_imageIDT3 = data_loader.split_data_RNN(rnn_HpredsT1_padArray,rnn_HpredsT2_padArray,rnn_HpredsT3_padArray,rnn_ApredsT1_padArray,rnn_ApredsT2_padArray,rnn_ApredsT3_padArray,rnn_HptidT1_padded,rnn_HptidT2_padded,rnn_HptidT3_padded,rnn_HimageIDT1_padded,rnn_HimageIDT2_padded,rnn_HimageIDT3_padded,rnn_AptidT1_padded,rnn_AptidT2_padded,rnn_AptidT3_padded,rnn_AimageIDT1_padded,rnn_AimageIDT2_padded,rnn_AimageIDT3_padded,val_split)
  511. #RUN THE RNN:
  512. netRNN = RNN_Net(params)
  513. historyRNN = netRNN.train(([train_predsT1_padArray,train_predsT2_padArray,train_predsT3_padArray],train_labels_padArray,[val_predsT1_padArray,val_predsT2_padArray,val_predsT3_padArray],val_labels_padArray))
  514. #EVALUATE RNN:
  515. test_lossRNN, test_accRNN = netRNN.evaluate (([test_predsT1_padArray,test_predsT2_padArray,test_predsT3_padArray],test_labels_padArray))
  516. test_predsRNN = netRNN.predict(([test_predsT1_padArray,test_predsT2_padArray,test_predsT3_padArray],test_labels_padArray))
  517. """
  518. #TO LOAD A PREVIOUS MODEL INSTEAD: (uncomment this chunk and comment above chunk - all the way up to through the data prep for RNN)
  519. #note: this is not needed for CNN heatmaps
  520. #can't seem to figure out how to load the whole model (but am saving it anyway). I'm only able to save and load the weights, so note that the model needs to be recompiled, so it has to be the correct architecture
  521. #Also, I should check that it works by running the same test set first and making sure I get the same results
  522. netRNN = RNN_Net(params)
  523. netRNN.load_the_weights("SavedRNNWeights")
  524. pickle_in = open(model_filepath+'/'+picklename+'.pickle', 'rb') #change this to be the pickle filename
  525. pickle0=pickle.load(pickle_in)
  526. pickle_in.close()
  527. test_predsT1_padArray = pickle0[5][1]
  528. test_predsT2_padArray = pickle0[5][2]
  529. test_predsT3_padArray = pickle0[5][3]
  530. test_labels_padArray = pickle0[5][4]
  531. test_labels_padArray = np.delete(test_labels_padArray,0)
  532. pickle0 = 0
  533. print('test_labels_padArray: ',test_labels_padArray)
  534. test_lossRNN, test_accRNN = netRNN.evaluate(([test_predsT1_padArray,test_predsT2_padArray,test_predsT3_padArray],test_labels_padArray))
  535. test_predsRNN = netRNN.predict(([test_predsT1_padArray,test_predsT2_padArray,test_predsT3_padArray],test_labels_padArray))
  536. print('check_lossRNN, check_accRNN: '+ str(test_lossRNN)+', '+ str(test_accRNN))
  537. """
  538. #PLOTS FOR THE CNN ALONE
  539. #plot accuracy learning curves
  540. plt.figure()
  541. plt.plot(historyCNN['acc'],color='red')
  542. plt.plot(historyCNN['val_acc'],color='blue')
  543. plt.title('CNN model accuracy learning curve')
  544. plt.ylabel('accuracy')
  545. plt.xlabel('epoch')
  546. plt.xlabel('1 - Specificity',fontsize=20)
  547. plt.ylabel('Sensitivity',fontsize=20)
  548. plt.legend(['training', 'validation'], loc='upper left')
  549. plt.savefig(model_filepath+'/figures/CNN_LCacc'+str(seed)+'.png', bbox_inches='tight')
  550. #plot loss learning curves
  551. plt.figure()
  552. plt.plot(historyCNN['loss'],color='orange')
  553. plt.plot(historyCNN['val_loss'],color='purple')
  554. plt.title('CNN model loss learning curve')
  555. plt.ylabel('loss')
  556. plt.xlabel('epoch')
  557. plt.legend(['training', 'validation'], loc='upper right')
  558. plt.savefig(model_filepath+'/figures/CNN_LCloss'+str(seed)+'.png', bbox_inches='tight')
  559. #plot test ROC curve
  560. fpr_testCNN = dict()
  561. tpr_testCNN = dict()
  562. thresholds_testCNN = dict()
  563. areaundercurveCNN = dict()
  564. plt.figure()
  565. test_predsCNN_class = np.argmax(test_predsCNN,axis=-1)
  566. test_predsCNN_count = np.bincount(test_predsCNN_class, minlength=n_classes)
  567. print('test_labelsCNN: ', test_data[3])
  568. print('test_predCNNclass: ', test_predsCNN_class)
  569. print('test_predCNNcount: ', test_predsCNN_count)
  570. tROC = True
  571. for i in range(n_classes):
  572. if test_predsCNN_count[i]==0: #skips ROC curve for situation where one class is never predicted
  573. print('Class ', i, 'is predicted 0 times in CNN testing.')
  574. print('Cannot plot Test ROC curve for CNN.')
  575. tROC = False
  576. break
  577. if tROC == True:
  578. if n_classes ==2:
  579. #fpr_testCNN, tpr_testCNN, thresholds_testCNN = roc_curve(np.array(pd.get_dummies(test_data[3]))[:,1], np.array(test_predsCNN)[:,1])
  580. fpr_testCNN, tpr_testCNN, thresholds_testCNN = roc_curve(test_data[3], test_predsCNN[:,1])
  581. areaundercurveCNN = auc(fpr_testCNN,tpr_testCNN)
  582. lw = 3
  583. class_name = ['AD','Healthy']
  584. plt.plot(fpr_testCNN, tpr_testCNN, lw=lw)
  585. plt.title('CNN ROC')
  586. plt.xlabel('1 - Specificity',fontsize=13)
  587. plt.ylabel('Sensitivity',fontsize=13)
  588. else:
  589. for i in range(n_classes):
  590. fpr_testCNN[i], tpr_testCNN[i], thresholds_testCNN[i] = roc_curve(np.array(pd.get_dummies(test_data[3]))[:,i], np.array(test_predsCNN)[:,i])
  591. areaundercurveCNN[i] = auc(fpr_testCNN[i],tpr_testCNN[i])
  592. lw = 3
  593. class_name = ['AD','Healthy']
  594. plt.plot(fpr_testCNN[i], tpr_testCNN[i],
  595. lw=lw, label=str(class_name[i]))
  596. plt.title('CNN ROC')
  597. plt.xlabel('1 - Specificity',fontsize=13)
  598. plt.ylabel('Sensitivity',fontsize=13)
  599. if tROC==True: #skips ROC curve and TPRs for situation where one class is never predicted
  600. #plot testROC
  601. plt.legend(loc="lower right")
  602. plt.savefig(model_filepath+'/figures/CNN_ROC'+str(seed)+'.png', bbox_inches='tight')
  603. #print TPRs for each class
  604. #print('TPR_AD_CNN = '+str(tpr_testCNN[0]))
  605. #print('TPR_Healthy_CNN = '+str(tpr_testCNN[1]))
  606. #Confusion matrix
  607. mci_conf_matrix_testCNN = confusion_matrix(y_true = test_data[3], y_pred = np.round(test_predsCNN_class))
  608. plt.figure()
  609. ax = plt.subplot()
  610. cax = ax.matshow(mci_conf_matrix_testCNN)
  611. plt.title('Full CNN T1 Confusion Matrix')
  612. plt.colorbar(cax)
  613. ax.set_xticklabels(['','AD','Healthy'],fontsize=11)
  614. ax.set_yticklabels(['','AD','Healthy'],fontsize=11)
  615. plt.xlabel('Predicted',fontsize=13)
  616. plt.ylabel('True',fontsize=13)
  617. plt.savefig(model_filepath+'/figures/CNN_ConfMatrix'+str(seed)+'.png', bbox_inches='tight')
  618. #Normalized confusion matrix
  619. mci_conf_matrix_test_normedCNN = mci_conf_matrix_testCNN/(mci_conf_matrix_testCNN.sum(axis=1)[:,np.newaxis])
  620. plt.figure()
  621. ax = plt.subplot()
  622. cax = ax.matshow(mci_conf_matrix_test_normedCNN)
  623. plt.title('Full CNN T1 Normalized Confusion Matrix')
  624. plt.colorbar(cax)
  625. ax.set_xticklabels(['','AD','Healthy'],fontsize=11)
  626. ax.set_yticklabels(['','AD','Healthy'],fontsize=11)
  627. plt.xlabel('Predicted',fontsize=13)
  628. plt.ylabel('True',fontsize=13)
  629. plt.savefig(model_filepath+'/figures/CNN_ConfMatrixNormed'+str(seed)+'.png', bbox_inches='tight')
  630. #validation ROC
  631. val_lossCNN, val_accCNN = netCNN.evaluate ((val_data))
  632. val_predsCNN = netCNN.predict((val_data))
  633. val_predsCNN_class = np.argmax(val_predsCNN,axis=-1)
  634. fpr_valCNN = dict()
  635. tpr_valCNN = dict()
  636. thresholds_valCNN = dict()
  637. val_predsCNN_count = np.bincount(val_predsCNN_class, minlength=n_classes)
  638. print('val_predsCNN_count: ', val_predsCNN_count)
  639. vROC = True
  640. for i in range(n_classes):
  641. if val_predsCNN_count[i]==0: #skips ROC curve for situation where one class is never predicted
  642. print('Class ', i, 'is predicted 0 times in CNN validation.')
  643. print('Cannot plot vROC curve for CNN.')
  644. vROC = False
  645. break
  646. if vROC == True:
  647. if n_classes ==2:
  648. fpr_valCNN, tpr_valCNN, thresholds_valCNN = roc_curve(np.array(pd.get_dummies(val_data[3]))[:,1], np.array(val_predsCNN)[:,1])
  649. else:
  650. fpr_valCNN[i], tpr_valCNN[i], thresholds_valCNN[i] = roc_curve(np.array(pd.get_dummies(val_data[3]))[:,i], np.array(val_predsCNN)[:,i])
  651. mci_conf_matrix_valCNN = confusion_matrix(y_true = val_data[3], y_pred = np.round(val_predsCNN_class))
  652. mci_conf_matrix_val_normedCNN = mci_conf_matrix_valCNN/(mci_conf_matrix_valCNN.sum(axis=1)[:,np.newaxis])
  653. print("Test CNN accuracy: "+str(test_accCNN))
  654. print("CNN AUC: " +str(areaundercurveCNN))
  655. #PLOTS FOR THE RNN
  656. #plot accuracy learning curves
  657. plt.figure()
  658. plt.plot(historyRNN['acc'],color='red')
  659. plt.plot(historyRNN['val_acc'],color='blue')
  660. plt.title('RNN model accuracy learning curve')
  661. plt.ylabel('accuracy')
  662. plt.xlabel('epoch')
  663. plt.legend(['training', 'validation'], loc='upper left')
  664. plt.savefig(model_filepath+'/figures/RNN_LCacc'+str(seed)+'.png', bbox_inches='tight')
  665. #plot loss learning curves
  666. plt.figure()
  667. plt.plot(historyRNN['loss'],color='orange')
  668. plt.plot(historyRNN['val_loss'],color='purple')
  669. plt.title('RNN model loss learning curve')
  670. plt.ylabel('loss')
  671. plt.xlabel('epoch')
  672. plt.legend(['training', 'validation'], loc='upper right')
  673. plt.savefig(model_filepath+'/figures/RNN_LCloss'+str(seed)+'.png', bbox_inches='tight')
  674. #plot 2-class test ROC curve
  675. fpr_testRNN = dict()
  676. tpr_testRNN = dict()
  677. thresholds_testRNN = dict()
  678. areaundercurveRNN = dict()
  679. plt.figure()
  680. test_predsRNN_class = np.argmax(test_predsRNN,axis=-1)
  681. test_predsRNN_count = np.bincount(test_predsRNN_class, minlength=n_classes)
  682. print('test_labelsRNN: ', test_labels_padArray)
  683. print('test_predsRNN_class: ', test_predsRNN_class)
  684. print('test_predsRNN_count: ', test_predsRNN_count)
  685. tROC = True
  686. for i in range(n_classes):
  687. if test_predsRNN_count[i]==0: #skips ROC curve for situation where one class is never predicted
  688. print('Class ', i, 'is predicted 0 times in RNN testing.')
  689. print('Cannot plot Test ROC curve for RNN.')
  690. tROC = False
  691. break
  692. if tROC == True:
  693. if n_classes ==2:
  694. fpr_testRNN, tpr_testRNN, thresholds_testRNN = roc_curve(np.array(pd.get_dummies(test_labels_padArray))[:,i], np.array(test_predsRNN)[:,1]) #changed first 1 from i
  695. areaundercurveRNN = auc(fpr_testRNN,tpr_testRNN)
  696. lw = 3
  697. class_name = ['AD','Healthy']
  698. plt.plot(fpr_testRNN, tpr_testRNN, lw=lw)
  699. plt.title('RNN ROC')
  700. plt.xlabel('1 - Specificity',fontsize=13)
  701. plt.ylabel('Sensitivity',fontsize=13)
  702. else:
  703. for i in range(n_classes):
  704. fpr_testRNN[i], tpr_testRNN[i], thresholds_testRNN[i] = roc_curve(np.array(pd.get_dummies(test_labels_padArray))[:,i], np.array(test_predsRNN)[:,i])
  705. areaundercurveRNN[i] = auc(fpr_testRNN[i],tpr_testRNN[i])
  706. lw = 3
  707. class_name = ['AD','Healthy']
  708. plt.plot(fpr_testRNN[i], tpr_testRNN[i],
  709. lw=lw, label=str(class_name[i]))
  710. plt.title('RNN ROC')
  711. plt.xlabel('1 - Specificity',fontsize=13)
  712. plt.ylabel('Sensitivity',fontsize=13)
  713. if tROC==True: #skips ROC curve and TPRs for situation where one class is never predicted
  714. #plot testROC
  715. plt.legend(loc="lower right")
  716. plt.savefig(model_filepath+'/figures/RNN_ROC'+str(seed)+'.png', bbox_inches='tight')
  717. #print TPRs for each class
  718. #print('TPR_AD_RNN = '+str(tpr_testRNN[0]))
  719. #print('TPR_Healthy_RNN = '+str(tpr_testRNN[1]))
  720. #Confusion matrix
  721. mci_conf_matrix_testRNN = confusion_matrix(y_true = test_labels_padArray, y_pred = np.round(test_predsRNN_class))
  722. plt.figure()
  723. ax = plt.subplot()
  724. cax = ax.matshow(mci_conf_matrix_testRNN)
  725. plt.title('RNN Confusion Matrix')
  726. plt.colorbar(cax)
  727. ax.set_xticklabels(['','AD','Healthy'],fontsize=11)
  728. ax.set_yticklabels(['','AD','Healthy'],fontsize=11)
  729. plt.xlabel('Predicted',fontsize=13)
  730. plt.ylabel('True',fontsize=13)
  731. plt.savefig(model_filepath+'/figures/RNN_ConfMatrix'+str(seed)+'.png', bbox_inches='tight')
  732. #Normalized confusion matrix
  733. mci_conf_matrix_test_normedRNN = mci_conf_matrix_testRNN/(mci_conf_matrix_testRNN.sum(axis=1)[:,np.newaxis])
  734. plt.figure()
  735. ax = plt.subplot()
  736. cax = ax.matshow(mci_conf_matrix_test_normedRNN)
  737. plt.title('RNN Normalized Confusion Matrix')
  738. plt.colorbar(cax)
  739. ax.set_xticklabels(['','AD','Healthy'],fontsize=11)
  740. ax.set_yticklabels(['','AD','Healthy'],fontsize=11)
  741. plt.xlabel('Predicted',fontsize=13)
  742. plt.ylabel('True',fontsize=13)
  743. plt.savefig(model_filepath+'/figures/RNN_ConfMatrixNormed'+str(seed)+'.png', bbox_inches='tight')
  744. #validation ROC
  745. val_lossRNN, val_accRNN = netRNN.evaluate (([val_predsT1_padArray,val_predsT2_padArray,val_predsT3_padArray],val_labels_padArray))
  746. val_predsRNN = netRNN.predict(([val_predsT1_padArray,val_predsT2_padArray,val_predsT3_padArray],val_labels_padArray))
  747. val_predsRNN_class = np.argmax(val_predsRNN,axis=-1)
  748. fpr_valRNN = dict()
  749. tpr_valRNN = dict()
  750. thresholds_valRNN = dict()
  751. val_predsRNN_count = np.bincount(val_predsRNN_class, minlength=n_classes)
  752. print('val_predsRNN_count: ', val_predsRNN_count)
  753. vROC = True
  754. for i in range(n_classes):
  755. if val_predsRNN_count[i]==0: #skips ROC curve for situation where one class is never predicted
  756. print('Class ', i, 'is predicted 0 times in RNN validation.')
  757. print('Cannot plot vROC curve for RNN.')
  758. vROC = False
  759. break
  760. if vROC==True:
  761. if n_classes == 2:
  762. fpr_valRNN, tpr_valRNN, thresholds_valRNN = roc_curve(np.array(pd.get_dummies(val_labels_padArray))[:,1], np.array(val_predsRNN)[:,1])
  763. else:
  764. for i in range(n_classes):
  765. fpr_valRNN[i], tpr_valRNN[i], thresholds_valRNN[i] = roc_curve(np.array(pd.get_dummies(val_labels_padArray))[:,i], np.array(val_predsRNN)[:,i])
  766. mci_conf_matrix_valRNN = confusion_matrix(y_true = val_labels_padArray, y_pred = np.round(val_predsRNN_class))
  767. mci_conf_matrix_val_normedRNN = mci_conf_matrix_valRNN/(mci_conf_matrix_valRNN.sum(axis=1)[:,np.newaxis])
  768. print("Test RNN accuracy: "+str(test_accRNN))
  769. print("RNN AUC: " +str(areaundercurveRNN))
  770. #TEST SET TABLES
  771. test_table_CNN = (test_data[4],test_data[5],test_data[3],test_data[6],test_data[7],test_predsCNN_class,test_predsCNN[0],test_predsCNN[1])
  772. test_table_RNN = (test_ptidT1,test_imageIDT1,test_imageIDT2,test_imageIDT3,test_labels_padArray,test_predsRNN_class,test_predsRNN[0],test_predsRNN[1])
  773. #WRITE THE OUTPUT FILE
  774. with open(model_filepath+'/figures/Outputs'+str(seed)+'.txt','w') as outputs:
  775. #RNN
  776. outputs.write('RNN Confusion Matrix Values:'+'\n')
  777. outputs.write(str(mci_conf_matrix_testRNN)+'\n')
  778. outputs.write('RNN Normalized Confusion Matrix Values:'+'\n')
  779. outputs.write(str(mci_conf_matrix_test_normedRNN)+'\n')
  780. outputs.write('RNN Test accuracy:'+'\n')
  781. outputs.write(str(test_accRNN)+'\n')
  782. outputs.write('RNN AUC:'+'\n')
  783. outputs.write(str(areaundercurveRNN) +'\n')
  784. outputs.write('RNN Test Predictions Probabilities'+'\n')
  785. outputs.write(str(test_predsRNN) +'\n')
  786. outputs.write('RNN Test Predictions MaxProb Class'+'\n')
  787. outputs.write(str(test_predsRNN_class) +'\n')
  788. #CNN
  789. outputs.write('Full CNN Confusion Matrix Values:'+'\n')
  790. outputs.write(str(mci_conf_matrix_testCNN)+'\n')
  791. outputs.write('Full CNN Normalized Confusion Matrix Values:'+'\n')
  792. outputs.write(str(mci_conf_matrix_test_normedCNN)+'\n')
  793. outputs.write('Full CNN Test accuracy:'+'\n')
  794. outputs.write(str(test_accCNN)+'\n')
  795. outputs.write('Full CNN AUC:'+'\n')
  796. outputs.write(str(areaundercurveCNN) +'\n')
  797. outputs.write('Full CNN Test Predictions Probabilities'+'\n')
  798. outputs.write(str(test_predsCNN) +'\n')
  799. outputs.write('Full CNN Test Predictions MaxProb Class'+'\n')
  800. outputs.write(str(test_predsCNN_class) +'\n')
  801. #outputs.write('Index of best CNN Gmean'+'\n')
  802. #outputs.write(str(ixC) +'\n')
  803. #outputs.write('Optimal CNN Threshold'+'\n')
  804. #outputs.write(str(bestThreshCNN) +'\n')
  805. #outputs.write('Value of highest Gmean'+'\n')
  806. #outputs.write(str(highGmeanCNN) +'\n')
  807. #outputs.write('CNN Accuracy at Optimized Threshold'+'\n')
  808. #outputs.write(str(OptAccCNN) +'\n'+'\n')
  809. #Testset output tables
  810. outputs.write('test_table_CNN'+'\n')
  811. outputs.write(str(test_table_CNN)+'\n'+'\n')
  812. outputs.write('test_table_RNN'+'\n')
  813. outputs.write(str(test_table_RNN)+'\n'+'\n')
  814. #TEST SET TABLES
  815. Cptid = test_data[4]
  816. Cptid = np.insert(Cptid,0,'PTID')
  817. CimageID = test_data[5]
  818. CimageID = np.insert(CimageID,0,'imgID')
  819. Cconfid = test_data[6]
  820. Cconfid = np.insert(Cconfid,0,'DxConfidence')
  821. Ccsf = test_data[7]
  822. Ccsf = np.insert(Ccsf,0,'CSF_Path')
  823. Clabels = test_data[3]
  824. Clabels = np.insert(Clabels.astype(str),0,'label')
  825. test_predsCNN_class = np.insert(test_predsCNN_class.astype(str),0,'prediction')
  826. probsCAD = [item[0] for item in test_predsCNN]
  827. probsCNC = [item[1] for item in test_predsCNN]
  828. probsCAD.insert(0,'prediction probabilities AD')
  829. probsCNC.insert(0,'prediction probabilities NC')
  830. test_ptidT1 = np.insert(test_ptidT1,0,'PTID')
  831. test_imageIDT1 = np.insert(test_imageIDT1,0,'imIDT1')
  832. test_imageIDT2 = np.insert(test_imageIDT2,0,'imIDT2')
  833. test_imageIDT3 = np.insert(test_imageIDT3,0,'imIDT3')
  834. test_labels_padArray = np.insert(test_labels_padArray.astype(str),0,'label')
  835. test_predsRNN_class = np.insert(test_predsRNN_class.astype(str),0,'prediction')
  836. probsRAD = [item[0] for item in test_predsRNN]
  837. probsRNC = [item[1] for item in test_predsRNN]
  838. probsRAD.insert(0,'prediction probabilities AD')
  839. probsRNC.insert(0,'prediction probabilities NC')
  840. with open(model_filepath+'/figures/test_table_'+str(seed)+'.csv','w') as Testcsv:
  841. Testcsv_writer = csv.writer(Testcsv, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
  842. Testcsv_writer.writerow(['CNN'])
  843. Testcsv_writer.writerow(Cptid)
  844. Testcsv_writer.writerow(CimageID)
  845. Testcsv_writer.writerow(Clabels)
  846. Testcsv_writer.writerow(test_predsCNN_class)
  847. Testcsv_writer.writerow(Cconfid)
  848. Testcsv_writer.writerow(Ccsf)
  849. Testcsv_writer.writerow(probsCAD)
  850. Testcsv_writer.writerow(probsCNC)
  851. Testcsv_writer.writerow(' ')
  852. Testcsv_writer.writerow(' ')
  853. Testcsv_writer.writerow(['RNN'])
  854. Testcsv_writer.writerow(test_ptidT1)
  855. Testcsv_writer.writerow(test_imageIDT1)
  856. Testcsv_writer.writerow(test_imageIDT2)
  857. Testcsv_writer.writerow(test_imageIDT3)
  858. Testcsv_writer.writerow(test_labels_padArray)
  859. Testcsv_writer.writerow(test_predsRNN_class)
  860. Testcsv_writer.writerow(probsRAD)
  861. Testcsv_writer.writerow(probsRNC)
  862. """
  863. #HEATMAPS:
  864. HeatmapPlotter = heatmapPlotter(seed)
  865. #if plotting NC data (need to load AD avg):
  866. picklename = 'loadedGC4912' #CHANGE to LRP or GGC and change seed
  867. pickle_in = open(model_filepath+'/'+picklename+'.pickle', 'rb')
  868. pickle0=pickle.load(pickle_in)
  869. pickle_in.close()
  870. mean_map_AD = pickle0[0]["AD"]
  871. pickle0=0
  872. #if plotting AD data:
  873. mean_map_AD = np.zeros((91,109,91))
  874. #RUN LRP
  875. case_maps_LRP, counts = HeatmapPlotter.LRP(test_data, test_mri_nonorm, model_filepath, netCNN, test_predsCNN) #Removed CNN_LRP to save memory
  876. mean_maps_LRP = HeatmapPlotter.plot_avg_maps(case_maps_LRP, counts, 'LRP', test_mri_nonorm, model_filepath, mean_map_AD)
  877. #WRITE A PICKLE FILE
  878. with open(model_filepath+'/figures/loadedLRP' + str(seed)+'.pickle', 'wb') as f:
  879. pickle.dump([mean_maps_LRP, case_maps_LRP, counts], f) #Removed CNN_LRP to save memory
  880. #RUN GGC
  881. case_maps_GGC, counts = HeatmapPlotter.GuidedGradCAM(test_data, test_mri_nonorm, model_filepath, netCNN, test_predsCNN) #Removed CNN_gradcam, CNN_gb, CNN_guided_gradcam to save memory
  882. mean_maps_GGC = HeatmapPlotter.plot_avg_maps(case_maps_GGC, counts, 'Guided GradCAM', test_mri_nonorm, model_filepath, mean_map_AD)
  883. #WRITE A PICKLE FILE
  884. with open(model_filepath+'/figures/loadedGC' + str(seed)+'.pickle', 'wb') as f:
  885. pickle.dump([mean_maps_GGC, case_maps_GGC, counts], f) #Removed CNN_gradcam, CNN_gb, CNN_guided_gradcam to save memory
  886. """
  887. #WRITE THE PICKLE FILE
  888. with open(model_filepath+'/figures/' + str(seed)+'.pickle', 'wb') as f:
  889. pickle.dump([[fpr_testRNN, tpr_testRNN, thresholds_testRNN, test_lossRNN, test_accRNN, mci_conf_matrix_testRNN, mci_conf_matrix_test_normedRNN, test_predsRNN, test_predsRNN_class, test_labels_padArray, val_predsRNN, val_labels_padArray ],
  890. [fpr_valRNN, tpr_valRNN, thresholds_valRNN, val_lossRNN, val_accRNN, mci_conf_matrix_valRNN, mci_conf_matrix_val_normedRNN],
  891. [fpr_testCNN, tpr_testCNN, thresholds_testCNN, test_lossCNN, test_accCNN, mci_conf_matrix_testCNN, mci_conf_matrix_test_normedCNN, test_predsCNN, test_predsCNN_class, test_labels_padArray, val_predsCNN, val_labels_padArray ],
  892. [fpr_valCNN, tpr_valCNN, thresholds_valCNN, val_lossCNN, val_accCNN, mci_conf_matrix_valCNN, mci_conf_matrix_val_normedCNN],
  893. [test_table_CNN,test_table_RNN],
  894. [test_data, test_predsT1_padArray,test_predsT2_padArray,test_predsT3_padArray,test_labels_padArray]],f)
  895. #[CNN_LRP, mean_maps_LRP],
  896. #[CNN_gradcam, CNN_gb, CNN_guided_gradcam, mean_maps_GGC]], f)
  897. #For Recovery of old runs
  898. # pickle_in = open('/data_wnx1/_Data/AlzheimersDL/MCI-spasov-3class/pickles/' + str(seed)+'.pickle', 'rb')
  899. # seedpickle=pickle.load(pickle_in)
  900. # print(seedpickle)
  901. #Plots to potentially add in the future
  902. #add sMCI+pMCI ROC curve
  903. # mci_fpr_test = fpr_test[1]+fpr_test[2]
  904. # mci_tpr_test = tpr_test[1]+tpr_test[2]
  905. # areaundercurve['all_mci'] = auc(mci_fpr_test,mci_tpr_test)
  906. # plt.plot(mci_fpr_test, mci_tpr_test,
  907. # lw=lw, label='All MCI')
  908. #Calculate macros
  909. # all_fpr_test = np.unique(np.concatenate([fpr_test[i] for i in range(n_classes)]))
  910. # mean_tpr_test = np.zeros_like(all_fpr_test)
  911. # for i in range(n_classes):
  912. # mean_tpr_test += interp(all_fpr_test, fpr_test[i], tpr_test[i])
  913. # mean_tpr_test /= n_classes
  914. # fpr_test["macro"] = all_fpr_test
  915. # tpr_test["macro"] = mean_tpr_test
  916. # areaundercurve["macro"] = auc(fpr_test["macro"], tpr_test["macro"])
  917. # plt.plot(fpr_test["macro"], tpr_test["macro"],label='ROC curve of macro-average')
  918. # plt.legend(loc="lower right")
  919. # plt.savefig(model_filepath+'/figures/ROC'+str(seed)+'.png', bbox_inches='tight')
  920. #Calculate micros
  921. # test_data_bin = label_binarize(test_data[-1], classes=[0, 1, 2])
  922. # test_preds_class_bin = label_binarize(test_preds_class, classes=[0, 1, 2])
  923. # fpr_test["micro"], tpr_test["micro"], _ = roc_curve(test_data_bin.ravel(), test_preds_class_bin.ravel())
  924. # areaundercurve["micro"] = auc(fpr_test["micro"], tpr_test["micro"])
  925. # plt.plot(fpr_test["macro"], tpr_test["macro"],label='ROC curve of micro-average')
  926. # plt.legend(loc="lower right")
  927. # plt.savefig(model_filepath+'/figures/ROC'+str(seed)+'.png', bbox_inches='tight')
  928. #RUN IT!
  929. for seed in seeds:
  930. #Load data
  931. print('Processing seed number ', seed)
  932. # data_loader = DataLoader((target_rows, target_cols, depth, axis), seed = seed)
  933. # train_data, val_data, test_data, healthy_dict = data_loader.get_train_val_test()
  934. # print('length of train data '+str((train_data)),'; length of val data '+str((val_data)),'; test data '+str((test_data)))
  935. # print('length of healthy_dict[mri] '+str(len(healthy_dict)))
  936. evaluate_net(seed)