% NLP_optimizer_v2.m
  % paths.patient_dir
  % paths.Goal_dir (previously called DP_dir)
  % paths.patient
  % paths.goalsName
  % colorwash(Geometry.data, D_full, [500, 1500], [0,70])
  % orthoslice(D_full, [0,70])
  function [D_full, w_fin, Geometry, optGoal] = NLP_optimizer_v2(varargin)
  % NLP_OPTIMIZER_V2 Two-stage fmincon optimization of beam and beamlet weights.
  %
  % Usage:
  %   [D_full, w_fin, Geometry, optGoal] = NLP_optimizer_v2;
  %   [D_full, w_fin, Geometry, optGoal] = NLP_optimizer_v2(Pat_path, path2goal);
  %
  % Inputs (pass both, or none):
  %   Pat_path  - char, patient folder. Geometry.mat is read from, and
  %               NLP_result.mat is written to, its 'matlab_files' subfolder.
  %   path2goal - char, path of the .mat file that stores the OptGoals variable.
  %   When fewer than two inputs are given, both are selected through
  %   uigetdir / uigetfile dialogs.
  %
  % Outputs:
  %   D_full   - beamlets * w_fin reshaped to size(Geometry.data)
  %   w_fin    - beamlet weights returned by the second fmincon run
  %   Geometry - variable loaded from Geometry.mat
  %   optGoal  - cell array of goal structs after make_robust_optGoal
  %
  % Side effects: saves NLP_result (dose and weights) to NLP_result.mat,
  % then calls plot_DVH and colorwash.
  %
  % Made by Peter Ferjancic 1. May 2018
  % Last updated: 1. April 2019

      % -- resolve input paths --
      if nargin < 2
          prefs = load('WiscPlan_preferences.mat', 'WiscPlan_preferences');
          Pat_path = uigetdir(prefs.WiscPlan_preferences.patientDataPath, 'Select Patient folder');
          if isequal(Pat_path, 0)
              % uigetdir returns 0 when the dialog is cancelled
              error('NLP_optimizer_v2:noPatientFolder', 'No patient folder was selected.');
          end
          [Goal_file, Goal_path] = uigetfile(fullfile(Pat_path, 'matlab_files', '*.mat'), 'Select OptGoal file');
          if isequal(Goal_file, 0)
              % uigetfile returns 0 when the dialog is cancelled
              error('NLP_optimizer_v2:noGoalFile', 'No OptGoal file was selected.');
          end
          path2goal = fullfile(Goal_path, Goal_file);
      else
          Pat_path = varargin{1};
          path2goal = varargin{2};
      end
      % fullfile inserts the platform's file separator
      path2geometry = fullfile(Pat_path, 'matlab_files', 'Geometry.mat');

      % function-evaluation budgets for the beam stage and the beamlet stage
      N_fcallback1 = 10000;
      N_fcallback2 = 200000;

      %% PROGRAM STARTS HERE
      % - do not touch what is below -
      fprintf('starting NLP optimization process... \n')

      % -- LOAD GEOMETRY, GOALS, BEAMLETS --
      % load into structs so only the expected variables enter this workspace
      geometryFile = load(path2geometry, 'Geometry');
      Geometry = geometryFile.Geometry;
      goalFile = load(path2goal, 'OptGoals');
      OptGoals = goalFile.OptGoals;
      [beamlets, beamlets_joined, numBeamlet, numBeam, beam_i_list] = get_beam_lets(Geometry, Pat_path);

      %% -- OPTIMIZATION TARGETS --
      % -- make the optimization optGoal structure --
      % optGoal uses the individual beamlets, optGoal_beam the per-beam
      % joined beamlets; both are restricted to the rows in ROI_idx
      numGoals = size(OptGoals.goals, 1);
      optGoal = cell(1, numGoals);
      optGoal_beam = cell(1, numGoals);
      for i_goal = 1:numGoals
          optGoal{i_goal} = OptGoals.data{i_goal};
          optGoal{i_goal}.beamlets_pruned = sparse(beamlets(optGoal{i_goal}.ROI_idx, :));
          optGoal_beam{i_goal} = OptGoals.data{i_goal};
          optGoal_beam{i_goal}.beamlets_pruned = sparse(beamlets_joined(optGoal{i_goal}.ROI_idx, :));
      end

      % -- make them robust --
      RO_params = 0;
      optGoal_beam = make_robust_optGoal(optGoal_beam, RO_params, beamlets_joined);
      optGoal = make_robust_optGoal(optGoal, RO_params, beamlets);

      %% -- INITIALIZE BEAMLET WEIGHTS --
      % scale uniform beam weights by the mean ratio of goal dose to the dose
      % delivered by unit weights (+0.1 avoids division by zero)
      % NOTE(review): D_final is indexed by ROI_idx here but by positions
      % within the ROI in make_robust_optGoal - confirm which is intended
      w0_beams = ones(numBeam,1);
      w0_beams = mean(optGoal_beam{1}.D_final(optGoal{1}.ROI_idx) ./ (optGoal_beam{1}.beamlets_pruned*w0_beams+0.1)) * w0_beams;
      % w0_beams = w0_beams + (2*rand(size(w0_beams))-1) *0.1 .*w0_beams; % random perturbation
      w0_beams = double(w0_beams);

      % -- CALLBACK OPTIMIZATION FUNCTION --
      fun1 = @(x) get_penalty(x, optGoal_beam);
      fun2 = @(x) get_penalty(x, optGoal);

      % -- OPTIMIZATION PARAMETERS --
      % no linear or nonlinear constraints; weights are bounded below by zero
      A = [];
      b = [];
      Aeq = [];
      beq = [];
      lb = zeros(1, numBeamlet);
      lb_beam = zeros(1, numBeam);
      ub = [];
      nonlcon = [];

      % define opt limits, and make fmincon show its progress
      options = optimoptions('fmincon');
      options.MaxFunctionEvaluations = N_fcallback1;
      options.Display = 'iter';
      options.PlotFcn = 'optimplotfval';
      options.UseParallel = true;
      % options.OptimalityTolerance = 1e-9;

      fprintf('\n running initial optimizer:')
      %% Run the optimization
      % -- GET FULL BEAM WEIGHTS --
      beamTimer = tic;
      w_beam = fmincon(fun1,w0_beams,A,b,Aeq,beq,lb_beam,ub,nonlcon,options);
      fprintf(' done!:')
      t = toc(beamTimer);
      disp(['Optimization time for beams = ',num2str(t)]);

      w_beamlets = ones(numBeamlet,1);
      % NOTE(review): numBeam is recomputed from beam_i_list here; this
      % presumably matches the value returned by get_beam_lets - confirm
      numBeam = numel(unique(beam_i_list));
      for beam_i = 1:numBeam % assign weights to beamlets
          % beamlets from same beam get same initial weights
          w_beamlets(beam_i_list == beam_i) = w_beam(beam_i);
      end
      w_beamlets = w_beamlets + (2*rand(size(w_beamlets))-1) *0.1 .*w_beamlets; % small random perturbation
      w_beamlets = 1.1* w_beamlets; % this just kicks the beamlets up a bit to make it easier for the optimizer to start

      % -- GET FULL BEAMLET WEIGHTS --
      options.MaxFunctionEvaluations = N_fcallback2;
      % separate timer so the reported time covers the beamlet stage only
      beamletTimer = tic;
      fprintf('\n running full optimizer:')
      w_fin = fmincon(fun2,w_beamlets,A,b,Aeq,beq,lb,ub,nonlcon,options);
      fprintf(' done!:')
      t = toc(beamletTimer);
      disp(['Optimization time for beamlets = ',num2str(t)]);

      %% evaluate the results
      D_full = reshape(beamlets * w_fin, size(Geometry.data));

      %% save outputs
      NLP_result.dose = D_full;
      NLP_result.weights = w_fin;
      save(fullfile(Pat_path, 'matlab_files', 'NLP_result.mat'), 'NLP_result');

      warning('this part needs modification')
      plot_DVH(D_full, optGoal)
      colorwash(Geometry.data, D_full, [-500, 500], [0, 60]);
      % plot_DVH_robust(D_full, optGoal, optGoal_idx)
  end
  %% support functions
  % ---- PENALTY FUNCTION ----
  function penalty = get_penalty(x, optGoal)
  % GET_PENALTY Worst-case objective over all robust scenarios.
  %
  % Called by the optimizer with the current weight vector x. Evaluates
  % eval_f once for every (random, systematic setup, range) scenario
  % combination and returns the largest value. Scenario counts are read
  % from the first goal.
      nRand  = optGoal{1}.NbrRandScenarios;
      nSetup = optGoal{1}.NbrSystSetUpScenarios;
      nRange = optGoal{1}.NbrRangeScenarios;

      fobj = zeros(nRand * nSetup * nRange, 1);
      slot = 0;
      for nrs_i = 1:nRand
          for sss_i = 1:nSetup        % systematic setup scenarios
              for rgs_i = 1:nRange    % range scenarios
                  slot = slot + 1;
                  fobj(slot) = eval_f(x, optGoal, nrs_i, sss_i, rgs_i);
              end
          end
      end

      % the worst scenario determines the penalty
      penalty = max(fobj);
  end
  % ------ supp: penalty for single scenario ------
  function penalty = eval_f(x, optGoal, nrs_i, sss_i, rgs_i)
  % EVAL_F Weighted penalty, summed over all goals, for a single scenario.
  %
  % Inputs:
  %   x       - weight vector multiplied by each goal's beamlets_pruned
  %   optGoal - cell array of goal structs; each provides .function,
  %             .opt_weight and .nrs{..}.sss{..}.rgs{..} with fields
  %             .target and .beamlets_pruned
  %   nrs_i, sss_i, rgs_i - indices selecting the scenario
  %
  % Output:
  %   penalty - sum over goals of (goal penalty * opt_weight)
  %
  % Raises an error for a goal whose .function is not one of:
  %   min, max, min_sq, max_sq, LeastSquare, min_perc_Volume, max_perc_Volume
      penalty = 0;
      % for each condition
      for goal_i = 1:numel(optGoal)
          scenario = optGoal{goal_i}.nrs{nrs_i}.sss{sss_i}.rgs{rgs_i};
          % achieved dose minus target dose, per voxel
          excess = scenario.beamlets_pruned * x - scenario.target;

          switch optGoal{goal_i}.function
              case 'min'
                  % penalize if achieved dose is lower than target dose
                  d_penalty = 1.0e0 * sum(max(0, -excess));
              case 'max'
                  % penalize if achieved dose is higher than target dose
                  d_penalty = 1.0e0 * sum(max(0, excess));
              case 'min_sq'
                  % penalize (squared) if achieved dose is lower than target dose
                  d_penalty = 1.0e0 * sum(min(0, excess).^2);
              case 'max_sq'
                  % penalize (squared) if achieved dose is higher than target dose
                  d_penalty = 1.0e0 * sum(max(0, excess).^2);
              case 'LeastSquare'
                  % penalize with sum of squares any deviation from target dose
                  d_penalty = 1.0e-1 * sum(excess.^2);
              case 'min_perc_Volume'
                  % penalize by the fraction of voxels below target, beyond
                  % a 0.05 allowance; max() keeps the term non-negative so
                  % violations are penalized and compliance is not rewarded
                  perc_vox = sum(excess < 0) / numel(scenario.target);
                  d_penalty = 3.0e5 * max(perc_vox - 0.05, 0);
              case 'max_perc_Volume'
                  % penalize by the fraction of voxels above target, beyond
                  % a 0.05 allowance
                  perc_vox = sum(excess > 0) / numel(scenario.target);
                  d_penalty = 3.0e4 * max(perc_vox - 0.05, 0);
              otherwise
                  % without this branch an unknown name would silently reuse
                  % the previous goal's d_penalty
                  error('eval_f:unknownGoalFunction', ...
                      'Unknown goal function "%s" for goal %d.', ...
                      optGoal{goal_i}.function, goal_i);
          end
          penalty = penalty + d_penalty * optGoal{goal_i}.opt_weight;
      end
  end
  % ---- MAKE ROI ROBUST ----
  function optGoal = make_robust_optGoal(optGoal, RO_params, beamlets)
  % MAKE_ROBUST_OPTGOAL Expand every goal into one entry per robust scenario.
  %
  % Inputs:
  %   optGoal   - cell array of goal structs (uses .ROI_idx, .imgDim, .D_final)
  %   RO_params - currently not read by this function
  %   beamlets  - matrix whose rows are selected with each scenario's ROI_idx
  %
  % Output:
  %   optGoal - same cell array; every goal gains the three scenario counts
  %             and .nrs{..}.sss{..}.rgs{..} with fields ROI_idx,
  %             beamlets_pruned and target
  %
  % Scenario kinds:
  %   nrs - random scenarios
  %   sss - system setup scenarios
  %   rgs - random range scenarios
  % Shift convention for a scenario vector [X, Y, Z]:
  %   X - X>0 moves image right
  %   Y - Y>0 moves image down
  %   Z - in/out.
      shift_mag = 2; % vox of shift, only referenced by the optional lists below
      nrs_scene_list = {[0,0,0]};
      % ----====#### CHANGE ROBUSTNESS HERE ####====----
      sss_scene_list = {[0,0,0]};
      % sss_scene_list={[0,0,0], [-shift_mag,0,0], [shift_mag,0,0], [0,-shift_mag,0], [0,shift_mag,0], [0,0,-1], [0,0,1]};
      % sss_scene_list={[0,0,0], [-shift_mag,0,0], [shift_mag,0,0], [0,-shift_mag,0], [0,shift_mag,0],...
      % [-shift_mag*2,0,0], [shift_mag*2,0,0], [0,-shift_mag*2,0], [0,shift_mag*2,0]};
      % ----====#### CHANGE ROBUSTNESS HERE ####====----
      rgs_scene_list = {[0,0,0]};
      % [targetIn, meta] = nrrdread('C:\010-work\003_localGit\WiscPlan_v2\data\archive\CDP_data\CDP5_DP_target.nrrd');
      % [targetIn, meta] = nrrdread('C:\010-work\003_localGit\WiscPlan_v2\data\PD_HD_dicomPhantom\Tomo_DP_target.nrrd');

      nRand  = numel(nrs_scene_list);
      nSetup = numel(sss_scene_list);
      nRange = numel(rgs_scene_list);

      for g = 1:numel(optGoal)
          % record how many scenarios of each kind this goal carries
          optGoal{g}.NbrRandScenarios      = nRand;
          optGoal{g}.NbrSystSetUpScenarios = nSetup;
          optGoal{g}.NbrRangeScenarios     = nRange;

          % binary image of the goal's ROI
          roiMask = zeros(optGoal{g}.imgDim);
          roiMask(optGoal{g}.ROI_idx) = 1;

          for r = 1:nRand
              % random scenarios currently leave the mask unchanged
              maskAfterRand = roiMask;

              for s = 1:nSetup
                  % shift the mask; keptPos marks the ROI voxels still valid
                  [maskAfterSetup, keptPos] = get_RO_sss(maskAfterRand, sss_scene_list{s});

                  for k = 1:nRange
                      % range scenarios currently leave the mask unchanged
                      maskAfterRange = maskAfterSetup;

                      % voxel indices, goal dose and beamlet rows for this scenario
                      scenarioIdx   = find(maskAfterRange > 0);
                      scenarioDose  = optGoal{g}.D_final(keptPos);
                      scenarioBeams = beamlets(scenarioIdx, :);

                      optGoal{g}.nrs{r}.sss{s}.rgs{k}.ROI_idx = scenarioIdx;
                      optGoal{g}.nrs{r}.sss{s}.rgs{k}.beamlets_pruned = scenarioBeams;
                      optGoal{g}.nrs{r}.sss{s}.rgs{k}.target = scenarioDose;
                  end
              end
          end
      end
  end
  % ------ supp: RO case SSS ------
  function [targetImg3, ia] = get_RO_sss(targetImg2, sss_scene_shift)
  % GET_RO_SSS Translate a target image and report which target voxels survive.
  %
  % Inputs:
  %   targetImg2      - target image; its nonzero voxels are the target
  %   sss_scene_shift - translation vector handed to imtranslate
  %
  % Outputs:
  %   targetImg3 - targetImg2 translated by sss_scene_shift
  %   ia         - positions within find(targetImg2) of the target voxels
  %                that are recovered when the shift is undone

      % translate the target image
      targetImg3 = imtranslate(targetImg2, sss_scene_shift);

      % shift back: target voxels that fell out during the shift are missing
      % from imgValid, so they show up as nonzero in the difference image
      imgValid = imtranslate(targetImg3, -sss_scene_shift);
      imgInvalid = targetImg2 - imgValid;

      idx_1 = find(targetImg2);
      idx_2 = find(imgInvalid);

      % setdiff's second output already gives the positions in idx_1 of the
      % surviving voxels, so no further intersect is needed
      [~, ia] = setdiff(idx_1, idx_2);
  end