dryrun.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. import six
  5. ###############################################################################
  6. ###############################################################################
  7. ###############################################################################
  8. import keras.backend as K
  9. import keras.models
  10. import numpy as np
  11. import unittest
  12. from ...analyzer.base import AnalyzerBase
  13. from . import networks
  14. __all__ = [
  15. "AnalyzerTestCase",
  16. "EqualAnalyzerTestCase",
  17. "PatternComputerTestCase",
  18. ]
  19. ###############################################################################
  20. ###############################################################################
  21. ###############################################################################
  22. def _set_zero_weights_to_random(weights):
  23. ret = []
  24. for weight in weights:
  25. if weight.sum() == 0:
  26. weight = np.random.rand(*weight.shape)
  27. ret.append(weight)
  28. return ret
  29. ###############################################################################
  30. ###############################################################################
  31. ###############################################################################
  32. class BaseLayerTestCase(unittest.TestCase):
  33. """
  34. A dryrun test on various networks for an analyzing method.
  35. For each network the test check that the generated network
  36. has the right output shape, can be compiled
  37. and executed with random inputs.
  38. """
  39. _network_filter = "trivia.*"
  40. def __init__(self, *args, **kwargs):
  41. network_filter = kwargs.pop("network_filter", None)
  42. if network_filter is not None:
  43. self._network_filter = network_filter
  44. super(BaseLayerTestCase, self).__init__(*args, **kwargs)
  45. def _apply_test(self, network):
  46. raise NotImplementedError("Set in subclass.")
  47. def runTest(self):
  48. np.random.seed(2349784365)
  49. K.clear_session()
  50. for network in networks.iterator(self._network_filter,
  51. clear_sessions=True):
  52. if six.PY2:
  53. self._apply_test(network)
  54. else:
  55. with self.subTest(network_name=network["name"]):
  56. self._apply_test(network)
  57. ###############################################################################
  58. ###############################################################################
  59. ###############################################################################
  60. class AnalyzerTestCase(BaseLayerTestCase):
  61. """TestCase for analyzers execution
  62. TestCase that applies the method to several networks and
  63. runs the analyzer with random data.
  64. :param method: A function that returns an Analyzer class.
  65. """
  66. def __init__(self, *args, **kwargs):
  67. method = kwargs.pop("method", None)
  68. if method is not None:
  69. self._method = method
  70. super(AnalyzerTestCase, self).__init__(*args, **kwargs)
  71. def _method(self, model):
  72. raise NotImplementedError("Set in subclass.")
  73. def _apply_test(self, network):
  74. # Create model.
  75. model = keras.models.Model(inputs=network["in"],
  76. outputs=network["out"])
  77. model.set_weights(_set_zero_weights_to_random(model.get_weights()))
  78. # Get analyzer.
  79. analyzer = self._method(model)
  80. # Dryrun.
  81. x = np.random.rand(1, *(network["input_shape"][1:]))
  82. analysis = analyzer.analyze(x)
  83. self.assertEqual(tuple(analysis.shape),
  84. (1,)+tuple(network["input_shape"][1:]))
  85. self.assertFalse(np.any(np.isinf(analysis.ravel())))
  86. self.assertFalse(np.any(np.isnan(analysis.ravel())))
  87. def test_analyzer(method, network_filter):
  88. """Workaround for move from unit-tests to pytest."""
  89. # todo: Mixing of pytest and unittest is not ideal.
  90. # Move completely to pytest.
  91. test_case = AnalyzerTestCase(method=method,
  92. network_filter=network_filter)
  93. test_result = unittest.TextTestRunner().run(test_case)
  94. assert len(test_result.errors) == 0
  95. assert len(test_result.failures) == 0
  96. class AnalyzerTrainTestCase(BaseLayerTestCase):
  97. """TestCase for analyzers execution
  98. TestCase that applies the method to several networks and
  99. trains and runs the analyzer with random data.
  100. :param method: A function that returns an Analyzer class.
  101. """
  102. def __init__(self, *args, **kwargs):
  103. method = kwargs.pop("method", None)
  104. if method is not None:
  105. self._method = method
  106. super(AnalyzerTrainTestCase, self).__init__(*args, **kwargs)
  107. def _method(self, model):
  108. raise NotImplementedError("Set in subclass.")
  109. def _apply_test(self, network):
  110. # Create model.
  111. model = keras.models.Model(inputs=network["in"],
  112. outputs=network["out"])
  113. model.set_weights(_set_zero_weights_to_random(model.get_weights()))
  114. # Get analyzer.
  115. analyzer = self._method(model)
  116. # Dryrun.
  117. x = np.random.rand(16, *(network["input_shape"][1:]))
  118. analyzer.fit(x)
  119. x = np.random.rand(1, *(network["input_shape"][1:]))
  120. analysis = analyzer.analyze(x)
  121. self.assertEqual(tuple(analysis.shape),
  122. (1,)+tuple(network["input_shape"][1:]))
  123. self.assertFalse(np.any(np.isinf(analysis.ravel())))
  124. self.assertFalse(np.any(np.isnan(analysis.ravel())))
  125. self.assertFalse(True)
  126. def test_train_analyzer(method, network_filter):
  127. """Workaround for move from unit-tests to pytest."""
  128. # todo: Mixing of pytest and unittest is not ideal.
  129. # Move completely to pytest.
  130. test_case = AnalyzerTrainTestCase(method=method,
  131. network_filter=network_filter)
  132. test_result = unittest.TextTestRunner().run(test_case)
  133. assert len(test_result.errors) == 0
  134. assert len(test_result.failures) == 0
  135. class EqualAnalyzerTestCase(BaseLayerTestCase):
  136. """TestCase for analyzers execution
  137. TestCase that applies two method to several networks and
  138. runs the analyzer with random data and checks for equality
  139. of the results.
  140. :param method1: A function that returns an Analyzer class.
  141. :param method2: A function that returns an Analyzer class.
  142. """
  143. def __init__(self, *args, **kwargs):
  144. method1 = kwargs.pop("method1", None)
  145. method2 = kwargs.pop("method2", None)
  146. if method1 is not None:
  147. self._method1 = method1
  148. if method2 is not None:
  149. self._method2 = method2
  150. super(EqualAnalyzerTestCase, self).__init__(*args, **kwargs)
  151. def _method1(self, model):
  152. raise NotImplementedError("Set in subclass.")
  153. def _method2(self, model):
  154. raise NotImplementedError("Set in subclass.")
  155. def _apply_test(self, network):
  156. # Create model.
  157. model = keras.models.Model(inputs=network["in"],
  158. outputs=network["out"])
  159. model.set_weights(_set_zero_weights_to_random(model.get_weights()))
  160. # Get analyzer.
  161. analyzer1 = self._method1(model)
  162. analyzer2 = self._method2(model)
  163. # Dryrun.
  164. x = np.random.rand(1, *(network["input_shape"][1:]))*100
  165. analysis1 = analyzer1.analyze(x)
  166. analysis2 = analyzer2.analyze(x)
  167. self.assertEqual(tuple(analysis1.shape),
  168. (1,)+tuple(network["input_shape"][1:]))
  169. self.assertFalse(np.any(np.isinf(analysis1.ravel())))
  170. self.assertFalse(np.any(np.isnan(analysis1.ravel())))
  171. self.assertEqual(tuple(analysis2.shape),
  172. (1,)+tuple(network["input_shape"][1:]))
  173. self.assertFalse(np.any(np.isinf(analysis2.ravel())))
  174. self.assertFalse(np.any(np.isnan(analysis2.ravel())))
  175. all_close_kwargs = {}
  176. if hasattr(self, "_all_close_rtol"):
  177. all_close_kwargs["rtol"] = self._all_close_rtol
  178. if hasattr(self, "_all_close_atol"):
  179. all_close_kwargs["atol"] = self._all_close_atol
  180. #print(analysis1.sum(), analysis2.sum())
  181. self.assertTrue(np.allclose(analysis1, analysis2, **all_close_kwargs))
  182. def test_equal_analyzer(method1, method2, network_filter):
  183. """Workaround for move from unit-tests to pytest."""
  184. # todo: Mixing of pytest and unittest is not ideal.
  185. # Move completely to pytest.
  186. test_case = EqualAnalyzerTestCase(method1=method1,
  187. method2=method2,
  188. network_filter=network_filter)
  189. test_result = unittest.TextTestRunner().run(test_case)
  190. assert len(test_result.errors) == 0
  191. assert len(test_result.failures) == 0
