pattern.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. from builtins import range
  5. import six
  6. ###############################################################################
  7. ###############################################################################
  8. ###############################################################################
  9. import keras.backend as K
  10. import keras.layers
  11. import keras.models
  12. import keras.optimizers
  13. import keras.utils
  14. import numpy as np
  15. from .. import layers as ilayers
  16. from .. import utils as iutils
  17. from ..utils.keras import checks as kchecks
  18. from ..utils.keras import graph as kgraph
  19. __all__ = [
  20. "get_active_neuron_io",
  21. "get_pattern_class",
  22. "BasePattern",
  23. "DummyPattern",
  24. "LinearPattern",
  25. "ReLUPositivePattern",
  26. "ReLUNegativePattern",
  27. "PatternComputer",
  28. ]
  29. ###############################################################################
  30. ###############################################################################
  31. ###############################################################################
  32. def get_active_neuron_io(layer, active_node_indices,
  33. return_i=True, return_o=True,
  34. do_activation_search=False):
  35. """
  36. Returns the neuron-wise input output for the passed layer.
  37. This is done while taking care of only considering layer nodes that
  38. are listed as active.
  39. Starting from the passed layer this functions
  40. returns the first layer with an activation upstream in the model,
  41. if do_activation_search is an execution list.
  42. Otherwise the current layer's output is returned.
  43. """
  44. def contains_activation(layer):
  45. return (kchecks.contains_activation(layer) and
  46. not kchecks.contains_activation(layer, "linear"))
  47. def get_Xs(node_index):
  48. return iutils.to_list(layer.get_input_at(node_index))
  49. def get_Ys(node_index):
  50. ret = iutils.to_list(layer.get_output_at(node_index))
  51. if(do_activation_search is not False and
  52. not contains_activation(layer)):
  53. # Walk along execution graph until we find an activation function,
  54. # if current layer has none.
  55. execution_list = do_activation_search
  56. # First find current node.
  57. layer_i = None
  58. for i, node in enumerate(execution_list):
  59. if layer is node[0]:
  60. layer_i = i
  61. break
  62. assert layer_i is not None
  63. assert len(ret) == 1
  64. input_to_next_layer = ret[0]
  65. found = False
  66. for i in range(layer_i+1, len(execution_list)):
  67. l, Xs, Ys = execution_list[i]
  68. if input_to_next_layer in Xs:
  69. if not isinstance(
  70. l,
  71. kchecks.get_activation_search_safe_layers()):
  72. break
  73. if contains_activation(l):
  74. found = Ys
  75. break
  76. assert len(Ys) == 1
  77. input_to_next_layer = Ys[0]
  78. if found is not False:
  79. ret = Ys
  80. return ret
  81. # Get neuron-wise io for active layer nodes.
  82. tmp = [kgraph.get_layer_neuronwise_io(layer, Xs=get_Xs(i), Ys=get_Ys(i),
  83. return_i=return_i, return_o=return_o)
  84. for i in active_node_indices]
  85. if len(tmp) == 1:
  86. return tmp[0]
  87. else:
  88. raise NotImplementedError("This code seems not to handle several Ys.")
  89. # Layer is applied several times in model.
  90. # Concatenate the io of the applications.
  91. concatenate = keras.layers.Concatenate(axis=0)
  92. if return_i and return_o:
  93. return (concatenate([x[0] for x in tmp]),
  94. concatenate([x[1] for x in tmp]))
  95. else:
  96. return concatenate([x[0] for x in tmp])
  97. ###############################################################################
  98. ###############################################################################
  99. ###############################################################################
  100. class BasePattern(object):
  101. """
  102. Interface for pattern objects used to compute patterns by the
  103. PatternComputer class.
  104. The basic work-flow is that a pattern computes statistics for the
  105. passed layer, which are then used to compute the final pattern.
  106. """
  107. def __init__(self,
  108. model,
  109. layer,
  110. model_tensors=None,
  111. execution_list=None):
  112. self.model = model
  113. self.layer = layer
  114. # All the tensors used by the model.
  115. # Allows to filter nodes in layers that do not
  116. # belong to this model.
  117. self.model_tensors = model_tensors
  118. self.execution_list = execution_list
  119. self._active_node_indices = self._get_active_node_indices()
  120. def _get_active_node_indices(self):
  121. """
  122. A layer can be applied in several models.
  123. This functions returns a list with all nodes of the given
  124. layer that are active/used in the current model.
  125. If no model_tensors are passed to the pattern,
  126. it is assumed all nodes are active.
  127. """
  128. n_nodes = kgraph.get_layer_inbound_count(self.layer)
  129. if self.model_tensors is None:
  130. return list(range(n_nodes))
  131. else:
  132. ret = []
  133. for i in range(n_nodes):
  134. output_tensors = iutils.to_list(self.layer.get_output_at(i))
  135. # Check if output is used in the model.
  136. if all([tmp in self.model_tensors
  137. for tmp in output_tensors]):
  138. ret.append(i)
  139. return ret
  140. def has_pattern(self):
  141. return kchecks.contains_kernel(self.layer)
  142. def stats_from_batch(self):
  143. """
  144. Creates statistics while the PatternComputer passes the
  145. dataset once.
  146. """
  147. raise NotImplementedError()
  148. def compute_pattern(self):
  149. """
  150. Computes the pattern after computing the statistics.
  151. """
  152. raise NotImplementedError()
  153. class DummyPattern(BasePattern):
  154. """
  155. Computes a dummy pattern for test purposes.
  156. """
  157. def get_stats_from_batch(self):
  158. Xs, Ys = get_active_neuron_io(self.layer,
  159. self._active_node_indices)
  160. self.mean_x = ilayers.RunningMeans()
  161. count = ilayers.CountNonZero(axis=0)(Ys[0])
  162. sum_x = ilayers.Dot()([ilayers.Transpose()(Xs[0]), Ys[0]])
  163. mean_x, count_x = self.mean_x([sum_x, count])
  164. # Return dummy output to have connected graph!
  165. return ilayers.Sum(axis=None)(count_x)
  166. def compute_pattern(self):
  167. return self.mean_x.get_weights()[0]
  168. class LinearPattern(BasePattern):
  169. def _get_neuron_mask(self):
  170. """
  171. Select which neurons are considered for the pattern computation.
  172. """
  173. Ys = get_active_neuron_io(self.layer,
  174. self._active_node_indices,
  175. return_i=False, return_o=True)
  176. return ilayers.OnesLike()(Ys[0])
  177. def get_stats_from_batch(self):
  178. # Get the neuron-wise I/O for this layer.
  179. layer = kgraph.copy_layer_wo_activation(self.layer,
  180. keep_bias=False,
  181. reuse_symbolic_tensors=False)
  182. # Readjust the layer nodes.
  183. for i in range(kgraph.get_layer_inbound_count(self.layer)):
  184. layer(self.layer.get_input_at(i))
  185. Xs, Ys = get_active_neuron_io(layer, self._active_node_indices)
  186. if len(Ys) != 1:
  187. raise ValueError("Assume that kernel layer have only one output.")
  188. X, Y = Xs[0], Ys[0]
  189. # Create layers that keep a running mean for the desired stats.
  190. self.mean_x = ilayers.RunningMeans()
  191. self.mean_y = ilayers.RunningMeans()
  192. self.mean_xy = ilayers.RunningMeans()
  193. # Compute mask and active neuron counts.
  194. mask = ilayers.AsFloatX()(self._get_neuron_mask())
  195. Y_masked = keras.layers.multiply([Y, mask])
  196. count = ilayers.CountNonZero(axis=0)(mask)
  197. count_all = ilayers.Sum(axis=0)(ilayers.OnesLike()(mask))
  198. # Get means ...
  199. def norm(x, count):
  200. return ilayers.SafeDivide(factor=1)([x, count])
  201. # ... along active neurons.
  202. mean_x = norm(ilayers.Dot()([ilayers.Transpose()(X), mask]), count)
  203. mean_xy = norm(ilayers.Dot()([ilayers.Transpose()(X), Y_masked]),
  204. count)
  205. _, a = self.mean_x([mean_x, count])
  206. _, b = self.mean_xy([mean_xy, count])
  207. # ... along all neurons.
  208. mean_y = norm(ilayers.Sum(axis=0)(Y), count_all)
  209. _, c = self.mean_y([mean_y, count_all])
  210. # Create a dummy output to have a connected graph.
  211. # Needs to have the shape (mb_size, 1)
  212. dummy = keras.layers.Average()([a, b, c])
  213. return ilayers.Sum(axis=None)(dummy)
  214. def compute_pattern(self):
  215. """Computes the patterns according to the formula in the paper."""
  216. def safe_divide(a, b):
  217. return a / (b + (b == 0))
  218. W = kgraph.get_kernel(self.layer)
  219. W2D = W.reshape((-1, W.shape[-1]))
  220. mean_x, cnt_x = self.mean_x.get_weights()
  221. mean_y, cnt_y = self.mean_y.get_weights()
  222. mean_xy, cnt_xy = self.mean_xy.get_weights()
  223. ExEy = mean_x * mean_y
  224. cov_xy = mean_xy - ExEy
  225. w_cov_xy = np.diag(np.dot(W2D.T, cov_xy))
  226. A = safe_divide(cov_xy, w_cov_xy[None, :])
  227. # update length
  228. if False:
  229. norm = np.diag(np.dot(W2D.T, A))
  230. A = safe_divide(A, norm)
  231. # check pattern
  232. if False:
  233. tmp = np.diag(np.dot(W2D.T, A))
  234. print("pattern_check", W.shape, tmp.min(), tmp.max())
  235. return A.reshape(W.shape)
  236. class ReLUPositivePattern(LinearPattern):
  237. def _get_neuron_mask(self):
  238. Ys = get_active_neuron_io(self.layer,
  239. self._active_node_indices,
  240. return_i=False, return_o=True,
  241. do_activation_search=self.execution_list)
  242. return ilayers.GreaterThanZero()(Ys[0])
  243. class ReLUNegativePattern(LinearPattern):
  244. def _get_neuron_mask(self):
  245. Ys = get_active_neuron_io(self.layer,
  246. self._active_node_indices,
  247. return_i=False, return_o=True,
  248. do_activation_search=self.execution_list)
  249. return ilayers.LessEqualThanZero()(Ys[0])
  250. def get_pattern_class(pattern_type):
  251. return {
  252. "dummy": DummyPattern,
  253. "linear": LinearPattern,
  254. "relu": ReLUPositivePattern,
  255. "relu.positive": ReLUPositivePattern,
  256. "relu.negative": ReLUNegativePattern,
  257. }.get(pattern_type, pattern_type)
  258. ###############################################################################
  259. ###############################################################################
  260. ###############################################################################
  261. class PatternComputer(object):
  262. """Pattern computer.
  263. Computes a pattern for each layer with a kernel of a given model.
  264. :param model: A Keras model.
  265. :param pattern_type: A string or a tuple of strings. Valid types are
  266. 'linear', 'relu', 'relu.positive', 'relu.negative'.
  267. :param compute_layers_in_parallel: Not supported yet.
  268. Compute all patterns at once.
  269. Otherwise computer layer after layer.
  270. :param gpus: Not supported yet. Gpus to use.
  271. """
  272. def __init__(self, model,
  273. pattern_type="linear",
  274. # todo: this options seems to be buggy,
  275. # if it sequential tensorflow still pushes all models to gpus
  276. compute_layers_in_parallel=True,
  277. gpus=None):
  278. self.model = model
  279. # Break cyclic import.
  280. import innvestigate.analyzer.pattern_based
  281. supported_layers = (
  282. innvestigate.analyzer.pattern_based.SUPPORTED_LAYER_PATTERNNET)
  283. for layer in self.model.layers:
  284. if not isinstance(layer, supported_layers):
  285. raise Exception("Model contains not supported layer: %s"
  286. % layer)
  287. pattern_types = iutils.to_list(pattern_type)
  288. self.pattern_types = {k: get_pattern_class(k)
  289. for k in pattern_types}
  290. self.compute_layers_in_parallel = compute_layers_in_parallel
  291. self.gpus = gpus
  292. if self.compute_layers_in_parallel is False:
  293. raise NotImplementedError("Not supported.")
  294. def _create_computers(self):
  295. """
  296. Creates pattern objects and Keras models that are used to collect
  297. statistics and compute patterns.
  298. We compute the patterns by first computing statistics within
  299. the Keras framework, which are then used to compute the patterns.
  300. This is based on a workaround. We connect the stats computation
  301. via dummy outputs to a model's output and then iterate over the
  302. dataset to compute statistics.
  303. """
  304. # Create a broadcasting function that is used to connect
  305. # the dummy outputs.
  306. # Broadcaster has shape (mini_batch_size, 1)
  307. reduce_axes = list(range(len(K.int_shape(self.model.inputs[0]))))[1:]
  308. dummy_broadcaster = ilayers.Sum(axis=reduce_axes,
  309. keepdims=True)(self.model.inputs[0])
  310. def broadcast(x):
  311. return ilayers.Broadcast()([dummy_broadcaster, x])
  312. # Collect all tensors that are part of a model's execution.
  313. layers, execution_list, _ = kgraph.trace_model_execution(self.model)
  314. model_tensors = set()
  315. for _, input_tensors, output_tensors in execution_list:
  316. for t in input_tensors+output_tensors:
  317. model_tensors.add(t)
  318. # Create pattern instances and collect the dummy outputs.
  319. self._pattern_instances = {k: [] for k in self.pattern_types}
  320. computer_outputs = []
  321. for layer_id, layer in enumerate(layers):
  322. # This does not work with containers!
  323. # They should be replaced by trace_model_execution.
  324. if kchecks.is_network(layer):
  325. raise Exception("Network in network is not suppored!")
  326. for pattern_type, clazz in six.iteritems(self.pattern_types):
  327. pinstance = clazz(self.model, layer,
  328. model_tensors=model_tensors,
  329. execution_list=execution_list)
  330. if pinstance.has_pattern() is False:
  331. continue
  332. self._pattern_instances[pattern_type].append(pinstance)
  333. dummy_output = pinstance.get_stats_from_batch()
  334. # Broadcast dummy_output to right shape.
  335. computer_outputs += iutils.to_list(broadcast(dummy_output))
  336. # Now we create one or more Keras models to train the patterns.
  337. self._n_computer_outputs = len(computer_outputs)
  338. if self.compute_layers_in_parallel is True:
  339. self._computers = [
  340. keras.models.Model(inputs=self.model.inputs,
  341. outputs=computer_outputs)
  342. ]
  343. else:
  344. self._computers = [
  345. keras.models.Model(inputs=self.model.inputs,
  346. outputs=computer_output)
  347. for computer_output in computer_outputs
  348. ]
  349. # Distribute computation on more gpus.
  350. if self.gpus is not None and self.gpus > 1:
  351. raise NotImplementedError("Not supported yet.")
  352. self._computers = [keras.utils.multi_gpu_model(tmp, gpus=self.gpus)
  353. for tmp in self._computers]
  354. def compute(self, X, batch_size=32, verbose=0):
  355. """
  356. Compute and return the patterns for the model and the data `X`.
  357. :param X: Data to compute patterns.
  358. :param batch_size: Batch size to use.
  359. :param verbose: As for keras model.fit.
  360. """
  361. generator = iutils.BatchSequence(X, batch_size)
  362. return self.compute_generator(generator, verbose=verbose)
  363. def compute_generator(self, generator, **kwargs):
  364. """
  365. Compute and return the patterns for the model and the data `X`.
  366. :param generator: Data to compute patterns.
  367. :param kwargs: Same as for keras model.fit_generator.
  368. """
  369. self._create_computers()
  370. # We don't do gradient updates.
  371. class NoOptimizer(keras.optimizers.Optimizer):
  372. def get_updates(self, *args, **kwargs):
  373. return []
  374. optimizer = NoOptimizer()
  375. # We only pass the training data once.
  376. if "epochs" in kwargs and kwargs["epochs"] != 1:
  377. raise ValueError("Pattern are computed with "
  378. "a closed form solution. "
  379. "Only need to do one epoch.")
  380. kwargs["epochs"] = 1
  381. if self.compute_layers_in_parallel is True:
  382. n_dummy_outputs = self._n_computer_outputs
  383. else:
  384. n_dummy_outputs = 1
  385. # Augment the input with dummy targets.
  386. def get_dummy_targets(Xs):
  387. n, dtype = Xs[0].shape[0], Xs[0].dtype
  388. dummy = np.ones(shape=(n, 1), dtype=dtype)
  389. return [dummy for _ in range(n_dummy_outputs)]
  390. if isinstance(generator, keras.utils.Sequence):
  391. generator = iutils.TargetAugmentedSequence(generator,
  392. get_dummy_targets)
  393. else:
  394. base_generator = generator
  395. def generator(*args, **kwargs):
  396. for Xs in base_generator(*args, **kwargs):
  397. Xs = iutils.to_list(Xs)
  398. yield Xs, get_dummy_targets(Xs)
  399. # Compile models.
  400. for computer in self._computers:
  401. computer.compile(optimizer=optimizer, loss=lambda x, y: x)
  402. # Compute pattern statistics.
  403. for computer in self._computers:
  404. computer.fit_generator(generator, **kwargs)
  405. # Compute and retrieve the actual patterns.
  406. pis = self._pattern_instances
  407. patterns = {ptype: [tmp.compute_pattern() for tmp in pis[ptype]]
  408. for ptype in self.pattern_types}
  409. # Free memory.
  410. del self._computers
  411. del self._pattern_instances
  412. if len(self.pattern_types) == 1:
  413. return patterns[list(self.pattern_types.keys())[0]]
  414. else:
  415. return patterns