import statUtils
import os
import radiomics
import SimpleITK
import sys
import json
import numpy

def main(parFile='../templates/statistics.json'):
    """Run the threshold-based SUV analysis for the imaging rows selected by the setup file.

    For every (participant, visit) row of ``setup['imagingDataset']`` that is still
    missing from at least one of the ``SUVanalysis_liver``, ``SUVanalysis_liver1p5``
    and ``SUVanalysis_SUVmax`` datasets:

    1. the resampled PET image and every user's segmentation are downloaded;
    2. per-label statistics are computed with the current radiomics settings;
    3. three thresholds per user are derived from them (liver mean + 2 SD,
       liver mean + 1.5 SD, and 0.41 x the larger SUVmax of labels 4 and 3);
    4. the statistics re-segmented at each still-missing threshold are uploaded.

    Downloaded files are deleted after each processed row. This also happens when
    the row is skipped because a segmentation is unavailable, and when an error
    is raised.

    Args:
        parFile: setup file passed to ``statUtils.loadSetup``.
    """
    setup=statUtils.loadSetup(parFile)
    rFile='radiomics.json'

    # The feature extractor is configured from a parameter file, so dump the
    # radiomics settings of the setup to disk first.
    with open(rFile,'w') as f:
        f.write(json.dumps(setup['radiomics']))
    setup['db'],setup['fb']=statUtils.connectDB('onko-nix')
    users=statUtils.getUsers(setup['db'],setup['project'])
    qFilter=_buildRowFilter(setup)

    ds=setup['db'].selectRows(setup['project'],'study',setup['imagingDataset'],qFilter)
    if not os.path.isdir(setup['localDir']):
        os.mkdir(setup['localDir'])
    rows=ds['rows']
    setup['values']=['COM','MTV','TLG','SUVmean','SUVmax','voxelCount','SUVSD']
    setup['featureExtractor']=radiomics.featureextractor.RadiomicsFeatureExtractor(rFile)
    # The extractor has read its parameters at construction. getValues* rewrite
    # the file whenever they need it, so do not leave it behind (e.g. when every
    # row turns out to be skipped).
    os.remove(rFile)

    # Optional limit on the number of rows processed (setup key 'n').
    n=setup.get('n',-1)
    if n>0:
        rows=rows[:n]

    for r in rows:

        # Check which of the result datasets already contain this participant/visit.
        setup['SUVdataset']='SUVanalysis_liver'
        liverDone=checkData(setup,r)

        setup['SUVdataset']='SUVanalysis_SUVmax'
        suvMaxDone=checkData(setup,r)

        setup['SUVdataset']='SUVanalysis_liver1p5'
        liver1p5Done=checkData(setup,r)

        if liverDone and suvMaxDone and liver1p5Done:
            print('Skipping {} {}'.format(r['ParticipantId'],r['visitCode']))
            continue
        doneCode=f'({liverDone}/{liver1p5Done}/{suvMaxDone})'
        print(f'Done: (liver/liver1p5/suvMax): {doneCode}')

        #PET
        localPath=statUtils.getImage(setup,r,'petResampled')
        if localPath=="NONE":
            continue

        segPaths={}
        try:
            pet=SimpleITK.ReadImage(localPath)

            #Seg
            segPaths=statUtils.getSegmentations(setup,r)
            if "NONE" in segPaths.values():
                # Skip the row; the finally clause still removes the PET file and
                # every segmentation that was downloaded.
                continue
            for x in segPaths:
                print('Loaded {}/{}'.format(users[x],segPaths[x]))

            seg={x:SimpleITK.ReadImage(segPaths[x]) for x in segPaths}

            # First pass: statistics over the unthresholded segmentations.
            setup['radiomics']['setting']['resegmentRange']=None
            # Variance is requested in addition to the configured first-order
            # features (presumably the source of SUVSD - confirm in statUtils).
            # In pyradiomics an empty/None feature list enables every feature of
            # the class, so only extend an explicit list: appending to an empty
            # one would restrict the class to Variance alone.
            firstOrder=setup['radiomics']['featureClass']['firstorder']
            if firstOrder and 'Variance' not in firstOrder:
                firstOrder.append('Variance')
            outputs=getValues(setup,r,pet,seg)

            print(outputs)

            liverThreshold,liver1p5Threshold,suvMaxThreshold=_computeThresholds(outputs)

            print('thr[liver]={} thr[liver/1.5]={} thr(suvmax)={}'.format(liverThreshold,liver1p5Threshold,suvMaxThreshold))

            # Second pass: re-segment at each threshold still missing from the DB.
            if not liverDone:
                setup['SUVdataset']='SUVanalysis_liver'
                liverOutputs=thresholdAnalysis(setup,r,pet,seg,liverThreshold)
                uploadData(setup,r,liverOutputs)

            if not liver1p5Done:
                setup['SUVdataset']='SUVanalysis_liver1p5'
                liver1p5Outputs=thresholdAnalysis(setup,r,pet,seg,liver1p5Threshold)
                uploadData(setup,r,liver1p5Outputs)

            if not suvMaxDone:
                setup['SUVdataset']='SUVanalysis_SUVmax'
                suvMaxOutputs=thresholdAnalysis(setup,r,pet,seg,suvMaxThreshold)
                uploadData(setup,r,suvMaxOutputs)

            # Analysis at a fixed SUV threshold of 4 is currently switched off.
            doThreshold4=False
            if doThreshold4:
                setup['radiomics']['setting']['resegmentRange']=[4]
                setup['radiomics']['setting']['resegmentShape']=True
                outputs4=getValues(setup,r,pet,seg)
                setup['SUVdataset']='SUVanalysis_SUV4'
                uploadData(setup,r,outputs4)
        finally:
            #cleanup
            os.remove(localPath)
            _removeLocalFiles(segPaths)


