% [D_full, w_fin, Geometry, optGoal] = NLP_beamlet_optimizer;
% orthoslice(D_full, [0,70])
% colorwash(Geometry.data, D_full, [500, 1500], [0,70])


function [D_full, w_fin, Geometry, optGoal] = NLP_beamlet_optimizer
% NLP_BEAMLET_OPTIMIZER  Two-stage non-linear optimization of beamlet weights.
%
% Stage 1 optimizes one weight per beam (beamlets sharing the same origin
% are summed into a beam). Stage 2 starts from those weights, expanded to
% one weight per beamlet, and optimizes every beamlet individually.
%
% Inputs: none. The patient folder is selected by the hard-coded "patient"
%   variable below; "Geometry" and the beamlets are read from that folder
%   (WiscPlan output).
% Outputs:
%   D_full   - dose image, same size as Geometry.data
%   w_fin    - optimized beamlet weights
%   Geometry - geometry structure loaded from Geometry.mat
%   optGoal  - per-beamlet optimization goals, including robust scenarios
%
% Side effects: saves NLP_result.mat into <patient_dir>\matlab_files and
% opens the DVH and colorwash figures.
%
% Made by Peter Ferjancic 1. May 2018
% Last updated: 14. August 2018
% Inspired by Ana Barrigan's REGGUI optimization procedures

N_fcallback1 = 5000;    % max. function evaluations, beam-level optimization
N_fcallback2 = 200000;  % max. function evaluations, beamlet-level optimization
patient = 'gbm_005';

switch patient
    case 'patient'
        patient_dir = 'C:\010-work\003_localGit\WiscPlan_v2\data\PatientData';
    case 'tomoPhantom'
        patient_dir = 'C:\010-work\003_localGit\WiscPlan_v2\data\PatientData';
    case 'phantom_HD'
        patient_dir = 'C:\010-work\003_localGit\WiscPlan_v2\data\PD_HD_dicomPhantom';
    case 'doggo'
        patient_dir = 'C:\010-work\003_localGit\WiscPlan_v2\data\PatientData_dog5_3';
    case 'gbm_005'
        patient_dir = 'C:\010-work\003_localGit\WiscPlan_v2\data\PatientData_ozzy1';
    otherwise
        error('invalid case')
end

% NOTE(review): presumably merges the per-batch beamlet files into
% batch_dose.bin (read below); the meaning of the "4" is not visible here.
merge_beamlets(4, patient_dir);


%% PROGRAM STARTS HERE
% - do not touch what is below -
fprintf('starting NLP optimization process... ')

