base.py

# Get Python six functionality:
from __future__ import \
    absolute_import, print_function, division, unicode_literals
from builtins import zip
import six


###############################################################################
###############################################################################
###############################################################################


import keras.backend as K
import keras.layers
import keras.models
import numpy as np
import warnings


from .. import layers as ilayers
from .. import utils as iutils
from ..utils.keras import checks as kchecks
from ..utils.keras import graph as kgraph


__all__ = [
    "NotAnalyzeableModelException",
    "AnalyzerBase",

    "TrainerMixin",
    "OneEpochTrainerMixin",

    "AnalyzerNetworkBase",
    "ReverseAnalyzerBase",
]


###############################################################################
###############################################################################
###############################################################################


class NotAnalyzeableModelException(Exception):
    """Indicates that the model cannot be analyzed by an analyzer."""
    pass


class AnalyzerBase(object):
    """The basic interface of an iNNvestigate analyzer.

    This class defines the basic interface for analyzers:

    >>> model = create_keras_model()
    >>> a = Analyzer(model)
    >>> a.fit(X_train)  # If the analyzer needs training.
    >>> analysis = a.analyze(X_test)
    >>>
    >>> state = a.save()
    >>> a_new = Analyzer.load(*state)
    >>> analysis = a_new.analyze(X_test)

    :param model: A Keras model.
    :param disable_model_checks: Do not execute model checks that enforce
      compatibility of analyzer and model.

    .. note:: To develop a new analyzer derive from
      :class:`AnalyzerNetworkBase`.
    """

    def __init__(self, model, disable_model_checks=False):
        self._model = model
        self._disable_model_checks = disable_model_checks

        self._do_model_checks()

    def _add_model_check(self, check, message, check_type="exception"):
        """Registers a check that :func:`_do_model_checks` runs against
        the model before the analysis is created."""
        if getattr(self, "_model_check_done", False):
            raise Exception("Cannot add model check anymore."
                            " Check was already performed.")

        if not hasattr(self, "_model_checks"):
            self._model_checks = []

        check_instance = {
            "check": check,
            "message": message,
            "type": check_type,
        }
        self._model_checks.append(check_instance)
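
    # Example (sketch): a subclass would typically register a check in its
    # __init__ before calling super().__init__, e.g. to warn on a layer type
    # it cannot handle:
    #
    #   self._add_model_check(
    #       lambda layer: isinstance(layer, keras.layers.Dropout),
    #       "This analyzer does not support Dropout layers.",
    #       check_type="warning",
    #   )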

    def _do_model_checks(self):
        model_checks = getattr(self, "_model_checks", [])

        if not self._disable_model_checks and len(model_checks) > 0:
            check = [x["check"] for x in model_checks]
            types = [x["type"] for x in model_checks]
            messages = [x["message"] for x in model_checks]

            checked = kgraph.model_contains(self._model, check)
            tmp = zip(iutils.to_list(checked), messages, types)

            for checked_layers, message, check_type in tmp:
                if len(checked_layers) > 0:
                    tmp_message = ("%s\nCheck triggered by layers: %s" %
                                   (message, checked_layers))

                    if check_type == "exception":
                        raise NotAnalyzeableModelException(tmp_message)
                    elif check_type == "warning":
                        # TODO(albermax) only the first warning will be shown
                        warnings.warn(tmp_message)
                    else:
                        raise NotImplementedError()

        self._model_check_done = True

    def fit(self, *args, **kwargs):
        """
        Stub that eats arguments. If an analyzer needs training
        include :class:`TrainerMixin`.

        :param disable_no_training_warning: Do not warn if this function is
          called although no training is needed.
        """
        disable_no_training_warning = kwargs.pop("disable_no_training_warning",
                                                 False)
        if not disable_no_training_warning:
            # Issue a warning if no training is foreseen,
            # but fit() is still called.
            warnings.warn("This analyzer does not need to be trained."
                          " Still fit() is called.", RuntimeWarning)

    def fit_generator(self, *args, **kwargs):
        """
        Stub that eats arguments. If an analyzer needs training
        include :class:`TrainerMixin`.

        :param disable_no_training_warning: Do not warn if this function is
          called although no training is needed.
        """
        disable_no_training_warning = kwargs.pop("disable_no_training_warning",
                                                 False)
        if not disable_no_training_warning:
            # Issue a warning if no training is foreseen,
            # but fit_generator() is still called.
            warnings.warn("This analyzer does not need to be trained."
                          " Still fit_generator() is called.", RuntimeWarning)

    def analyze(self, X):
        """
        Analyze the behavior of model on input `X`.

        :param X: Input as expected by model.
        """
        raise NotImplementedError()

    def _get_state(self):
        state = {
            "model_json": self._model.to_json(),
            "model_weights": self._model.get_weights(),
            "disable_model_checks": self._disable_model_checks,
        }
        return state

    def save(self):
        """
        Save state of analyzer, can be passed to :func:`Analyzer.load`
        to restore the analyzer.

        :return: The class name and the state.
        """
        state = self._get_state()
        class_name = self.__class__.__name__
        return class_name, state

    def save_npz(self, fname):
        """
        Save state of analyzer, can be passed to :func:`Analyzer.load_npz`
        to restore the analyzer.

        :param fname: The file's name.
        """
        class_name, state = self.save()
        np.savez(fname, **{"class_name": class_name,
                           "state": state})

    @classmethod
    def _state_to_kwargs(clazz, state):
        model_json = state.pop("model_json")
        model_weights = state.pop("model_weights")
        disable_model_checks = state.pop("disable_model_checks")
        assert len(state) == 0

        model = keras.models.model_from_json(model_json)
        model.set_weights(model_weights)
        return {"model": model,
                "disable_model_checks": disable_model_checks}

    @staticmethod
    def load(class_name, state):
        """
        Restores an analyzer from the state created by
        :func:`analyzer.save()`.

        :param class_name: The analyzer's class name.
        :param state: The analyzer's state.
        """
        # TODO: do this in a smarter way!
        import innvestigate.analyzer
        clazz = getattr(innvestigate.analyzer, class_name)

        kwargs = clazz._state_to_kwargs(state)
        return clazz(**kwargs)

    @staticmethod
    def load_npz(fname):
        """
        Restores an analyzer from the file created by
        :func:`analyzer.save_npz()`.

        :param fname: The file's name.
        """
        f = np.load(fname)

        class_name = f["class_name"].item()
        state = f["state"].item()
        return AnalyzerBase.load(class_name, state)
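
    # Example (sketch; ``Gradient`` stands in for any concrete analyzer
    # class exported by ``innvestigate.analyzer``):
    #
    #   analyzer = Gradient(model)
    #   analyzer.save_npz("analyzer.npz")
    #   restored = AnalyzerBase.load_npz("analyzer.npz")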


###############################################################################
###############################################################################
###############################################################################


class TrainerMixin(object):
    """Mixin for analyzers that adapt to data.

    This convenience interface exposes a Keras-like training routine
    to the user.
    """

    # TODO: extend with Y
    def fit(self,
            X=None,
            batch_size=32,
            **kwargs):
        """
        Takes the same parameters as Keras's :func:`model.fit` function.
        """
        generator = iutils.BatchSequence(X, batch_size)
        return self._fit_generator(generator,
                                   **kwargs)

    def fit_generator(self, *args, **kwargs):
        """
        Takes the same parameters as Keras's :func:`model.fit_generator`
        function.
        """
        return self._fit_generator(*args, **kwargs)

    def _fit_generator(self,
                       generator,
                       steps_per_epoch=None,
                       epochs=1,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0,
                       disable_no_training_warning=None):
        raise NotImplementedError()


class OneEpochTrainerMixin(TrainerMixin):
    """Exposes the same interface and functionality as :class:`TrainerMixin`
    except that the training is limited to one epoch.
    """

    def fit(self, *args, **kwargs):
        """
        Same interface as :func:`fit` of :class:`TrainerMixin` except that
        the parameter epochs is fixed to 1.
        """
        return super(OneEpochTrainerMixin, self).fit(*args, epochs=1, **kwargs)

    def fit_generator(self, *args, **kwargs):
        """
        Same interface as :func:`fit_generator` of :class:`TrainerMixin`
        except that the parameter epochs is fixed to 1.
        """
        steps = kwargs.pop("steps", None)
        return super(OneEpochTrainerMixin, self).fit_generator(
            *args,
            steps_per_epoch=steps,
            epochs=1,
            **kwargs)
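
    # Example (sketch; ``PatternNet`` stands in for any trainable analyzer
    # that includes this mixin):
    #
    #   analyzer = PatternNet(model)
    #   analyzer.fit(X_train, batch_size=256, verbose=1)
    #   analysis = analyzer.analyze(X_test)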


###############################################################################
###############################################################################
###############################################################################


class AnalyzerNetworkBase(AnalyzerBase):
    """Convenience interface for analyzers.

    This class provides helpful functionality to create analyzers.
    Basically it:

    * takes the input model and adds a layer that selects
      the desired output neuron to analyze.
    * passes the new model to :func:`_create_analysis` which should
      return the analysis as Keras tensors.
    * compiles the function and serves the output to :func:`analyze` calls.
    * allows :func:`_create_analysis` to return tensors
      that are intercepted for debugging purposes.

    :param neuron_selection_mode: How to select the neuron to analyze.
      Possible values are 'max_activation', 'index' for the neuron
      (expects indices at :func:`analyze` calls), 'all' takes all neurons.
    :param allow_lambda_layers: Allow the model to contain lambda layers.
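
    Example (sketch; ``Gradient`` stands in for a concrete subclass):

    >>> analyzer = Gradient(model, neuron_selection_mode="index")
    >>> analysis = analyzer.analyze(X_test, neuron_selection=0)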
  252. """
  253. def __init__(self, model,
  254. neuron_selection_mode="max_activation",
  255. allow_lambda_layers=False,
  256. **kwargs):
  257. if neuron_selection_mode not in ["max_activation", "index", "all"]:
  258. raise ValueError("neuron_selection parameter is not valid.")
  259. self._neuron_selection_mode = neuron_selection_mode
  260. self._allow_lambda_layers = allow_lambda_layers
  261. self._add_model_check(
  262. lambda layer: (not self._allow_lambda_layers and
  263. isinstance(layer, keras.layers.core.Lambda)),
  264. ("Lamda layers are not allowed. "
  265. "To force use set allow_lambda_layers parameter."),
  266. check_type="exception",
  267. )
  268. self._special_helper_layers = []
  269. super(AnalyzerNetworkBase, self).__init__(model, **kwargs)

    def _add_model_softmax_check(self):
        """
        Adds a check that prevents models from containing a softmax.
        """
        self._add_model_check(
            lambda layer: kchecks.contains_activation(
                layer, activation="softmax"),
            "This analysis method does not support softmax layers.",
            check_type="exception",
        )

    def _prepare_model(self, model):
        """
        Prepares the model to be analyzed before it is actually analyzed.

        This method adds the code to select a specific output neuron.
        """
        neuron_selection_mode = self._neuron_selection_mode
        model_inputs = model.inputs
        model_output = model.outputs
        if len(model_output) > 1:
            raise ValueError("Only models with one output tensor are allowed.")

        analysis_inputs = []
        stop_analysis_at_tensors = []

        # Flatten to form (batch_size, other_dimensions):
        if K.ndim(model_output[0]) > 2:
            model_output = keras.layers.Flatten()(model_output)

        if neuron_selection_mode == "max_activation":
            l = ilayers.Max(name="iNNvestigate_max")
            model_output = l(model_output)
            self._special_helper_layers.append(l)
        elif neuron_selection_mode == "index":
            neuron_indexing = keras.layers.Input(
                batch_shape=[None, None], dtype=np.int32,
                name='iNNvestigate_neuron_indexing')
            self._special_helper_layers.append(
                neuron_indexing._keras_history[0])
            analysis_inputs.append(neuron_indexing)
            # The indexing tensor should not be analyzed.
            stop_analysis_at_tensors.append(neuron_indexing)

            l = ilayers.GatherND(name="iNNvestigate_gather_nd")
            model_output = l(model_output+[neuron_indexing])
            self._special_helper_layers.append(l)
        elif neuron_selection_mode == "all":
            pass
        else:
            raise NotImplementedError()

        model = keras.models.Model(inputs=model_inputs+analysis_inputs,
                                   outputs=model_output)
        return model, analysis_inputs, stop_analysis_at_tensors

    def create_analyzer_model(self):
        """
        Creates the analyze functionality. If not called beforehand
        it will be called by :func:`analyze`.
        """
        model_inputs = self._model.inputs
        tmp = self._prepare_model(self._model)
        model, analysis_inputs, stop_analysis_at_tensors = tmp
        self._analysis_inputs = analysis_inputs
        self._prepared_model = model

        tmp = self._create_analysis(
            model, stop_analysis_at_tensors=stop_analysis_at_tensors)
        if isinstance(tmp, tuple):
            if len(tmp) == 3:
                analysis_outputs, debug_outputs, constant_inputs = tmp
            elif len(tmp) == 2:
                analysis_outputs, debug_outputs = tmp
                constant_inputs = list()
            elif len(tmp) == 1:
                analysis_outputs = iutils.to_list(tmp[0])
                constant_inputs, debug_outputs = list(), list()
            else:
                raise Exception("Unexpected output from _create_analysis.")
        else:
            analysis_outputs = tmp
            constant_inputs, debug_outputs = list(), list()

        analysis_outputs = iutils.to_list(analysis_outputs)
        debug_outputs = iutils.to_list(debug_outputs)
        constant_inputs = iutils.to_list(constant_inputs)

        self._n_data_input = len(model_inputs)
        self._n_constant_input = len(constant_inputs)
        self._n_data_output = len(analysis_outputs)
        self._n_debug_output = len(debug_outputs)
        self._analyzer_model = keras.models.Model(
            inputs=model_inputs+analysis_inputs+constant_inputs,
            outputs=analysis_outputs+debug_outputs)

    def _create_analysis(self, model, stop_analysis_at_tensors=[]):
        """
        Interface that needs to be implemented by a derived class.

        This function is expected to create a Keras graph that creates
        a custom analysis for the model inputs given the model outputs.

        :param model: Target of analysis.
        :param stop_analysis_at_tensors: A list of tensors where to stop the
          analysis. Similar to stop_gradient arguments when computing the
          gradient of a graph.
        :return: Either a one-, two- or three-tuple of lists of tensors.
          * The first list of tensors represents the analysis for each
            model input tensor. Tensors present in stop_analysis_at_tensors
            should be omitted.
          * The second list, if present, is a list of debug tensors that will
            be passed to :func:`_handle_debug_output` after the analysis
            is executed.
          * The third list, if present, is a list of constant input tensors
            added to the analysis model.
        """
        raise NotImplementedError()
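
    # Example (sketch): a plain-gradient subclass might implement this as
    # follows, assuming ilayers.Gradient computes the gradient of the single
    # model output w.r.t. the given input tensors:
    #
    #   def _create_analysis(self, model, stop_analysis_at_tensors=[]):
    #       tensors_to_analyze = [x for x in model.inputs
    #                             if x not in stop_analysis_at_tensors]
    #       return iutils.to_list(ilayers.Gradient()(
    #           tensors_to_analyze + [model.outputs[0]]))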

    def _handle_debug_output(self, debug_values):
        raise NotImplementedError()

    def analyze(self, X, neuron_selection=None):
        """
        Same interface as :class:`Analyzer` besides:

        :param neuron_selection: If neuron_selection_mode is 'index' this
          should be an integer with the index of the chosen neuron.
        """
        if not hasattr(self, "_analyzer_model"):
            self.create_analyzer_model()

        X = iutils.to_list(X)

        if(neuron_selection is not None and
           self._neuron_selection_mode != "index"):
            raise ValueError("Only neuron_selection_mode 'index' expects "
                             "the neuron_selection parameter.")
        if(neuron_selection is None and
           self._neuron_selection_mode == "index"):
            raise ValueError("neuron_selection_mode 'index' expects "
                             "the neuron_selection parameter.")

        if self._neuron_selection_mode == "index":
            neuron_selection = np.asarray(neuron_selection).flatten()
            if neuron_selection.size == 1:
                neuron_selection = np.repeat(neuron_selection, len(X[0]))

            # Add the batch indices as the first axis for gather_nd:
            # each row is a (batch_index, neuron_index) pair.
            neuron_selection = np.hstack(
                (np.arange(len(neuron_selection)).reshape((-1, 1)),
                 neuron_selection.reshape((-1, 1)))
            )
            ret = self._analyzer_model.predict_on_batch(X+[neuron_selection])
        else:
            ret = self._analyzer_model.predict_on_batch(X)

        if self._n_debug_output > 0:
            self._handle_debug_output(ret[-self._n_debug_output:])
            ret = ret[:-self._n_debug_output]

        if isinstance(ret, list) and len(ret) == 1:
            ret = ret[0]

        return ret

    def _get_state(self):
        state = super(AnalyzerNetworkBase, self)._get_state()
        state.update({"neuron_selection_mode": self._neuron_selection_mode})
        state.update({"allow_lambda_layers": self._allow_lambda_layers})
        return state

    @classmethod
    def _state_to_kwargs(clazz, state):
        neuron_selection_mode = state.pop("neuron_selection_mode")
        allow_lambda_layers = state.pop("allow_lambda_layers")
        kwargs = super(AnalyzerNetworkBase, clazz)._state_to_kwargs(state)
        kwargs.update({
            "neuron_selection_mode": neuron_selection_mode,
            "allow_lambda_layers": allow_lambda_layers,
        })
        return kwargs


class ReverseAnalyzerBase(AnalyzerNetworkBase):
    """Convenience class for analyzers that revert the model's structure.

    This class contains many helper functions around the graph
    reverse function :func:`innvestigate.utils.keras.graph.reverse_model`.

    The deriving classes should specify how the graph should be reverted
    by implementing the following functions:

    * :func:`_reverse_mapping(layer)` given a layer this function
      returns a reverse mapping for the layer as specified in
      :func:`innvestigate.utils.keras.graph.reverse_model` or None.

      This function can be implemented, but it is encouraged to
      implement a default mapping and add additional changes with
      the function :func:`_add_conditional_reverse_mapping` (see below).

      The default behavior is finding a conditional mapping (see below);
      if none is found, :func:`_default_reverse_mapping` is applied.
    * :func:`_default_reverse_mapping` defines the default
      reverse mapping.
    * :func:`_head_mapping` defines how the outputs of the model
      should be instantiated before they are passed to the reversed
      network.

    Furthermore, other parameters of the function
    :func:`innvestigate.utils.keras.graph.reverse_model` can
    be changed by setting the corresponding parameters of the
    init function:

    :param reverse_verbose: Print information on the reverse process.
    :param reverse_clip_values: Clip the values that are passed along
      the reverted network. Expects tuple (min, max).
    :param reverse_project_bottleneck_layers: Project the value range
      of bottleneck tensors in the reverse network into another range.
    :param reverse_check_min_max_values: Print the min/max values
      observed in each tensor along the reverse network whenever
      :func:`analyze` is called.
    :param reverse_check_finite: Check if values passed along the
      reverse network are finite.
    :param reverse_keep_tensors: Keeps the tensors created in the
      backward pass and stores them in the attribute
      :attr:`_reversed_tensors`.
    :param reverse_reapply_on_copied_layers: See
      :func:`innvestigate.utils.keras.graph.reverse_model`.
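
    Example (sketch): a subclass could route all ReLU layers through a
    custom mapping and fall back to gradients elsewhere. Here
    ``my_relu_mapping`` is a hypothetical function of form (A) as
    described in :func:`_reverse_mapping`:

    >>> class MyReverseAnalyzer(ReverseAnalyzerBase):
    ...     def _create_analysis(self, *args, **kwargs):
    ...         self._add_conditional_reverse_mapping(
    ...             lambda layer: kchecks.contains_activation(layer, "relu"),
    ...             my_relu_mapping,
    ...             name="relu_mapping")
    ...         return super(MyReverseAnalyzer, self)._create_analysis(
    ...             *args, **kwargs)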
  464. """
  465. def __init__(self,
  466. model,
  467. reverse_verbose=False,
  468. reverse_clip_values=False,
  469. reverse_project_bottleneck_layers=False,
  470. reverse_check_min_max_values=False,
  471. reverse_check_finite=False,
  472. reverse_keep_tensors=False,
  473. reverse_reapply_on_copied_layers=False,
  474. **kwargs):
  475. self._reverse_verbose = reverse_verbose
  476. self._reverse_clip_values = reverse_clip_values
  477. self._reverse_project_bottleneck_layers = (
  478. reverse_project_bottleneck_layers)
  479. self._reverse_check_min_max_values = reverse_check_min_max_values
  480. self._reverse_check_finite = reverse_check_finite
  481. self._reverse_keep_tensors = reverse_keep_tensors
  482. self._reverse_reapply_on_copied_layers = (
  483. reverse_reapply_on_copied_layers)
  484. super(ReverseAnalyzerBase, self).__init__(model, **kwargs)

    def _gradient_reverse_mapping(self, Xs, Ys, reversed_Ys, reverse_state):
        mask = [x not in reverse_state["stop_mapping_at_tensors"] for x in Xs]
        return ilayers.GradientWRT(len(Xs), mask=mask)(Xs+Ys+reversed_Ys)

    def _reverse_mapping(self, layer):
        """
        This function should return a reverse mapping for the passed layer.

        If this function returns None, :func:`_default_reverse_mapping`
        is applied.

        :param layer: The layer for which a mapping should be returned.
        :return: The mapping can be of the following forms:
          * A function of form (A): f(Xs, Ys, reversed_Ys, reverse_state)
            that maps reversed_Ys to reversed_Xs (which should contain
            tensors of the same shape and type).
          * A function of form (B): f(layer, reverse_state) that returns
            a function of form (A).
          * A :class:`ReverseMappingBase` subclass.
        """
        if layer in self._special_helper_layers:
            # Special layers added by AnalyzerNetworkBase
            # that should not be exposed to the user.
            return self._gradient_reverse_mapping
        return self._apply_conditional_reverse_mappings(layer)

    def _add_conditional_reverse_mapping(
            self, condition, mapping, priority=-1, name=None):
        """
        Adds a conditional reverse mapping that is used when `condition`
        evaluates to True for a layer.

        :param condition: Condition when this mapping should be applied.
          Form: f(layer) -> bool
        :param mapping: The mapping can be of the following forms:
          * A function of form (A): f(Xs, Ys, reversed_Ys, reverse_state)
            that maps reversed_Ys to reversed_Xs (which should contain
            tensors of the same shape and type).
          * A function of form (B): f(layer, reverse_state) that returns
            a function of form (A).
          * A :class:`ReverseMappingBase` subclass.
        :param priority: The higher the priority, the earlier the
          condition gets evaluated.
        :param name: An identifying name.
        """
        if getattr(self, "_reverse_mapping_applied", False):
            raise Exception("Cannot add conditional mapping "
                            "after first application.")

        if not hasattr(self, "_conditional_reverse_mappings"):
            self._conditional_reverse_mappings = {}

        if priority not in self._conditional_reverse_mappings:
            self._conditional_reverse_mappings[priority] = []

        tmp = {"condition": condition, "mapping": mapping, "name": name}
        self._conditional_reverse_mappings[priority].append(tmp)
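
    # Example (sketch): a form-(A) mapping that simply backpropagates the
    # gradient through the layer, mirroring _gradient_reverse_mapping above:
    #
    #   def my_mapping(Xs, Ys, reversed_Ys, reverse_state):
    #       return ilayers.GradientWRT(len(Xs))(Xs + Ys + reversed_Ys)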

    def _apply_conditional_reverse_mappings(self, layer):
        mappings = getattr(self, "_conditional_reverse_mappings", {})
        self._reverse_mapping_applied = True

        # Search for a mapping. Consider highest priorities first and,
        # within a priority, the order in which the mappings were added.
        sorted_keys = sorted(mappings.keys())[::-1]
        for key in sorted_keys:
            for mapping in mappings[key]:
                if mapping["condition"](layer):
                    return mapping["mapping"]

        return None

    def _default_reverse_mapping(self, Xs, Ys, reversed_Ys, reverse_state):
        """
        Fallback function to map reversed_Ys to reversed_Xs
        (which should contain tensors of the same shape and type).
        """
        return self._gradient_reverse_mapping(
            Xs, Ys, reversed_Ys, reverse_state)

    def _head_mapping(self, X):
        """
        Map output tensors to new values before passing
        them into the reverted network.
        """
        return X
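
    # Example (sketch): a subclass could start the backward pass from a
    # constant tensor of ones instead of the model's output, assuming
    # ilayers.OnesLike builds such a tensor:
    #
    #   def _head_mapping(self, X):
    #       return ilayers.OnesLike()(X)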

    def _postprocess_analysis(self, X):
        return X

    def _reverse_model(self,
                       model,
                       stop_analysis_at_tensors=[],
                       return_all_reversed_tensors=False):
        return kgraph.reverse_model(
            model,
            reverse_mappings=self._reverse_mapping,
            default_reverse_mapping=self._default_reverse_mapping,
            head_mapping=self._head_mapping,
            stop_mapping_at_tensors=stop_analysis_at_tensors,
            verbose=self._reverse_verbose,
            clip_all_reversed_tensors=self._reverse_clip_values,
            project_bottleneck_tensors=self._reverse_project_bottleneck_layers,
            return_all_reversed_tensors=return_all_reversed_tensors)

    def _create_analysis(self, model, stop_analysis_at_tensors=[]):
        return_all_reversed_tensors = (
            self._reverse_check_min_max_values or
            self._reverse_check_finite or
            self._reverse_keep_tensors
        )
        ret = self._reverse_model(
            model,
            stop_analysis_at_tensors=stop_analysis_at_tensors,
            return_all_reversed_tensors=return_all_reversed_tensors)

        if return_all_reversed_tensors:
            ret = (self._postprocess_analysis(ret[0]), ret[1])
        else:
            ret = self._postprocess_analysis(ret)

        if return_all_reversed_tensors:
            debug_tensors = []
            self._debug_tensors_indices = {}

            values = list(six.itervalues(ret[1]))
            mapping = {i: v["id"] for i, v in enumerate(values)}
            tensors = [v["final_tensor"] for v in values]
            self._reverse_tensors_mapping = mapping

            if self._reverse_check_min_max_values:
                tmp = [ilayers.Min(None)(x) for x in tensors]
                self._debug_tensors_indices["min"] = (
                    len(debug_tensors),
                    len(debug_tensors)+len(tmp))
                debug_tensors += tmp

                tmp = [ilayers.Max(None)(x) for x in tensors]
                self._debug_tensors_indices["max"] = (
                    len(debug_tensors),
                    len(debug_tensors)+len(tmp))
                debug_tensors += tmp

            if self._reverse_check_finite:
                tmp = iutils.to_list(ilayers.FiniteCheck()(tensors))
                self._debug_tensors_indices["finite"] = (
                    len(debug_tensors),
                    len(debug_tensors)+len(tmp))
                debug_tensors += tmp

            if self._reverse_keep_tensors:
                self._debug_tensors_indices["keep"] = (
                    len(debug_tensors),
                    len(debug_tensors)+len(tensors))
                debug_tensors += tensors

            ret = (ret[0], debug_tensors)
        return ret

    def _handle_debug_output(self, debug_values):
        if self._reverse_check_min_max_values:
            indices = self._debug_tensors_indices["min"]
            tmp = debug_values[indices[0]:indices[1]]
            tmp = sorted([(self._reverse_tensors_mapping[i], v)
                          for i, v in enumerate(tmp)])
            print("Minimum values in tensors: "
                  "((NodeID, TensorID), Value) - {}".format(tmp))

            indices = self._debug_tensors_indices["max"]
            tmp = debug_values[indices[0]:indices[1]]
            tmp = sorted([(self._reverse_tensors_mapping[i], v)
                          for i, v in enumerate(tmp)])
            print("Maximum values in tensors: "
                  "((NodeID, TensorID), Value) - {}".format(tmp))

        if self._reverse_check_finite:
            indices = self._debug_tensors_indices["finite"]
            tmp = debug_values[indices[0]:indices[1]]
            nfinite_tensors = np.flatnonzero(np.asarray(tmp) > 0)

            if len(nfinite_tensors) > 0:
                nfinite_tensors = sorted([self._reverse_tensors_mapping[i]
                                          for i in nfinite_tensors])
                print("Non-finite values found in the following nodes: "
                      "(NodeID, TensorID) - {}".format(nfinite_tensors))

        if self._reverse_keep_tensors:
            indices = self._debug_tensors_indices["keep"]
            tmp = debug_values[indices[0]:indices[1]]
            tmp = sorted([(self._reverse_tensors_mapping[i], v)
                          for i, v in enumerate(tmp)])
            self._reversed_tensors = tmp

    def _get_state(self):
        state = super(ReverseAnalyzerBase, self)._get_state()
        state.update({"reverse_verbose": self._reverse_verbose})
        state.update({"reverse_clip_values": self._reverse_clip_values})
        state.update({"reverse_project_bottleneck_layers":
                      self._reverse_project_bottleneck_layers})
        state.update({"reverse_check_min_max_values":
                      self._reverse_check_min_max_values})
        state.update({"reverse_check_finite": self._reverse_check_finite})
        state.update({"reverse_keep_tensors": self._reverse_keep_tensors})
        state.update({"reverse_reapply_on_copied_layers":
                      self._reverse_reapply_on_copied_layers})
        return state

    @classmethod
    def _state_to_kwargs(clazz, state):
        reverse_verbose = state.pop("reverse_verbose")
        reverse_clip_values = state.pop("reverse_clip_values")
        reverse_project_bottleneck_layers = (
            state.pop("reverse_project_bottleneck_layers"))
        reverse_check_min_max_values = (
            state.pop("reverse_check_min_max_values"))
        reverse_check_finite = state.pop("reverse_check_finite")
        reverse_keep_tensors = state.pop("reverse_keep_tensors")
        reverse_reapply_on_copied_layers = (
            state.pop("reverse_reapply_on_copied_layers"))
        kwargs = super(ReverseAnalyzerBase, clazz)._state_to_kwargs(state)
        kwargs.update({"reverse_verbose": reverse_verbose,
                       "reverse_clip_values": reverse_clip_values,
                       "reverse_project_bottleneck_layers":
                       reverse_project_bottleneck_layers,
                       "reverse_check_min_max_values":
                       reverse_check_min_max_values,
                       "reverse_check_finite": reverse_check_finite,
                       "reverse_keep_tensors": reverse_keep_tensors,
                       "reverse_reapply_on_copied_layers":
                       reverse_reapply_on_copied_layers})
        return kwargs