# patientsort.py
import csv
import math
import random
from collections import OrderedDict
from itertools import permutations
from os import listdir

import nibabel as nib
import numpy as np
from numpy.random import RandomState
  9. class PatientSorter():
  10. def __init__(self, seed=None):
  11. self.seed = seed
  12. def sort_patients(self, scan_dict, label, xls_datapath,first=False):
  13. #from LP_ADNIMERGE, AD patients have at most 4 scans and healthy patients have at most 7 scans; but there are 8 possible timepoints (excluding m06)
  14. #initialize dicts
  15. keys = ['mri','PTID','viscode','imageID']
  16. T1_dict,T2_dict,T3_dict,T4_dict,T5_dict,T6_dict,T7_dict,T8_dict = [{key: [] for key in keys} for i in range(8)]
  17. dictList = [T1_dict,T2_dict,T3_dict,T4_dict,T5_dict,T6_dict,T7_dict,T8_dict]
  18. #convert viscodes into numbers
  19. z=0
  20. for viscode in scan_dict['viscode']:
  21. z+=1
  22. if viscode == 'bl':
  23. scan_dict['viscode'][z-1] = 0.
  24. elif viscode == 'm12':
  25. scan_dict['viscode'][z-1] = 12.
  26. elif viscode == 'm24':
  27. scan_dict['viscode'][z-1] = 24.
  28. elif viscode == 'm36':
  29. scan_dict['viscode'][z-1] = 36.
  30. elif viscode == 'm48':
  31. scan_dict['viscode'][z-1] = 48.
  32. elif viscode == 'm60':
  33. scan_dict['viscode'][z-1] = 60.
  34. elif viscode == 'm72':
  35. scan_dict['viscode'][z-1] = 72.
  36. elif viscode == 'm84':
  37. scan_dict['viscode'][z-1] = 84.
  38. elif viscode == 'm96':
  39. scan_dict['viscode'][z-1] = 96.
  40. else:
  41. scan_dict['viscode'][z-1] = 2000.
  42. #Initialize Counting Variables
  43. usedPTIDs = []
  44. scanCount = []
  45. sortedViscodes = []
  46. a=0
  47. #combos2 = combinations([T1,T2,T3,T4,T5,T6,T7],2) #all combinations of TPs of length 2
  48. #combos3 = combinations([T1,T2,T3,T4,T5,T6,T7],3)
  49. #combos4 = combinations([T1,T2,T3,T4,T5,T6,T7],4)
  50. #combos5 = combinations([T1,T2,T3,T4,T5,T6,T7],5)
  51. #combos6 = combinations([T1,T2,T3,T4,T5,T6,T7],6)
  52. #for i in list(combos2):
  53. # print i #at the end, check if all values are true in each i, if so, add to count
  54. #sort scans into TP dicts
  55. #print("scan_dict['mri']",scan_dict['mri'])
  56. T1only=0
  57. T1T2=0
  58. T1T3=0
  59. T1T2T3=0
  60. for ptid1 in scan_dict['PTID']:
  61. used = False
  62. T2=False
  63. T3=False
  64. a+=1
  65. tempList = []
  66. tempViscodes = []
  67. tempImageID = []
  68. sortedTempList = []
  69. sortedTempViscodes = []
  70. sortedTempImageID = []
  71. TempLength = 0
  72. for usedptid in usedPTIDs: #check if this ptid has been accounted for already
  73. if ptid1 == usedptid:
  74. used = True
  75. break
  76. if used == False: #if this PTID has not been accounted for yet... (otherise go to next ptid)
  77. tempList.append(scan_dict['mri'][a-1]) #add first scan to templist
  78. tempViscodes.append(scan_dict['viscode'][a-1]) #add viscode of this scan to templist
  79. tempImageID.append(scan_dict['imageID'][a-1])
  80. usedPTIDs.append(ptid1) #add ptid to used list
  81. for b in range(a,len(scan_dict['PTID'])): #check for other scans with same ptid
  82. if scan_dict['PTID'][b] == ptid1:
  83. tempList.append(scan_dict['mri'][b]) #if match, then add that scan to the templist
  84. tempViscodes.append(scan_dict['viscode'][b]) #add viscode of this scan to templist
  85. tempImageID.append(scan_dict['imageID'][b])
  86. #record number of scans for that patient
  87. tempLength = len(tempList)
  88. scanCount.append(tempLength)
  89. if tempLength > 1: #Throw out all scans with only 1 timepoint
  90. sortedTempList = [x for _,x in sorted(zip(tempViscodes,tempList))] #sort the scans in order by viscode
  91. sortedTempImageID = [x for _,x in sorted(zip(tempViscodes,tempImageID))]
  92. #sortedTempViscodes = [y for _,y in sorted(zip(tempList,tempViscodes))] #sort the viscodes
  93. tempViscodes.sort() #sort the viscodes in order
  94. #print('sortViscodes ',tempViscodes)
  95. sortedViscodes.append(tempViscodes)
  96. #print('usedPTIDs ',usedPTIDs)
  97. #print('length of tempList ',len(tempList))
  98. #print('scanCount ',scanCount)
  99. T1_dict['mri'].append(sortedTempList[0]) #add first scan to T1_dict
  100. T1_dict['PTID'].append(ptid1)
  101. T1_dict['viscode'].append(tempViscodes[0])
  102. T1_dict['imageID'].append(sortedTempImageID[0])
  103. diff = []
  104. for i in range(tempLength-1):
  105. diff.append(tempViscodes[i+1] - tempViscodes[0]) #so that diff[0] applies to the difference between temp1 and temp0 and diff[1] is diff bw temp2 and temp0
  106. for i in range(len(diff)): #sort scans into the appropriate list based on time diff from T1
  107. if diff[i] == 12:
  108. T2_dict['mri'].append(sortedTempList[i+1])
  109. T2_dict['PTID'].append(ptid1)
  110. T2_dict['viscode'].append(tempViscodes[i+1])
  111. T2_dict['imageID'].append(sortedTempImageID[i+1])
  112. T2=True
  113. elif diff[i] == 24:
  114. T3_dict['mri'].append(sortedTempList[i+1])
  115. T3_dict['PTID'].append(ptid1)
  116. T3_dict['viscode'].append(tempViscodes[i+1])
  117. T3_dict['imageID'].append(sortedTempImageID[i+1])
  118. T3=True
  119. elif diff[i] == 36:
  120. T4_dict['mri'].append(sortedTempList[i+1])
  121. T4_dict['PTID'].append(ptid1)
  122. T4_dict['viscode'].append(tempViscodes[i+1])
  123. T4_dict['imageID'].append(sortedTempImageID[i+1])
  124. elif diff[i] == 48:
  125. T5_dict['mri'].append(sortedTempList[i+1])
  126. T5_dict['PTID'].append(ptid1)
  127. T5_dict['viscode'].append(tempViscodes[i+1])
  128. T5_dict['imageID'].append(sortedTempImageID[i+1])
  129. elif diff[i] == 60:
  130. T6_dict['mri'].append(sortedTempList[i+1])
  131. T6_dict['PTID'].append(ptid1)
  132. T6_dict['viscode'].append(tempViscodes[i+1])
  133. T6_dict['imageID'].append(sortedTempImageID[i+1])
  134. elif diff[i] == 72:
  135. T7_dict['mri'].append(sortedTempList[i+1])
  136. T7_dict['PTID'].append(ptid1)
  137. T7_dict['viscode'].append(tempViscodes[i+1])
  138. T7_dict['imageID'].append(sortedTempImageID[i+1])
  139. elif diff[i] == 84:
  140. T8_dict['mri'].append(sortedTempList[i+1])
  141. T8_dict['PTID'].append(ptid1)
  142. T8_dict['viscode'].append(tempViscodes[i+1])
  143. T8_dict['imageID'].append(sortedTempImageID[i+1])
  144. if T2==True and T3==False:
  145. T1T2+=1
  146. if T2==True and T3==True:
  147. T1T2T3+=1
  148. if T2==False and T3==True:
  149. T1T3+=1
  150. if T2==False and T3==False:
  151. T1only+=1
  152. #all scans have been sorted
  153. #get counts of scanCounts
  154. scans1tp = sum(1 for i in scanCount if i == 1)
  155. scans2tp = sum(1 for i in scanCount if i == 2)
  156. scans3tp = sum(1 for i in scanCount if i == 3)
  157. scans4tp = sum(1 for i in scanCount if i == 4)
  158. scans5tp = sum(1 for i in scanCount if i == 5)
  159. scans6tp = sum(1 for i in scanCount if i == 6)
  160. scans7tp = sum(1 for i in scanCount if i == 7)
  161. scans8tp = sum(1 for i in scanCount if i == 8)
  162. if first == True:
  163. with open(xls_datapath+'/figures/RNNDicts.txt','w') as InitialDicts:
  164. InitialDicts.write('Label: '+label+'\n')
  165. InitialDicts.write('Length of T1_dict: '+str(len(T1_dict['mri']))+'\n')
  166. InitialDicts.write('Length of T2_dict: '+str(len(T2_dict['mri']))+'\n')
  167. InitialDicts.write('Length of T3_dict: '+str(len(T3_dict['mri']))+'\n')
  168. InitialDicts.write('Length of T4_dict: '+str(len(T4_dict['mri']))+'\n')
  169. InitialDicts.write('Length of T5_dict: '+str(len(T5_dict['mri']))+'\n')
  170. InitialDicts.write('Length of T6_dict: '+str(len(T6_dict['mri']))+'\n')
  171. InitialDicts.write('Length of T7_dict: '+str(len(T7_dict['mri']))+'\n')
  172. InitialDicts.write('Length of T8_dict: '+str(len(T8_dict['mri']))+'\n')
  173. InitialDicts.write('ScanCount: '+str(scanCount)+'\n')
  174. InitialDicts.write('Number of patients with 1 scan: '+str(scans1tp)+'\n')
  175. InitialDicts.write('Number of patients with 2 scans: '+str(scans2tp)+'\n')
  176. InitialDicts.write('Number of patients with 3 scans: '+str(scans3tp)+'\n')
  177. InitialDicts.write('Number of patients with 4 scans: '+str(scans4tp)+'\n')
  178. InitialDicts.write('Number of patients with 5 scans: '+str(scans5tp)+'\n')
  179. InitialDicts.write('Number of patients with 6 scans: '+str(scans6tp)+'\n')
  180. InitialDicts.write('Number of patients with 7 scans: '+str(scans7tp)+'\n')
  181. InitialDicts.write('Number of patients with 8 scans: '+str(scans8tp)+'\n')
  182. InitialDicts.write('Used PTIDs: '+str(usedPTIDs)+'\n')
  183. InitialDicts.write('Sorted Viscodes: '+str(sortedViscodes)+'\n')
  184. InitialDicts.write('For first 3 timepoints...'+'\n')
  185. InitialDicts.write('T1only: '+str(T1only)+'\n')
  186. InitialDicts.write('T1T2: '+str(T1T2)+'\n')
  187. InitialDicts.write('T1T3: '+str(T1T3)+'\n')
  188. InitialDicts.write('T1T2T3: '+str(T1T2T3)+'\n')
  189. else:
  190. with open(xls_datapath+'/figures/RNNDicts.txt','a') as InitialDicts:
  191. InitialDicts.write('Label: '+label+'\n')
  192. InitialDicts.write('Length of T1_dict: '+str(len(T1_dict['mri']))+'\n')
  193. InitialDicts.write('Length of T2_dict: '+str(len(T2_dict['mri']))+'\n')
  194. InitialDicts.write('Length of T3_dict: '+str(len(T3_dict['mri']))+'\n')
  195. InitialDicts.write('Length of T4_dict: '+str(len(T4_dict['mri']))+'\n')
  196. InitialDicts.write('Length of T5_dict: '+str(len(T5_dict['mri']))+'\n')
  197. InitialDicts.write('Length of T6_dict: '+str(len(T6_dict['mri']))+'\n')
  198. InitialDicts.write('Length of T7_dict: '+str(len(T7_dict['mri']))+'\n')
  199. InitialDicts.write('Length of T8_dict: '+str(len(T8_dict['mri']))+'\n')
  200. InitialDicts.write('ScanCount: '+str(scanCount)+'\n')
  201. InitialDicts.write('Number of patients with 1 scan: '+str(scans1tp)+'\n')
  202. InitialDicts.write('Number of patients with 2 scans: '+str(scans2tp)+'\n')
  203. InitialDicts.write('Number of patients with 3 scans: '+str(scans3tp)+'\n')
  204. InitialDicts.write('Number of patients with 4 scans: '+str(scans4tp)+'\n')
  205. InitialDicts.write('Number of patients with 5 scans: '+str(scans5tp)+'\n')
  206. InitialDicts.write('Number of patients with 6 scans: '+str(scans6tp)+'\n')
  207. InitialDicts.write('Number of patients with 7 scans: '+str(scans7tp)+'\n')
  208. InitialDicts.write('Number of patients with 8 scans: '+str(scans8tp)+'\n')
  209. InitialDicts.write('Used PTIDs: '+str(usedPTIDs)+'\n')
  210. InitialDicts.write('Sorted Viscodes: '+str(sortedViscodes)+'\n')
  211. InitialDicts.write('For first 3 timepoints...'+'\n')
  212. InitialDicts.write('T1only: '+str(T1only)+'\n')
  213. InitialDicts.write('T1T2: '+str(T1T2)+'\n')
  214. InitialDicts.write('T1T3: '+str(T1T3)+'\n')
  215. InitialDicts.write('T1T2T3: '+str(T1T2T3)+'\n')
  216. return T1_dict,T2_dict,T3_dict,T4_dict,T5_dict,T6_dict,T7_dict,T8_dict
    # def run_sort(self, healthy_dict, ad_dict):
    #     healthy_sorted = sort_patients(healthy_dict)
    #     ad_sorted = sort_patients(ad_dict)
    #
    #     return healthy_sorted, ad_sorted