# wrapper.py
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. from builtins import zip
  5. ###############################################################################
  6. ###############################################################################
  7. ###############################################################################
  8. import keras.models
  9. import keras.backend as K
  10. import numpy as np
  11. from . import base
  12. from .. import layers as ilayers
  13. from .. import utils as iutils
  14. from ..utils import keras as kutils
  15. __all__ = [
  16. "WrapperBase",
  17. "AugmentReduceBase",
  18. "GaussianSmoother",
  19. "PathIntegrator",
  20. ]
  21. ###############################################################################
  22. ###############################################################################
  23. ###############################################################################
  24. class WrapperBase(base.AnalyzerBase):
  25. """Interface for wrappers around analyzers
  26. This class is the basic interface for wrappers around analyzers.
  27. :param subanalyzer: The analyzer to be wrapped.
  28. """
  29. def __init__(self, subanalyzer, *args, **kwargs):
  30. self._subanalyzer = subanalyzer
  31. model = None
  32. super(WrapperBase, self).__init__(model,
  33. *args, **kwargs)
  34. def analyze(self, *args, **kwargs):
  35. return self._subanalyzer.analyze(*args, **kwargs)
  36. def _get_state(self):
  37. sa_class_name, sa_state = self._subanalyzer.save()
  38. state = {}
  39. state.update({"subanalyzer_class_name": sa_class_name})
  40. state.update({"subanalyzer_state": sa_state})
  41. return state
  42. @classmethod
  43. def _state_to_kwargs(clazz, state):
  44. sa_class_name = state.pop("subanalyzer_class_name")
  45. sa_state = state.pop("subanalyzer_state")
  46. assert len(state) == 0
  47. subanalyzer = base.AnalyzerBase.load(sa_class_name, sa_state)
  48. kwargs = {"subanalyzer": subanalyzer}
  49. return kwargs
  50. ###############################################################################
  51. ###############################################################################
  52. ###############################################################################
