heatmapPlotting.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. import numpy as np
  2. import nibabel as nib
  3. import math
  4. import csv
  5. import matplotlib.cm as cm
  6. import SimpleITK as sitk
  7. import csv
  8. from copy import deepcopy
  9. import matplotlib.colors as mcolors
  10. import nibabel as nib
  11. from matplotlib import pyplot as plt
  12. class heatmapPlotter():
  13. def __init__(self, seed=None):
  14. self.seed = seed
  15. #self.shape = test_mri_nonorm[0].shape
  16. #ATTEMPT AT VISUALIZE_SALIENCY:
  17. #for j in range(len(test_data[0])):
  18. # grads = netCNN.make_vis_saliency(test_data,j)
  19. # plt.imshow(grads,alpha=0.6)
  20. def plot_idv_brain(self, heat_map, brain_img, ref_scale, fig=None, ax=None, contour_areas=[],
  21. x_idx=slice(0, 91), y_idx=slice(0, 109), z_idx=slice(0, 91),
  22. vmin=90, vmax=99.5, set_nan=True, cmap=None, c=None):
  23. if fig is None or ax is None:
  24. fig, ax = plt.subplots(1, figsize=(12, 12))
  25. img = deepcopy(heat_map)
  26. #if set_nan:
  27. #img[nmm_mask==0]=np.nan
  28. if cmap is None:
  29. cmap = mcolors.LinearSegmentedColormap.from_list(name='alphared',colors=[(1, 0, 0, 0),"darkred", "red", "darkorange", "orange", "yellow"],N=5000)
  30. grey_vmin, grey_vmax = np.min(brain_img), np.max(brain_img)
  31. if brain_img is not None:
  32. brain = deepcopy(brain_img)
  33. ax.imshow(np.squeeze(brain[x_idx, y_idx, z_idx],-1), cmap="gray", #was .T before (but I dont need to transpose the indices I dont think)
  34. vmin=grey_vmin, vmax=grey_vmax ) #,alpha=.9
  35. vmin, vmax = np.percentile(ref_scale, vmin), np.percentile(ref_scale, vmax)
  36. im = ax.imshow(np.squeeze(img[x_idx, y_idx, z_idx],-1), cmap=cmap, #was .T before (but I dont need to transpose the indices I dont think)
  37. vmin=vmin, vmax=vmax, interpolation="gaussian", alpha=.7)
  38. ax.axis('off')
  39. #plot_contours(contour_areas, x_idx, y_idx, z_idx, fig=fig, ax=ax, c=c)
  40. plt.gca().invert_yaxis()
  41. return fig, ax, im
  42. ##GRAD-CAM
  43. def GuidedGradCAM(self, test_data, test_mri_nonorm, model_filepath, netCNN, test_predsCNN):
  44. last_conv_layer_name = "features" #maybe supposed to be fc1?
  45. classifier_layer_names = "CNNclass_output" #supposed to have 2 layers??
  46. shape = test_mri_nonorm[0].shape
  47. cases = ["AD", "NC", "TP", "TN", "FP", "FN"]
  48. case_maps_GGC = {case: np.zeros(shape) for case in cases}
  49. mean_maps_GGC = {case: np.zeros(shape) for case in cases}
  50. counts = {case: 0 for case in cases}
  51. j=53 #CHANGE START POINT FOR NC DATA
  52. while j < len(test_data[0]): #CHANGE END POINT FOR NC DATA = len(test_data[0]), for AD data = len(test_data[0])/2
  53. #sitk_mri = sitk.GetImageFromArray(test_mri_nonorm[j], isVector=True) #use the non normalized image array
  54. #sitk.WriteImage(sitk_mri,model_filepath+'/figures/mri_'+str(self.seed)+'_'+str(j)+'_'+str(test_data[4][j])+'_'+str(test_data[3][j])+'.nii')
  55. #sitk_mri_normed = sitk.GetImageFromArray(test_data[0][j],isVector=True) #check out the normalized image
  56. #sitk.WriteImage(sitk_mri_normed,model_filepath+'/figures/mri_normed_'+str(seed)+'_'+str(j)+'_'+test_data[4][j]+'_'+test_data[3][j]+'.nii')
  57. CNN_gradcam_map = netCNN.make_gradcam_heatmap2(test_data,j)
  58. #CNN_gradcam[j] = CNN_gradcam_map
  59. #CNN_sitk_gradcam = sitk.GetImageFromArray(CNN_gradcam_map, isVector=True)
  60. #CNN_sitk_gradcam.CopyInformation(sitk_mri)
  61. #sitk.WriteImage(CNN_sitk_gradcam,model_filepath+'/figures/CNN_gradcam_'+str(self.seed)+'_'+str(j)+'_'+str(test_data[4][j])+'_'+str(test_data[3][j])+'.nii')
  62. #GUIDED BACKPROP
  63. CNN_gb_map = netCNN.guided_backprop(test_data,j)
  64. #CNN_gb[j] = CNN_gb_map
  65. #CNN_sitk_gb = sitk.GetImageFromArray(CNN_gb_map)
  66. #sitk.WriteImage(CNN_sitk_gb,model_filepath+'/figures/CNN_gb_'+str(self.seed)+'_'+str(j)+'_'+str(test_data[4][j])+'_'+str(test_data[3][j])+'.nii')#
  67. #GUIDED GRAD-CAM
  68. CNN_guided_gradcam_map = CNN_gb_map * CNN_gradcam_map
  69. #CNN_guided_gradcam[j] = CNN_guided_gradcam_map
  70. #CNN_sitk_guided_gradcam = sitk.GetImageFromArray(CNN_guided_gradcam_map)
  71. #sitk.WriteImage(CNN_sitk_guided_gradcam,model_filepath+'/figures/CNN_guided_gradcam_'+str(self.seed)+'_'+str(j)+'_'+str(test_data[4][j])+'_'+str(test_data[3][j])+'.nii')
  72. """ Just for now for memory purposes
  73. #Plot middle slice of each
  74. subplot_args = { 'nrows': 1, 'ncols': 5, 'figsize': (12, 4),
  75. 'subplot_kw': {'xticks': [], 'yticks': []} }
  76. f, ax = plt.subplots(**subplot_args)
  77. ax[0].set_title('Original Image', fontsize=11)
  78. ax[0].imshow(test_mri_nonorm[j][:,:,45,0],cmap='gray')
  79. ax[1].set_title('Guided Backprop overlay', fontsize=11)
  80. ax[1].imshow(test_mri_nonorm[j][:,:,45,0],cmap='gray')
  81. ax[1].imshow(CNN_gb_map[:,:,45,0],cmap='jet', alpha=0.4)
  82. ax[2].set_title('GRAD-CAM', fontsize=11)
  83. ax[2].imshow(CNN_gradcam_map[:,:,45,0],cmap='jet')
  84. ax[3].set_title('Guided GRAD-CAM', fontsize=11)
  85. ax[3].imshow(CNN_guided_gradcam_map[:,:,45,0],cmap='jet')
  86. ax[4].set_title('Guided GRAD-CAM overlay', fontsize=11)
  87. ax[4].imshow(test_mri_nonorm[j][:,:,45,0],cmap='gray')
  88. ax[4].imshow(CNN_guided_gradcam_map[:,:,45,0],cmap='jet', alpha=0.4)
  89. plt.savefig(model_filepath+'/figures/CNN_grad_maps_z45_'+str(self.seed)+'_'+str(j)+'_'+str(test_data[4][j])+'_'+str(test_data[3][j])+'.png')
  90. #plt.show()
  91. fig.clf()
  92. plt.close(f)
  93. """
  94. #Sort maps by cases
  95. true_case = "AD" if test_data[3][j]==0 else "NC"
  96. if np.argmax(test_predsCNN[j])==0 and true_case=="AD":
  97. case = "TP"
  98. elif np.argmax(test_predsCNN[j])==0 and true_case!="AD":
  99. case = "FP"
  100. elif np.argmax(test_predsCNN[j])==1 and true_case=="NC":
  101. case = "TN"
  102. elif np.argmax(test_predsCNN[j])==1 and true_case!="NC":
  103. case = "FN"
  104. """
  105. #for Guided Grad Cam
  106. case_maps_GGC[case] += CNN_guided_gradcam_map
  107. counts[case] += 1
  108. case_maps_GGC[true_case] += CNN_guided_gradcam_map
  109. counts[true_case] += 1
  110. """
  111. #for Grad Cam
  112. case_maps_GGC[case] += CNN_gradcam_map
  113. counts[case] += 1
  114. case_maps_GGC[true_case] += CNN_gradcam_map
  115. counts[true_case] += 1
  116. print('counts: ',counts)
  117. j+=1
  118. """
  119. #Plot INDIVIDUAL heatmaps - can't do this anymore because I removed CNN_gradcam, CNN_gb, CNN_guided_gradcam in order to save memory
  120. mean_maps_GGC["AD"] = case_maps_GGC["AD"]/counts["AD"]
  121. for j in range(len(test_data[0])):
  122. subplot_args = { 'nrows': 4, 'ncols': 1, 'figsize': (12, 12), 'sharey':True, 'sharex':True,
  123. 'subplot_kw': {'xticks': [], 'yticks': []} }
  124. fig, axes = plt.subplots(**subplot_args)
  125. vmin, vmax = 50, 99.5 #NOT SURE I WANT THIS (READ PAPER) - might be what is creating the 'mask' effect
  126. for ax, idx in zip(axes[:],[30, 40, 50, 60]):
  127. ax.text(-25, 22, "Slice " + str(idx), rotation="vertical", fontsize=20)
  128. fig, ax, im = self.plot_idv_brain(CNN_guided_gradcam[j], test_mri_nonorm[j], mean_maps_GGC["AD"],x_idx=slice(0, shape[0]),y_idx=slice(0, shape[1]),z_idx=idx, contour_areas=[],
  129. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  130. ax.text(5, -20, "class: "+str(test_data[3][j])+", prediction: "+str(np.argmax(test_predsCNN[j])), fontsize=20)
  131. fig.tight_layout()
  132. fig.subplots_adjust(right=0.8, top=0.95, hspace=0.05, wspace=0.05)
  133. fig.suptitle("LRP for Patient "+str(test_data[4][j])+", ImageID: "+str(test_data[5][j]), fontsize=22, x=.41)
  134. cbar_ax = fig.add_axes([0.6, 0.15, 0.025, 0.7])
  135. cbar = fig.colorbar(im, shrink=0.5, ticks=[vmin, vmax], cax=cbar_ax)
  136. vmin_val, vmax_val = np.percentile(mean_maps_GGC["AD"], vmin), np.percentile(mean_maps_GGC["AD"], vmax)
  137. cbar.set_ticks([vmin_val, vmax_val])
  138. cbar.ax.set_yticklabels(['{0:.1f}%'.format(vmin), '{0:.1f}%'.format(vmax)],
  139. fontsize=16)
  140. cbar.set_label('Percentile of average AD patient values', rotation=270, fontsize=18)
  141. fig.savefig(model_filepath+'/figures/CNN_GGC_'+str(self.seed)+'_'+str(j)+'_'+str(test_data[5][j])+'_'+str(test_data[3][j])+'.png')
  142. fig.clf()
  143. plt.close(fig)
  144. """
  145. return case_maps_GGC, counts #Removed CNN_gradcam, CNN_gb, CNN_guided_gradcam to save memory
  146. #LAYERWISE RELEVANCE PROPAGATION #https://github.com/moboehle/Pytorch-LRP/blob/master/Plotting%20brain%20maps.ipynb
  147. def LRP(self, test_data, test_mri_nonorm, model_filepath, netCNN, test_predsCNN):
  148. shape = test_mri_nonorm[0].shape
  149. print('length of test_data[3]: ',len(test_data[3]))
  150. #Run LRP for each test image
  151. cases = ["AD", "NC", "TP", "TN", "FP", "FN"]
  152. case_maps_LRP = {case: np.zeros(shape) for case in cases}
  153. mean_maps_LRP = {case: np.zeros(shape) for case in cases}
  154. counts = {case: 0 for case in cases}
  155. j=53 #CHANGE START POINT FOR NC DATA
  156. while j < len(test_data[0]): #CHANGE END POINT FOR NC DATA = len(test_data[0]), for AD data = len(test_data[0])/2
  157. #sitk_mri = sitk.GetImageFromArray(test_mri_nonorm[j], isVector=True) #use the non normalized image array
  158. #sitk.WriteImage(sitk_mri,model_filepath+'/figures/mri_'+str(seed)+'_'+str(j)+'_'+str(test_data[5][j])+'_'+str(test_data[3][j])+'.nii')
  159. LRP_analysis = netCNN.LRP_heatmap(test_data, j)
  160. CNN_LRP = LRP_analysis
  161. #CNN_sitk_LRP = sitk.GetImageFromArray(CNN_LRP[j], isVector=True)
  162. #CNN_sitk_LRP.CopyInformation(sitk_mri)
  163. #sitk.WriteImage(CNN_sitk_LRP,model_filepath+'/figures/CNN_LRP_'+str(self.seed)+'_'+str(j)+'_'+str(test_data[5][j])+'_'+str(test_data[3][j])+'.nii')
  164. #Sort maps by cases
  165. true_case = "AD" if test_data[3][j]==0 else "NC"
  166. if np.argmax(test_predsCNN[j])==0 and true_case=="AD":
  167. case = "TP"
  168. elif np.argmax(test_predsCNN[j])==0 and true_case!="AD":
  169. case = "FP"
  170. elif np.argmax(test_predsCNN[j])==1 and true_case=="NC":
  171. case = "TN"
  172. elif np.argmax(test_predsCNN[j])==1 and true_case!="NC":
  173. case = "FN"
  174. #case_maps_LRP[case] += CNN_LRP[j]
  175. case_maps_LRP[case] += CNN_LRP
  176. counts[case] += 1
  177. #case_maps_LRP[true_case] += CNN_LRP[j]
  178. case_maps_LRP[true_case] += CNN_LRP
  179. counts[true_case] += 1
  180. print('counts: ',counts)
  181. j+=1
  182. """
  183. #Plot INDIVIDUAL heatmaps - can't do this anymore because I removed CNN_LRP in order to save memory
  184. mean_maps_LRP["AD"] = case_maps_LRP["AD"]/counts["AD"]
  185. for j in range(len(test_data[0])):
  186. subplot_args = { 'nrows': 4, 'ncols': 1, 'figsize': (12, 12), 'sharey':True, 'sharex':True,
  187. 'subplot_kw': {'xticks': [], 'yticks': []} }
  188. fig, axes = plt.subplots(**subplot_args)
  189. vmin, vmax = 50, 99.5 #NOT SURE I WANT THIS (READ PAPER) - might be what is creating the 'mask' effect
  190. for ax, idx in zip(axes[:],[30, 40, 50, 60]):
  191. ax.text(-25, 22, "Slice " + str(idx), rotation="vertical", fontsize=20)
  192. fig, ax, im = self.plot_idv_brain(CNN_LRP[j], test_mri_nonorm[j], mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=slice(0, shape[1]),z_idx=idx, contour_areas=[],
  193. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  194. ax.text(5, -20, "class: "+str(test_data[3][j])+", prediction: "+str(np.argmax(test_predsCNN[j])), fontsize=20)
  195. fig.tight_layout()
  196. fig.subplots_adjust(right=0.8, top=0.95, hspace=0.05, wspace=0.05)
  197. fig.suptitle("LRP for Patient "+str(test_data[4][j])+", ImageID: "+str(test_data[5][j]), fontsize=22, x=.41)
  198. cbar_ax = fig.add_axes([0.6, 0.15, 0.025, 0.7])
  199. cbar = fig.colorbar(im, shrink=0.5, ticks=[vmin, vmax], cax=cbar_ax)
  200. vmin_val, vmax_val = np.percentile(mean_maps_LRP["AD"], vmin), np.percentile(mean_maps_LRP["AD"], vmax)
  201. cbar.set_ticks([vmin_val, vmax_val])
  202. cbar.ax.set_yticklabels(['{0:.1f}%'.format(vmin), '{0:.1f}%'.format(vmax)],
  203. fontsize=16)
  204. cbar.set_label('Percentile of average AD patient values', rotation=270, fontsize=18)
  205. fig.savefig(model_filepath+'/figures/CNN_LRP_'+str(self.seed)+'_'+str(j)+'_'+str(test_data[5][j])+'_'+str(test_data[3][j])+'.png')
  206. fig.clf()
  207. plt.close(fig)
  208. """
  209. return case_maps_LRP, counts #Removed CNN_LRP to save memory
  210. #Create AVERAGE heatmaps
  211. def plot_avg_maps(self, case_maps_LRP, counts, map_type, test_mri_nonorm, model_filepath, mean_map_AD):
  212. shape = test_mri_nonorm[0].shape
  213. cases = ["AD", "NC", "TP", "TN", "FP", "FN"]
  214. mean_maps_LRP = {case: np.zeros(shape) for case in cases}
  215. mean_maps_LRP["AD"] = mean_map_AD
  216. #Get the PET template
  217. proxy_image = nib.load(model_filepath + '/rbet_TEMPLATE_FDGPET_100.Resampled.nii')
  218. template = np.asarray(proxy_image.dataobj)
  219. PETtemplate = np.asarray(np.expand_dims(template, axis = -1))
  220. print('PET template shape: ', PETtemplate.shape)
  221. #Calculate the mean maps
  222. CNN_sitk_mean_maps_LRP = {case: np.zeros(shape) for case in cases}
  223. print('counts: ',counts)
  224. for case in cases:
  225. is_all_0 = np.all((mean_maps_LRP[case]==0))
  226. if is_all_0:
  227. mean_maps_LRP[case] = case_maps_LRP[case]/counts[case]
  228. sitk_mri = sitk.GetImageFromArray(test_mri_nonorm[0], isVector=True)
  229. CNN_sitk_mean_maps_LRP[case] = sitk.GetImageFromArray(mean_maps_LRP[case], isVector=True)
  230. CNN_sitk_mean_maps_LRP[case].CopyInformation(sitk_mri)
  231. sitk.WriteImage(CNN_sitk_mean_maps_LRP[case],model_filepath+'/figures/CNN_mean_'+str(map_type)+'_'+str(case)+'_'+str(self.seed)+'.nii')
  232. #Plot average heatmaps for AD vs NC
  233. subplot_args = { 'nrows': 3, 'ncols': 2, 'figsize': (12,12), 'sharey':True, 'sharex':True,
  234. 'subplot_kw': {'xticks': [], 'yticks': []},'constrained_layout':True }
  235. fig, axes = plt.subplots(**subplot_args)
  236. vmin, vmax = 50, 99.5 #NOT SURE I WANT THIS (READ PAPER) - might be what is creating the 'mask' effect
  237. #Plot all three views (matching ADRP format):
  238. ax = axes[0,0]
  239. idx = 36
  240. ax.text(-25, 20, "Slice " + str(idx), rotation="vertical", fontsize=20)
  241. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["AD"], PETtemplate, mean_maps_LRP["AD"], x_idx=idx,y_idx=slice(0, shape[1]),z_idx=slice(0, shape[2]), contour_areas=[],
  242. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  243. ax = axes[1,0]
  244. idx = 58
  245. ax.text(-25, 20, "Slice " + str(idx), rotation="vertical", fontsize=20)
  246. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["AD"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=idx,z_idx=slice(0, shape[2]), contour_areas=[],
  247. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  248. ax = axes[2,0]
  249. idx = 58
  250. ax.text(-25, 20, "Slice " + str(idx), rotation="vertical", fontsize=20)
  251. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["AD"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=slice(0, shape[1]),z_idx=idx, contour_areas=[],
  252. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  253. ax.text(45, -20, "AD", fontsize=20)
  254. ax = axes[0,1]
  255. idx = 36
  256. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["NC"], PETtemplate, mean_maps_LRP["AD"], x_idx=idx,y_idx=slice(0, shape[1]),z_idx=slice(0, shape[2]), contour_areas=[],
  257. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  258. ax = axes[1,1]
  259. idx = 58
  260. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["NC"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=idx,z_idx=slice(0, shape[2]), contour_areas=[],
  261. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  262. ax = axes[2,1]
  263. idx = 58
  264. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["NC"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=slice(0, shape[1]),z_idx=idx, contour_areas=[],
  265. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  266. ax.text(45, -20, "NC", fontsize=20)
  267. #Plot several slices along z axis: (matching slices from Boehle paper (https://github.com/moboehle/Pytorch-LRP/blob/master/Plotting%20brain%20maps.ipynb)
  268. # for ax, idx in zip(axes[:, 0], [30, 40, 50, 60]):
  269. # ax.text(-25, 20, "Slice " + str(idx), rotation="vertical", fontsize=20)
  270. # fig, ax, im = plot_idv_brain(mean_maps_LRP["AD"], PETtemplate, mean_maps_LRP["AD"], z_idx=idx, contour_areas=[],
  271. # vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  272. # ax.text(45, -20, "AD", fontsize=20)
  273. # for ax, idx in zip(axes[:, 1], [30, 40, 50, 60]):
  274. # fig, ax, im = plot_idv_brain(mean_maps_LRP["NC"], PETtemplate, mean_maps_LRP["AD"], z_idx=idx, contour_areas=[],
  275. # vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  276. # ax.text(45, -20, "NC", fontsize=20)
  277. #fig.tight_layout()
  278. fig.subplots_adjust(right=0.8, top=0.95, hspace=0.05, wspace=0.05)
  279. fig.suptitle("Average "+str(map_type)+" for AD and NC patients", fontsize=22, x=.41)
  280. cbar_ax = fig.add_axes([0.95, 0.15, 0.025, 0.7])
  281. cbar = fig.colorbar(im, shrink=0.5, ticks=[vmin, vmax], cax=cbar_ax)
  282. vmin_val, vmax_val = np.percentile(mean_maps_LRP["AD"], vmin), np.percentile(mean_maps_LRP["AD"], vmax)
  283. cbar.set_ticks([vmin_val, vmax_val])
  284. cbar.ax.set_yticklabels(['{0:.1f}%'.format(vmin), '{0:.1f}%'.format(vmax)],
  285. fontsize=16)
  286. cbar.set_label('Percentile of average AD patient values', rotation=270, fontsize=18)
  287. fig.savefig(model_filepath+'/figures/CNN_'+str(map_type)+'_avg_ADvNC_'+str(self.seed)+'.png', bbox_inches='tight')
  288. plt.close(fig)
  289. """
  290. #Plot average heatmaps for TP, FP, TN, FN
  291. fig, axes = plt.subplots(4, 4, figsize=(12, 12), sharey=True, sharex=True)
  292. vmin, vmax = 50, 99.5
  293. for ax, idx in zip(axes[:, 0], [30, 40, 50, 60]):
  294. ax.text(-25, 20, "Slice " + str(idx), rotation="vertical", fontsize=18)
  295. fig, ax, im = plot_idv_brain(mean_maps_LRP["TP"], PETtemplate, mean_maps_LRP["AD"], z_idx=idx, contour_areas=[],
  296. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  297. ax.text(10, -20, "True positives", fontsize=18)
  298. for ax, idx in zip(axes[:, 1], [30, 40, 50, 60]):
  299. fig, ax, im = plot_idv_brain(mean_maps_LRP["FP"], PETtemplate, mean_maps_LRP["AD"], z_idx=idx, contour_areas=[],
  300. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  301. ax.text(10, -20, "False positives", fontsize=18)
  302. for ax, idx in zip(axes[:, 2], [30, 40, 50, 60]):
  303. fig, ax, im = plot_idv_brain(mean_maps_LRP["TN"], PETtemplate, mean_maps_LRP["AD"], z_idx=idx, contour_areas=[],
  304. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  305. ax.text(10, -20, "True negatives", fontsize=18)
  306. for ax, idx in zip(axes[:, 3], [30, 40, 50, 60]):
  307. fig, ax, im = plot_idv_brain(mean_maps_LRP["FN"], PETtemplate, mean_maps_LRP["AD"], z_idx=idx, contour_areas=[],
  308. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  309. ax.text(10, -20, "False negatives", fontsize=18)
  310. """
  311. #Plot average heatmaps for TP, FP, TN, FN
  312. fig, axes = plt.subplots(3, 4, figsize=(12, 12), sharey=True, sharex=True, constrained_layout=True)
  313. vmin, vmax = 50, 99.5
  314. ax = axes[0,0]
  315. idx = 36
  316. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["TP"], PETtemplate, mean_maps_LRP["AD"], x_idx=idx,y_idx=slice(0, shape[1]),z_idx=slice(0, shape[2]), contour_areas=[],
  317. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  318. ax = axes[1,0]
  319. idx = 58
  320. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["TP"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=idx,z_idx=slice(0, shape[2]), contour_areas=[],
  321. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  322. ax = axes[2,0]
  323. idx = 58
  324. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["TP"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=slice(0, shape[1]),z_idx=idx, contour_areas=[],
  325. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  326. ax.text(10, -20, "True positives", fontsize=18)
  327. ax = axes[0,1]
  328. idx = 36
  329. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["FP"], PETtemplate, mean_maps_LRP["AD"], x_idx=idx,y_idx=slice(0, shape[1]),z_idx=slice(0, shape[2]), contour_areas=[],
  330. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  331. ax = axes[1,1]
  332. idx = 58
  333. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["FP"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=idx,z_idx=slice(0, shape[2]), contour_areas=[],
  334. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  335. ax = axes[2,1]
  336. idx = 58
  337. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["FP"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=slice(0, shape[1]),z_idx=idx, contour_areas=[],
  338. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  339. ax.text(10, -20, "False positives", fontsize=18)
  340. ax = axes[0,2]
  341. idx = 36
  342. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["TN"], PETtemplate, mean_maps_LRP["AD"], x_idx=idx,y_idx=slice(0, shape[1]),z_idx=slice(0, shape[2]), contour_areas=[],
  343. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  344. ax = axes[1,2]
  345. idx = 58
  346. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["TN"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=idx,z_idx=slice(0, shape[2]), contour_areas=[],
  347. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  348. ax = axes[2,2]
  349. idx = 58
  350. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["TN"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=slice(0, shape[1]),z_idx=idx, contour_areas=[],
  351. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  352. ax.text(10, -20, "True negatives", fontsize=18)
  353. ax = axes[0,3]
  354. idx = 36
  355. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["FN"], PETtemplate, mean_maps_LRP["AD"], x_idx=idx,y_idx=slice(0, shape[1]),z_idx=slice(0, shape[2]), contour_areas=[],
  356. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  357. ax = axes[1,3]
  358. idx = 58
  359. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["FN"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=idx,z_idx=slice(0, shape[2]), contour_areas=[],
  360. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  361. ax = axes[2,3]
  362. idx = 58
  363. fig, ax, im = self.plot_idv_brain(mean_maps_LRP["FN"], PETtemplate, mean_maps_LRP["AD"], x_idx=slice(0, shape[0]),y_idx=slice(0, shape[1]),z_idx=idx, contour_areas=[],
  364. vmin=vmin, vmax=vmax, fig=fig, ax=ax, set_nan=False, cmap="hot");
  365. ax.text(10, -20, "False negatives", fontsize=18)
  366. fig.suptitle("Average "+str(map_type)+" for varying cases", fontsize=24, x=.42)
  367. # fig.tight_layout()
  368. fig.subplots_adjust(top=0.95, right=0.8, hspace=0.05, wspace=0.05)
  369. cbar_ax = fig.add_axes([0.85, 0.15, 0.02, 0.7])
  370. cbar = fig.colorbar(im, shrink=0.5, ticks=[vmin, vmax], cax=cbar_ax)
  371. vmin_val, vmax_val = np.percentile(mean_maps_LRP["AD"], vmin), np.percentile(mean_maps_LRP["AD"], vmax)
  372. cbar.set_ticks([vmin_val, vmax_val])
  373. cbar.ax.set_yticklabels(['{0:.1f}%'.format(vmin), '{0:.1f}%'.format(vmax)],
  374. fontsize=16)
  375. cbar.set_label('Percentile of average AD patient values', rotation=270, fontsize=20)
  376. fig.savefig(model_filepath+'/figures/CNN_'+str(map_type)+'_avg_TPvFPvTNvFN_'+str(self.seed)+'.png', bbox_inches='tight')
  377. plt.close(fig)
  378. return mean_maps_LRP