% -- LOAD GEOMETRY AND BEAMLETS --
% Load into a struct so that the variable we rely on is explicit.
geometryFile = load([patient_dir '\matlab_files\Geometry.mat']);
Geometry = geometryFile.Geometry;
% beamlet_batch_filename = [patient_dir '\beamlet_batch_files\' 'beamletbatch0.bin'];
beamlet_batch_filename = [patient_dir '\' 'batch_dose.bin'];
beamlet_cell_array = read_ryan_beamlets(beamlet_batch_filename, 'ryan');
fprintf('\n loaded beamlets...')

% -- SET INITIAL PARAMETERS --
numVox  = numel(Geometry.data);
numBeamlet = size(beamlet_cell_array,2);

%% -- BEAMLET DOSE DELIVERY --
beamlets = get_beamlets(beamlet_cell_array, numVox);
% show_joint_beamlets(beamlets, size(Geometry.data), 7:9)
fprintf('\n beamlet array constructed...')

% - merge beamlets into beams -
beamsFile = load([patient_dir '\all_beams.mat']);
all_beams = beamsFile.all_beams;

% Consecutive beamlets with the same origin (.ip) belong to the same beam.
% The first beamlet always opens beam #1, so beam indices are guaranteed to
% run 1..numBeam even if the first origin happens to be [0 0 0].
beamletOrigin = [0 0 0];
beam_i = 0;
beam_i_list = zeros(1, numel(all_beams));
for beamlet_i = 1:numel(all_beams)
    newBeamletOrigin = all_beams{beamlet_i}.ip;
    if beam_i == 0 || any(newBeamletOrigin ~= beamletOrigin)
        beam_i = beam_i+1;
        beamletOrigin = newBeamletOrigin;
    end
    beam_i_list(beamlet_i) = beam_i;
end
numBeam = beam_i;

% Sum the dose of all beamlets of a beam into a single column per beam.
beamlets_joined = zeros(numVox, numBeam);
wbar2 = waitbar(0, 'merging beamlets into beams');
for beam_i = 1:numBeam
    beamlets_joined(:,beam_i) = sum(beamlets(:,beam_i_list == beam_i), 2);
    waitbar(beam_i/numBeam, wbar2)
end
close(wbar2)


%% -- OPTIMIZATION TARGETS --
make_ROI_goals(Geometry, beamlets, beamlets_joined, patient);

[optGoal, optGoal_beam, optGoal_idx, targetMinMax_idx] = get_ROI_goals(patient);

% -- make them robust --
RO_params=0;
optGoal_beam = make_robust_optGoal(optGoal_beam, RO_params, beamlets_joined);
optGoal = make_robust_optGoal(optGoal, RO_params, beamlets);

%% -- INITIALIZE BEAMLET WEIGHTS --
% Uniform weights, scaled so that the first goal is met on average
% (the +0.1 avoids division by zero), then perturbed by up to +-10 %.
w0_beams = ones(numBeam,1);
w0_beams = mean(optGoal_beam{1}.target ./ (optGoal_beam{1}.beamlets_pruned*w0_beams+0.1)) * w0_beams;
w0_beams = w0_beams + (2*rand(size(w0_beams))-1) *0.1 .*w0_beams; % random perturbation


% -- CALLBACK OPTIMIZATION FUNCTION --
fun1 = @(x) get_penalty(x, optGoal_beam);
fun2 = @(x) get_penalty(x, optGoal);

% -- OPTIMIZATION PARAMETERS --
% no linear or non-linear constraints; weights are only bounded below by 0
A = [];
b = [];
Aeq = [];
beq = [];
lb = zeros(1, numBeamlet);
lb_beam = zeros(1, numBeam);   % one bound per beam weight (length of w0_beams)
ub = [];
nonlcon = [];

% define opt limits, and make it fmincon progress
options = optimoptions('fmincon');
options.MaxFunctionEvaluations = N_fcallback1;
options.Display = 'iter';
options.PlotFcn = 'optimplotfval';
% options.UseParallel = true;
fprintf('\n running initial optimizer:')

%% Run the optimization
% -- GET FULL BEAM WEIGHTS --
tic
w_beam = fmincon(fun1,w0_beams,A,b,Aeq,beq,lb_beam,ub,nonlcon,options);
% t=toc;
% disp(['Optimization time for beams = ',num2str(t)]);

w_beamlets = ones(numBeamlet,1);
for beam_i = 1:numBeam % assign weights to beamlets
    % beamlets from same beam get same initial weights
    w_beamlets(beam_i_list == beam_i) = w_beam(beam_i);
end
w_beamlets = w_beamlets + (2*rand(size(w_beamlets))-1) *0.1 .*w_beamlets; % small random perturbation

% -- GET FULL BEAMLET WEIGHTS --
options.MaxFunctionEvaluations = N_fcallback2;
% tic
w_fin = fmincon(fun2,w_beamlets,A,b,Aeq,beq,lb,ub,nonlcon,options);
% The timer was started before the beam-level run, so t covers both stages.
t=toc;
disp(['Optimization time for beamlets = ',num2str(t)]);


%% evaluate the results
D_full = reshape(beamlets * w_fin, size(Geometry.data));

%% save outputs
NLP_result.dose = D_full;
NLP_result.weights = w_fin;
save([patient_dir '\matlab_files\NLP_result.mat'], 'NLP_result');

plot_DVH(D_full, optGoal, optGoal_idx, targetMinMax_idx)
colorwash(Geometry.data, D_full);
% plot_DVH_robust(D_full, optGoal, optGoal_idx)
end

%% support functions
% ---- PENALTY FUNCTION ----
function penalty = get_penalty(x, optGoal)
    % Objective function handed to the optimizer.
    % Evaluates the penalty of the weight vector x for every combination of
    % random, systematic set-up and range scenario and returns the largest
    % (worst-case) value. Scenario counts are read from the first goal.

    nRand  = optGoal{1}.NbrRandScenarios;
    nSetup = optGoal{1}.NbrSystSetUpScenarios;
    nRange = optGoal{1}.NbrRangeScenarios;

    scenarioPenalty = zeros(nRand * nSetup * nRange, 1);
    k = 0;
    for nrs_i = 1:nRand
        for sss_i = 1:nSetup
            for rrs_i = 1:nRange
                k = k + 1;
                scenarioPenalty(k) = eval_f(x, optGoal, nrs_i, sss_i, rrs_i);
            end
        end
    end

    % worst case over all evaluated scenarios
    penalty = max(scenarioPenalty);
end
% ------ supp: penalty for single scenario ------
function penalty = eval_f(x, optGoal, nrs_i, sss_i, rrs_i)
    % Penalty of the weight vector x for a single robust scenario.
    %
    % x                   - column vector of beam/beamlet weights
    % optGoal             - cell array of goals; each goal provides
    %                       .function, .opt_weight and, per scenario,
    %                       .nrs{}.sss{}.rrs{} with .target and
    %                       .beamlets_pruned
    % nrs_i, sss_i, rrs_i - indices of the random / systematic set-up /
    %                       range scenario to evaluate
    %
    % Returns the sum over all goals of (goal penalty * goal.opt_weight).
    % Raises an error for an unrecognized goal.function instead of silently
    % reusing the penalty of the previous goal.

    penalty = 0;
    % fraction of the ROI volume allowed to violate a *_perc_Volume goal
    volTolerance = 0.05;

    for goal_i = 1:numel(optGoal)
        goal = optGoal{goal_i};
        scen = goal.nrs{nrs_i}.sss{sss_i}.rrs{rrs_i};

        % achieved minus target dose: >0 is overdose, <0 is underdose
        doseDiff = scen.beamlets_pruned * x - scen.target;

        switch goal.function
            % min, max, min_sq, max_sq, LeastSquare, min_perc_Volume, max_perc_Volume
            case 'min'
                % penalize if achieved dose is lower than target dose
                d_penalty = 1.0e0 * sum(max(0, -doseDiff));
            case 'max'
                % penalize if achieved dose is higher than target dose
                d_penalty = 1.0e0 * sum(max(0, doseDiff));
            case 'min_sq'
                % penalize (quadratically) if achieved dose is lower than target dose
                d_penalty = 1.0e0 * sum(min(0, doseDiff).^2);
            case 'max_sq'
                % penalize (quadratically) if achieved dose is higher than target dose
                d_penalty = 1.0e0 * sum(max(0, doseDiff).^2);
            case 'LeastSquare'
                % penalize with sum of squares any deviation from target dose
                d_penalty = 1.0e-1 * sum(doseDiff.^2);
            case 'min_perc_Volume'
                % penalize by the fraction of volume below the target dose,
                % once it exceeds the tolerated fraction
                perc_vox = nnz(doseDiff < 0) / numel(scen.target);
                d_penalty = 3.0e5 * max(perc_vox - volTolerance, 0);
            case 'max_perc_Volume'
                % penalize by the fraction of volume above the target dose,
                % once it exceeds the tolerated fraction
                perc_vox = nnz(doseDiff > 0) / numel(scen.target);
                d_penalty = 3.0e4 * max(perc_vox - volTolerance, 0);
            otherwise
                error('NLP_beamlet_optimizer:unknownGoalFunction', ...
                    'Unknown optimization goal function "%s" (goal #%d).', ...
                    goal.function, goal_i);
        end
        penalty = penalty + d_penalty * goal.opt_weight;
    end
end

% ---- GET BEAMLETS ----
function beamlets = get_beamlets(beamlet_cell_array, numVox)
    % Assemble the sparse dose matrix from the beamlet cell array.
    %
    % beamlet_cell_array - 1 x N cell array; element {1,k} provides
    %                      .non_zero_indices (voxel indices) and
    %                      .non_zero_values  (dose at those voxels)
    % numVox             - number of voxels (rows of the output)
    %
    % Returns a numVox x N sparse matrix whose column k holds the dose of
    % beamlet k multiplied by 1000.
    %
    % Columns are filled in batches of batchSize and each batch is appended
    % to the result exactly once, so the column count is always N (also
    % when N is an exact multiple of batchSize).

    wbar1 = waitbar(0, 'Creating beamlet array');
    numBeam = size(beamlet_cell_array,2);
    batchSize = 100;
    beamlets = sparse(numVox, 0);

    for batchStart = 1:batchSize:numBeam
        batchEnd = min(batchStart + batchSize - 1, numBeam);
        beamlet_batch = sparse(numVox, batchEnd - batchStart + 1);

        for beam_i = batchStart:batchEnd
            % for each beam define how much dose it delivers on each voxel
            idx = beamlet_cell_array{1, beam_i}.non_zero_indices;
            beamlet_batch(idx, beam_i - batchStart + 1) = ...
                1000*beamlet_cell_array{1, beam_i}.non_zero_values;
            waitbar(beam_i/numBeam, wbar1, ['Adding beamlet array: #', num2str(beam_i)])
        end

        % add the finished (possibly partial) batch to the full set
        beamlets = [beamlets, beamlet_batch];
    end
    close(wbar1)

end
function show_joint_beamlets(beamlets, IMGsize, Beam_list)
    % Debug helper: sums the selected beamlet columns into one dose
    % distribution and displays it with orthoslice, so that one can check
    % by eye whether those beamlets belong to the same beam.
    %
    % beamlets  - dose matrix, one column per beamlet
    % IMGsize   - size of the image the summed dose is reshaped to
    % Beam_list - column indices of the beamlets to overlay

    summedDose = full(sum(beamlets(:, Beam_list), 2));
    orthoslice(reshape(summedDose, IMGsize))
end


% ---- MAKE ROI ROBUST ----
function optGoal = make_robust_optGoal(optGoal, RO_params, beamlets)
    % Expand every optimization goal into a set of robust scenarios.
    %
    % optGoal   - cell array of goals; each needs .ROI_idx, .imgDim and
    %             .D_final
    % RO_params - currently unused; intended to carry the scenario
    %             definitions that are hard-coded below
    % beamlets  - dose matrix (voxels x beams/beamlets)
    %
    % For each goal the scenario counts (.NbrRandScenarios,
    % .NbrSystSetUpScenarios, .NbrRangeScenarios) are stored, and for each
    % scenario .nrs{i}.sss{j}.rrs{k} receives
    %   .ROI_idx         - voxel indices of the (shifted) ROI
    %   .beamlets_pruned - rows of "beamlets" for those voxels
    %   .target          - target dose for those voxels (uniform .D_final)
    %
    % nrs - random scenarios
    % sss - system setup scenarios
    % rrs - random range scenarios
    %
    % Only the systematic set-up scenarios currently modify the ROI; the
    % random and range scenarios are single pass-through placeholders.
    %
    % X - X>0 moves image right
    % Y - Y>0 moves image down
    % Z - in/out.

    shift_mag = 1; % vox of shift
    nrs_scene_list={[0,0,0]};

%     sss_scene_list={[0,0,0]};
    sss_scene_list={[0,0,0], [-shift_mag,0,0], [shift_mag,0,0], [0,-shift_mag,0], [0,shift_mag,0]};
%     sss_scene_list={[0,0,0], [-shift_mag,0,0], [shift_mag,0,0], [0,-shift_mag,0], [0,shift_mag,0],...
%         [-shift_mag*2,0,0], [shift_mag*2,0,0], [0,-shift_mag*2,0], [0,shift_mag*2,0]};

    rrs_scene_list={[0,0,0]};

    numNrs = numel(nrs_scene_list);
    numSss = numel(sss_scene_list);
    numRrs = numel(rrs_scene_list);

    for goal_i = 1:numel(optGoal)
        optGoal{goal_i}.NbrRandScenarios      = numNrs;
        optGoal{goal_i}.NbrSystSetUpScenarios = numSss;
        optGoal{goal_i}.NbrRangeScenarios     = numRrs;

        % Non-uniform (target_alpha) goals need a voxel-wise target map. The
        % code that used to read it (nrrdread of a *_DP_target.nrrd file)
        % is disabled, so fail with a clear message rather than with an
        % undefined-variable error further down.
        if isfield(optGoal{goal_i}, 'target_alpha')
            error('NLP_beamlet_optimizer:missingTargetMap', ...
                ['Goal #%d defines target_alpha, but no voxel-wise target ' ...
                 'map is loaded in make_robust_optGoal.'], goal_i);
        end

        % binary mask of the nominal ROI
        targetImg1 = zeros(optGoal{goal_i}.imgDim);
        targetImg1(optGoal{goal_i}.ROI_idx) = 1;

        for nrs_i = 1:numNrs          % num. of random scenarios
            % target and beamlets stay the same
            targetImg2 = targetImg1;

            for sss_i = 1:numSss      % syst. setup scenarios = sss
                % shift the target; beamlets stay the same
                targetImg3 = get_RO_sss(targetImg2, sss_scene_list{sss_i});

                for rrs_i = 1:numRrs  % range scenario = rrs
                    % target and beamlets stay the same
                    targetImg4 = targetImg3;

                    %% make new target and beamlets
                    ROI_idx = find(targetImg4>0);
                    target = ones(numel(ROI_idx), 1) * optGoal{goal_i}.D_final;
                    beamlets_pruned = beamlets(ROI_idx, :);

                    % save to optGoal output
                    optGoal{goal_i}.nrs{nrs_i}.sss{sss_i}.rrs{rrs_i}.ROI_idx            = ROI_idx;
                    optGoal{goal_i}.nrs{nrs_i}.sss{sss_i}.rrs{rrs_i}.beamlets_pruned    = beamlets_pruned;
                    optGoal{goal_i}.nrs{nrs_i}.sss{sss_i}.rrs{rrs_i}.target             = target;
                end
            end
        end
    end
end
% ------ supp: RO case SSS ------
function targetImg3 = get_RO_sss(targetImg2, sss_scene_shift)
    % Systematic set-up scenario: returns the target image translated by
    % the shift vector sss_scene_shift (passed unchanged to imtranslate).
    shiftedImg = imtranslate(targetImg2, sss_scene_shift);
    targetImg3 = shiftedImg;
end