# graph.py

# Get Python six functionality:
from __future__ import \
    absolute_import, print_function, division, unicode_literals
from builtins import range, zip
import six


###############################################################################
###############################################################################
###############################################################################


import inspect
import keras.backend as K
import keras.engine.topology
import keras.layers
import keras.models
import numpy as np

from . import checks as kchecks
from ... import layers as ilayers
from ... import utils as iutils


__all__ = [
    "get_kernel",

    "get_layer_inbound_count",
    "get_layer_outbound_count",
    "get_layer_neuronwise_io",
    "copy_layer_wo_activation",
    "copy_layer",

    "pre_softmax_tensors",
    "model_wo_softmax",

    "get_model_layers",
    "model_contains",

    "trace_model_execution",
    "get_model_execution_trace",
    "get_model_execution_graph",
    "print_model_execution_graph",

    "get_bottleneck_nodes",
    "get_bottleneck_tensors",

    "ReverseMappingBase",
    "reverse_model",
]


###############################################################################
###############################################################################
###############################################################################


def get_kernel(layer):
    """Returns the kernel weights of a layer, i.e., w/o biases."""
    ret = [x for x in layer.get_weights() if len(x.shape) > 1]
    assert len(ret) == 1
    return ret[0]
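
# Usage sketch (the Dense layer and shapes below are illustrative, not part
# of this module): `get_kernel` keeps only the rank>1 weight, i.e. the
# kernel matrix, and skips the rank-1 bias vector:
#
#     dense = keras.layers.Dense(10, input_shape=(5,))
#     keras.models.Sequential([dense])   # building the model creates weights
#     kernel = get_kernel(dense)         # ndarray of shape (5, 10), no bias
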
def get_input_layers(layer):
    """Returns all layers that created this layer's inputs."""
    ret = set()

    for node_index in range(len(layer._inbound_nodes)):
        Xs = iutils.to_list(layer.get_input_at(node_index))
        for X in Xs:
            ret.add(X._keras_history[0])

    return ret
###############################################################################
###############################################################################
###############################################################################


def get_layer_inbound_count(layer):
    """Returns the number of inbound nodes of a layer."""
    return len(layer._inbound_nodes)


def get_layer_outbound_count(layer):
    """Returns the number of outbound nodes of a layer."""
    return len(layer.outbound_nodes)


def get_layer_neuronwise_io(layer,
                            node_index=0,
                            Xs=None,
                            Ys=None,
                            return_i=True,
                            return_o=True):
    """Returns the input and output for each neuron in a layer

    Returns the symbolic input and output for each neuron in a layer.
    For a dense layer this is the input and output itself.
    For convolutional layers this method extracts, for each neuron,
    the input-output mapping.

    At the moment this function is designed
    to work with dense and conv2d layers.

    :param layer: The targeted layer.
    :param node_index: Index of the layer node to use.
    :param Xs: Ignore the layer's input and use Xs instead.
    :param Ys: Ignore the layer's output and use Ys instead.
    :param return_i: Return the inputs.
    :param return_o: Return the outputs.
    :return: Inputs and outputs, if specified, for each individual neuron.
    """
    if not kchecks.contains_kernel(layer):
        raise NotImplementedError()

    if Xs is None:
        Xs = iutils.to_list(layer.get_input_at(node_index))
    if Ys is None:
        Ys = iutils.to_list(layer.get_output_at(node_index))

    if isinstance(layer, keras.layers.Dense):
        # Xs and Ys are already in the right shape.
        ret_Xs = Xs
        ret_Ys = Ys
    elif isinstance(layer, keras.layers.Conv2D):
        kernel = get_kernel(layer)
        # Expect the filter dimension to be last.
        n_channels = kernel.shape[-1]

        if return_i:
            extract_patches = ilayers.ExtractConv2DPatches(
                kernel.shape[:2],
                kernel.shape[2],
                layer.strides,
                layer.dilation_rate,
                layer.padding)
            # shape [samples, out_row, out_col, weight_size]
            reshape = ilayers.Reshape((-1, np.prod(kernel.shape[:3])))
            ret_Xs = [reshape(extract_patches(x)) for x in Xs]

        if return_o:
            # Get Ys into shape (samples, channels)
            if K.image_data_format() == "channels_first":
                # Ys shape is [samples, channels, out_row, out_col]
                def reshape(x):
                    x = ilayers.Transpose((0, 2, 3, 1))(x)
                    x = ilayers.Reshape((-1, n_channels))(x)
                    return x
            else:
                # Ys shape is [samples, out_row, out_col, channels]
                def reshape(x):
                    x = ilayers.Reshape((-1, n_channels))(x)
                    return x
            ret_Ys = [reshape(x) for x in Ys]
    else:
        raise NotImplementedError()

    # Xs is (n, d) and Ys is (d, channels)
    if return_i and return_o:
        return ret_Xs, ret_Ys
    elif return_i:
        return ret_Xs
    elif return_o:
        return ret_Ys
    else:
        raise Exception()
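
# Usage sketch (assuming a built Conv2D layer `conv` applied in a model):
# each returned row pairs one neuron's receptive field with its output,
# i.e., inputs of shape (samples*out_positions, k_h*k_w*in_channels) and
# outputs of shape (samples*out_positions, n_channels):
#
#     neuron_Xs, neuron_Ys = get_layer_neuronwise_io(conv)
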
def get_symbolic_weight_names(layer, weights=None):
    """Attribute names for weights

    Looks up the attribute names of weight tensors.

    :param layer: Targeted layer.
    :param weights: A list of weight tensors.
    :return: The attribute names of the weights.
    """

    if weights is None:
        weights = layer.weights

    good_guesses = [
        "kernel",
        "bias",
        "gamma",
        "beta",
        "moving_mean",
        "moving_variance",
        "depthwise_kernel",
        "pointwise_kernel"
    ]

    ret = []
    for weight in weights:
        for attr_name in good_guesses + dir(layer):
            if (hasattr(layer, attr_name) and
                    id(weight) == id(getattr(layer, attr_name))):
                ret.append(attr_name)
                break
    if len(weights) != len(ret):
        raise Exception("Could not find symbolic weight name(s).")

    return ret
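
# Usage sketch (assuming a built Dense layer with a bias): the lookup
# resolves each weight tensor to its attribute name on the layer:
#
#     get_symbolic_weight_names(dense)   # -> ["kernel", "bias"]
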
def update_symbolic_weights(layer, weight_mapping):
    """Updates the symbolic tensors of a layer

    Updates the symbolic tensors of a layer by replacing them.

    Note this does not update the loss or anything alike!
    Use with caution!

    :param layer: Targeted layer.
    :param weight_mapping: Dict with attribute names as keys
      and weight tensors as values.
    """

    trainable_weight_ids = [id(x) for x in layer._trainable_weights]
    non_trainable_weight_ids = [id(x) for x in layer._non_trainable_weights]

    for name, weight in six.iteritems(weight_mapping):
        current_weight = getattr(layer, name)
        current_weight_id = id(current_weight)
        if current_weight_id in trainable_weight_ids:
            idx = trainable_weight_ids.index(current_weight_id)
            layer._trainable_weights[idx] = weight
        else:
            idx = non_trainable_weight_ids.index(current_weight_id)
            layer._non_trainable_weights[idx] = weight

        setattr(layer, name, weight)
def get_layer_from_config(old_layer,
                          new_config,
                          weights=None,
                          reuse_symbolic_tensors=True):
    """Creates a new layer from a config

    Creates a new layer given a changed config and weights etc.

    :param old_layer: A layer that shall be used as base.
    :param new_config: The config to create the new layer.
    :param weights: Weights to set in the new layer.
      Options: np tensors, symbolic tensors, or None,
      in which case the weights from old_layer are used.
    :param reuse_symbolic_tensors: If the weights of
      old_layer are used, either reuse the symbolic tensors
      or copy the Numpy weights.
    :return: The new layer instance.
    """
    new_layer = old_layer.__class__.from_config(new_config)

    if weights is None:
        if reuse_symbolic_tensors:
            weights = old_layer.weights
        else:
            weights = old_layer.get_weights()

    if len(weights) > 0:
        input_shapes = old_layer.get_input_shape_at(0)
        # todo: inspect and set initializers to something fast for speedup
        new_layer.build(input_shapes)

        is_np_weight = [isinstance(x, np.ndarray) for x in weights]
        if all(is_np_weight):
            new_layer.set_weights(weights)
        else:
            if any(is_np_weight):
                raise ValueError("Expect either all weights to be "
                                 "np tensors or symbolic tensors.")
            symbolic_names = get_symbolic_weight_names(old_layer)
            update = {name: weight
                      for name, weight in zip(symbolic_names, weights)}
            update_symbolic_weights(new_layer, update)

    return new_layer
def copy_layer_wo_activation(layer,
                             keep_bias=True,
                             name_template=None,
                             weights=None,
                             reuse_symbolic_tensors=True,
                             **kwargs):
    """Copy a Keras layer and remove the activations

    Copies a Keras layer but removes a potential activation.

    :param layer: A layer that should be copied.
    :param keep_bias: Keep a potential bias.
    :param weights: Weights to set in the new layer.
      Options: np tensors, symbolic tensors, or None,
      in which case the weights from the old layer are used.
    :param reuse_symbolic_tensors: If the weights of the
      old layer are used, either reuse the symbolic tensors
      or copy the Numpy weights.
    :return: The new layer instance.
    """
    config = layer.get_config()
    if name_template is None:
        config["name"] = None
    else:
        config["name"] = name_template % config["name"]
    if kchecks.contains_activation(layer):
        config["activation"] = None
    if hasattr(layer, "use_bias"):
        if keep_bias is False and config.get("use_bias", True):
            config["use_bias"] = False
            if weights is None:
                # Drop the bias, i.e. the last weight tensor.
                if reuse_symbolic_tensors:
                    weights = layer.weights[:-1]
                else:
                    weights = layer.get_weights()[:-1]
    return get_layer_from_config(layer, config, weights=weights, **kwargs)
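
# Usage sketch (the `name_template` value is illustrative): create a copy
# of a layer that computes the same affine transformation but skips the
# non-linearity, sharing the original symbolic weight tensors:
#
#     dense_wo_act = copy_layer_wo_activation(dense,
#                                             name_template="%s_wo_act")
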
def copy_layer(layer,
               keep_bias=True,
               name_template=None,
               weights=None,
               reuse_symbolic_tensors=True,
               **kwargs):
    """Copy a Keras layer

    Copies a Keras layer.

    :param layer: A layer that should be copied.
    :param keep_bias: Keep a potential bias.
    :param weights: Weights to set in the new layer.
      Options: np tensors, symbolic tensors, or None,
      in which case the weights from the old layer are used.
    :param reuse_symbolic_tensors: If the weights of the
      old layer are used, either reuse the symbolic tensors
      or copy the Numpy weights.
    :return: The new layer instance.
    """
    config = layer.get_config()
    if name_template is None:
        config["name"] = None
    else:
        config["name"] = name_template % config["name"]
    if hasattr(layer, "use_bias"):
        if keep_bias is False and config.get("use_bias", True):
            config["use_bias"] = False
            if weights is None:
                # Drop the bias, i.e. the last weight tensor.
                if reuse_symbolic_tensors:
                    weights = layer.weights[:-1]
                else:
                    weights = layer.get_weights()[:-1]
    return get_layer_from_config(layer, config, weights=weights, **kwargs)
def pre_softmax_tensors(Xs, should_find_softmax=True):
    """Finds the tensors that were preceding a potential softmax."""
    softmax_found = False

    Xs = iutils.to_list(Xs)
    ret = []
    for x in Xs:
        layer, node_index, tensor_index = x._keras_history
        if kchecks.contains_activation(layer, activation="softmax"):
            softmax_found = True
            if isinstance(layer, keras.layers.Activation):
                ret.append(layer.get_input_at(node_index))
            else:
                layer_wo_act = copy_layer_wo_activation(layer)
                ret.append(layer_wo_act(layer.get_input_at(node_index)))

    if should_find_softmax and not softmax_found:
        raise Exception("No softmax found.")

    return ret


def model_wo_softmax(model):
    """Creates a new model w/o the final softmax activation."""
    return keras.models.Model(inputs=model.inputs,
                              outputs=pre_softmax_tensors(model.outputs),
                              name=model.name)
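
# Usage sketch: most analyses operate on pre-softmax scores, so a softmax
# classifier is typically wrapped like this (model definition illustrative):
#
#     model = keras.models.Sequential([
#         keras.layers.Dense(10, activation="softmax", input_shape=(5,)),
#     ])
#     logit_model = model_wo_softmax(model)
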
###############################################################################
###############################################################################
###############################################################################


def get_model_layers(model):
    """Returns all layers of a model."""
    ret = []

    def collect_layers(container):
        for layer in container.layers:
            assert layer not in ret
            ret.append(layer)
            if kchecks.is_network(layer):
                collect_layers(layer)
    collect_layers(model)

    return ret


def model_contains(model, layer_condition, return_only_counts=False):
    if callable(layer_condition):
        layer_condition = [layer_condition, ]
        single_condition = True
    else:
        single_condition = False

    layers = get_model_layers(model)
    collected_layers = []
    for condition in layer_condition:
        tmp = [layer for layer in layers if condition(layer)]
        collected_layers.append(tmp)

    if return_only_counts is True:
        collected_layers = [len(v) for v in collected_layers]

    if single_condition is True:
        return collected_layers[0]
    else:
        return collected_layers
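
# Usage sketch: conditions are plain predicates over layers, e.g. counting
# the Dense layers in a model:
#
#     n_dense = model_contains(
#         model,
#         lambda layer: isinstance(layer, keras.layers.Dense),
#         return_only_counts=True)
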
###############################################################################
###############################################################################
###############################################################################


def apply_mapping_to_fused_bn_layer(mapping, fuse_mode="one_linear"):
    """
    Applies a mapping to a linearized Batch Normalization layer.

    :param mapping: The mapping to be applied.
      Should take the parameters layer and reverse_state and
      return a mapping function.
    :param fuse_mode: Either 'one_linear': apply the mapping
      to a once-linearized layer, or
      'two_linear': apply it twice to a twice-linearized layer.
    """
    if fuse_mode not in ["one_linear", "two_linear"]:
        raise ValueError("fuse_mode can only be 'one_linear' or 'two_linear'")

    # todo(alber): remove this workaround and make a proper class
    def ScaleLayer(kernel, bias):
        _kernel = kernel
        _bias = bias

        class ScaleLayer(keras.layers.Layer):

            def __init__(self, use_bias=True, **kwargs):
                self._kernel_to_be = _kernel
                self._bias_to_be = _bias
                self.use_bias = use_bias
                super(ScaleLayer, self).__init__(**kwargs)

            def build(self, input_shape):
                self.kernel = self.add_weight(
                    name='kernel',
                    shape=K.int_shape(self._kernel_to_be),
                    initializer=lambda a, b=None: self._kernel_to_be,
                    trainable=False)
                if self.use_bias:
                    self.bias = self.add_weight(
                        name='bias',
                        shape=K.int_shape(self._bias_to_be),
                        initializer=lambda a, b=None: self._bias_to_be,
                        trainable=False)
                super(ScaleLayer, self).build(input_shape)

            def call(self, x):
                ret = (x * self.kernel)
                if self.use_bias:
                    ret += self.bias
                return ret

            def compute_output_shape(self, input_shape):
                return input_shape

        return ScaleLayer()

    def meta_mapping(layer, reverse_state):
        # get bn params
        weights = layer.weights[:]
        if layer.scale:
            gamma = weights.pop(0)
        else:
            gamma = K.ones_like(weights[0])
        if layer.center:
            beta = weights.pop(0)
        else:
            beta = K.zeros_like(weights[0])
        mean, variance = weights

        if fuse_mode == "one_linear":
            # BN computes gamma * (x - mean) / sqrt(variance + epsilon) + beta,
            # which fuses into a single scale-and-shift.
            tmp = K.sqrt(variance + layer.epsilon)
            tmp_k = gamma / tmp
            tmp_b = -mean * tmp_k + beta

            inputs = layer.get_input_at(0)
            surrogate_layer = ScaleLayer(tmp_k, tmp_b)
            # init layer
            surrogate_layer(inputs)
            actual_mapping = mapping(surrogate_layer, reverse_state).apply
        else:
            # Split the same computation into two scale-and-shift layers:
            # first normalize, then apply gamma and beta.
            tmp = K.sqrt(variance + layer.epsilon)
            tmp_k1 = 1 / tmp
            tmp_b1 = -mean / tmp
            tmp_k2 = gamma
            tmp_b2 = beta

            inputs = layer.get_input_at(0)
            surrogate_layer1 = ScaleLayer(tmp_k1, tmp_b1)
            surrogate_layer2 = ScaleLayer(tmp_k2, tmp_b2)
            # init layers
            surrogate_layer1(inputs)
            surrogate_layer2(inputs)

            # todo(alber): update reverse state
            actual_mapping_1 = mapping(surrogate_layer1, reverse_state).apply
            actual_mapping_2 = mapping(surrogate_layer2, reverse_state).apply

            def actual_mapping(Xs, Ys, reversed_Ys, reverse_state):
                from . import apply as kapply
                # Recompute the intermediate tensors of the first
                # scale layer, then map back through both layers.
                X2s = kapply(surrogate_layer1, Xs)

                # todo(alber): update reverse state
                reversed_X2s = actual_mapping_2(
                    X2s, Ys, reversed_Ys, reverse_state)
                return actual_mapping_1(Xs, X2s, reversed_X2s, reverse_state)
        return actual_mapping

    return meta_mapping
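
# Usage sketch (`MyMapping` stands for any ReverseMappingBase subclass
# written for linear layers): the wrapper lets such a mapping handle
# BatchNormalization by routing it through ScaleLayer surrogates:
#
#     bn_mapping = apply_mapping_to_fused_bn_layer(MyMapping,
#                                                  fuse_mode="one_linear")
#     reverse_mappings = {keras.layers.BatchNormalization: bn_mapping}
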
###############################################################################
###############################################################################
###############################################################################


def trace_model_execution(model, reapply_on_copied_layers=False):
    """
    Traces and linearizes the execution of a model and its possible containers.
    Returns a triple with all layers, a list with a linearized execution
    with (layer, input_tensors, output_tensors), and, possibly regenerated,
    outputs of the execution.

    :param model: A Keras model.
    :param reapply_on_copied_layers: If the execution needs to be linearized,
      reapply with copied layers. Might be slow. Prevents changes of the
      original layer's node lists.
    """

    # Get all layers in model.
    layers = get_model_layers(model)

    # Check if some layers are containers.
    # Ignoring the outermost container, i.e. the passed model.
    contains_container = any([((l is not model) and kchecks.is_network(l))
                              for l in layers])

    # If so, rebuild the graph; otherwise recycle computations
    # and create the executed node list. (Keep track of paths?)
    if contains_container is True:
        # When containers/models are used as layers, then layers
        # inside the container/model do not keep track of nodes.
        # This makes it impossible to iterate over the node list and
        # recover the input/output tensors. (see else clause)
        #
        # To recover the computational graph we need to re-apply it.
        # This implies that the tensor objects we use for the forward
        # pass are different to the passed model. This is not the case
        # for the else clause.
        #
        # Note that reapplying the model only changes the inbound
        # and outbound nodes of the model itself. We copy the model
        # so the passed model should not be affected by the
        # reapplication.
        executed_nodes = []

        # Monkeypatch the call function in all the used layer classes.
        monkey_patches = [(layer, getattr(layer, "call")) for layer in layers]
        try:
            def patch(self, method):
                if hasattr(method, "__patched__") is True:
                    raise Exception("Should not happen as we patch "
                                    "objects, not classes.")

                def f(*args, **kwargs):
                    input_tensors = args[0]
                    output_tensors = method(*args, **kwargs)
                    executed_nodes.append((self,
                                           input_tensors,
                                           output_tensors))
                    return output_tensors
                f.__patched__ = True
                return f

            # Apply the patches.
            for layer in layers:
                setattr(layer, "call", patch(layer, getattr(layer, "call")))

            # Trigger reapplication of model.
            model_copy = keras.models.Model(inputs=model.inputs,
                                            outputs=model.outputs)
            outputs = iutils.to_list(model_copy(model.inputs))
        finally:
            # Revert the monkey patches
            for layer, old_method in monkey_patches:
                setattr(layer, "call", old_method)

        # Now we have the problem that all the tensors
        # do not have a keras_history attribute as they are not part
        # of any node. Apply the flat model to get it.
        from . import apply as kapply
        new_executed_nodes = []
        tensor_mapping = {tmp: tmp for tmp in model.inputs}
        if reapply_on_copied_layers is True:
            layer_mapping = {layer: copy_layer(layer) for layer in layers}
        else:
            layer_mapping = {layer: layer for layer in layers}

        for layer, Xs, Ys in executed_nodes:
            layer = layer_mapping[layer]
            Xs, Ys = iutils.to_list(Xs), iutils.to_list(Ys)

            if isinstance(layer, keras.layers.InputLayer):
                # Special case. Do nothing.
                new_Xs, new_Ys = Xs, Ys
            else:
                new_Xs = [tensor_mapping[x] for x in Xs]
                new_Ys = iutils.to_list(kapply(layer, new_Xs))

            tensor_mapping.update({k: v for k, v in zip(Ys, new_Ys)})
            new_executed_nodes.append((layer, new_Xs, new_Ys))

        layers = [layer_mapping[layer] for layer in layers]
        outputs = [tensor_mapping[x] for x in outputs]
        executed_nodes = new_executed_nodes
    else:
        # Easy and safe way.
        reverse_executed_nodes = [
            (node.outbound_layer, node.input_tensors, node.output_tensors)
            for depth in sorted(model._nodes_by_depth.keys())
            for node in model._nodes_by_depth[depth]
        ]
        outputs = model.outputs

        executed_nodes = reversed(reverse_executed_nodes)

    # This list potentially contains nodes that are not part of the
    # final execution graph.
    # E.g., if a layer was also applied outside of the model, then its
    # node list contains nodes that do not contribute to the model's output.
    # Those nodes are filtered here.
    used_as_input = [x for x in outputs]
    tmp = []
    for l, Xs, Ys in reversed(list(executed_nodes)):
        if all([y in used_as_input for y in Ys]):
            used_as_input += Xs
            tmp.append((l, Xs, Ys))
    executed_nodes = list(reversed(tmp))

    return layers, executed_nodes, outputs
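
# Usage sketch: the execution list linearizes the forward pass; each entry
# pairs a layer with its input and output tensors:
#
#     layers, execution_list, outputs = trace_model_execution(model)
#     for layer, Xs, Ys in execution_list:
#         print(layer.name,
#               len(iutils.to_list(Xs)),
#               len(iutils.to_list(Ys)))
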
def get_model_execution_trace(model,
                              keep_input_layers=False,
                              reapply_on_copied_layers=False):
    """
    Returns a list representing the execution graph.
    Each entry is a dictionary with the following items:

    * nid: the node id, as it is used by the reverse_model method.
    * layer: the layer creating this node.
    * Xs: the input tensors (only valid if not in a nested container).
    * Ys: the output tensors (only valid if not in a nested container).
    * Xs_nids: the ids of the nodes creating the Xs.
    * Ys_nids: the ids of the nodes using the according output tensor.
    * Xs_layers: the layers that created the according input tensors.
    * Ys_layers: the layers using the according output tensor.

    :param model: A Keras model.
    :param keep_input_layers: Keep input layers.
    :param reapply_on_copied_layers: If the execution needs to be linearized,
      reapply with copied layers. Might be slow. Prevents changes of the
      original layer's node lists.
    """
    _, execution_trace, _ = trace_model_execution(
        model,
        reapply_on_copied_layers=reapply_on_copied_layers)

    # Enrich trace with node ids.
    current_nid = 0
    tmp = []
    for l, Xs, Ys in execution_trace:
        if isinstance(l, keras.layers.InputLayer):
            tmp.append((None, l, Xs, Ys))
        else:
            tmp.append((current_nid, l, Xs, Ys))
            current_nid += 1
    execution_trace = tmp

    # Create lookups from tensor to creating or receiving layer-node
    inputs_to_node = {}
    outputs_to_node = {}
    for nid, l, Xs, Ys in execution_trace:
        if nid is not None:
            for X in Xs:
                Xid = id(X)
                if Xid in inputs_to_node:
                    inputs_to_node[Xid].append(nid)
                else:
                    inputs_to_node[Xid] = [nid]

        if keep_input_layers or nid is not None:
            for Y in Ys:
                Yid = id(Y)
                if Yid in outputs_to_node:
                    raise Exception("Cannot be more than one creating node.")
                outputs_to_node[Yid] = nid

    # Enrich trace with this info.
    nid_to_nodes = {t[0]: t for t in execution_trace}
    tmp = []
    for nid, l, Xs, Ys in execution_trace:
        if isinstance(l, keras.layers.InputLayer):
            # The nids that created or receive the tensors.
            Xs_nids = []  # Input layer does not receive.
            Ys_nids = [inputs_to_node[id(Y)] for Y in Ys]
            # The layers that created or receive the tensors.
            Xs_layers = []  # Input layer does not receive.
            Ys_layers = [[nid_to_nodes[Ynid][1] for Ynid in Ynids2]
                         for Ynids2 in Ys_nids]
        else:
            # The nids that created or receive the tensors.
            Xs_nids = [outputs_to_node.get(id(X), None) for X in Xs]
            Ys_nids = [inputs_to_node.get(id(Y), [None]) for Y in Ys]
            # The layers that created or receive the tensors.
            Xs_layers = [nid_to_nodes[Xnid][1]
                         for Xnid in Xs_nids if Xnid is not None]
            Ys_layers = [[nid_to_nodes[Ynid][1]
                          for Ynid in Ynids2 if Ynid is not None]
                         for Ynids2 in Ys_nids]

        entry = {
            "nid": nid,
            "layer": l,
            "Xs": Xs,
            "Ys": Ys,
            "Xs_nids": Xs_nids,
            "Ys_nids": Ys_nids,
            "Xs_layers": Xs_layers,
            "Ys_layers": Ys_layers,
        }
        tmp.append(entry)
    execution_trace = tmp

    if not keep_input_layers:
        execution_trace = [tmp
                           for tmp in execution_trace
                           if tmp["nid"] is not None]

    return execution_trace
def get_model_execution_graph(model, keep_input_layers=False):
    """
    Returns a dictionary representing the execution graph.
    Each key is the node's id as it is used by the reverse_model method.

    Each associated value contains a dictionary with the following items:

    * nid: the node id.
    * layer: the layer creating this node.
    * Xs: the input tensors (only valid if not in a nested container).
    * Ys: the output tensors (only valid if not in a nested container).
    * Xs_nids: the ids of the nodes creating the Xs.
    * Ys_nids: the ids of the nodes using the according output tensor.
    * Xs_layers: the layers that created the according input tensors.
    * Ys_layers: the layers using the according output tensor.

    :param model: A Keras model.
    :param keep_input_layers: Keep input layers.
    """
    trace = get_model_execution_trace(model,
                                      keep_input_layers=keep_input_layers,
                                      reapply_on_copied_layers=False)

    input_layers = [tmp for tmp in trace if tmp["nid"] is None]
    graph = {tmp["nid"]: tmp for tmp in trace}
    if keep_input_layers:
        graph[None] = input_layers

    return graph


def print_model_execution_graph(graph):
    """Pretty prints a model execution graph."""

    def nids_as_str(nids):
        return ", ".join(["%s" % nid for nid in nids])

    def print_node(node):
        print(" [NID: %4s] [Layer: %20s] "
              "[Inputs from: %20s] [Outputs to: %20s]" %
              (node["nid"],
               node["layer"].name,
               nids_as_str(node["Xs_nids"]),
               nids_as_str(node["Ys_nids"]),))

    if None in graph:
        print("Graph input layers:")
        for tmp in graph[None]:
            print_node(tmp)

    print("Graph nodes:")
    for nid in sorted([k for k in graph if k is not None]):
        if nid is None:
            continue
        print_node(graph[nid])
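
# Usage sketch: inspect how the layer-nodes are wired, e.g. before
# reversing a model:
#
#     graph = get_model_execution_graph(model, keep_input_layers=True)
#     print_model_execution_graph(graph)
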
def get_bottleneck_nodes(inputs, outputs, execution_list):
    """
    Given an execution list this function returns all nodes that
    are a bottleneck in the network, i.e., "all information" must pass
    through this node.
    """

    forward_connections = {}
    for l, Xs, Ys in execution_list:
        if isinstance(l, keras.layers.InputLayer):
            # Special case, do nothing.
            continue

        for x in Xs:
            if x in forward_connections:
                forward_connections[x] += Ys
            else:
                forward_connections[x] = list(Ys)

    open_connections = {}
    for x in inputs:
        for fw_c in forward_connections[x]:
            open_connections[fw_c] = True

    ret = list()
    for l, Xs, Ys in execution_list:
        if isinstance(l, keras.layers.InputLayer):
            # Special case, do nothing.
            # Note: if a single input branches
            # this is not detected.
            continue

        for y in Ys:
            assert y in open_connections
            del open_connections[y]

        if len(open_connections) == 0:
            ret.append((l, (Xs, Ys)))

        for y in Ys:
            if y not in outputs:
                for fw_c in forward_connections[y]:
                    open_connections[fw_c] = True

    return ret


def get_bottleneck_tensors(inputs, outputs, execution_list):
    """
    Given an execution list this function returns all tensors that
    are a bottleneck in the network, i.e., "all information" must pass
    through this tensor.
    """

    nodes = get_bottleneck_nodes(inputs, outputs, execution_list)

    ret = list()
    for l, (Xs, Ys) in nodes:
        for tensor_list in (Xs, Ys):
            if len(tensor_list) == 1:
                tensor = tensor_list[0]
                if tensor not in ret:
                    ret.append(tensor)
            else:
                # TODO(albermax): put warning here?
                pass

    return ret
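
# Usage sketch: in a plain chain of layers every intermediate tensor is a
# bottleneck; with parallel branches only tensors at or after the merge
# qualify:
#
#     layers, execution_list, outputs = trace_model_execution(model)
#     bottlenecks = get_bottleneck_tensors(model.inputs,
#                                          outputs,
#                                          execution_list)
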
###############################################################################
###############################################################################
###############################################################################


class ReverseMappingBase(object):

    def __init__(self, layer, state):
        pass

    def apply(self, Xs, Ys, reversed_Ys, reverse_state):
        raise NotImplementedError()
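
# Minimal sketch of a subclass: a mapping that passes the reversed signal
# through unchanged; real mappings compute, e.g., gradients or relevances:
#
#     class IdentityReverseMapping(ReverseMappingBase):
#
#         def __init__(self, layer, state):
#             pass
#
#         def apply(self, Xs, Ys, reversed_Ys, reverse_state):
#             return reversed_Ys
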
def reverse_model(model, reverse_mappings,
                  default_reverse_mapping=None,
                  head_mapping=None,
                  stop_mapping_at_tensors=None,
                  verbose=False,
                  return_all_reversed_tensors=False,
                  clip_all_reversed_tensors=False,
                  project_bottleneck_tensors=False,
                  execution_trace=None,
                  reapply_on_copied_layers=False):
    """
    Reverses a Keras model based on the given reverse functions.
    It returns the reversed tensors for the according model inputs.

    :param model: A Keras model.
    :param reverse_mappings: Either a callable that matches layers to
      mappings or a dictionary with layers as keys and mappings as values.
      Allowed mapping forms are:
      * A function of form (A): f(Xs, Ys, reversed_Ys, reverse_state).
      * A function of form (B): f(layer, reverse_state) that returns
        a function of form (A).
      * A :class:`ReverseMappingBase` subclass.
    :param default_reverse_mapping: A function that reverses layers for
      which no mapping was given by param "reverse_mappings".
    :param head_mapping: Map output tensors to new values before passing
      them into the reversed network.
    :param stop_mapping_at_tensors: Tensors at which to stop the mapping.
      Similar to the stop_gradient parameter for gradient computation.
    :param verbose: Print what's going on.
    :param return_all_reversed_tensors: Return all reversed tensors in
      addition to the reversed model input tensors.
    :param clip_all_reversed_tensors: Clip each reversed tensor. False or
      a tuple with min/max values.
    :param project_bottleneck_tensors: Project bottleneck tensors in the
      reverting process into a given value range. False, True or (a, b) for
      the projection range.
    :param reapply_on_copied_layers: When the model execution needs to be
      linearized, copy layers before reapplying them. See
      :func:`trace_model_execution`.
    """

    # Set default values ######################################################

    if stop_mapping_at_tensors is None:
        # Avoid a mutable default argument.
        stop_mapping_at_tensors = []

    if head_mapping is None:
        def head_mapping(X):
            return X

    if not callable(reverse_mappings):
        # not callable, assume a dict that maps from layer to mapping
        reverse_mapping_data = reverse_mappings

        def reverse_mappings(layer):
            try:
                return reverse_mapping_data[type(layer)]
            except KeyError:
                return None

    def _print(s):
        if verbose is True:
            print(s)

    # Initialize structure that keeps track of reversed tensors ###############

    reversed_tensors = {}
    bottleneck_tensors = set()

    def add_reversed_tensors(nid,
                             tensors_list,
                             reversed_tensors_list):

        def add_reversed_tensor(i, X, reversed_X):
            # Do not keep tensors that should stop the mapping.
            if X in stop_mapping_at_tensors:
                return

            if X not in reversed_tensors:
                reversed_tensors[X] = {"id": (nid, i),
                                       "tensor": reversed_X}
            else:
                tmp = reversed_tensors[X]
                if "tensor" in tmp and "tensors" in tmp:
                    raise Exception("Wrong order, tensors already aggregated!")
                if "tensor" in tmp:
                    tmp["tensors"] = [tmp["tensor"], reversed_X]
                    del tmp["tensor"]
                else:
                    tmp["tensors"].append(reversed_X)

        tmp = zip(tensors_list, reversed_tensors_list)
        for i, (X, reversed_X) in enumerate(tmp):
            add_reversed_tensor(i, X, reversed_X)

    def get_reversed_tensor(tensor):
        tmp = reversed_tensors[tensor]

        if "final_tensor" not in tmp:
            if "tensor" not in tmp:
                final_tensor = keras.layers.Add()(tmp["tensors"])
            else:
                final_tensor = tmp["tensor"]

            if project_bottleneck_tensors is not False:
                if tensor in bottleneck_tensors:
                    project = ilayers.Project(project_bottleneck_tensors)
                    final_tensor = project(final_tensor)

            if clip_all_reversed_tensors is not False:
                clip = ilayers.Clip(*clip_all_reversed_tensors)
                final_tensor = clip(final_tensor)

            tmp["final_tensor"] = final_tensor

        return tmp["final_tensor"]

    # Reverse the model #######################################################
    _print("Reverse model: {}".format(model))

    # Create a list with nodes in reverse execution order.
    if execution_trace is None:
        execution_trace = trace_model_execution(
            model,
            reapply_on_copied_layers=reapply_on_copied_layers)
    layers, execution_list, outputs = execution_trace
    len_execution_list = len(execution_list)
    num_input_layers = len([_ for l, _, _ in execution_list
                            if isinstance(l, keras.layers.InputLayer)])
    len_execution_list_wo_inputs_layers = len_execution_list - num_input_layers
    reverse_execution_list = reversed(execution_list)

    # Initialize the reverse mapping functions.
    initialized_reverse_mappings = {}
    for layer in layers:
        # A layer can be shared, i.e., applied several times.
        # Allow sharing a ReverseMappingBase instance for each layer
        # in order to reduce the overhead.
        meta_reverse_mapping = reverse_mappings(layer)
        if meta_reverse_mapping is None:
            reverse_mapping = default_reverse_mapping
        elif (inspect.isclass(meta_reverse_mapping) and
              issubclass(meta_reverse_mapping, ReverseMappingBase)):
            # Mapping is a class
            reverse_mapping_obj = meta_reverse_mapping(
                layer,
                {
                    "model": model,
                    "layer": layer,
                }
            )
            reverse_mapping = reverse_mapping_obj.apply
        else:
            def parameter_count(func):
                if hasattr(inspect, "signature"):
                    ret = len(inspect.signature(func).parameters)
                else:
                    spec = inspect.getargspec(func)
                    ret = len(spec.args)
                    if spec.varargs is not None:
                        ret += len(spec.varargs)
                    if spec.keywords is not None:
                        ret += len(spec.keywords)
                if ret == 3:
                    # assume class function with self
                    ret -= 1
                return ret

            if (callable(meta_reverse_mapping) and
                    parameter_count(meta_reverse_mapping) == 2):
                # Function that returns mapping
                reverse_mapping = meta_reverse_mapping(
                    layer,
                    {
                        "model": model,
                        "layer": layer,
                    }
                )
            else:
                # Nothing meta here
                reverse_mapping = meta_reverse_mapping

        initialized_reverse_mappings[layer] = reverse_mapping

    if project_bottleneck_tensors:
        bottleneck_tensors.update(
            get_bottleneck_tensors(
                model.inputs,
                outputs,
                execution_list))

    # Initialize the reverse tensor mappings.
    add_reversed_tensors(-1,
                         outputs,
                         [head_mapping(tmp) for tmp in outputs])

    # Follow the list and revert the graph.
    for _nid, (layer, Xs, Ys) in enumerate(reverse_execution_list):
        nid = len_execution_list_wo_inputs_layers - _nid - 1

        if isinstance(layer, keras.layers.InputLayer):
            # Special case. Do nothing.
            pass
        elif kchecks.is_network(layer):
            raise Exception("This is not supposed to happen!")
        else:
            Xs, Ys = iutils.to_list(Xs), iutils.to_list(Ys)
            if not all([ys in reversed_tensors for ys in Ys]):
                # This node is not part of our computational graph.
                # The (node-)world is bigger than this model.
                # Potentially this node is also not part of the
                # reversed tensor set because it depends on a tensor
                # that is listed in stop_mapping_at_tensors.
                continue
            reversed_Ys = [get_reversed_tensor(ys)
                           for ys in Ys]
            local_stop_mapping_at_tensors = [x for x in Xs
                                             if x in stop_mapping_at_tensors]

            _print(" [NID: {}] Reverse layer-node {}".format(nid, layer))
            reverse_mapping = initialized_reverse_mappings[layer]
            reversed_Xs = reverse_mapping(
                Xs, Ys, reversed_Ys,
                {
                    "nid": nid,
                    "model": model,
                    "layer": layer,
                    "stop_mapping_at_tensors": local_stop_mapping_at_tensors,
                })
            reversed_Xs = iutils.to_list(reversed_Xs)
            add_reversed_tensors(nid, Xs, reversed_Xs)

    # Return requested values #################################################
    # THIS LINE ADDED FOR 3 INPUT MODEL: skip the second and third input.
    # Note: this override assumes the model has (at least) three inputs.
    stop_mapping_at_tensors = [model.inputs[1], model.inputs[2]]
    reversed_input_tensors = [get_reversed_tensor(tmp)
                              for tmp in model.inputs
                              if tmp not in stop_mapping_at_tensors]
    if return_all_reversed_tensors is True:
        return reversed_input_tensors, reversed_tensors
    else:
        return reversed_input_tensors
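
# Usage sketch: reverse a model with a gradient-like default mapping
# (assuming ilayers.GradientWRT behaves as in iNNvestigate's gradient
# analyzer). Note that the hard-coded stop_mapping_at_tensors override
# above assumes a three-input model:
#
#     def default_mapping(Xs, Ys, reversed_Ys, reverse_state):
#         grad = ilayers.GradientWRT(len(Xs))
#         return iutils.to_list(grad(Xs + Ys + reversed_Ys))
#
#     reversed_inputs = reverse_model(
#         model,
#         reverse_mappings=lambda layer: None,
#         default_reverse_mapping=default_mapping)
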