class AugmentReduceBase(WrapperBase):
    """Interface for wrappers that augment the input and reduce the analysis.

    This class is an interface for wrappers that:
    * augment the input to the analyzer by creating new samples.
    * reduce the returned analysis to match the initial input shapes.

    :param subanalyzer: The analyzer to be wrapped.
    :param augment_by_n: Number of samples to create.
    """

    def __init__(self, subanalyzer, *args, **kwargs):
        # Number of augmented copies created per input sample.
        self._augment_by_n = kwargs.pop("augment_by_n", 2)
        # Remember the subanalyzer's original mode; analyze() branches on it.
        self._neuron_selection_mode = subanalyzer._neuron_selection_mode
        if self._neuron_selection_mode != "all":
            # TODO: this is not transparent, find a better way.
            # Force index mode so this wrapper can broadcast the selected
            # neuron indices across the augmented samples (see analyze()).
            subanalyzer._neuron_selection_mode = "index"

        super(AugmentReduceBase, self).__init__(subanalyzer,
                                                *args, **kwargs)

        if isinstance(self._subanalyzer, base.AnalyzerNetworkBase):
            # Take the keras analyzer model and
            # add augment and reduce functionality.
            self._keras_based_augment_reduce = True
        else:
            raise NotImplementedError("Keras-based subanalyzer required.")

    def create_analyzer_model(self):
        """Wrap the subanalyzer's Keras model with augment/reduce stages.

        Rebuilds ``self._subanalyzer._analyzer_model`` in place so that it
        augments its inputs, runs the original analyzer model on the
        augmented batch, and reduces the result back to the original
        batch size.
        """
        if not self._keras_based_augment_reduce:
            return

        self._subanalyzer.create_analyzer_model()

        if self._subanalyzer._n_debug_output > 0:
            raise Exception("No debug output at subanalyzer is supported.")

        model = self._subanalyzer._analyzer_model
        # The augment/reduce reshapes need concrete dimensions.
        if None in model.input_shape[1:]:
            raise ValueError("The input shape for the model needs "
                             "to be fully specified (except the batch axis). "
                             "Model input shape is: %s" % (model.input_shape,))

        inputs = model.inputs[:self._subanalyzer._n_data_input]
        extra_inputs = model.inputs[self._subanalyzer._n_data_input:]
        # todo: check this, index seems not right.
        #outputs = model.outputs[:self._subanalyzer._n_data_input]
        extra_outputs = model.outputs[self._subanalyzer._n_data_input:]

        if len(extra_outputs) > 0:
            raise Exception("No extra output is allowed "
                            "with this wrapper.")

        # Augment the data inputs, run the original analyzer model on the
        # enlarged batch, then reduce back to the original batch size.
        new_inputs = iutils.to_list(self._augment(inputs))
        # print(type(new_inputs), type(extra_inputs))
        tmp = iutils.to_list(model(new_inputs+extra_inputs))
        new_outputs = iutils.to_list(self._reduce(tmp))
        new_constant_inputs = self._keras_get_constant_inputs()

        new_model = keras.models.Model(
            inputs=inputs+extra_inputs+new_constant_inputs,
            outputs=new_outputs+extra_outputs)
        self._subanalyzer._analyzer_model = new_model

    def analyze(self, X, *args, **kwargs):
        """Analyze ``X``, broadcasting neuron selection to augmented samples."""
        if self._keras_based_augment_reduce is True:
            if not hasattr(self._subanalyzer, "_analyzer_model"):
                # Lazily build the wrapped model on first use.
                self.create_analyzer_model()

            ns_mode = self._neuron_selection_mode
            if ns_mode in ["max_activation", "index"]:
                if ns_mode == "max_activation":
                    # Pick the most active output neuron per sample.
                    tmp = self._subanalyzer._model.predict(X)
                    indices = np.argmax(tmp, axis=1)
                else:
                    # Index mode: indices come either positionally or as the
                    # "neuron_selection" keyword.
                    if len(args):
                        args = list(args)
                        indices = args.pop(0)
                    else:
                        indices = kwargs.pop("neuron_selection")

                # broadcast to match augmented samples.
                indices = np.repeat(indices, self._augment_by_n)

                kwargs["neuron_selection"] = indices
            return self._subanalyzer.analyze(X, *args, **kwargs)
        else:
            raise DeprecationWarning("Not supported anymore.")

    def _keras_get_constant_inputs(self):
        # Extra constant inputs for the wrapped model; none by default.
        # Subclasses (e.g. PathIntegrator) override this.
        return list()

    def _augment(self, X):
        # Repeat each sample augment_by_n times along the batch axis.
        repeat = ilayers.Repeat(self._augment_by_n, axis=0)
        return [repeat(x) for x in iutils.to_list(X)]

    def _reduce(self, X):
        # Undo the augmentation: fold the augmented copies into a separate
        # axis and average over it, restoring the original batch size.
        X_shape = [K.int_shape(x) for x in iutils.to_list(X)]
        reshape = [ilayers.Reshape((-1, self._augment_by_n)+shape[1:])
                   for shape in X_shape]
        mean = ilayers.Mean(axis=1)

        return [mean(reshape_x(x)) for x, reshape_x in zip(X, reshape)]

    def _get_state(self):
        if self._neuron_selection_mode != "all":
            # TODO: this is not transparent, find a better way.
            # revert the tempering in __init__
            # NOTE(review): the subanalyzer keeps this reverted mode after
            # saving — it is not set back to "index"; verify whether analyze()
            # still works on this instance after a save().
            tmp = self._neuron_selection_mode
            self._subanalyzer._neuron_selection_mode = tmp

        state = super(AugmentReduceBase, self)._get_state()
        state.update({"augment_by_n": self._augment_by_n})
        return state

    @classmethod
    def _state_to_kwargs(clazz, state):
        augment_by_n = state.pop("augment_by_n")

        kwargs = super(AugmentReduceBase, clazz)._state_to_kwargs(state)
        kwargs.update({"augment_by_n": augment_by_n})
        return kwargs
  150. ###############################################################################
  151. ###############################################################################
  152. ###############################################################################
  153. class GaussianSmoother(AugmentReduceBase):
  154. """Wrapper that adds noise to the input and averages over analyses
  155. This wrapper creates new samples by adding Gaussian noise
  156. to the input. The final analysis is an average of the returned analyses.
  157. :param subanalyzer: The analyzer to be wrapped.
  158. :param noise_scale: The stddev of the applied noise.
  159. :param augment_by_n: Number of samples to create.
  160. """
  161. def __init__(self, subanalyzer, *args, **kwargs):
  162. self._noise_scale = kwargs.pop("noise_scale", 1)
  163. super(GaussianSmoother, self).__init__(subanalyzer,
  164. *args, **kwargs)
  165. def _augment(self, X):
  166. tmp = super(GaussianSmoother, self)._augment(X)
  167. noise = ilayers.TestPhaseGaussianNoise(stddev=self._noise_scale)
  168. return [noise(x) for x in tmp]
  169. def _get_state(self):
  170. state = super(GaussianSmoother, self)._get_state()
  171. state.update({"noise_scale": self._noise_scale})
  172. return state
  173. @classmethod
  174. def _state_to_kwargs(clazz, state):
  175. noise_scale = state.pop("noise_scale")
  176. kwargs = super(GaussianSmoother, clazz)._state_to_kwargs(state)
  177. kwargs.update({"noise_scale": noise_scale})
  178. return kwargs
  179. ###############################################################################
  180. ###############################################################################
  181. ###############################################################################
