123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396 |
- import numpy as np
- import nibabel as nib
- import math
- import csv
- import matplotlib.cm as cm
- import SimpleITK as sitk
- import csv
- from copy import deepcopy
- import matplotlib.colors as mcolors
- import nibabel as nib
- from matplotlib import pyplot as plt
class heatmapPlotter():
    """Build and plot saliency heatmaps (Grad-CAM / guided Grad-CAM / LRP) for a CNN.

    Per-sample heatmaps are summed per ground-truth group ("AD"/"NC") and per
    confusion-matrix case ("TP"/"TN"/"FP"/"FN", with class 0 = AD as the
    positive class). The averaged maps are written to NIfTI files and plotted
    over a PET template.
    """

    # Keys of the per-case accumulators; order is kept for printing.
    _CASES = ("AD", "NC", "TP", "TN", "FP", "FN")

    def __init__(self, seed=None):
        """Store `seed`, which is used to tag the names of written output files."""
        self.seed = seed

    def plot_idv_brain(self, heat_map, brain_img, ref_scale, fig=None, ax=None, contour_areas=None,
                       x_idx=slice(0, 91), y_idx=slice(0, 109), z_idx=slice(0, 91),
                       vmin=90, vmax=99.5, set_nan=True, cmap=None, c=None):
        """Draw one 2-D slice of `heat_map` over the matching slice of `brain_img`.

        Both volumes are indexed as ``vol[x_idx, y_idx, z_idx]`` and the trailing
        axis is squeezed, so they must end in a singleton channel axis. Callers
        in this class pass an integer for exactly one of x_idx / y_idx / z_idx.

        Args:
            heat_map: saliency volume drawn on top.
            brain_img: background volume drawn in grey, or None to skip it.
            ref_scale: array whose `vmin`/`vmax` percentiles give the heatmap
                colour limits, so several panels can share one scale.
            fig, ax: target figure/axes; a new 12x12 figure is created if
                either is None.
            contour_areas, set_nan, c: currently unused (kept for API compatibility).
            vmin, vmax: percentiles (0-100) of `ref_scale` used as colour limits.
            cmap: heatmap colormap; defaults to a transparent-to-yellow red ramp.

        Returns:
            (fig, ax, im) where `im` is the heatmap AxesImage (usable for colorbars).
        """
        if fig is None or ax is None:
            fig, ax = plt.subplots(1, figsize=(12, 12))
        if cmap is None:
            cmap = mcolors.LinearSegmentedColormap.from_list(
                name='alphared',
                colors=[(1, 0, 0, 0), "darkred", "red", "darkorange", "orange", "yellow"],
                N=5000)
        if brain_img is not None:
            # Grey limits come from the whole volume so every slice of one brain
            # is rendered with the same contrast.
            ax.imshow(np.squeeze(brain_img[x_idx, y_idx, z_idx], -1), cmap="gray",
                      vmin=np.min(brain_img), vmax=np.max(brain_img))
        # vmin/vmax arrive as percentiles; convert them to values of ref_scale.
        lo, hi = np.percentile(ref_scale, vmin), np.percentile(ref_scale, vmax)
        im = ax.imshow(np.squeeze(heat_map[x_idx, y_idx, z_idx], -1), cmap=cmap,
                       vmin=lo, vmax=hi, interpolation="gaussian", alpha=.7)
        ax.axis('off')
        # Invert the axes we actually drew on: plt.gca() returns the figure's
        # "current" axes, which in a subplot grid is not necessarily `ax`.
        ax.invert_yaxis()
        return fig, ax, im

    @staticmethod
    def _classify_case(label, pred_scores):
        """Return (ground-truth group, confusion case) for one sample.

        Label 0 is AD (the positive class), any other label is NC. The predicted
        class is argmax(pred_scores): 0 = AD, 1 = NC.
        """
        true_case = "AD" if label == 0 else "NC"
        predicted = np.argmax(pred_scores)
        if predicted not in (0, 1):
            # Otherwise no confusion case applies and the map would be mis-binned.
            raise ValueError("expected a binary prediction, got argmax=" + str(predicted))
        predicted_ad = predicted == 0
        correct = predicted_ad == (true_case == "AD")
        return true_case, ("T" if correct else "F") + ("P" if predicted_ad else "N")

    def _accumulate_by_case(self, test_data, shape, test_predsCNN, make_map, start_idx, stop_idx):
        """Sum ``make_map(j)`` into per-case accumulators for samples start_idx..stop_idx-1."""
        case_maps = {case: np.zeros(shape) for case in self._CASES}
        counts = {case: 0 for case in self._CASES}
        n_samples = len(test_data[0])
        stop = n_samples if stop_idx is None else min(stop_idx, n_samples)
        for j in range(start_idx, stop):
            heat_map = make_map(j)
            true_case, case = self._classify_case(test_data[3][j], test_predsCNN[j])
            # Each sample counts once for its confusion case and once for its
            # ground-truth group.
            for key in (case, true_case):
                case_maps[key] += heat_map
                counts[key] += 1
            print('counts: ', counts)
        return case_maps, counts

    def GuidedGradCAM(self, test_data, test_mri_nonorm, model_filepath, netCNN, test_predsCNN,
                      start_idx=53, stop_idx=None, guided=False):
        """Sum per-sample Grad-CAM (or guided Grad-CAM) maps by diagnostic case.

        Args:
            test_data: indexable test set. ``len(test_data[0])`` bounds the loop
                and ``test_data[3][j]`` is the label of sample j (0 = AD, else NC).
                The whole object is passed through to `netCNN`.
            test_mri_nonorm: sequence of volumes; only ``[0].shape`` is used (to
                size the accumulators).
            model_filepath: unused here (kept for API compatibility).
            netCNN: object providing ``make_gradcam_heatmap2(test_data, j)`` and
                ``guided_backprop(test_data, j)``.
            test_predsCNN: per-sample class scores; argmax 0 means "AD predicted".
            start_idx: first sample processed. The default 53 is the start point
                the original script hard-coded (noted there as the start of the
                NC data); adjust it for your split.
            stop_idx: exclusive end index (int); None means ``len(test_data[0])``.
            guided: accumulate guided Grad-CAM (guided backprop * Grad-CAM)
                instead of plain Grad-CAM, which is the default.

        Returns:
            (case_maps, counts): summed maps and sample counts keyed by
            "AD", "NC", "TP", "TN", "FP", "FN".

        Raises:
            ValueError: if a prediction's argmax is not 0 or 1.
        """
        shape = test_mri_nonorm[0].shape

        def make_map(j):
            gradcam_map = netCNN.make_gradcam_heatmap2(test_data, j)
            # NOTE(review): the guided-backprop map is discarded when guided=False;
            # it is still computed so netCNN sees the same sequence of calls as
            # before. Drop the call if guided_backprop is side-effect free.
            gb_map = netCNN.guided_backprop(test_data, j)
            return gb_map * gradcam_map if guided else gradcam_map

        return self._accumulate_by_case(test_data, shape, test_predsCNN, make_map,
                                        start_idx, stop_idx)

    # Layer-wise relevance propagation, following
    # https://github.com/moboehle/Pytorch-LRP/blob/master/Plotting%20brain%20maps.ipynb
    def LRP(self, test_data, test_mri_nonorm, model_filepath, netCNN, test_predsCNN,
            start_idx=53, stop_idx=None):
        """Sum per-sample LRP relevance maps by diagnostic case.

        Arguments and return value are as for `GuidedGradCAM`, except that
        `netCNN` must provide ``LRP_heatmap(test_data, j)``. `model_filepath` is
        unused here (kept for API compatibility).

        Raises:
            ValueError: if a prediction's argmax is not 0 or 1.
        """
        shape = test_mri_nonorm[0].shape
        print('length of test_data[3]: ', len(test_data[3]))
        return self._accumulate_by_case(test_data, shape, test_predsCNN,
                                        lambda j: netCNN.LRP_heatmap(test_data, j),
                                        start_idx, stop_idx)

    def _plot_view_column(self, fig, axes_col, heat_map, template, ref_map, shape, slice_idxs,
                          vmin, vmax, label_slices=False):
        """Plot `heat_map` over `template` as one slice along each volume axis.

        Row r of `axes_col` shows slice ``slice_idxs[r]`` taken along axis r;
        with `label_slices` a vertical "Slice <idx>" label is added to each row.
        Returns the heatmap image of the last row (all rows share one colour scale).
        """
        im = None
        for axis, (ax, idx) in enumerate(zip(axes_col, slice_idxs)):
            if label_slices:
                ax.text(-25, 20, "Slice " + str(idx), rotation="vertical", fontsize=20)
            index = [slice(0, shape[0]), slice(0, shape[1]), slice(0, shape[2])]
            index[axis] = idx
            _, _, im = self.plot_idv_brain(heat_map, template, ref_map,
                                           x_idx=index[0], y_idx=index[1], z_idx=index[2],
                                           contour_areas=[], vmin=vmin, vmax=vmax,
                                           fig=fig, ax=ax, set_nan=False, cmap="hot")
        return im

    def _add_percentile_colorbar(self, fig, im, ref_map, vmin, vmax, rect, label_fontsize):
        """Add a colorbar at `rect`, ticked at the vmin/vmax percentiles of `ref_map`."""
        cbar_ax = fig.add_axes(rect)
        cbar = fig.colorbar(im, shrink=0.5, ticks=[vmin, vmax], cax=cbar_ax)
        # Ticks sit at the values of those percentiles (the panels' colour
        # limits) but are labelled with the percentiles themselves.
        cbar.set_ticks([np.percentile(ref_map, vmin), np.percentile(ref_map, vmax)])
        cbar.ax.set_yticklabels(['{0:.1f}%'.format(vmin), '{0:.1f}%'.format(vmax)],
                                fontsize=16)
        cbar.set_label('Percentile of average AD patient values', rotation=270,
                       fontsize=label_fontsize)

    def plot_avg_maps(self, case_maps_LRP, counts, map_type, test_mri_nonorm, model_filepath,
                      mean_map_AD, slice_idxs=(36, 58, 58)):
        """Average per-case maps, save them as NIfTI and plot them over a PET template.

        Args:
            case_maps_LRP: per-case summed maps, as returned by LRP/GuidedGradCAM.
            counts: per-case sample counts matching `case_maps_LRP`.
            map_type: label used in figure titles and output file names.
            test_mri_nonorm: sequence of volumes; ``[0]`` supplies the shape and
                the SimpleITK image whose metadata is copied to the outputs.
            model_filepath: directory holding
                ``rbet_TEMPLATE_FDGPET_100.Resampled.nii`` and a ``figures``
                sub-directory that receives the outputs.
            mean_map_AD: precomputed mean AD map used as-is (it is also the
                colour reference for every panel). Pass None or an all-zero
                array to compute it from `case_maps_LRP` / `counts`.
            slice_idxs: positions of the slices taken along the first, second
                and third volume axes (one figure row each).

        Writes ``figures/CNN_mean_<map_type>_<case>_<seed>.nii`` for every case,
        plus ``figures/CNN_<map_type>_avg_ADvNC_<seed>.png`` and
        ``figures/CNN_<map_type>_avg_TPvFPvTNvFN_<seed>.png``.

        Returns:
            dict of per-case mean maps. A case with a zero count gets a NaN map.
        """
        shape = test_mri_nonorm[0].shape
        mean_maps = {case: np.zeros(shape) for case in self._CASES}
        if mean_map_AD is not None:
            mean_maps["AD"] = mean_map_AD
        # PET template used as the grey background of every panel.
        proxy_image = nib.load(model_filepath + '/rbet_TEMPLATE_FDGPET_100.Resampled.nii')
        PETtemplate = np.expand_dims(np.asarray(proxy_image.dataobj), axis=-1)
        print('PET template shape: ', PETtemplate.shape)
        print('counts: ', counts)

        # Reference image for spatial metadata; loop-invariant, so built once.
        sitk_mri = sitk.GetImageFromArray(test_mri_nonorm[0], isVector=True)
        for case in self._CASES:
            if np.all(mean_maps[case] == 0):
                if counts[case] == 0:
                    print('No samples for case ' + case + '; its mean map is undefined (NaN)')
                # 0/0 deliberately yields NaN ("no data") for empty cases;
                # silence numpy's divide warnings for that.
                with np.errstate(divide='ignore', invalid='ignore'):
                    mean_maps[case] = case_maps_LRP[case] / counts[case]
            sitk_mean = sitk.GetImageFromArray(mean_maps[case], isVector=True)
            sitk_mean.CopyInformation(sitk_mri)
            sitk.WriteImage(sitk_mean, model_filepath + '/figures/CNN_mean_' + str(map_type)
                            + '_' + str(case) + '_' + str(self.seed) + '.nii')

        vmin, vmax = 50, 99.5  # colour limits, as percentiles of the mean AD map
        ref_map = mean_maps["AD"]

        # Figure 1: mean AD vs NC maps, one column each, one row per view.
        fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(12, 12), sharey=True, sharex=True,
                                 subplot_kw={'xticks': [], 'yticks': []},
                                 constrained_layout=True)
        for col, case in enumerate(("AD", "NC")):
            im = self._plot_view_column(fig, axes[:, col], mean_maps[case], PETtemplate, ref_map,
                                        shape, slice_idxs, vmin, vmax, label_slices=(col == 0))
            axes[2, col].text(45, -20, case, fontsize=20)
        # NOTE(review): matplotlib treats constrained_layout and subplots_adjust
        # as incompatible (it warns, and depending on version drops one of them);
        # kept unchanged so the figures match earlier runs.
        fig.subplots_adjust(right=0.8, top=0.95, hspace=0.05, wspace=0.05)
        fig.suptitle("Average " + str(map_type) + " for AD and NC patients", fontsize=22, x=.41)
        self._add_percentile_colorbar(fig, im, ref_map, vmin, vmax, [0.95, 0.15, 0.025, 0.7], 18)
        fig.savefig(model_filepath + '/figures/CNN_' + str(map_type) + '_avg_ADvNC_'
                    + str(self.seed) + '.png', bbox_inches='tight')
        plt.close(fig)

        # Figure 2: mean maps per confusion-matrix case.
        fig, axes = plt.subplots(3, 4, figsize=(12, 12), sharey=True, sharex=True,
                                 constrained_layout=True)
        case_titles = (("TP", "True positives"), ("FP", "False positives"),
                       ("TN", "True negatives"), ("FN", "False negatives"))
        for col, (case, title) in enumerate(case_titles):
            im = self._plot_view_column(fig, axes[:, col], mean_maps[case], PETtemplate, ref_map,
                                        shape, slice_idxs, vmin, vmax)
            axes[2, col].text(10, -20, title, fontsize=18)
        fig.suptitle("Average " + str(map_type) + " for varying cases", fontsize=24, x=.42)
        fig.subplots_adjust(top=0.95, right=0.8, hspace=0.05, wspace=0.05)
        self._add_percentile_colorbar(fig, im, ref_map, vmin, vmax, [0.85, 0.15, 0.02, 0.7], 20)
        fig.savefig(model_filepath + '/figures/CNN_' + str(map_type) + '_avg_TPvFPvTNvFN_'
                    + str(self.seed) + '.png', bbox_inches='tight')
        plt.close(fig)

        return mean_maps
|