# preprocess.py -- data loading and preprocessing for the 2-class CNN+RNN model
  1. import numpy as np
  2. from numpy.random import RandomState
  3. from os import listdir
  4. import nibabel as nib
  5. import math
  6. import csv
  7. import random
  8. from keras.utils import to_categorical
  9. from utils.patientsort import PatientSorter
  10. ##for 2 class model CNN + RNN ##
  11. class DataLoader:
  12. """The DataLoader class is intended to be used on images placed in folder ../ADNI_volumes_customtemplate_float32
  13. naming convention is: class_subjectID_imageType.nii.gz
  14. masked_brain denotes structural MRI, JD_masked_brain denotes Jacobian Determinant
  15. stableNL: healthy controls
  16. stableMCItoAD: progressive MCI
  17. stableAD: Alzheimer's subjects
  18. Additionally, we use clinical features from csv file ../LP_ADNIMERGE.csv
  19. """
  20. def __init__(self, target_shape, seed = None):
  21. self.mri_datapath = '//data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/ADNI_volumes_customtemplate_float32'
  22. self.xls_datapath = '//data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data'
  23. self.target_shape = target_shape
  24. self.seed = seed
  25. def shuffle_dict_lists (self, dictionary):
  26. p = RandomState(self.seed).permutation(len(list(dictionary.values())[0]))
  27. for key in list(dictionary.keys()):
  28. dictionary[key] = [dictionary[key][i] for i in p]
  29. def get_filenames (self,mri_datapath):
  30. '''Puts filenames in ../ADNI_volumes_customtemplate_float32 in
  31. dictionaries according to class (stableMCI, MCItoAD, stableNL and stableAD)
  32. with keys corresponding to image modality (mri and JD)
  33. '''
  34. file_names = sorted(listdir(mri_datapath))
  35. keys = ['mri','PTID','viscode','imageID'] #is it an issue that I added viscodes here?
  36. healthy_dict, ad_dict = [{key: [] for key in keys} for i in range(2)] #!!
  37. healthyBL_dict,healthyM6_dict,healthyM12_dict,healthyM24_dict,healthyM36_dict,healthyM48_dict,healthyM60_dict,healthyM72_dict,healthyM84_dict,healthyM96_dict = [{key: [] for key in keys} for i in range(10)]
  38. adBL_dict,adM6_dict,adM12_dict,adM24_dict,adM36_dict,adM48_dict,adM60_dict,adM72_dict,adM84_dict,adM96_dict = [{key: [] for key in keys} for i in range(10)]
  39. healthyT1_dict,healthyT2_dict,healthyT3_dict,adT1_dict,adT2_dict,adT3_dict = [{key: [] for key in keys} for i in range(6)]
  40. healthyT1_Rdict,healthyT2_Rdict,healthyT3_Rdict,adT1_Rdict,adT2_Rdict,adT3_Rdict = [{key: [] for key in keys} for i in range(6)]
  41. #Get xls info
  42. with open(self.xls_datapath + '/' + 'LP_ADNIMERGE.csv', 'r') as f:
  43. reader = csv.reader(f)
  44. xls = [row for row in reader] #Extract all data from csv file in a list.
  45. test_xls=[]
  46. for _file in file_names:
  47. for row in xls:
  48. #imageID = 'I'+row[3] #prevents shorter imageIDs from matching to longer IDs which contain them
  49. imageID = row[3] #use this for loading validation set
  50. if imageID in _file:
  51. test_xls.append(row)
  52. break
  53. #sort the filenames into dicts
  54. for _file in file_names:
  55. if _file[-3:] == 'nii':
  56. if 'stableNL' in _file:
  57. for row in test_xls:
  58. #imageID = 'I'+row[3] #prevents shorter imageIDs from matching to longer IDs which contain them
  59. imageID = row[3] #use this for loading validation set
  60. if imageID in _file:
  61. if row[5] == 'bl':
  62. healthyBL_dict['mri'].append(_file)
  63. healthyBL_dict['PTID'].append(row[2])
  64. healthyBL_dict['viscode'].append(row[2])
  65. healthyBL_dict['imageID'].append(row[3])
  66. break
  67. if row[5] == 'm06':
  68. healthyM6_dict['mri'].append(_file)
  69. healthyM6_dict['PTID'].append(row[2])
  70. healthyM6_dict['viscode'].append(row[2])
  71. healthyM6_dict['imageID'].append(row[3])
  72. break
  73. elif row[5] == 'm12':
  74. healthyM12_dict['mri'].append(_file)
  75. healthyM12_dict['PTID'].append(row[2])
  76. healthyM12_dict['viscode'].append(row[2])
  77. healthyM12_dict['imageID'].append(row[3])
  78. break
  79. elif row[5] == 'm24':
  80. healthyM24_dict['mri'].append(_file)
  81. healthyM24_dict['PTID'].append(row[2])
  82. healthyM24_dict['viscode'].append(row[2])
  83. healthyM24_dict['imageID'].append(row[3])
  84. break
  85. elif row[5] == 'm36':
  86. healthyM36_dict['mri'].append(_file)
  87. healthyM36_dict['PTID'].append(row[2])
  88. healthyM36_dict['viscode'].append(row[2])
  89. healthyM36_dict['imageID'].append(row[3])
  90. break
  91. elif row[5] == 'm48':
  92. healthyM48_dict['mri'].append(_file)
  93. healthyM48_dict['PTID'].append(row[2])
  94. healthyM48_dict['viscode'].append(row[2])
  95. healthyM48_dict['imageID'].append(row[3])
  96. break
  97. elif row[5] == 'm60':
  98. healthyM60_dict['mri'].append(_file)
  99. healthyM60_dict['PTID'].append(row[2])
  100. healthyM60_dict['viscode'].append(row[2])
  101. healthyM60_dict['imageID'].append(row[3])
  102. break
  103. elif row[5] == 'm72':
  104. healthyM72_dict['mri'].append(_file)
  105. healthyM72_dict['PTID'].append(row[2])
  106. healthyM72_dict['viscode'].append(row[2])
  107. healthyM72_dict['imageID'].append(row[3])
  108. break
  109. elif row[5] == 'm84':
  110. healthyM84_dict['mri'].append(_file)
  111. healthyM84_dict['PTID'].append(row[2])
  112. healthyM84_dict['viscode'].append(row[2])
  113. healthyM84_dict['imageID'].append(row[3])
  114. break
  115. elif row[5] == 'm96':
  116. healthyM96_dict['mri'].append(_file)
  117. healthyM96_dict['PTID'].append(row[2])
  118. healthyM96_dict['viscode'].append(row[2])
  119. healthyM96_dict['imageID'].append(row[3])
  120. elif 'stableAD' in _file:
  121. for row in test_xls:
  122. #imageID = 'I'+row[3] #prevents shorter imageIDs from matching to longer IDs which contain them
  123. imageID = row[3] #use this for loading validation set
  124. if imageID in _file:
  125. if row[5] == 'bl':
  126. adBL_dict['mri'].append(_file)
  127. adBL_dict['PTID'].append(row[2])
  128. adBL_dict['viscode'].append(row[2])
  129. adBL_dict['imageID'].append(row[3])
  130. break
  131. elif row[5] == 'm06':
  132. adM6_dict['mri'].append(_file)
  133. adM6_dict['PTID'].append(row[2])
  134. adM6_dict['viscode'].append(row[2])
  135. adM6_dict['imageID'].append(row[3])
  136. break
  137. elif row[5] == 'm12':
  138. adM12_dict['mri'].append(_file)
  139. adM12_dict['PTID'].append(row[2])
  140. adM12_dict['viscode'].append(row[2])
  141. adM12_dict['imageID'].append(row[3])
  142. break
  143. elif row[5] == 'm24':
  144. adM24_dict['mri'].append(_file)
  145. adM24_dict['PTID'].append(row[2])
  146. adM24_dict['viscode'].append(row[2])
  147. adM24_dict['imageID'].append(row[3])
  148. break
  149. elif row[5] == 'm36':
  150. adM36_dict['mri'].append(_file)
  151. adM36_dict['PTID'].append(row[2])
  152. adM36_dict['viscode'].append(row[2])
  153. adM36_dict['imageID'].append(row[3])
  154. break
  155. elif row[5] == 'm48':
  156. adM48_dict['mri'].append(_file)
  157. adM48_dict['PTID'].append(row[2])
  158. adM48_dict['viscode'].append(row[2])
  159. adM48_dict['imageID'].append(row[3])
  160. break
  161. elif row[5] == 'm60':
  162. adM60_dict['mri'].append(_file)
  163. adM60_dict['PTID'].append(row[2])
  164. adM60_dict['viscode'].append(row[2])
  165. adM60_dict['imageID'].append(row[3])
  166. break
  167. elif row[5] == 'm72':
  168. adM72_dict['mri'].append(_file)
  169. adM72_dict['PTID'].append(row[2])
  170. adM72_dict['viscode'].append(row[2])
  171. adM72_dict['imageID'].append(row[3])
  172. break
  173. elif row[5] == 'm84':
  174. adM84_dict['mri'].append(_file)
  175. adM84_dict['PTID'].append(row[2])
  176. adM84_dict['viscode'].append(row[2])
  177. adM84_dict['imageID'].append(row[3])
  178. break
  179. elif row[5] == 'm96':
  180. adM96_dict['mri'].append(_file)
  181. adM96_dict['PTID'].append(row[2])
  182. adM96_dict['viscode'].append(row[2])
  183. adM96_dict['imageID'].append(row[3])
  184. #Choose which tps to call T1,T2,T3 ONLY NEEDED IF NOT USING ALL DATA IN CNN (then define healthy_dict_CNN as whichever of these scans you want
  185. #healthyT1_dict['mri'] = healthyM24_dict['mri']
  186. #healthyT2_dict['mri'] = healthyM36_dict['mri']
  187. #healthyT3_dict['mri'] = healthyM48_dict['mri']
  188. #adT1_dict['mri'] = adBL_dict['mri']
  189. #adT2_dict['mri'] = adM12_dict['mri']
  190. #adT3_dict['mri'] = adM24_dict['mri']
  191. #healthyT1_dict['PTID'] = healthyM24_dict['PTID']
  192. #healthyT2_dict['PTID'] = healthyM36_dict['PTID']
  193. #healthyT3_dict['PTID'] = healthyM48_dict['PTID']
  194. #adT1_dict['PTID'] = adBL_dict['PTID']
  195. #adT2_dict['PTID'] = adM12_dict['PTID']
  196. #adT3_dict['PTID'] = adM24_dict['PTID']
  197. #Use the above dicts for the CNN, now create the dicts for the RNN
  198. #sort into all healthy and all AD dicts, with PTIDs
  199. for _file in file_names:
  200. if _file[-3:] == 'nii':
  201. if 'stableNL' in _file:
  202. for row in test_xls:
  203. #imageID = 'I'+row[3] #prevents shorter imageIDs from matching to longer IDs which contain them
  204. imageID = row[3] #use this for loading validation set
  205. if imageID in _file:
  206. if row[5] != 'm06': #throw out all m06 scans
  207. healthy_dict['mri'].append(_file)
  208. healthy_dict['PTID'].append(row[2])
  209. healthy_dict['viscode'].append(row[5])
  210. healthy_dict['imageID'].append(row[3])
  211. break
  212. if 'stableAD' in _file:
  213. for row in test_xls:
  214. #imageID = 'I'+row[3] #prevents shorter imageIDs from matching to longer IDs which contain them
  215. imageID = row[3] #use this for loading validation set
  216. if imageID in _file:
  217. if row[5] != 'm06': #throw out all m06 scans
  218. ad_dict['mri'].append(_file)
  219. ad_dict['PTID'].append(row[2])
  220. ad_dict['viscode'].append(row[5])
  221. ad_dict['imageID'].append(row[3])
  222. break
  223. #sort RNN data into TP dicts
  224. patientSorter = PatientSorter(self.seed)
  225. healthyT1_Rdict,healthyT2_Rdict,healthyT3_Rdict,healthyT4_Rdict,healthyT5_Rdict,healthyT6_Rdict,healthyT7_Rdict,healthyT8_Rdict = patientSorter.sort_patients(healthy_dict,'healthy',self.xls_datapath,first=True)
  226. adT1_Rdict,adT2_Rdict,adT3_Rdict,adT4_Rdict,adT5_Rdict,adT6_Rdict,adT7_Rdict,adT8_Rdict = patientSorter.sort_patients(ad_dict,'ad',self.xls_datapath)
  227. with open(self.xls_datapath+'/figures/InitialDicts.txt','w') as InitialDicts:
  228. InitialDicts.write('healthyBL: '+str(len(healthyBL_dict['mri']))+'\n')
  229. InitialDicts.write('healthyM6: '+str(len(healthyM6_dict['mri']))+'\n')
  230. InitialDicts.write('healthyM12: '+str(len(healthyM12_dict['mri']))+'\n')
  231. InitialDicts.write('healthyM24: '+str(len(healthyM24_dict['mri']))+'\n')
  232. InitialDicts.write('healthyM36: '+str(len(healthyM36_dict['mri']))+'\n')
  233. InitialDicts.write('healthyM48: '+str(len(healthyM48_dict['mri']))+'\n')
  234. InitialDicts.write('healthyM60: '+str(len(healthyM60_dict['mri']))+'\n')
  235. InitialDicts.write('healthyM72: '+str(len(healthyM72_dict['mri']))+'\n')
  236. InitialDicts.write('healthyM84: '+str(len(healthyM84_dict['mri']))+'\n')
  237. InitialDicts.write('healthyM96: '+str(len(healthyM96_dict['mri']))+'\n')
  238. InitialDicts.write('adBL: '+str(len(adBL_dict['mri']))+'\n')
  239. InitialDicts.write('adM6: '+str(len(adM6_dict['mri']))+'\n')
  240. InitialDicts.write('adM12: '+str(len(adM12_dict['mri']))+'\n')
  241. InitialDicts.write('adM24: '+str(len(adM24_dict['mri']))+'\n')
  242. InitialDicts.write('adM36: '+str(len(adM36_dict['mri']))+'\n')
  243. InitialDicts.write('adM48: '+str(len(adM48_dict['mri']))+'\n')
  244. InitialDicts.write('adM60: '+str(len(adM60_dict['mri']))+'\n')
  245. InitialDicts.write('adM72: '+str(len(adM72_dict['mri']))+'\n')
  246. InitialDicts.write('adM84: '+str(len(adM84_dict['mri']))+'\n')
  247. InitialDicts.write('adM96: '+str(len(adM96_dict['mri']))+'\n')
  248. InitialDicts.write('healthyBL: '+'\n')
  249. InitialDicts.write(str(healthyBL_dict['mri'])+'\n')
  250. InitialDicts.write('healthyM6: '+'\n')
  251. InitialDicts.write(str(healthyM6_dict['mri'])+'\n')
  252. InitialDicts.write('healthyM12: '+'\n')
  253. InitialDicts.write(str(healthyM12_dict['mri'])+'\n')
  254. InitialDicts.write('healthyM24: '+'\n')
  255. InitialDicts.write(str(healthyM24_dict['mri'])+'\n')
  256. InitialDicts.write('healthyM36: '+'\n')
  257. InitialDicts.write(str(healthyM36_dict['mri'])+'\n')
  258. InitialDicts.write('healthyM48: '+'\n')
  259. InitialDicts.write(str(healthyM48_dict['mri'])+'\n')
  260. InitialDicts.write('adBL: '+'\n')
  261. InitialDicts.write(str(adBL_dict['mri'])+'\n')
  262. InitialDicts.write('adM6: '+'\n')
  263. InitialDicts.write(str(adM6_dict['mri'])+'\n')
  264. InitialDicts.write('adM12: '+'\n')
  265. InitialDicts.write(str(adM12_dict['mri'])+'\n')
  266. InitialDicts.write('adM24: '+'\n')
  267. InitialDicts.write(str(adM24_dict['mri'])+'\n')
  268. InitialDicts.write('adM36: '+'\n')
  269. InitialDicts.write(str(adM36_dict['mri'])+'\n')
  270. InitialDicts.write('adM48: '+'\n')
  271. InitialDicts.write(str(adM48_dict['mri'])+'\n')
  272. self.shuffle_dict_lists (healthyBL_dict)
  273. self.shuffle_dict_lists (healthyM6_dict)
  274. self.shuffle_dict_lists (healthyM12_dict) #Randomly shuffle lists healthy_dict ['JD'] and healthy_dict['mri'] in unison
  275. self.shuffle_dict_lists (healthyM24_dict)
  276. self.shuffle_dict_lists (healthyM36_dict)
  277. self.shuffle_dict_lists (healthyM48_dict)
  278. self.shuffle_dict_lists (adBL_dict)
  279. self.shuffle_dict_lists (adM6_dict)
  280. self.shuffle_dict_lists (adM12_dict)
  281. self.shuffle_dict_lists (adM24_dict)
  282. self.shuffle_dict_lists (adM36_dict)
  283. self.shuffle_dict_lists (adM48_dict)
  284. self.shuffle_dict_lists (healthyT1_Rdict) #This shuffling is actually getting the patients out of order!
  285. self.shuffle_dict_lists (healthyT2_Rdict) #But doesn't matter because I use the PTIDs to sort them again later.
  286. self.shuffle_dict_lists (healthyT3_Rdict)
  287. self.shuffle_dict_lists (healthyT4_Rdict)
  288. self.shuffle_dict_lists (healthyT5_Rdict)
  289. self.shuffle_dict_lists (healthyT6_Rdict)
  290. self.shuffle_dict_lists (healthyT7_Rdict)
  291. self.shuffle_dict_lists (healthyT8_Rdict)
  292. self.shuffle_dict_lists (adT1_Rdict)
  293. self.shuffle_dict_lists (adT2_Rdict)
  294. self.shuffle_dict_lists (adT3_Rdict)
  295. self.shuffle_dict_lists (adT4_Rdict)
  296. self.shuffle_dict_lists (adT5_Rdict)
  297. self.shuffle_dict_lists (adT6_Rdict)
  298. self.shuffle_dict_lists (adT7_Rdict)
  299. self.shuffle_dict_lists (adT8_Rdict)
  300. self.shuffle_dict_lists (ad_dict)
  301. self.shuffle_dict_lists (healthy_dict)
  302. #return healthyT1_dict,healthyT2_dict,healthyT3_dict,adT1_dict,adT2_dict,adT3_dict,healthyT1_Rdict,healthyT2_Rdict,healthyT3_Rdict,adT1_Rdict,adT2_Rdict,adT3_Rdict #,healthyExtra_dict,adExtra_dict #, smci_dict, pmci_dict
  303. #return healthyBL_dict,healthyM6_dict,healthyM12_dict,healthyM24_dict,healthyM36_dict,healthyM48_dict,adBL_dict,adM6_dict,adM12_dict,adM24_dict,adM36_dict,adM48_dict,healthyT1_Rdict,healthyT2_Rdict,healthyT3_Rdict,adT1_Rdict,adT2_Rdict,adT3_Rdict
  304. return healthy_dict, ad_dict, healthyT1_Rdict,healthyT2_Rdict,healthyT3_Rdict,healthyT4_Rdict,healthyT5_Rdict,healthyT6_Rdict,healthyT7_Rdict,healthyT8_Rdict,adT1_Rdict,adT2_Rdict,adT3_Rdict,adT4_Rdict,adT5_Rdict,adT6_Rdict,adT7_Rdict,adT8_Rdict
  305. def split_filenames (self, healthy_dict, ad_dict, val_split = 0.20):
  306. '''Split filename dictionaries in training/validation and test sets.
  307. '''
  308. keys = ['mri']
  309. train_dict, val_dict, test_dict = [{key: [] for key in keys} for _ in range(3)]
  310. # num_test_samples = int(((len(healthy_dict['mri']) + len(ad_dict['mri']) \
  311. # +len(pmci_dict['mri']) + len(smci_dict['mri']))*val_split)/4)
  312. # num_val_samples = int(int(math.ceil ((val_split*(len(ad_dict['mri']) + len(healthy_dict['mri']) \
  313. # +len(pmci_dict['mri']) + len(smci_dict['mri'])- num_test_samples*4)))/4))
  314. num_test_ad = int(len(ad_dict['mri'])*val_split)
  315. num_test_healthy = int(len(healthy_dict['mri'])*val_split)
  316. num_val_ad = int((len(ad_dict['mri'])-num_test_ad)*val_split)
  317. num_val_healthy = int((len(healthy_dict['mri'])-num_test_healthy)*val_split)
  318. with open(self.xls_datapath+'/figures/DataList.txt','w') as dataList:
  319. dataList.write('FOR CNN'+'\n')
  320. dataList.write('Dict Sizes:'+'\n')
  321. dataList.write('#AD_dict '+str(len(ad_dict['mri']))+'#NC_dict '+str(len(healthy_dict['mri']))+'\n'+'\n')
  322. #dataList.write('#ADT1_dict '+str(len(adT1_dict['mri']))+'#ADT2_dict '+str(len(adT2_dict['mri']))+'#ADT3_dict '+str(len(adT3_dict['mri']))+
  323. #'#NCT1_dict '+str(len(healthyT1_dict['mri']))+'#NCT2_dict '+str(len(healthyT2_dict['mri']))+'#NCT3_dict '+str(len(healthyT3_dict['mri']))+'\n'+'\n')
  324. #dataList.write('Test Dict ADT2:'+'\n')
  325. #dataList.write(str(adT2_dict['mri'])+'\n')
  326. dataList.write('Test Data Split by class:'+'\n')
  327. dataList.write('#ADtestsamples '+str(num_test_ad)+'#NCtestsamples '+str(num_test_healthy)+'\n'+'\n')
  328. #dataList.write('#ADtestsamplesT1 '+str(num_test_adT1)+'#ADtestsamplesT2 '+str(num_test_adT2)+'#ADtestsamplesT3 '+str(num_test_adT3)+
  329. #'#NCtestsamplesT1 '+str(num_test_healthyT1)+'#NCtestsamplesT2 '+str(num_test_healthyT2)+'#NCtestsamplesT3 '+str(num_test_healthyT3)+'\n'+'\n')
  330. dataList.write('Val Data Split by class:'+'\n')
  331. dataList.write('#ADvalsamples '+str(num_val_ad)+'#NCvalsamples '+str(num_val_healthy)+'\n'+'\n')
  332. #dataList.write('#ADvalsamplesT1 '+str(num_val_adT1)+'#ADvalsamplesT2 '+str(num_val_adT2)+'#ADvalsamplesT3 '+str(num_val_adT3)+
  333. #'#NCvalsamplesT1 '+str(num_val_healthyT1)+'#NCvalsamplesT2 '+str(num_val_healthyT2)+'#NCvalsamplesT3 '+str(num_val_healthyT3)+'\n'+'\n')
  334. test_ad = ad_dict['mri'][:num_test_ad]
  335. test_healthy = healthy_dict['mri'][:num_test_healthy]
  336. test_dict['mri'] = test_ad + test_healthy
  337. test_dict['health_state'] = np.concatenate((np.zeros(len(test_ad)),np.ones(len(test_healthy))))
  338. val_ad = ad_dict['mri'][num_test_ad : num_test_ad + num_val_ad]
  339. val_healthy = healthy_dict['mri'][num_test_healthy : num_test_healthy + num_val_healthy]
  340. val_dict['mri'] = val_ad + val_healthy
  341. val_dict['health_state'] = np.concatenate((np.zeros(len(val_ad)),np.ones(len(val_healthy))))
  342. train_ad = ad_dict['mri'][num_test_ad + num_val_ad:]
  343. train_healthy = healthy_dict['mri'][num_test_healthy + num_val_healthy:]
  344. train_dict['mri'] = train_ad + train_healthy
  345. train_dict['health_state'] = np.concatenate((np.zeros(len(train_ad)),np.ones(len(train_healthy))))
  346. with open(self.xls_datapath+'/figures/DataList.txt','a') as dataList:
  347. dataList.write('Train Data Split by class:'+'\n')
  348. dataList.write('#ADtrainsamples '+str(len(train_ad))+'#NCtrainsamples '+str(len(train_healthy))+'\n')
  349. #dataList.write('#ADtrainsamplesT1 '+str(len(train_adT1))+'#ADtrainsamplesT2 '+str(len(train_adT2))+'#ADtrainsamplesT3 '+str(len(train_adT3))+
  350. #'#NCtrainsamplesT1 '+str(len(train_healthyT1))+'#NCtrainsamplesT2 '+str(len(train_healthyT2))+'#NCtrainsamplesT3 '+str(len(train_healthyT3))+'\n'+'\n')
  351. #dataList.write('Number of non-dummy images in train data dictionaries:'+'\n')
  352. #dataList.wrtie('#ADtrainsamplesT1 '+str(len(train_adT1))+'#ADtrainsamplesT2 '+str(len(train_adT2))+'#ADtrainsamplesT3 '+str(len(train_adT3))+
  353. #'#NCtrainsamplesT1 '+str(len(train_healthyT1))+'#NCtrainsamplesT2 '+str(len(train_healthyT2))+'#NCtrainsamplesT3 '+str(len(train_healthyT3))+'\n'+'\n')
  354. return train_dict,val_dict,test_dict
  355. #SHOULD FOLLOW SAME ORDER OF SUBJECTS AS mri_file_names
  356. def get_data_xls (self, mri_file_names, RNN=False):
  357. '''Method extracts clinical variables data for all files in mri_file_names list
  358. Both mri_file_names and LP_ADNIMERGE.csv are in imageID order
  359. '''
  360. with open(self.xls_datapath + '/' + 'LP_ADNIMERGE.csv', 'r') as f:
  361. reader = csv.reader(f)
  362. xls = [row for row in reader] #Extract all data from csv file in a list.
  363. #xls extracts baseline features for patients sorted as in mri_file_names
  364. test_xls=[]
  365. for file_name in mri_file_names:
  366. for row in xls[1:]:
  367. #imageID = 'I'+row[3] #prevents shorter imageIDs from matching to longer IDs which contain them
  368. imageID = row[3] #use this for loading validation set
  369. if imageID in file_name:
  370. test_xls.append(row)
  371. break
  372. #check datalists
  373. if RNN == False:
  374. with open(self.xls_datapath+'/figures/DataList.txt','a') as dataList:
  375. dataList.write('Total CNN Train/Val/Test for each timepoint:'+'\n')
  376. dataList.write("length of _dict(mri) "+str(len(mri_file_names))+'\n')
  377. dataList.write("length of test_xls "+str(len(test_xls))+'\n'+'\n')
  378. """
  379. with open(self.xls_datapath + '/xlschecks/' + 'dictmri'+str(mri_file_names[1])+'.txt', 'w') as names:
  380. for line in mri_file_names:
  381. names.write(" ".join(line)+"\n")
  382. with open(self.xls_datapath + '/xlschecks/' + 'testxls'+str(mri_file_names[1])+'.txt', 'w') as testxls:
  383. for line in test_xls:
  384. testxls.write(" ".join(line)+"\n")
  385. """
  386. else:
  387. with open(self.xls_datapath+'/figures/DataList.txt','a') as dataList:
  388. dataList.write('Total RNN scans in each class for each timepoint (H/A):'+'\n')
  389. dataList.write("length of _dict(mri) "+str(len(mri_file_names))+'\n')
  390. dataList.write("length of test_xls "+str(len(test_xls))+'\n'+'\n')
  391. # with open(self.xls_datapath + '/xlschecks/' + 'dictmri'+str(mri_file_names[1])+'.txt', 'w') as names:
  392. # for line in mri_file_names:
  393. # names.write(" ".join(line)+"\n")
  394. # with open(self.xls_datapath + '/xlschecks/' + 'testxls'+str(mri_file_names[1])+'.txt', 'w') as testxls:
  395. # for line in test_xls:
  396. # testxls.write(" ".join(line)+"\n")
  397. #convert gender features to binary variables #removed ethnicity/race
  398. for row in test_xls:
  399. # row[8] = float(row[8])
  400. if row[6] == 'M':
  401. row[6] = 1.
  402. else:
  403. row[6] = 0.
  404. # row[10] = float(row[10])
  405. # if row[11] == 'Hisp/Latino':
  406. # row[11] = 1.
  407. # else:
  408. # row[11] = 0.
  409. # if row[12] == 'White': #White or non-white only;
  410. # row[12] = 1. #Cluster Am. Indian, unknown, black, asian
  411. # else:
  412. # row[12] = 0.
  413. # row[13:] = [float(x) for x in row[13:]]
  414. clinical_features = np.asarray([row[6:8] for row in test_xls]) #only capturing gender and age
  415. PTIDs = np.asarray([row[2] for row in test_xls])
  416. imageIDs = np.asarray([row[3] for row in test_xls])
  417. confids = np.asarray([row[15] for row in test_xls])
  418. csfs = np.asarray([row[16] for row in test_xls])
  419. return clinical_features, PTIDs, imageIDs, confids, csfs
  420. def get_data_mri (self, filename_dict, mri_datapath, RNN=False):
  421. '''Loads subject volumes from filename dictionary
  422. Returns MRI volume and label
  423. '''
  424. mris = np.zeros( (len(filename_dict['mri']),) + self.target_shape)
  425. jacs = np.zeros( (len(filename_dict['mri']),) + self.target_shape)
  426. if RNN == False:
  427. labels = filename_dict['health_state']
  428. else:
  429. labels = np.zeros(len(filename_dict['mri'])) #just a placeholder bc I never actually use this value
  430. #keys = ['JD', 'mri']
  431. keys = ['mri']
  432. for key in keys:
  433. for j, filename in enumerate (filename_dict[key]):
  434. if filename == 'NaN': #for dummy images, can likely delete
  435. mris[j] = np.full((91,109,91,1),-1)
  436. else:
  437. proxy_image = nib.load(mri_datapath + '/' + filename)
  438. image = np.asarray(proxy_image.dataobj)
  439. # if key == 'JD':
  440. # jacs[j] = np.asarray(np.expand_dims(image, axis = -1))
  441. # else:
  442. mris[j] = np.asarray(np.expand_dims(image, axis = -1))
  443. with open(self.xls_datapath+'/figures/getdatamri.txt','w') as getdatamri:
  444. getdatamri.write('Images:'+'\n')
  445. getdatamri.write(str(mris)+'\n')
  446. return mris.astype('float32'), jacs.astype('float32'), labels
  447. def normalize_data (self, train_im, val_im, test_im, mode):
  448. #We use different normalization procedures depending on data type
  449. if mode != 'mri' and mode != 'jac' and mode != 'xls':
  450. print ('Mode has to be mri, jac or xls depending on image data type')
  451. elif mode == 'mri':
  452. print('length of train_im: ', len(train_im))
  453. std = np.std(train_im, axis = 0)
  454. #print('std: ', std)
  455. mean = np.mean (train_im, axis = 0)
  456. #print('mean: ', mean)
  457. train_im = (train_im - mean)/(std + 1e-20)
  458. print('length of norm train_im: ', len(train_im))
  459. val_im = (val_im - mean)/(std + 1e-20)
  460. test_im = (test_im - mean)/(std + 1e-20)
  461. elif mode == 'jac':
  462. high = np.max(train_im)
  463. low = np.min(train_im)
  464. train_im = (train_im - low)/(high - low)
  465. val_im = (val_im - low)/(high - low)
  466. test_im = (test_im - low)/(high - low)
  467. else:
  468. high = np.max(train_im, axis = 0)
  469. low = np.min(train_im, axis = 0)
  470. train_im = (train_im - low)/(high - low)
  471. val_im = (val_im - low)/(high - low)
  472. test_im = (test_im - low)/(high - low)
  473. return train_im, val_im, test_im
  474. def normalize_data_RNN (self, dataT1, dataT2, dataT3, mode):
  475. #We use different normalization procedures depending on data type
  476. if mode != 'mri' and mode != 'jac' and mode != 'xls':
  477. print ('Mode has to be mri, jac or xls depending on image data type')
  478. elif mode == 'mri':
  479. stdT1 = np.std(dataT1, axis = 0)
  480. meanT1 = np.mean (dataT1, axis = 0)
  481. dataT1 = (dataT1 - meanT1)/(stdT1 + 1e-20)
  482. stdT2 = np.std(dataT2, axis = 0)
  483. meanT2 = np.mean (dataT2, axis = 0)
  484. dataT2 = (dataT2 - meanT2)/(stdT2 + 1e-20)
  485. stdT3 = np.std(dataT3, axis = 0)
  486. meanT3 = np.mean (dataT3, axis = 0)
  487. dataT3 = (dataT3 - meanT3)/(stdT3 + 1e-20)
  488. return dataT1, dataT2, dataT3
  489. def split_data_RNN (self, healthy_arrayT1,healthy_arrayT2,healthy_arrayT3, ad_arrayT1,ad_arrayT2,ad_arrayT3,rnn_HptidT1_padded,rnn_HptidT2_padded,rnn_HptidT3_padded,rnn_HimageIDT1_padded,rnn_HimageIDT2_padded,rnn_HimageIDT3_padded,rnn_AptidT1_padded,rnn_AptidT2_padded,rnn_AptidT3_padded,rnn_AimageIDT1_padded,rnn_AimageIDT2_padded,rnn_AimageIDT3_padded, val_split = 0.20):
  490. '''Split the feature vectors for the RNN into training/validation and test sets.
  491. Should be the same process as split filenames, but now I have arrays instead of dictionaries
  492. Also, I want to split data by patient, not by scan.
  493. All timepoint arrays should be organized by patient with dummy vectors as placeholders.
  494. So I only need to split T1, then the same spots in T2 and T3 can follow.
  495. '''
  496. train_arrayT1 = []
  497. train_arrayT2 = []
  498. train_arrayT3 = []
  499. val_arrayT1 = []
  500. val_arrayT2 = []
  501. val_arrayT3 = []
  502. test_arrayT1 = []
  503. test_arrayT2 = []
  504. test_arrayT3 = []
  505. num_test_ad= int(len(ad_arrayT1)*val_split)
  506. num_test_healthy = int(len(healthy_arrayT1)*val_split)
  507. num_val_ad = int((len(ad_arrayT1)-num_test_ad)*val_split)
  508. num_val_healthy = int((len(healthy_arrayT1)-num_test_healthy)*val_split)
  509. test_adT1 = ad_arrayT1[:num_test_ad]
  510. test_adT2 = ad_arrayT2[:num_test_ad]
  511. test_adT3 = ad_arrayT3[:num_test_ad]
  512. test_healthyT1 = healthy_arrayT1[:num_test_healthy]
  513. test_healthyT2 = healthy_arrayT2[:num_test_healthy]
  514. test_healthyT3 = healthy_arrayT3[:num_test_healthy]
  515. test_arrayT1 = np.concatenate((test_adT1, test_healthyT1),axis=0)
  516. test_arrayT2 = np.concatenate((test_adT2, test_healthyT2),axis=0)
    # --- tail of the RNN split routine (the method's head is above this view) ---
    # Naming scheme: prefix A = AD class, H = healthy/NC class; T1-T3 = successive
    # scan timepoints kept as parallel arrays.
    # NOTE(review): the slices below assume the test samples occupy the FRONT of
    # each array ([:num_test_*]), and that the rnn_*_padded ID arrays are ordered
    # identically to ad_arrayT*/healthy_arrayT* — confirm against the code above.
    test_arrayT3 = np.concatenate((test_adT3, test_healthyT3),axis=0)
    # Labels: 0 = AD, 1 = healthy; the AD block is always stacked first.
    test_labels = np.concatenate((np.zeros(len(test_adT1)),np.ones(len(test_healthyT1))))
    # Patient IDs and image IDs for the test split, per class and timepoint.
    test_AptidT1 = rnn_AptidT1_padded[:num_test_ad]
    test_AptidT2 = rnn_AptidT2_padded[:num_test_ad]
    test_AptidT3 = rnn_AptidT3_padded[:num_test_ad]
    test_AimageIDT1 = rnn_AimageIDT1_padded[:num_test_ad]
    test_AimageIDT2 = rnn_AimageIDT2_padded[:num_test_ad]
    test_AimageIDT3 = rnn_AimageIDT3_padded[:num_test_ad]
    test_HptidT1 = rnn_HptidT1_padded[:num_test_healthy]
    test_HptidT2 = rnn_HptidT2_padded[:num_test_healthy]
    test_HptidT3 = rnn_HptidT3_padded[:num_test_healthy]
    test_HimageIDT1 = rnn_HimageIDT1_padded[:num_test_healthy]
    test_HimageIDT2 = rnn_HimageIDT2_padded[:num_test_healthy]
    test_HimageIDT3 = rnn_HimageIDT3_padded[:num_test_healthy]
    # Stack AD-first so the ID arrays stay aligned with test_array*/test_labels.
    test_ptidT1 = np.concatenate((test_AptidT1, test_HptidT1),axis=0)
    test_ptidT2 = np.concatenate((test_AptidT2, test_HptidT2),axis=0)
    test_ptidT3 = np.concatenate((test_AptidT3, test_HptidT3),axis=0)
    test_imageIDT1 = np.concatenate((test_AimageIDT1, test_HimageIDT1),axis=0)
    test_imageIDT2 = np.concatenate((test_AimageIDT2, test_HimageIDT2),axis=0)
    test_imageIDT3 = np.concatenate((test_AimageIDT3, test_HimageIDT3),axis=0)
    # Validation split: the rows immediately after the test rows in each array.
    val_adT1 = ad_arrayT1[num_test_ad : num_test_ad + num_val_ad]
    val_adT2 = ad_arrayT2[num_test_ad : num_test_ad + num_val_ad]
    val_adT3 = ad_arrayT3[num_test_ad : num_test_ad + num_val_ad]
    val_healthyT1 = healthy_arrayT1[num_test_healthy : num_test_healthy + num_val_healthy]
    val_healthyT2 = healthy_arrayT2[num_test_healthy : num_test_healthy + num_val_healthy]
    val_healthyT3 = healthy_arrayT3[num_test_healthy : num_test_healthy + num_val_healthy]
    val_arrayT1 = np.concatenate((val_adT1, val_healthyT1),axis=0)
    val_arrayT2 = np.concatenate((val_adT2, val_healthyT2),axis=0)
    val_arrayT3 = np.concatenate((val_adT3, val_healthyT3),axis=0)
    val_labels = np.concatenate((np.zeros(len(val_adT1)),np.ones(len(val_healthyT1))))
    # Training split: everything left after the test and validation rows.
    train_adT1 = ad_arrayT1[num_test_ad + num_val_ad:]
    train_adT2 = ad_arrayT2[num_test_ad + num_val_ad:]
    train_adT3 = ad_arrayT3[num_test_ad + num_val_ad:]
    train_healthyT1 = healthy_arrayT1[num_test_healthy + num_val_healthy:]
    train_healthyT2 = healthy_arrayT2[num_test_healthy + num_val_healthy:]
    train_healthyT3 = healthy_arrayT3[num_test_healthy + num_val_healthy:]
    train_arrayT1 = np.concatenate((train_adT1, train_healthyT1),axis=0)
    train_arrayT2 = np.concatenate((train_adT2, train_healthyT2),axis=0)
    train_arrayT3 = np.concatenate((train_adT3, train_healthyT3),axis=0)
    train_labels = np.concatenate((np.zeros(len(train_adT1)),np.ones(len(train_healthyT1))))
    # Append the per-split, per-timepoint sample counts to the audit log
    # ('a' mode: earlier sections of this file write to the same log).
    with open(self.xls_datapath+'/figures/DataList.txt','a') as dataList:
        dataList.write('AFTER CLASS BALANCING'+'\n')
        dataList.write('RNN Train Data Split by class and timepoint:'+'\n')
        dataList.write('#ADtrainsamplesT1 '+str(len(train_adT1))+'#ADtrainsamplesT2 '+str(len(train_adT2))+'#ADtrainsamplesT3 '+str(len(train_adT3))+
            '#NCtrainsamplesT1 '+str(len(train_healthyT1))+'#NCtrainsamplesT2 '+str(len(train_healthyT2))+'#NCtrainsamplesT3 '+str(len(train_healthyT3))+'\n'+'\n')
        dataList.write('RNN Val Data Split by class and timepoint:'+'\n')
        dataList.write('#ADvalsamplesT1 '+str(len(val_adT1))+'#ADvalsamplesT2 '+str(len(val_adT2))+'#ADvalsamplesT3 '+str(len(val_adT3))+
            '#NCvalsamplesT1 '+str(len(val_healthyT1))+'#NCvalsamplesT2 '+str(len(val_healthyT2))+'#NCvalsamplesT3 '+str(len(val_healthyT3))+'\n'+'\n')
        dataList.write('RNN Test Data Split by class and timepoint:'+'\n')
        dataList.write('#ADtestsamplesT1 '+str(len(test_adT1))+'#ADtestsamplesT2 '+str(len(test_adT2))+'#ADtestsamplesT3 '+str(len(test_adT3))+
            '#NCtestsamplesT1 '+str(len(test_healthyT1))+'#NCtestsamplesT2 '+str(len(test_healthyT2))+'#NCtestsamplesT3 '+str(len(test_healthyT3))+'\n'+'\n')
    # 18 values: arrays (train/val/test x T1-T3), labels per split, and the
    # test-split patient/image IDs per timepoint.
    return train_arrayT1,train_arrayT2,train_arrayT3,val_arrayT1,val_arrayT2,val_arrayT3,test_arrayT1,test_arrayT2,test_arrayT3, train_labels,val_labels,test_labels, test_ptidT1,test_ptidT2,test_ptidT3,test_imageIDT1,test_imageIDT2,test_imageIDT3
  569. def get_train_val_test (self, val_split, mri_datapath):
  570. healthy_dict,ad_dict,healthyT1_Rdict,healthyT2_Rdict,healthyT3_Rdict,healthyT4_Rdict,healthyT5_Rdict,healthyT6_Rdict,healthyT7_Rdict,healthyT8_Rdict,adT1_Rdict,adT2_Rdict,adT3_Rdict,adT4_Rdict,adT5_Rdict,adT6_Rdict,adT7_Rdict,adT8_Rdict = self.get_filenames(mri_datapath)
  571. #make classes balanced
  572. diff = len(healthy_dict['mri'])-len(ad_dict['mri'])
  573. for i in range(diff):
  574. healthy_dict['mri'].pop(-1)
  575. healthy_dict['PTID'].pop(-1)
  576. healthy_dict['viscode'].pop(-1)
  577. healthy_dict['imageID'].pop(-1)
  578. train_dict, val_dict, test_dict = self.split_filenames (healthy_dict, ad_dict, val_split = val_split)
  579. #train_dictT1,train_dictT2,train_dictT3, val_dictT1,val_dictT2,val_dictT3, test_dictT1,test_dictT2,test_dictT3 = self.split_filenames (healthyM24_dict,healthyM36_dict,healthyM48_dict, adBL_dict,adM12_dict,adM24_dict, val_split = val_split)
  580. #train_dictT4,train_dictT5,train_dictT6, val_dictT4,val_dictT5,val_dictT6, test_dictT4,test_dictT5,test_dictT6 = self.split_filenames (healthyBL_dict,healthyM6_dict,healthyM12_dict, adM6_dict,adM36_dict,adM48_dict, val_split = val_split, first=False)
  581. # print("length of train_dict[mri]"+str(len(train_dict['mri'])))
  582. train_mri, train_jac, train_labels = self.get_data_mri(train_dict,mri_datapath)
  583. train_xls, train_ptid, train_imageID, train_confid, train_csf = self.get_data_xls (train_dict['mri'])
  584. val_mri, val_jac, val_labels = self.get_data_mri(val_dict,mri_datapath)
  585. val_xls, val_ptid, val_imageID, val_confid, val_csf = self.get_data_xls (val_dict['mri'])
  586. test_mri, test_jac, test_labels = self.get_data_mri(test_dict,mri_datapath)
  587. test_xls, test_ptid, test_imageID, test_confid, test_csf = self.get_data_xls (test_dict['mri'])
  588. #previously removed normalization because it seemed to be making all my images the exact same...?
  589. #somehow it's ok now though! See normalizedTestData
  590. #now seems to be making the images weird and dark. I normalize in preprocess so I don't think I need to here
  591. #carry the non-normalized through for grad-cam purposes
  592. test_mri_nonorm = test_mri
  593. #train_mri, val_mri, test_mri = self.normalize_data (train_mri, val_mri, test_mri, mode = 'mri')
  594. #with open(self.xls_datapath+'/figures/normalizedTestData.txt','w') as normed:
  595. # normed.write('Normalized CNN Train Images:'+'\n')
  596. # normed.write(str(train_mri)+'\n')
  597. test_data = (test_mri, test_mri, test_xls, test_labels, test_ptid, test_imageID, test_confid, test_csf)
  598. val_data = (val_mri, val_mri, val_xls, val_labels, val_ptid, val_imageID, val_confid, val_csf)
  599. train_data = (train_mri, train_mri, train_xls, train_labels, train_ptid, train_imageID, train_confid, train_csf)
  600. #get data lists for RNN
  601. rnn_HmriT1, rnn_HjacT1, rnn_HlabelsT1 = self.get_data_mri(healthyT1_Rdict,mri_datapath, RNN=True)
  602. rnn_HxlsT1, rnn_HptidT1, rnn_HimageIDT1, rnn_HconfidT1, rnn_HcsfT1 = self.get_data_xls (healthyT1_Rdict['mri'], RNN=True)
  603. rnn_HmriT2, rnn_HjacT2, rnn_HlabelsT2 = self.get_data_mri(healthyT2_Rdict,mri_datapath, RNN=True)
  604. rnn_HxlsT2, rnn_HptidT2, rnn_HimageIDT2, rnn_HconfidT2, rnn_HcsfT2 = self.get_data_xls (healthyT2_Rdict['mri'], RNN=True)
  605. rnn_HmriT3, rnn_HjacT3, rnn_HlabelsT3 = self.get_data_mri(healthyT3_Rdict,mri_datapath, RNN=True)
  606. rnn_HxlsT3, rnn_HptidT3, rnn_HimageIDT3, rnn_HconfidT3, rnn_HcsfT3 = self.get_data_xls (healthyT3_Rdict['mri'], RNN=True)
  607. #normalize:
  608. #rnn_HmriT1,rnn_HmriT2,rnn_HmriT3 = self.normalize_data_RNN (rnn_HmriT1,rnn_HmriT2,rnn_HmriT3, mode = 'mri') #Don't have any dummies yet, so this should only affect the actual images
  609. #rnn_HjacT1,rnn_HjacT2,rnn_HjacT3 = self.normalize_data_RNN (rnn_HjacT1,rnn_HjacT2,rnn_HjacT3, mode = 'jac')
  610. rnn_HdataT1 = (rnn_HmriT1, rnn_HjacT1, rnn_HxlsT1, rnn_HlabelsT1, rnn_HptidT1, rnn_HimageIDT1, rnn_HconfidT1, rnn_HcsfT1)
  611. rnn_HdataT2 = (rnn_HmriT2, rnn_HjacT2, rnn_HxlsT2, rnn_HlabelsT2, rnn_HptidT2, rnn_HimageIDT2, rnn_HconfidT2, rnn_HcsfT2)
  612. rnn_HdataT3 = (rnn_HmriT3, rnn_HjacT3, rnn_HxlsT3, rnn_HlabelsT3, rnn_HptidT3, rnn_HimageIDT3, rnn_HconfidT3, rnn_HcsfT3)
  613. rnn_AmriT1, rnn_AjacT1, rnn_AlabelsT1 = self.get_data_mri(adT1_Rdict,mri_datapath, RNN=True)
  614. rnn_AxlsT1, rnn_AptidT1, rnn_AimageIDT1, rnn_AconfidT1, rnn_AcsfT1 = self.get_data_xls (adT1_Rdict['mri'], RNN=True)
  615. rnn_AmriT2, rnn_AjacT2, rnn_AlabelsT2 = self.get_data_mri(adT2_Rdict,mri_datapath, RNN=True)
  616. rnn_AxlsT2, rnn_AptidT2, rnn_AimageIDT2, rnn_AconfidT2, rnn_AcsfT2 = self.get_data_xls (adT2_Rdict['mri'], RNN=True)
  617. rnn_AmriT3, rnn_AjacT3, rnn_AlabelsT3 = self.get_data_mri(adT3_Rdict,mri_datapath, RNN=True)
  618. rnn_AxlsT3, rnn_AptidT3, rnn_AimageIDT3, rnn_AconfidT3, rnn_AcsfT3 = self.get_data_xls (adT3_Rdict['mri'], RNN=True)
  619. #normalize:
  620. #rnn_AmriT1,rnn_AmriT2,rnn_AmriT3 = self.normalize_data_RNN (rnn_AmriT1,rnn_AmriT2,rnn_AmriT3, mode = 'mri')
  621. #rnn_AjacT1,rnn_AjacT2,rnn_AjacT3 = self.normalize_data_RNN (rnn_AjacT1,rnn_AjacT2,rnn_AjacT3, mode = 'jac')
  622. #with open(self.xls_datapath+'/figures/normalizedTestData.txt','a') as normed:
  623. # normed.write('Normalized RNN T1 NC Images:'+'\n')
  624. # normed.write(str(rnn_HmriT1)+'\n')
  625. # normed.write('Normalized RNN T1 AD Images:'+'\n')
  626. # normed.write(str(rnn_AmriT1)+'\n')
  627. rnn_AdataT1 = (rnn_AmriT1, rnn_AjacT1, rnn_AxlsT1, rnn_AlabelsT1, rnn_AptidT1, rnn_AimageIDT1, rnn_AconfidT1, rnn_AcsfT1)
  628. rnn_AdataT2 = (rnn_AmriT2, rnn_AjacT2, rnn_AxlsT2, rnn_AlabelsT2, rnn_AptidT2, rnn_AimageIDT2, rnn_AconfidT2, rnn_AcsfT2)
  629. rnn_AdataT3 = (rnn_AmriT3, rnn_AjacT3, rnn_AxlsT3, rnn_AlabelsT3, rnn_AptidT3, rnn_AimageIDT3, rnn_AconfidT3, rnn_AcsfT3)
  630. return train_data, val_data, test_data,rnn_HdataT1,rnn_HdataT2,rnn_HdataT3,rnn_AdataT1,rnn_AdataT2,rnn_AdataT3, test_mri_nonorm