class PathIntegrator(AugmentReduceBase):
    """Integrates the analysis along a path.

    This analyzer:
    * creates a path from input to reference image.
    * creates steps number of intermediate inputs and
      creates an analysis for them.
    * sums the analyses and multiplies them with the input-reference_input.

    This wrapper is used to implement Integrated Gradients.
    We refer to the paper for further information.

    :param subanalyzer: The analyzer to be wrapped.
    :param steps: Number of steps for integration.
    :param reference_inputs: The reference input.
    """

    def __init__(self, subanalyzer, *args, **kwargs):
        steps = kwargs.pop("steps", 16)
        self._reference_inputs = kwargs.pop("reference_inputs", 0)
        # Lazily created Keras inputs holding the reference point(s);
        # built on first _augment() call.
        self._keras_constant_inputs = None
        # One augmented copy per integration step.
        super(PathIntegrator, self).__init__(subanalyzer,
                                             *args,
                                             augment_by_n=steps,
                                             **kwargs)

    def _keras_set_constant_inputs(self, inputs):
        # Wrap the numeric reference inputs as constant Keras input tensors
        # so they can be part of the wrapped analyzer model.
        tmp = [K.variable(x) for x in inputs]
        self._keras_constant_inputs = [
            keras.layers.Input(tensor=x, shape=x.shape[1:])
            for x in tmp]

    def _keras_get_constant_inputs(self):
        return self._keras_constant_inputs

    def _compute_difference(self, X):
        """Return per-input tensors ``x - reference_input``."""
        if self._keras_constant_inputs is None:
            # Broadcast the (possibly scalar) reference to each input's
            # shape, then cache as constant Keras inputs.
            tmp = kutils.broadcast_np_tensors_to_keras_tensors(
                X, self._reference_inputs)
            self._keras_set_constant_inputs(tmp)

        reference_inputs = self._keras_get_constant_inputs()
        return [keras.layers.Subtract()([x, ri])
                for x, ri in zip(X, reference_inputs)]

    def _augment(self, X):
        # Replicate each sample once per integration step, then expose the
        # step axis explicitly: (batch, steps, ...).
        tmp = super(PathIntegrator, self)._augment(X)
        tmp = [ilayers.Reshape((-1, self._augment_by_n)+K.int_shape(x)[1:])(x)
               for x in tmp]

        difference = self._compute_difference(X)
        # Stash for _reduce(), which multiplies the summed analysis by it.
        self._keras_difference = difference
        # Make broadcastable.
        difference = [ilayers.Reshape((-1, 1)+K.int_shape(x)[1:])(x)
                      for x in difference]

        # Compute path steps.
        # Scale the difference by linspace(0, 1, steps) along the step axis,
        # producing the interpolation points between reference and input.
        multiply_with_linspace = ilayers.MultiplyWithLinspace(
            0, 1,
            n=self._augment_by_n,
            axis=1)
        path_steps = [multiply_with_linspace(d) for d in difference]

        # Shift each path step by the reference, then flatten the step axis
        # back into the batch axis for the subanalyzer.
        reference_inputs = self._keras_get_constant_inputs()
        ret = [keras.layers.Add()([x, p]) for x, p in zip(reference_inputs, path_steps)]
        ret = [ilayers.Reshape((-1,)+K.int_shape(x)[2:])(x) for x in ret]
        return ret

    def _reduce(self, X):
        # Average over the steps (super), then multiply by (input - reference)
        # to complete the integrated-gradients formula.
        tmp = super(PathIntegrator, self)._reduce(X)
        difference = self._keras_difference
        # One-shot handoff from _augment(); clear it after use.
        del self._keras_difference

        return [keras.layers.Multiply()([x, d])
                for x, d in zip(tmp, difference)]

    def _get_state(self):
        state = super(PathIntegrator, self)._get_state()
        state.update({"reference_inputs": self._reference_inputs})
        return state

    @classmethod
    def _state_to_kwargs(clazz, state):
        reference_inputs = state.pop("reference_inputs")

        kwargs = super(PathIntegrator, clazz)._state_to_kwargs(state)
        kwargs.update({"reference_inputs": reference_inputs})
        # We use steps instead.
        kwargs.update({"steps": kwargs["augment_by_n"]})
        del kwargs["augment_by_n"]
        return kwargs