# TODO: merge with the base test case? If we don't run the analysis,
# it's only half the test.
  194. class SerializeAnalyzerTestCase(BaseLayerTestCase):
  195. """TestCase for analyzers serialization
  196. TestCase that applies the method to several networks and
  197. runs the analyzer with random data, serializes it, and
  198. runs it again.
  199. :param method: A function that returns an Analyzer class.
  200. """
  201. def __init__(self, *args, **kwargs):
  202. method = kwargs.pop("method", None)
  203. if method is not None:
  204. self._method = method
  205. super(SerializeAnalyzerTestCase, self).__init__(*args, **kwargs)
  206. def _method(self, model):
  207. raise NotImplementedError("Set in subclass.")
  208. def _apply_test(self, network):
  209. # Create model.
  210. model = keras.models.Model(inputs=network["in"],
  211. outputs=network["out"])
  212. model.set_weights(_set_zero_weights_to_random(model.get_weights()))
  213. # Get analyzer.
  214. analyzer = self._method(model)
  215. # Dryrun.
  216. x = np.random.rand(1, *(network["input_shape"][1:]))
  217. class_name, state = analyzer.save()
  218. new_analyzer = AnalyzerBase.load(class_name, state)
  219. analysis = new_analyzer.analyze(x)
  220. self.assertEqual(tuple(analysis.shape),
  221. (1,)+tuple(network["input_shape"][1:]))
  222. self.assertFalse(np.any(np.isinf(analysis.ravel())))
  223. self.assertFalse(np.any(np.isnan(analysis.ravel())))
  224. def test_serialize_analyzer(method, network_filter):
  225. """Workaround for move from unit-tests to pytest."""
  226. # todo: Mixing of pytest and unittest is not ideal.
  227. # Move completely to pytest.
  228. test_case = SerializeAnalyzerTestCase(method=method,
  229. network_filter=network_filter)
  230. test_result = unittest.TextTestRunner().run(test_case)
  231. assert len(test_result.errors) == 0
  232. assert len(test_result.failures) == 0
  233. ###############################################################################
  234. ###############################################################################
  235. ###############################################################################
  236. class PatternComputerTestCase(BaseLayerTestCase):
  237. """TestCase pattern computation
  238. :param method: A function that returns an PatternComputer class.
  239. """
  240. def __init__(self, *args, **kwargs):
  241. method = kwargs.pop("method", None)
  242. if method is not None:
  243. self._method = method
  244. super(PatternComputerTestCase, self).__init__(*args, **kwargs)
  245. def _method(self, model):
  246. raise NotImplementedError("Set in subclass.")
  247. def _apply_test(self, network):
  248. # Create model.
  249. model = keras.models.Model(inputs=network["in"], outputs=network["out"])
  250. model.set_weights(_set_zero_weights_to_random(model.get_weights()))
  251. # Get computer.
  252. computer = self._method(model)
  253. # Dryrun.
  254. x = np.random.rand(10, *(network["input_shape"][1:]))
  255. computer.compute(x)
  256. def test_pattern_computer(method, network_filter):
  257. """Workaround for move from unit-tests to pytest."""
  258. # todo: Mixing of pytest and unittest is not ideal.
  259. # Move completely to pytest.
  260. test_case = PatternComputerTestCase(method=method,
  261. network_filter=network_filter)
  262. test_result = unittest.TextTestRunner().run(test_case)
  263. assert len(test_result.errors) == 0
  264. assert len(test_result.failures) == 0