123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240 |
- import numpy as np
- from numpy.random import RandomState
- from os import listdir
- import nibabel as nib
- import math
- import csv
- import random
- from itertools import permutations
class PatientSorter():
    """Group scans by patient and bin each patient's scans into timepoints.

    Timepoint dict ``Tk`` (k = 1..8) receives the scan taken ``12 * (k - 1)``
    months after that patient's earliest scan.
    """

    # Recognised visit codes mapped to months since baseline. From
    # LP_ADNIMERGE, AD patients have at most 4 scans and healthy patients at
    # most 7, but there are 8 possible timepoints (excluding m06).
    _VISCODE_MONTHS = {
        'bl': 0., 'm12': 12., 'm24': 24., 'm36': 36., 'm48': 48.,
        'm60': 60., 'm72': 72., 'm84': 84., 'm96': 96.,
    }
    # Month values produced by the mapping above. Accepting them on input
    # makes the in-place viscode conversion idempotent, so a scan_dict that
    # an earlier call already converted is not turned into all-sentinel.
    _KNOWN_MONTHS = frozenset(_VISCODE_MONTHS.values())
    # Sentinel for every other visit code (e.g. m06). It sorts after all
    # recognised visits and is never a multiple of 12 months away from one,
    # so such scans never reach T2..T8 (they can only reach T1 when a
    # patient has no recognised visit at all).
    _UNKNOWN_VISIT = 2000.
    # Number of timepoint dicts returned (T1..T8).
    _N_TIMEPOINTS = 8
    _KEYS = ('mri', 'PTID', 'viscode', 'imageID')

    def __init__(self, seed=None):
        """Store *seed*; nothing in this class currently reads it."""
        self.seed = seed

    @classmethod
    def _to_months(cls, viscode):
        """Return *viscode* as float months since baseline, or the sentinel."""
        if viscode in cls._VISCODE_MONTHS:
            return cls._VISCODE_MONTHS[viscode]
        if viscode in cls._KNOWN_MONTHS:  # already converted by an earlier call
            return float(viscode)
        return cls._UNKNOWN_VISIT

    @staticmethod
    def _append_scan(target, scan_dict, ptid, row, visit):
        """Append row *row* of *scan_dict* (visit month *visit*) to *target*."""
        target['mri'].append(scan_dict['mri'][row])
        target['PTID'].append(ptid)
        target['viscode'].append(visit)
        target['imageID'].append(scan_dict['imageID'][row])

    def sort_patients(self, scan_dict, label, xls_datapath, first=False):
        """Sort one cohort's scans into eight timepoint dicts.

        Scans are grouped by ``PTID`` and patients with a single scan are
        discarded. For every remaining patient the earliest scan goes into
        T1, and any scan exactly 12, 24, ..., 84 months later goes into
        T2..T8 respectively. Scans at any other offset are not placed in
        T2..T8.

        Args:
            scan_dict: dict of parallel, index-aligned lists under 'mri',
                'PTID', 'viscode' and 'imageID'. ``scan_dict['viscode']`` is
                rewritten IN PLACE to float months (unknown codes become
                2000.); already-converted values are kept, so the same dict
                can be sorted more than once.
            label: cohort name written to the summary file.
            xls_datapath: directory whose ``figures`` subdirectory receives
                ``RNNDicts.txt``; the subdirectory is not created here.
            first: if true the summary file is truncated, otherwise the
                summary is appended to it.

        Returns:
            Tuple of eight dicts (T1..T8), each holding lists under 'mri',
            'PTID', 'viscode' (float months) and 'imageID'.
        """
        timepoints = [{key: [] for key in self._KEYS}
                      for _ in range(self._N_TIMEPOINTS)]
        # T2..T8 keyed by their offset (months) from the patient's first scan.
        by_offset = {12. * k: timepoints[k]
                     for k in range(1, self._N_TIMEPOINTS)}

        # Convert visit codes to months and write them back element-wise,
        # preserving the historical in-place side effect on the caller's dict.
        viscodes = scan_dict['viscode']
        months = [self._to_months(code) for code in viscodes]
        for row, month in enumerate(months):
            viscodes[row] = month

        # Group row indices per patient in first-appearance order
        # (dicts preserve insertion order).
        rows_by_patient = {}
        for row, ptid in enumerate(scan_dict['PTID']):
            rows_by_patient.setdefault(ptid, []).append(row)

        used_ptids = list(rows_by_patient)
        scan_count = [len(rows) for rows in rows_by_patient.values()]
        sorted_viscodes = []
        combo_counts = {'T1only': 0, 'T1T2': 0, 'T1T3': 0, 'T1T2T3': 0}

        for ptid, rows in rows_by_patient.items():
            if len(rows) < 2:  # throw out patients with only 1 timepoint
                continue
            # Stable sort on the visit month alone: ties keep input order, the
            # scans themselves are never compared, and mri/imageID stay aligned.
            rows = sorted(rows, key=months.__getitem__)
            visits = [months[row] for row in rows]
            sorted_viscodes.append(visits)

            baseline = visits[0]
            self._append_scan(timepoints[0], scan_dict, ptid, rows[0], baseline)
            # NOTE(review): two scans at the same visit land in the same Tk
            # dict, so Tk lists are not guaranteed one-entry-per-patient.
            for row, visit in zip(rows[1:], visits[1:]):
                target = by_offset.get(visit - baseline)
                if target is not None:
                    self._append_scan(target, scan_dict, ptid, row, visit)

            offsets = {visit - baseline for visit in visits[1:]}
            has_t2, has_t3 = 12. in offsets, 24. in offsets
            if has_t2 and has_t3:
                combo_counts['T1T2T3'] += 1
            elif has_t2:
                combo_counts['T1T2'] += 1
            elif has_t3:
                combo_counts['T1T3'] += 1
            else:
                combo_counts['T1only'] += 1

        # Build the plain-text summary.
        lines = [f'Label: {label}']
        lines.extend(f"Length of T{k}_dict: {len(tp['mri'])}"
                     for k, tp in enumerate(timepoints, 1))
        lines.append(f'ScanCount: {scan_count}')
        for n in range(1, self._N_TIMEPOINTS + 1):
            noun = 'scan' if n == 1 else 'scans'
            lines.append(
                f'Number of patients with {n} {noun}: {scan_count.count(n)}')
        lines.append(f'Used PTIDs: {used_ptids}')
        lines.append(f'Sorted Viscodes: {sorted_viscodes}')
        lines.append('For first 3 timepoints...')
        lines.extend(f'{name}: {count}' for name, count in combo_counts.items())

        mode = 'w' if first else 'a'
        with open(xls_datapath + '/figures/RNNDicts.txt', mode) as summary:
            summary.writelines(line + '\n' for line in lines)

        return tuple(timepoints)
-
- # def run_sort(self, healthy_dict, ad_dict):
- # healthy_sorted = sort_patients(healthy_dict)
- # ad_sorted = sort_patients(ad_dict)
- #
- # return healthy_sorted, ad_sorted
-
-
-
-
-
|