gradient_based.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. import keras.models
  8. import keras
  9. from . import base
  10. from . import wrapper
  11. from .. import layers as ilayers
  12. from .. import utils as iutils
  13. from ..utils import keras as kutils
  14. from ..utils.keras import checks as kchecks
  15. from ..utils.keras import graph as kgraph
  16. __all__ = [
  17. "BaselineGradient",
  18. "Gradient",
  19. "InputTimesGradient",
  20. "Deconvnet",
  21. "GuidedBackprop",
  22. "IntegratedGradients",
  23. "SmoothGrad",
  24. ]
  25. ###############################################################################
  26. ###############################################################################
  27. ###############################################################################
  28. class BaselineGradient(base.AnalyzerNetworkBase):
  29. """Gradient analyzer based on build-in gradient.
  30. Returns as analysis the function value with respect to the input.
  31. The gradient is computed via the build in function.
  32. Is mainly used for debugging purposes.
  33. :param model: A Keras model.
  34. """
  35. def __init__(self, model, postprocess=None, **kwargs):
  36. if postprocess not in [None, "abs", "square"]:
  37. raise ValueError("Parameter 'postprocess' must be either "
  38. "None, 'abs', or 'square'.")
  39. self._postprocess = postprocess
  40. self._add_model_softmax_check()
  41. super(BaselineGradient, self).__init__(model, **kwargs)
  42. def _create_analysis(self, model, stop_analysis_at_tensors=[]):
  43. tensors_to_analyze = [x for x in iutils.to_list(model.inputs)
  44. if x not in stop_analysis_at_tensors]
  45. ret = iutils.to_list(ilayers.Gradient()(
  46. tensors_to_analyze+[model.outputs[0]]))
  47. if self._postprocess == "abs":
  48. ret = ilayers.Abs()(ret)
  49. elif self._postprocess == "square":
  50. ret = ilayers.Square()(ret)
  51. return iutils.to_list(ret)
  52. def _get_state(self):
  53. state = super(BaselineGradient, self)._get_state()
  54. state.update({"postprocess": self._postprocess})
  55. return state
  56. @classmethod
  57. def _state_to_kwargs(clazz, state):
  58. postprocess = state.pop("postprocess")
  59. kwargs = super(BaselineGradient, clazz)._state_to_kwargs(state)
  60. kwargs.update({
  61. "postprocess": postprocess,
  62. })
  63. return kwargs
  64. class Gradient(base.ReverseAnalyzerBase):
  65. """Gradient analyzer.
  66. Returns as analysis the function value with respect to the input.
  67. The gradient is computed via the librarie's network reverting.
  68. :param model: A Keras model.
  69. """
  70. def __init__(self, model, postprocess=None, **kwargs):
  71. if postprocess not in [None, "abs", "square"]:
  72. raise ValueError("Parameter 'postprocess' must be either "
  73. "None, 'abs', or 'square'.")
  74. self._postprocess = postprocess
  75. self._add_model_softmax_check()
  76. super(Gradient, self).__init__(model, **kwargs)
  77. def _head_mapping(self, X):
  78. return ilayers.OnesLike()(X)
  79. def _postprocess_analysis(self, X):
  80. ret = super(Gradient, self)._postprocess_analysis(X)
  81. if self._postprocess == "abs":
  82. ret = ilayers.Abs()(ret)
  83. elif self._postprocess == "square":
  84. ret = ilayers.Square()(ret)
  85. return iutils.to_list(ret)
  86. def _get_state(self):
  87. state = super(Gradient, self)._get_state()
  88. state.update({"postprocess": self._postprocess})
  89. return state
  90. @classmethod
  91. def _state_to_kwargs(clazz, state):
  92. postprocess = state.pop("postprocess")
  93. kwargs = super(Gradient, clazz)._state_to_kwargs(state)
  94. kwargs.update({
  95. "postprocess": postprocess,
  96. })
  97. return kwargs
  98. ###############################################################################
  99. ###############################################################################
  100. ###############################################################################
  101. class InputTimesGradient(Gradient):
  102. """Input*Gradient analyzer.
  103. :param model: A Keras model.
  104. """
  105. def __init__(self, model, **kwargs):
  106. self._add_model_softmax_check()
  107. super(InputTimesGradient, self).__init__(model, **kwargs)
  108. def _create_analysis(self, model, stop_analysis_at_tensors=[]):
  109. tensors_to_analyze = [x for x in iutils.to_list(model.inputs)
  110. if x not in stop_analysis_at_tensors]
  111. gradients = super(InputTimesGradient, self)._create_analysis(
  112. model, stop_analysis_at_tensors=stop_analysis_at_tensors)
  113. return [keras.layers.Multiply()([i, g])
  114. for i, g in zip(tensors_to_analyze, gradients)]
  115. ###############################################################################
  116. ###############################################################################
  117. ###############################################################################
  118. class DeconvnetReverseReLULayer(kgraph.ReverseMappingBase):
  119. def __init__(self, layer, state):
  120. self._activation = keras.layers.Activation("relu")
  121. self._layer_wo_relu = kgraph.copy_layer_wo_activation(
  122. layer,
  123. name_template="reversed_%s",
  124. )
  125. def apply(self, Xs, Ys, reversed_Ys, reverse_state):
  126. # Apply relus conditioned on backpropagated values.
  127. reversed_Ys = kutils.apply(self._activation, reversed_Ys)
  128. # Apply gradient of forward pass without relus.
  129. Ys_wo_relu = kutils.apply(self._layer_wo_relu, Xs)
  130. return ilayers.GradientWRT(len(Xs))(Xs+Ys_wo_relu+reversed_Ys)
  131. class Deconvnet(base.ReverseAnalyzerBase):
  132. """Deconvnet analyzer.
  133. Applies the "deconvnet" algorithm to analyze the model.
  134. :param model: A Keras model.
  135. """
  136. def __init__(self, model, **kwargs):
  137. self._add_model_softmax_check()
  138. self._add_model_check(
  139. lambda layer: not kchecks.only_relu_activation(layer),
  140. "Deconvnet is only specified for networks with ReLU activations.",
  141. check_type="exception",
  142. )
  143. super(Deconvnet, self).__init__(model, **kwargs)
  144. def _create_analysis(self, *args, **kwargs):
  145. self._add_conditional_reverse_mapping(
  146. lambda layer: kchecks.contains_activation(layer, "relu"),
  147. DeconvnetReverseReLULayer,
  148. name="deconvnet_reverse_relu_layer",
  149. )
  150. return super(Deconvnet, self)._create_analysis(*args, **kwargs)
  151. def GuidedBackpropReverseReLULayer(Xs, Ys, reversed_Ys, reverse_state):
  152. activation = keras.layers.Activation("relu")
  153. # Apply relus conditioned on backpropagated values.
  154. reversed_Ys = kutils.apply(activation, reversed_Ys)
  155. # Apply gradient of forward pass.
  156. return ilayers.GradientWRT(len(Xs))(Xs+Ys+reversed_Ys)
  157. class GuidedBackprop(base.ReverseAnalyzerBase):
  158. """Guided backprop analyzer.
  159. Applies the "guided backprop" algorithm to analyze the model.
  160. :param model: A Keras model.
  161. """
  162. def __init__(self, model, **kwargs):
  163. self._add_model_softmax_check()
  164. self._add_model_check(
  165. lambda layer: not kchecks.only_relu_activation(layer),
  166. "GuidedBackprop is only specified for "
  167. "networks with ReLU activations.",
  168. check_type="exception",
  169. )
  170. super(GuidedBackprop, self).__init__(model, **kwargs)
  171. def _create_analysis(self, *args, **kwargs):
  172. self._add_conditional_reverse_mapping(
  173. lambda layer: kchecks.contains_activation(layer, "relu"),
  174. GuidedBackpropReverseReLULayer,
  175. name="guided_backprop_reverse_relu_layer",
  176. )
  177. return super(GuidedBackprop, self)._create_analysis(*args, **kwargs)
  178. ###############################################################################
  179. ###############################################################################
  180. ###############################################################################
  181. class IntegratedGradients(wrapper.PathIntegrator):
  182. """Integrated gradient analyzer.
  183. Applies the "integrated gradient" algorithm to analyze the model.
  184. :param model: A Keras model.
  185. :param steps: Number of steps to use average along integration path.
  186. """
  187. def __init__(self, model, steps=64, **kwargs):
  188. subanalyzer_kwargs = {}
  189. kwargs_keys = ["neuron_selection_mode", "postprocess"]
  190. for key in kwargs_keys:
  191. if key in kwargs:
  192. subanalyzer_kwargs[key] = kwargs.pop(key)
  193. subanalyzer = Gradient(model, **subanalyzer_kwargs)
  194. super(IntegratedGradients, self).__init__(subanalyzer,
  195. steps=steps,
  196. **kwargs)
  197. ###############################################################################
  198. ###############################################################################
  199. ###############################################################################
  200. class SmoothGrad(wrapper.GaussianSmoother):
  201. """Smooth grad analyzer.
  202. Applies the "smooth grad" algorithm to analyze the model.
  203. :param model: A Keras model.
  204. :param augment_by_n: Number of distortions to average for smoothing.
  205. """
  206. def __init__(self, model, augment_by_n=64, **kwargs):
  207. subanalyzer_kwargs = {}
  208. kwargs_keys = ["neuron_selection_mode", "postprocess"]
  209. for key in kwargs_keys:
  210. if key in kwargs:
  211. subanalyzer_kwargs[key] = kwargs.pop(key)
  212. subanalyzer = Gradient(model, **subanalyzer_kwargs)
  213. super(SmoothGrad, self).__init__(subanalyzer,
  214. augment_by_n=augment_by_n,
  215. **kwargs)