deeplift.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. import importlib
  8. import keras.backend as K
  9. import keras.layers
  10. import numpy as np
  11. import tempfile
  12. import warnings
  13. from . import base
  14. from .. import layers as ilayers
  15. from .. import utils as iutils
  16. from ..utils import keras as kutils
  17. from ..utils.keras import checks as kchecks
  18. from ..utils.keras import graph as kgraph
  19. __all__ = [
  20. "DeepLIFT",
  21. "DeepLIFTWrapper",
  22. ]
  23. ###############################################################################
  24. ###############################################################################
  25. ###############################################################################
  26. def _create_deeplift_rules(reference_mapping, approximate_gradient=True):
  27. def RescaleRule(Xs, Ys, As, reverse_state, local_references={}):
  28. if approximate_gradient:
  29. def rescale_f(x):
  30. a, dx, dy, g = x
  31. return K.switch(K.less(K.abs(dx), K.epsilon()), g, a*(dy/dx))
  32. else:
  33. def rescale_f(x):
  34. a, dx, dy, _ = x
  35. return a*(dy/(dx + K.epsilon()))
  36. grad = ilayers.GradientWRT(len(Xs))
  37. rescale = keras.layers.Lambda(rescale_f)
  38. Xs_references = [
  39. reference_mapping.get(x, local_references.get(x, None))
  40. for x in Xs
  41. ]
  42. Ys_references = [
  43. reference_mapping.get(x, local_references.get(x, None))
  44. for x in Ys
  45. ]
  46. Xs_differences = [keras.layers.Subtract()([x, r])
  47. for x, r in zip(Xs, Xs_references)]
  48. Ys_differences = [keras.layers.Subtract()([x, r])
  49. for x, r in zip(Ys, Ys_references)]
  50. gradients = iutils.to_list(grad(Xs+Ys+As))
  51. return [rescale([a, dx, dy, g])
  52. for a, dx, dy, g
  53. in zip(As, Xs_differences, Ys_differences, gradients)]
  54. def LinearRule(Xs, Ys, As, reverse_state):
  55. if approximate_gradient:
  56. def switch_f(x):
  57. dx, a, g = x
  58. return K.switch(K.less(K.abs(dx), K.epsilon()), g, a)
  59. else:
  60. def switch_f(x):
  61. _, a, _ = x
  62. return a
  63. grad = ilayers.GradientWRT(len(Xs))
  64. switch = keras.layers.Lambda(switch_f)
  65. Xs_references = [reference_mapping[x] for x in Xs]
  66. Ys_references = [reference_mapping[x] for x in Ys]
  67. Xs_differences = [keras.layers.Subtract()([x, r])
  68. for x, r in zip(Xs, Xs_references)]
  69. Ys_differences = [keras.layers.Subtract()([x, r])
  70. for x, r in zip(Ys, Ys_references)]
  71. # Divide incoming relevance by the activations.
  72. tmp = [ilayers.SafeDivide()([a, b])
  73. for a, b in zip(As, Ys_differences)]
  74. # Propagate the relevance to input neurons
  75. # using the gradient.
  76. tmp = iutils.to_list(grad(Xs+Ys+tmp))
  77. # Re-weight relevance with the input values.
  78. tmp = [keras.layers.Multiply()([a, b])
  79. for a, b in zip(Xs_differences, tmp)]
  80. # only the gradient
  81. gradients = iutils.to_list(grad(Xs+Ys+As))
  82. return [switch([dx, a, g])
  83. for dx, a, g
  84. in zip(Xs_differences, tmp, gradients)]
  85. return RescaleRule, LinearRule
  86. class DeepLIFT(base.ReverseAnalyzerBase):
  87. """DeepLIFT-rescale algorithm
  88. This class implements the DeepLIFT algorithm using
  89. the rescale rule (as in DeepExplain (Ancona et.al.)).
  90. WARNING: This implementation contains bugs.
  91. :param model: A Keras model.
  92. """
  93. def __init__(self, model, *args, **kwargs):
  94. warnings.warn("This implementation contains bugs.")
  95. self._reference_inputs = kwargs.pop("reference_inputs", 0)
  96. self._approximate_gradient = kwargs.pop(
  97. "approximate_gradient", True)
  98. self._add_model_softmax_check()
  99. super(DeepLIFT, self).__init__(model, *args, **kwargs)
  100. def _prepare_model(self, model):
  101. ret = super(DeepLIFT, self)._prepare_model(model)
  102. # Store analysis input to create reference inputs.
  103. self._analysis_inputs = ret[1]
  104. return ret
  105. def _create_reference_activations(self, model):
  106. self._model_execution_trace = kgraph.trace_model_execution(model)
  107. layers, execution_list, outputs = self._model_execution_trace
  108. self._reference_activations = {}
  109. # Create references and graph inputs.
  110. tmp = kutils.broadcast_np_tensors_to_keras_tensors(
  111. model.inputs, self._reference_inputs)
  112. tmp = [K.variable(x) for x in tmp]
  113. constant_reference_inputs = [
  114. keras.layers.Input(tensor=x, shape=K.int_shape(x)[1:])
  115. for x in tmp
  116. ]
  117. for k, v in zip(model.inputs, constant_reference_inputs):
  118. self._reference_activations[k] = v
  119. for k, v in zip(self._analysis_inputs, self._analysis_inputs):
  120. self._reference_activations[k] = v
  121. # Compute intermediate states.
  122. for layer, Xs, Ys in execution_list:
  123. activations = [self._reference_activations[x] for x in Xs]
  124. if isinstance(layer, keras.layers.InputLayer):
  125. # Special case. Do nothing.
  126. next_activations = activations
  127. else:
  128. next_activations = iutils.to_list(
  129. kutils.apply(layer, activations))
  130. assert len(next_activations) == len(Ys)
  131. for k, v in zip(Ys, next_activations):
  132. self._reference_activations[k] = v
  133. return constant_reference_inputs
  134. def _create_analysis(self, model, *args, **kwargs):
  135. constant_reference_inputs = self._create_reference_activations(model)
  136. RescaleRule, LinearRule = _create_deeplift_rules(
  137. self._reference_activations, self._approximate_gradient)
  138. # Kernel layers.
  139. self._add_conditional_reverse_mapping(
  140. lambda l: kchecks.contains_kernel(l),
  141. LinearRule,
  142. name="deeplift_kernel_layer",
  143. )
  144. # Activation layers
  145. self._add_conditional_reverse_mapping(
  146. lambda l: (not kchecks.contains_kernel(l) and
  147. kchecks.contains_activation(l)),
  148. RescaleRule,
  149. name="deeplift_activation_layer",
  150. )
  151. tmp = super(DeepLIFT, self)._create_analysis(
  152. model, *args, **kwargs)
  153. if isinstance(tmp, tuple):
  154. if len(tmp) == 3:
  155. analysis_outputs, debug_outputs, constant_inputs = tmp
  156. elif len(tmp) == 2:
  157. analysis_outputs, debug_outputs = tmp
  158. constant_inputs = list()
  159. elif len(tmp) == 1:
  160. analysis_outputs = iutils.to_list(tmp[0])
  161. constant_inputs, debug_outputs = list(), list()
  162. else:
  163. raise Exception("Unexpected output from _create_analysis.")
  164. else:
  165. analysis_outputs = tmp
  166. constant_inputs, debug_outputs = list(), list()
  167. return (analysis_outputs,
  168. debug_outputs,
  169. constant_inputs+constant_reference_inputs)
  170. def _head_mapping(self, X):
  171. return keras.layers.Subtract()([X, self._reference_activations[X]])
  172. def _reverse_model(self,
  173. model,
  174. stop_analysis_at_tensors=[],
  175. return_all_reversed_tensors=False):
  176. return kgraph.reverse_model(
  177. model,
  178. reverse_mappings=self._reverse_mapping,
  179. default_reverse_mapping=self._default_reverse_mapping,
  180. head_mapping=self._head_mapping,
  181. stop_mapping_at_tensors=stop_analysis_at_tensors,
  182. verbose=self._reverse_verbose,
  183. clip_all_reversed_tensors=self._reverse_clip_values,
  184. project_bottleneck_tensors=self._reverse_project_bottleneck_layers,
  185. return_all_reversed_tensors=return_all_reversed_tensors,
  186. execution_trace=self._model_execution_trace)
  187. def _get_state(self):
  188. state = super(DeepLIFT, self)._get_state()
  189. state.update({"reference_inputs": self._reference_inputs})
  190. state.update({"approximate_gradient": self._approximate_gradient})
  191. return state
  192. @classmethod
  193. def _state_to_kwargs(clazz, state):
  194. reference_inputs = state.pop("reference_inputs")
  195. approximate_gradient = state.pop("approximate_gradient")
  196. kwargs = super(DeepLIFT, clazz)._state_to_kwargs(state)
  197. kwargs.update({"reference_inputs": reference_inputs})
  198. kwargs.update({"approximate_gradient": approximate_gradient})
  199. return kwargs
  200. ###############################################################################
  201. ###############################################################################
  202. ###############################################################################
  203. class DeepLIFTWrapper(base.AnalyzerNetworkBase):
  204. """Wrapper around DeepLIFT package
  205. This class wraps the DeepLIFT package.
  206. For further explanation of the parameters check out:
  207. https://github.com/kundajelab/deeplift
  208. :param model: A Keras model.
  209. :param nonlinear_mode: The nonlinear mode parameter.
  210. :param reference_inputs: The reference input used for DeepLIFT.
  211. :param verbose: Verbosity of the DeepLIFT package.
  212. :note: Requires the deeplift package.
  213. """
  214. def __init__(self, model, **kwargs):
  215. self._nonlinear_mode = kwargs.pop("nonlinear_mode", "rescale")
  216. self._reference_inputs = kwargs.pop("reference_inputs", 0)
  217. self._verbose = kwargs.pop("verbose", False)
  218. #relevant for "index" selection mode
  219. self._batch_size = kwargs.pop("batch_size", 32)
  220. self._add_model_softmax_check()
  221. try:
  222. self._deeplift_module = importlib.import_module("deeplift")
  223. except ImportError:
  224. raise ImportError("To use DeepLIFTWrapper please install "
  225. "the python module 'deeplift', e.g.: "
  226. "'pip install deeplift'")
  227. super(DeepLIFTWrapper, self).__init__(model, **kwargs)
  228. def _create_deep_lift_func(self):
  229. # Store model and load into deeplift format.
  230. kc = importlib.import_module("deeplift.conversion.kerasapi_conversion")
  231. modes = self._deeplift_module.layers.NonlinearMxtsMode
  232. key = self._nonlinear_mode
  233. nonlinear_mxts_mode = {
  234. "genomics_default": modes.DeepLIFT_GenomicsDefault,
  235. "reveal_cancel": modes.RevealCancel,
  236. "rescale": modes.Rescale,
  237. }[key]
  238. with tempfile.NamedTemporaryFile(suffix=".hdf5") as f:
  239. self._model.save(f.name)
  240. deeplift_model = kc.convert_model_from_saved_files(
  241. f.name, nonlinear_mxts_mode=nonlinear_mxts_mode,
  242. verbose=self._verbose)
  243. # Create function with respect to input layers
  244. def fix_name(s):
  245. return s.replace(":", "_")
  246. score_layer_names = [fix_name(l.name) for l in self._model.inputs]
  247. if len(self._model.outputs) > 1:
  248. raise ValueError("Only a single output layer is supported.")
  249. tmp = self._model.outputs[0]._keras_history
  250. target_layer_name = fix_name(tmp[0].name+"_%i" % tmp[1])
  251. self._func = deeplift_model.get_target_contribs_func(
  252. find_scores_layer_name=score_layer_names,
  253. pre_activation_target_layer_name=target_layer_name)
  254. self._references = kutils.broadcast_np_tensors_to_keras_tensors(
  255. self._model.inputs, self._reference_inputs)
  256. def _analyze_with_deeplift(self, X, neuron_idx, batch_size):
  257. return self._func(task_idx=neuron_idx,
  258. input_data_list=X,
  259. batch_size=batch_size,
  260. input_references_list=self._references,
  261. progress_update=None)
  262. def analyze(self, X, neuron_selection=None):
  263. if not hasattr(self, "_func"):
  264. self._create_deep_lift_func()
  265. X = iutils.to_list(X)
  266. if(neuron_selection is not None and
  267. self._neuron_selection_mode != "index"):
  268. raise ValueError("Only neuron_selection_mode 'index' expects "
  269. "the neuron_selection parameter.")
  270. if(neuron_selection is None and
  271. self._neuron_selection_mode == "index"):
  272. raise ValueError("neuron_selection_mode 'index' expects "
  273. "the neuron_selection parameter.")
  274. if self._neuron_selection_mode == "index":
  275. neuron_selection = np.asarray(neuron_selection).flatten()
  276. if neuron_selection.size != 1:
  277. # The code allows to select multiple neurons.
  278. raise ValueError("One neuron can be selected with DeepLIFT.")
  279. neuron_idx = neuron_selection[0]
  280. analysis = self._analyze_with_deeplift(X, neuron_idx, self._batch_size)
  281. # Parse the output.
  282. ret = []
  283. for x, analysis_for_x in zip(X, analysis):
  284. tmp = np.stack([a for a in analysis_for_x])
  285. tmp = tmp.reshape(x.shape)
  286. ret.append(tmp)
  287. elif self._neuron_selection_mode == "max_activation":
  288. neuron_idx = np.argmax(self._model.predict_on_batch(X), axis=1)
  289. analysis = []
  290. # run for each batch with its respective max activated neuron
  291. for i, ni in enumerate(neuron_idx):
  292. # slice input tensors
  293. tmp = [x[i:i+1] for x in X]
  294. analysis.append(self._analyze_with_deeplift(tmp, ni, 1))
  295. # Parse the output.
  296. ret = []
  297. for i, x in enumerate(X):
  298. tmp = np.stack([a[i] for a in analysis]).reshape(x.shape)
  299. ret.append(tmp)
  300. else:
  301. raise ValueError("Only neuron_selection_mode index or "
  302. "max_activation are supported.")
  303. if isinstance(ret, list) and len(ret) == 1:
  304. ret = ret[0]
  305. return ret
  306. def _get_state(self):
  307. state = super(DeepLIFTWrapper, self)._get_state()
  308. state.update({"nonlinear_mode": self._nonlinear_mode})
  309. state.update({"reference_inputs": self._reference_inputs})
  310. state.update({"verbose": self._verbose})
  311. return state
  312. @classmethod
  313. def _state_to_kwargs(clazz, state):
  314. nonlinear_mode = state.pop("nonlinear_mode")
  315. reference_inputs = state.pop("reference_inputs")
  316. verbose = state.pop("verbose")
  317. kwargs = super(DeepLIFTWrapper, clazz)._state_to_kwargs(state)
  318. kwargs.update({
  319. "nonlinear_mode": nonlinear_mode,
  320. "reference_inputs": reference_inputs,
  321. "verbose": verbose,
  322. })
  323. return kwargs