perturbate.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. # Get Python six functionality:
  2. from __future__ import \
  3. absolute_import, print_function, division, unicode_literals
  4. from builtins import range
  5. import six
  6. import numpy as np
  7. import warnings
  8. import time
  9. import keras.backend as K
  10. from keras.utils import Sequence
  11. from keras.utils.data_utils import OrderedEnqueuer, GeneratorEnqueuer
  12. import innvestigate.utils
  13. class Perturbation:
  14. """Perturbation of pixels based on analysis result.
  15. :param perturbation_function: Defines the function with which the samples are perturbated. Can be a function or a string that defines a predefined perturbation function.
  16. :type perturbation_function: function or callable or str
  17. :param num_perturbed_regions: Number of regions to be perturbed.
  18. :type num_perturbed_regions: int
  19. :param reduce_function: Function to reduce the analysis result to one channel, e.g. mean or max function.
  20. :type reduce_function: function or callable
  21. :param aggregation_function: Function to aggregate the analysis over subregions.
  22. :type aggregation_function: function or callable
  23. :param pad_mode: How to pad if the image cannot be subdivided into an integer number of regions. As in numpy.pad.
  24. :type pad_mode: str or function or callable
  25. :param in_place: If true, the perturbations are performed in place, i.e. the input samples are modified.
  26. :type in_place: bool
  27. :param value_range: Minimal and maximal value after perturbation as a tuple: (min_val, max_val). The input is clipped to this range
  28. :type value_range: tuple"""
  29. def __init__(self, perturbation_function, num_perturbed_regions=0, region_shape=(9, 9), reduce_function=np.mean,
  30. aggregation_function=np.mean, pad_mode="reflect", in_place=False, value_range=None):
  31. if isinstance(perturbation_function, six.string_types):
  32. if perturbation_function == "zeros":
  33. # This is equivalent to setting the perturbated values to the channel mean if the data are standardized.
  34. self.perturbation_function = np.zeros_like
  35. elif perturbation_function == "gaussian":
  36. # If scale = 1/3, most of the values will be between -1 and 1
  37. self.perturbation_function = lambda x: np.random.normal(loc=0.0, scale=0.3, size=x.shape)
  38. elif perturbation_function == "mean":
  39. self.perturbation_function = np.mean
  40. elif perturbation_function == "invert":
  41. self.perturbation_function = lambda x: -x
  42. else:
  43. raise ValueError("Perturbation function type '{}' not known.".format(perturbation_function))
  44. elif callable(perturbation_function):
  45. self.perturbation_function = perturbation_function
  46. else:
  47. raise TypeError("Cannot handle perturbation function of type {}.".format(type(perturbation_function)))
  48. self.num_perturbed_regions = num_perturbed_regions
  49. self.region_shape = region_shape
  50. self.reduce_function = reduce_function
  51. self.aggregation_function = aggregation_function
  52. self.pad_mode = pad_mode # numpy.pad
  53. self.in_place = in_place
  54. self.value_range = value_range
  55. @staticmethod
  56. def compute_perturbation_mask(ranks, num_perturbated_regions):
  57. perturbation_mask_regions = ranks <= num_perturbated_regions - 1
  58. return perturbation_mask_regions
  59. @staticmethod
  60. def compute_region_ordering(aggregated_regions):
  61. # 0 means highest scoring region
  62. new_shape = tuple(aggregated_regions.shape[:2]) + (-1,)
  63. order = np.argsort(-aggregated_regions.reshape(new_shape), axis=-1)
  64. ranks = order.argsort().reshape(aggregated_regions.shape)
  65. return ranks
  66. def expand_regions_to_pixels(self, regions):
  67. # Resize to pixels (repeat values).
  68. # (n, c, h_aggregated_region, w_aggregated_region) -> (n, c, h_aggregated_region, h_region, w_aggregated_region, w_region)
  69. regions_reshaped = np.expand_dims(np.expand_dims(regions, axis=3), axis=5)
  70. region_pixels = np.repeat(regions_reshaped, self.region_shape[0], axis=3)
  71. region_pixels = np.repeat(region_pixels, self.region_shape[1], axis=5)
  72. assert region_pixels.shape[0] == regions.shape[0] and region_pixels.shape[2:] == (
  73. regions.shape[2], self.region_shape[0], regions.shape[3], self.region_shape[1]), region_pixels.shape
  74. return region_pixels
  75. def reshape_region_pixels(self, region_pixels, target_shape):
  76. # Reshape to output shape
  77. pixels = region_pixels.reshape(target_shape)
  78. assert region_pixels.shape[0] == pixels.shape[0] and region_pixels.shape[1] == pixels.shape[1] and \
  79. region_pixels.shape[2] * region_pixels.shape[3] == pixels.shape[2] and region_pixels.shape[4] * \
  80. region_pixels.shape[5] == pixels.shape[3]
  81. return pixels
  82. def pad(self, analysis):
  83. pad_shape = self.region_shape - np.array(analysis.shape[2:]) % self.region_shape
  84. assert np.all(pad_shape < self.region_shape)
  85. # Pad half the window before and half after (on h and w axes)
  86. pad_shape_before = (pad_shape / 2).astype(int)
  87. pad_shape_after = pad_shape - pad_shape_before
  88. pad_shape = (
  89. (0, 0), (0, 0), (pad_shape_before[0], pad_shape_after[0]), (pad_shape_before[1], pad_shape_after[1]))
  90. analysis = np.pad(analysis, pad_shape, self.pad_mode)
  91. assert np.all(np.array(analysis.shape[2:]) % self.region_shape == 0), analysis.shape[2:]
  92. return analysis, pad_shape_before
  93. def reshape_to_regions(self, analysis):
  94. aggregated_shape = tuple((np.array(analysis.shape[2:]) / self.region_shape).astype(int))
  95. regions = analysis.reshape(
  96. (analysis.shape[0], analysis.shape[1], aggregated_shape[0], self.region_shape[0], aggregated_shape[1],
  97. self.region_shape[1]))
  98. return regions
  99. def aggregate_regions(self, analysis):
  100. regions = self.reshape_to_regions(analysis)
  101. aggregated_regions = self.aggregation_function(regions, axis=(3, 5))
  102. return aggregated_regions
  103. def perturbate_regions(self, x, perturbation_mask_regions):
  104. # Perturbate every region in tensor.
  105. # A single region (at region_x, region_y in sample) should be in mask[sample, channel, region_x, :, region_y, :]
  106. x_perturbated = self.reshape_to_regions(x)
  107. for sample_idx, channel_idx, region_row, region_col in np.ndindex(perturbation_mask_regions.shape):
  108. region = x_perturbated[sample_idx, channel_idx, region_row, :, region_col, :]
  109. region_mask = perturbation_mask_regions[sample_idx, channel_idx, region_row, region_col]
  110. if region_mask:
  111. x_perturbated[sample_idx, channel_idx, region_row, :, region_col, :] = self.perturbation_function(
  112. region)
  113. if self.value_range is not None:
  114. np.clip(x_perturbated,
  115. self.value_range[0],
  116. self.value_range[1],
  117. x_perturbated)
  118. x_perturbated = self.reshape_region_pixels(x_perturbated, x.shape)
  119. return x_perturbated
  120. def perturbate_on_batch(self, x, analysis):
  121. """
  122. :param x: Batch of images.
  123. :type x: numpy.ndarray
  124. :param analysis: Analysis of this batch.
  125. :type analysis: numpy.ndarray
  126. :return: Batch of perturbated images
  127. :rtype: numpy.ndarray
  128. """
  129. if K.image_data_format() == "channels_last":
  130. x = np.moveaxis(x, 3, 1)
  131. analysis = np.moveaxis(analysis, 3, 1)
  132. if not self.in_place:
  133. x = np.copy(x)
  134. assert analysis.shape == x.shape, analysis.shape
  135. original_shape = x.shape
  136. # reduce the analysis along channel axis -> n x 1 x h x w
  137. analysis = self.reduce_function(analysis, axis=1, keepdims=True)
  138. assert analysis.shape == (x.shape[0], 1, x.shape[2], x.shape[3]), analysis.shape
  139. padding = not np.all(np.array(analysis.shape[2:]) % self.region_shape == 0)
  140. if padding:
  141. analysis, pad_shape_before_analysis = self.pad(analysis)
  142. x, pad_shape_before_x = self.pad(x)
  143. aggregated_regions = self.aggregate_regions(analysis)
  144. # Compute perturbation mask (mask with ones where the input should be perturbated, zeros otherwise)
  145. ranks = self.compute_region_ordering(aggregated_regions)
  146. perturbation_mask_regions = self.compute_perturbation_mask(ranks, self.num_perturbed_regions)
  147. # Perturbate each region
  148. x_perturbated = self.perturbate_regions(x, perturbation_mask_regions)
  149. # Crop the original image region to remove the padding
  150. if padding:
  151. x_perturbated = x_perturbated[:, :, pad_shape_before_x[0]:pad_shape_before_x[0] + original_shape[2],
  152. pad_shape_before_x[1]:pad_shape_before_x[1] + original_shape[3]]
  153. if K.image_data_format() == "channels_last":
  154. x_perturbated = np.moveaxis(x_perturbated, 1, 3)
  155. x = np.moveaxis(x, 1, 3)
  156. analysis = np.moveaxis(analysis, 1, 3)
  157. return x_perturbated
  158. class PerturbationAnalysis:
  159. """
  160. Performs the perturbation analysis.
  161. :param analyzer: Analyzer.
  162. :type analyzer: innvestigate.analyzer.base.AnalyzerBase
  163. :param model: Trained Keras model.
  164. :type model: keras.engine.training.Model
  165. :param generator: Data generator.
  166. :type generator: innvestigate.utils.BatchSequence
  167. :param perturbation: Instance of Perturbation class that performs the perturbation.
  168. :type perturbation: innvestigate.tools.Perturbation
  169. :param steps: Number of perturbation steps.
  170. :type steps: int
  171. :param regions_per_step: Number of regions that are perturbed per step.
  172. :type regions_per_step: float
  173. :param recompute_analysis: If true, the analysis is recomputed after each perturbation step.
  174. :type recompute_analysis: bool
  175. :param verbose: If true, print some useful information, e.g. timing, progress etc.
  176. """
  177. def __init__(self, analyzer, model, generator, perturbation, steps=1, regions_per_step=1, recompute_analysis=False,
  178. verbose=False):
  179. self.analyzer = analyzer
  180. self.model = model
  181. self.generator = generator
  182. self.perturbation = perturbation
  183. # if not isinstance(perturbation, Perturbation):
  184. # raise TypeError(type(perturbation))
  185. self.steps = steps
  186. self.regions_per_step = regions_per_step
  187. self.recompute_analysis = recompute_analysis
  188. if not self.recompute_analysis:
  189. # Compute the analysis once in the beginning
  190. analysis = list()
  191. x = list()
  192. y = list()
  193. for xx, yy in self.generator:
  194. x.extend(list(xx))
  195. y.extend(list(yy))
  196. analysis.extend(list(self.analyzer.analyze(xx)))
  197. x = np.array(x)
  198. y = np.array(y)
  199. analysis = np.array(analysis)
  200. self.analysis_generator = innvestigate.utils.BatchSequence([x, y, analysis], batch_size=256)
  201. self.verbose = verbose
  202. def compute_on_batch(self, x, analysis=None, return_analysis=False):
  203. """
  204. Computes the analysis and perturbes the input batch accordingly.
  205. :param x: Samples.
  206. :param analysis: Analysis of x. If None, it is recomputed.
  207. :type x: numpy.ndarray
  208. """
  209. if analysis is None:
  210. analysis = self.analyzer.analyze(x)
  211. x_perturbated = self.perturbation.perturbate_on_batch(x, analysis)
  212. if return_analysis:
  213. return x_perturbated, analysis
  214. else:
  215. return x_perturbated
  216. def evaluate_on_batch(self, x, y, analysis=None, sample_weight=None):
  217. """
  218. Perturbs the input batch and scores the model on the perturbed batch.
  219. :param x: Samples.
  220. :type x: numpy.ndarray
  221. :param y: Labels.
  222. :type y: numpy.ndarray
  223. :param analysis: Analysis of x.
  224. :type analysis: numpy.ndarray
  225. :param sample_weight: Sample weights.
  226. :type sample_weight: None
  227. :return: List of test scores.
  228. :rtype: list
  229. """
  230. if sample_weight is not None:
  231. raise NotImplementedError("Sample weighting is not supported yet.") # TODO
  232. x_perturbated = self.compute_on_batch(x, analysis)
  233. score = self.model.test_on_batch(x_perturbated, y, sample_weight=sample_weight)
  234. return score
  235. def evaluate_generator(self, generator, steps=None,
  236. max_queue_size=10,
  237. workers=1,
  238. use_multiprocessing=False):
  239. """Evaluates the model on a data generator.
  240. The generator should return the same kind of data
  241. as accepted by `test_on_batch`.
  242. For documentation, refer to keras.engine.training.evaluate_generator (https://keras.io/models/model/)
  243. """
  244. steps_done = 0
  245. wait_time = 0.01
  246. all_outs = []
  247. batch_sizes = []
  248. is_sequence = isinstance(generator, Sequence)
  249. if not is_sequence and use_multiprocessing and workers > 1:
  250. warnings.warn(
  251. UserWarning('Using a generator with `use_multiprocessing=True`'
  252. ' and multiple workers may duplicate your data.'
  253. ' Please consider using the`keras.utils.Sequence'
  254. ' class.'))
  255. if steps is None:
  256. if is_sequence:
  257. steps = len(generator)
  258. else:
  259. raise ValueError('`steps=None` is only valid for a generator'
  260. ' based on the `keras.utils.Sequence` class.'
  261. ' Please specify `steps` or use the'
  262. ' `keras.utils.Sequence` class.')
  263. enqueuer = None
  264. try:
  265. if workers > 0:
  266. if is_sequence:
  267. enqueuer = OrderedEnqueuer(generator,
  268. use_multiprocessing=use_multiprocessing)
  269. else:
  270. enqueuer = GeneratorEnqueuer(generator,
  271. use_multiprocessing=use_multiprocessing,
  272. wait_time=wait_time)
  273. enqueuer.start(workers=workers, max_queue_size=max_queue_size)
  274. output_generator = enqueuer.get()
  275. else:
  276. output_generator = generator
  277. while steps_done < steps:
  278. generator_output = next(output_generator)
  279. if not hasattr(generator_output, '__len__'):
  280. raise ValueError('Output of generator should be a tuple '
  281. '(x, y, sample_weight) '
  282. 'or (x, y). Found: ' +
  283. str(generator_output))
  284. if len(generator_output) == 2:
  285. x, y = generator_output
  286. analysis = None
  287. elif len(generator_output) == 3:
  288. x, y, analysis = generator_output
  289. else:
  290. raise ValueError('Output of generator should be a tuple '
  291. '(x, y, analysis) '
  292. 'or (x, y). Found: ' +
  293. str(generator_output))
  294. outs = self.evaluate_on_batch(x, y, analysis=analysis, sample_weight=None)
  295. if isinstance(x, list):
  296. batch_size = x[0].shape[0]
  297. elif isinstance(x, dict):
  298. batch_size = list(x.values())[0].shape[0]
  299. else:
  300. batch_size = x.shape[0]
  301. if batch_size == 0:
  302. raise ValueError('Received an empty batch. '
  303. 'Batches should at least contain one item.')
  304. all_outs.append(outs)
  305. steps_done += 1
  306. batch_sizes.append(batch_size)
  307. finally:
  308. if enqueuer is not None:
  309. enqueuer.stop()
  310. if not isinstance(outs, list):
  311. return np.average(np.asarray(all_outs),
  312. weights=batch_sizes)
  313. else:
  314. averages = []
  315. for i in range(len(outs)):
  316. averages.append(np.average([out[i] for out in all_outs],
  317. weights=batch_sizes))
  318. return averages
  319. def compute_perturbation_analysis(self):
  320. scores = list()
  321. # Evaluate first on original data
  322. scores.append(self.model.evaluate_generator(self.generator))
  323. self.perturbation.num_perturbed_regions = 1
  324. time_start = time.time()
  325. for step in range(self.steps):
  326. tic = time.time()
  327. if self.verbose:
  328. print("Step {} of {}: {} regions perturbed.".format(step + 1, self.steps,
  329. self.perturbation.num_perturbed_regions), end=" ")
  330. scores.append(self.evaluate_generator(self.analysis_generator))
  331. self.perturbation.num_perturbed_regions += self.regions_per_step
  332. toc = time.time()
  333. if self.verbose:
  334. print("Time elapsed: {:.3f} seconds.".format(toc - tic))
  335. time_end = time.time()
  336. if self.verbose:
  337. print("Time elapsed for {} steps: {:.3f} seconds.".format(step + 1,
  338. time_end - time_start)) # Use step + 1 instead of self.steps because the analysis can stop prematurely.
  339. self.perturbation.num_perturbed_regions = 1 # Reset to original value
  340. assert len(scores) == self.steps + 1
  341. return scores