def _buildRowFilter(setup):
    # Build the optional 'in' filters on ParticipantId / visitCode from the setup lists.
    qFilter=[]
    for key,variable in (('participants','ParticipantId'),('visits','visitCode')):
        try:
            vList=';'.join(setup[key])
        except KeyError:
            continue
        qFilter.append({'variable':variable,'value':vList,'oper':'in'})
    return qFilter


def _computeThresholds(outputs):
    # Per-user thresholds: liver mean+2SD, liver mean+1.5SD, 0.41*max(SUVmax of labels 4, 3).
    # Labels missing from a user's output contribute zeros.
    default={'SUVmax':0,'SUVmean':0,'SUVSD':0}
    liverId=1
    lesionId=4
    bmId=3
    liverThreshold={}
    liver1p5Threshold={}
    suvMaxThreshold={}
    for x in outputs:
        liver=outputs[x].get(liverId,default)
        liverThreshold[x]=liver['SUVmean']+2*liver['SUVSD']
        liver1p5Threshold[x]=liver['SUVmean']+1.5*liver['SUVSD']
        suvMax=numpy.max([outputs[x].get(lesionId,default)['SUVmax'],
                          outputs[x].get(bmId,default)['SUVmax']])
        suvMaxThreshold[x]=0.41*suvMax
    return liverThreshold,liver1p5Threshold,suvMaxThreshold


def _removeLocalFiles(paths):
    # Delete downloaded files; "NONE" entries (not downloaded, see main) are skipped.
    for path in paths.values():
        if path!="NONE":
            os.remove(path)


def thresholdAnalysis(setup,r,pet,seg,thrs):
    """Re-segment each user's segmentation at that user's threshold and collect statistics.

    Args:
        setup: analysis setup. ``setup['radiomics']['setting']`` is modified in
            place: ``resegmentRange`` is left at the last threshold processed and
            ``resegmentShape`` is set to True.
        r: imaging dataset row, passed through to ``getValuesForSegmentation``.
        pet: PET image.
        seg: mapping user -> segmentation image; must contain every key of ``thrs``.
        thrs: mapping user -> threshold used as the lower bound of ``resegmentRange``.

    Returns:
        dict user -> {segment label -> statistics dict}. Each statistics dict also
        carries the ``'threshold'`` that was applied.
    """
    outputs={}
    for s,thr in thrs.items():
        settings=setup['radiomics']['setting']
        # The settings are JSON-dumped by getValuesForSegmentation, and numpy
        # scalars (e.g. numpy.float32 from numpy.max) are not JSON serializable.
        settings['resegmentRange']=[float(thr)]
        settings['resegmentShape']=True
        outputs[s]=getValuesForSegmentation(setup,r,pet,seg[s])

        for stats in outputs[s].values():
            stats['threshold']=thr
    return outputs


