# pattern_based.py
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. import keras.activations
  8. import keras.engine.topology
  9. import keras.layers
  10. import keras.layers.core
  11. import keras.layers.pooling
  12. import keras.models
  13. import keras
  14. import numpy as np
  15. import warnings
  16. from . import base
  17. from .. import layers as ilayers
  18. from .. import utils
  19. from .. import tools as itools
  20. from ..utils import keras as kutils
  21. from ..utils.keras import checks as kchecks
  22. from ..utils.keras import graph as kgraph
# Public API of this module.
__all__ = [
    "PatternNet",
    "PatternAttribution",
]
  27. ###############################################################################
  28. ###############################################################################
  29. ###############################################################################
# Layer types for which the PatternNet analysis is defined.
# A model containing any other layer type fails the
# check_type="exception" model check registered in PatternNet.__init__.
SUPPORTED_LAYER_PATTERNNET = (
    keras.engine.topology.InputLayer,
    keras.layers.convolutional.Conv2D,
    keras.layers.core.Dense,
    keras.layers.core.Dropout,
    keras.layers.core.Flatten,
    keras.layers.core.Masking,
    keras.layers.core.Permute,
    keras.layers.core.Reshape,
    keras.layers.Concatenate,
    keras.layers.pooling.GlobalMaxPooling1D,
    keras.layers.pooling.GlobalMaxPooling2D,
    keras.layers.pooling.GlobalMaxPooling3D,
    keras.layers.pooling.MaxPooling1D,
    keras.layers.pooling.MaxPooling2D,
    keras.layers.pooling.MaxPooling3D,
)
class PatternNetReverseKernelLayer(kgraph.ReverseMappingBase):
    """
    PatternNet backward mapping for layers with kernels.

    Applies the (filter) weights on the forward pass and
    on the backward pass applies the gradient computation
    where the filter weights are replaced with the patterns.
    """

    def __init__(self, layer, state, pattern):
        # :param layer: the Keras layer to build a reverse mapping for.
        # :param state: reverse-mapping state (unused here, part of the
        #   ReverseMappingBase interface).
        # :param pattern: numpy array whose shape must match exactly one
        #   of the layer's weight arrays (the kernel).
        config = layer.get_config()

        # Layer can contain a kernel and an activation.
        # Split layers in a kernel layer and an activation layer.
        activation = None
        if "activation" in config:
            activation = config["activation"]
            config["activation"] = None
        self._act_layer = keras.layers.Activation(
            activation,
            name="reversed_act_%s" % config["name"])
        self._filter_layer = kgraph.copy_layer_wo_activation(
            layer, name_template="reversed_filter_%s")

        # Replace filter/kernel weights with patterns.
        filter_weights = layer.get_weights()
        # Assume that only one weight has a corresponding pattern.
        # E.g., biases have no pattern.
        tmp = [pattern.shape == x.shape for x in filter_weights]
        if np.sum(tmp) != 1:
            # Zero or multiple weights match the pattern shape; the
            # mapping would be ambiguous.
            raise Exception("Cannot match pattern to filter.")
        filter_weights[np.argmax(tmp)] = pattern
        self._pattern_layer = kgraph.copy_layer_wo_activation(
            layer,
            name_template="reversed_pattern_%s",
            weights=filter_weights)

    def apply(self, Xs, Ys, reversed_Ys, reverse_state):
        # Backward mapping: propagate reversed_Ys through the activation
        # (via its gradient) and then through the pattern layer's gradient.
        # :param Xs: list of input tensors of the original layer.
        # :param Ys: list of output tensors of the original layer (unused;
        #   the layers are re-applied below instead).
        # :param reversed_Ys: list of already-reversed output tensors.
        # :param reverse_state: shared reverse-pass state (unused here).
        # Reapply the prepared layers.
        act_Xs = kutils.apply(self._filter_layer, Xs)
        act_Ys = kutils.apply(self._act_layer, act_Xs)
        pattern_Ys = kutils.apply(self._pattern_layer, Xs)

        # Layers that apply the backward pass.
        grad_act = ilayers.GradientWRT(len(act_Xs))
        grad_pattern = ilayers.GradientWRT(len(Xs))

        # First step: propagate through the activation layer.
        # Workaround for linear activations.
        linear_activations = [None, keras.activations.get("linear")]
        if self._act_layer.activation in linear_activations:
            # Linear activation: gradient is the identity, so pass
            # reversed_Ys through unchanged.
            tmp = reversed_Ys
        else:
            # if linear activation this behaves strange
            tmp = utils.to_list(grad_act(act_Xs+act_Ys+reversed_Ys))

        # Second step: propagate through the pattern layer.
        return grad_pattern(Xs+pattern_Ys+tmp)
  97. class PatternNet(base.OneEpochTrainerMixin, base.ReverseAnalyzerBase):
  98. """PatternNet analyzer.
  99. Applies the "PatternNet" algorithm to analyze the model's predictions.
  100. :param model: A Keras model.
  101. :param patterns: Pattern computed by
  102. :class:`innvestigate.tools.PatternComputer`. If None :func:`fit` needs
  103. to be called.
  104. :param allow_lambda_layers: Approximate lambda layers with the gradient.
  105. :param reverse_project_bottleneck_layers: Project the analysis vector into
  106. range [-1, +1]. (default: True)
  107. """
  108. def __init__(self,
  109. model,
  110. patterns=None,
  111. pattern_type=None,
  112. **kwargs):
  113. self._add_model_softmax_check()
  114. self._add_model_check(
  115. lambda layer: not kchecks.only_relu_activation(layer),
  116. ("PatternNet is not well defined for "
  117. "networks with non-ReLU activations."),
  118. check_type="warning",
  119. )
  120. self._add_model_check(
  121. lambda layer: not kchecks.is_convnet_layer(layer),
  122. ("PatternNet is only well defined for "
  123. "convolutional neural networks."),
  124. check_type="warning",
  125. )
  126. self._add_model_check(
  127. lambda layer: not isinstance(layer,
  128. SUPPORTED_LAYER_PATTERNNET),
  129. ("PatternNet is only well defined for "
  130. "conv2d/max-pooling/dense layers."),
  131. check_type="exception",
  132. )
  133. self._patterns = patterns
  134. if self._patterns is not None:
  135. # copy pattern references
  136. self._patterns = list(patterns)
  137. self._pattern_type = pattern_type
  138. # Pattern projections can lead to +-inf value with long networks.
  139. # We are only interested in the direction, therefore it is save to
  140. # Prevent this by projecting the values in bottleneck layers to +-1.
  141. if not kwargs.get("reverse_project_bottleneck_layers", True):
  142. warnings.warn("The standard setting for "
  143. "'reverse_project_bottleneck_layers' "
  144. "is overwritten.")
  145. else:
  146. kwargs["reverse_project_bottleneck_layers"] = True
  147. super(PatternNet, self).__init__(model, **kwargs)
  148. def _get_pattern_for_layer(self, layer, state):
  149. layers = [l for l in kgraph.get_model_layers(self._model)
  150. if kchecks.contains_kernel(l)]
  151. return self._patterns[layers.index(layer)]
  152. def _prepare_pattern(self, layer, state, pattern):
  153. """""Prepares a pattern before it is set in the back-ward pass."""
  154. return pattern
  155. def _create_analysis(self, *args, **kwargs):
  156. # Apply the pattern mapping on all layers that contain a kernel.
  157. def create_kernel_layer_mapping(layer, state):
  158. pattern = self._get_pattern_for_layer(layer, state)
  159. pattern = self._prepare_pattern(layer, state, pattern)
  160. mapping_obj = PatternNetReverseKernelLayer(layer, state, pattern)
  161. return mapping_obj.apply
  162. self._add_conditional_reverse_mapping(
  163. kchecks.contains_kernel,
  164. create_kernel_layer_mapping,
  165. name="patternnet_kernel_layer_mapping"
  166. )
  167. return super(PatternNet, self)._create_analysis(*args, **kwargs)
  168. def _fit_generator(self,
  169. generator,
  170. steps_per_epoch=None,
  171. epochs=1,
  172. max_queue_size=10,
  173. workers=1,
  174. use_multiprocessing=False,
  175. verbose=0,
  176. disable_no_training_warning=None,
  177. **kwargs):
  178. pattern_type = self._pattern_type
  179. if pattern_type is None:
  180. pattern_type = "relu"
  181. if isinstance(pattern_type, (list, tuple)):
  182. raise ValueError("Only one pattern type allowed. "
  183. "Please pass a string.")
  184. computer = itools.PatternComputer(self._model,
  185. pattern_type=pattern_type,
  186. **kwargs)
  187. self._patterns = computer.compute_generator(
  188. generator,
  189. steps_per_epoch=steps_per_epoch,
  190. max_queue_size=max_queue_size,
  191. workers=workers,
  192. use_multiprocessing=use_multiprocessing,
  193. verbose=verbose)
  194. def _get_state(self):
  195. state = super(PatternNet, self)._get_state()
  196. state.update({"patterns": self._patterns,
  197. "pattern_type": self._pattern_type})
  198. return state
  199. @classmethod
  200. def _state_to_kwargs(clazz, state):
  201. patterns = state.pop("patterns")
  202. pattern_type = state.pop("pattern_type")
  203. kwargs = super(PatternNet, clazz)._state_to_kwargs(state)
  204. kwargs.update({"patterns": patterns,
  205. "pattern_type": pattern_type})
  206. return kwargs
  207. class PatternAttribution(PatternNet):
  208. """PatternAttribution analyzer.
  209. Applies the "PatternNet" algorithm to analyze the model's predictions.
  210. :param model: A Keras model.
  211. :param patterns: Pattern computed by
  212. :class:`innvestigate.tools.PatternComputer`. If None :func:`fit` needs
  213. to be called.
  214. :param allow_lambda_layers: Approximate lambda layers with the gradient.
  215. :param reverse_project_bottleneck_layers: Project the analysis vector into
  216. range [-1, +1]. (default: True)
  217. """
  218. def _prepare_pattern(self, layer, state, pattern):
  219. weights = layer.get_weights()
  220. tmp = [pattern.shape == x.shape for x in weights]
  221. if np.sum(tmp) != 1:
  222. raise Exception("Cannot match pattern to kernel.")
  223. weight = weights[np.argmax(tmp)]
  224. return np.multiply(pattern, weight)