def getValues(setup,row,pet,seg):
    """Compute per-label statistics of ``pet`` inside every segmentation of ``seg``.

    The current ``setup['radiomics']`` settings are written to ``radiomics.json``
    in the working directory. A new ``RadiomicsFeatureExtractor`` is built from
    that file and stored in ``setup['featureExtractor']``. The file is removed
    again when done, even if an error is raised.

    The label ids are taken from the FIRST segmentation in ``seg`` and used for
    all of them. Labels for which ``statUtils.getRadiomicsComponentStats`` raises
    ``ValueError`` are skipped.

    Args:
        setup: analysis setup (``'radiomics'`` settings are read).
        row: imaging dataset row; not used here.
        pet: PET image.
        seg: mapping user -> segmentation image.

    Returns:
        dict user -> {label -> statistics}; ``{}`` when ``seg`` is empty.
    """
    rFile='radiomics.json'

    with open(rFile,'w') as f:
        f.write(json.dumps(setup['radiomics']))

    try:
        setup['featureExtractor']=radiomics.featureextractor.RadiomicsFeatureExtractor(rFile)
        if not seg:
            return {}

        #find labels associated with each (non-overlapping) segmentation
        ids=statUtils.getSegments(next(iter(seg.values())))
        outputs={x:{} for x in seg}
        for x in seg:
            for key in ids:
                label=ids[key]
                print('{} {}'.format(key,label))
                try:
                    output=statUtils.getRadiomicsComponentStats(setup,pet,seg[x],label)
                except ValueError:
                    continue
                outputs[x][label]=output
        return outputs
    finally:
        # Always remove the temporary parameter file, also on errors.
        os.remove(rFile)

def getValuesForSegmentation(setup,row,pet,seg):
    """Compute per-label statistics of ``pet`` inside the single segmentation ``seg``.

    The current ``setup['radiomics']`` settings are written to ``radiomics.json``
    in the working directory. A new ``RadiomicsFeatureExtractor`` is built from
    that file and stored in ``setup['featureExtractor']``. The file is removed
    again when done, even if an error is raised.

    Labels for which ``statUtils.getRadiomicsComponentStats`` raises
    ``ValueError`` are skipped.

    Args:
        setup: analysis setup (``'radiomics'`` settings are read).
        row: imaging dataset row; not used here.
        pet: PET image.
        seg: segmentation image.

    Returns:
        dict label -> statistics.
    """
    rFile='radiomics.json'

    with open(rFile,'w') as f:
        f.write(json.dumps(setup['radiomics']))

    try:
        setup['featureExtractor']=radiomics.featureextractor.RadiomicsFeatureExtractor(rFile)

        #find labels associated with each (non-overlapping) segmentation
        ids=statUtils.getSegments(seg)
        outputs={}

        for key in ids:
            label=ids[key]
            print('{} {}'.format(key,label))
            try:
                output=statUtils.getRadiomicsComponentStats(setup,pet,seg,label)
            except ValueError:
                continue
            outputs[label]=output
        return outputs
    finally:
        # Always remove the temporary parameter file, also on errors.
        os.remove(rFile)


def uploadData(setup,r,outputs):
    """Upload every (user, segment) statistics entry as one row of ``setup['SUVdataset']``.

    Each statistics dict in ``outputs`` is updated in place with the identifying
    columns of ``r`` plus ``'User'`` and ``'segment'``, then sent with
    ``statUtils.updateDatasetRows``, one call per row.

    Args:
        setup: analysis setup (``'db'``, ``'project'`` and ``'SUVdataset'`` are read).
        r: imaging dataset row providing ParticipantId, SequenceNum,
            patientCode and visitCode.
        outputs: mapping user -> {segment -> statistics dict}.
    """
    identifying=['ParticipantId','SequenceNum','patientCode','visitCode']
    for user,perSegment in outputs.items():
        for segment,stats in perSegment.items():
            stats.update({column:r[column] for column in identifying})
            stats['User']=user
            stats['segment']=segment
            statUtils.updateDatasetRows(setup['db'],setup['project'],setup['SUVdataset'],[stats])
    
def checkData(setup,r):
    """Return True when ``setup['SUVdataset']`` already holds rows for this participant/visit.

    Queries the ``study`` schema of ``setup['project']`` with equality filters on
    ParticipantId and visitCode taken from ``r``, and prints the number of rows found.
    """
    participant=r['ParticipantId']
    visit=r['visitCode']
    conditions=[
        {'variable':'ParticipantId','value':participant,'oper':'eq'},
        {'variable':'visitCode','value':visit,'oper':'eq'},
    ]
    dataset=setup['SUVdataset']
    found=setup['db'].selectRows(setup['project'],'study',dataset,conditions)['rows']
    count=len(found)
    print('[{}:{}/{}] got {} rows.'.format(dataset,participant,visit,count))
    return count>0



if __name__=='__main__':
    # Optional first argument: setup file. Without it, main's default
    # parFile is used instead of failing with IndexError.
    main(*sys.argv[1:2])