test_base.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. # Get Python six functionality:
  2. from __future__ import \
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. import pytest
  8. from innvestigate.utils.tests import dryrun
  9. from innvestigate.analyzer import BaselineGradient
  10. from innvestigate.analyzer import Gradient
  11. ###############################################################################
  12. ###############################################################################
  13. ###############################################################################
  14. @pytest.mark.fast
  15. @pytest.mark.precommit
  16. def test_fast__BasicGraphReversal():
  17. def method1(model):
  18. return BaselineGradient(model)
  19. def method2(model):
  20. return Gradient(model)
  21. dryrun.test_equal_analyzer(method1,
  22. method2,
  23. "trivia.*:mnist.log_reg")
  24. @pytest.mark.precommit
  25. def test_precommit__BasicGraphReversal():
  26. def method1(model):
  27. return BaselineGradient(model)
  28. def method2(model):
  29. return Gradient(model)
  30. dryrun.test_equal_analyzer(method1,
  31. method2,
  32. "mnist.*")
  33. # @pytest.mark.fast
  34. # @pytest.mark.precommit
  35. # def test_fast__ContainerGraphReversal():
  36. # def method1(model):
  37. # return Gradient(model)
  38. # def method2(model):
  39. # Create container execution
  40. # model = keras.models.Model(inputs=model.inputs,
  41. # outputs=model(model.inputs))
  42. # return Gradient(model)
  43. # dryrun.test_equal_analyzer(method1,
  44. # method2,
  45. # "trivia.*:mnist.log_reg")
  46. # @pytest.mark.precommit
  47. # def test_precommit__ContainerGraphReversal():
  48. # def method1(model):
  49. # return Gradient(model)
  50. # def method2(model):
  51. # Create container execution
  52. # model = keras.models.Model(inputs=model.inputs,
  53. # outputs=model(model.inputs))
  54. # return Gradient(model)
  55. # dryrun.test_equal_analyzer(method1,
  56. # method2,
  57. # "mnist.*")
  58. ###############################################################################
  59. ###############################################################################
  60. ###############################################################################
  61. @pytest.mark.fast
  62. @pytest.mark.precommit
  63. def test_fast__AnalyzerNetworkBase_neuron_selection_max():
  64. def method(model):
  65. return Gradient(model, neuron_selection_mode="max_activation")
  66. dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
  67. @pytest.mark.precommit
  68. def test_precommit__AnalyzerNetworkBase_neuron_selection_max():
  69. def method(model):
  70. return Gradient(model, neuron_selection_mode="max_activation")
  71. dryrun.test_analyzer(method, "mnist.*")
  72. @pytest.mark.fast
  73. @pytest.mark.precommit
  74. def test_fast__AnalyzerNetworkBase_neuron_selection_index():
  75. class CustomAnalyzer(Gradient):
  76. def analyze(self, X):
  77. index = 0
  78. return super(CustomAnalyzer, self).analyze(X, index)
  79. def method(model):
  80. return CustomAnalyzer(model, neuron_selection_mode="index")
  81. dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
  82. @pytest.mark.precommit
  83. def test_precommit__AnalyzerNetworkBase_neuron_selection_index():
  84. class CustomAnalyzer(Gradient):
  85. def analyze(self, X):
  86. index = 3
  87. return super(CustomAnalyzer, self).analyze(X, index)
  88. def method(model):
  89. return CustomAnalyzer(model, neuron_selection_mode="index")
  90. dryrun.test_analyzer(method, "mnist.*")
  91. ###############################################################################
  92. ###############################################################################
  93. ###############################################################################
  94. @pytest.mark.fast
  95. @pytest.mark.precommit
  96. def test_fast__BaseReverseNetwork_reverse_debug():
  97. def method(model):
  98. return Gradient(model, reverse_verbose=True)
  99. dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
  100. @pytest.mark.precommit
  101. def test_precommit__BaseReverseNetwork_reverse_debug():
  102. def method(model):
  103. return Gradient(model, reverse_verbose=True)
  104. dryrun.test_analyzer(method, "mnist.*")
  105. @pytest.mark.fast
  106. @pytest.mark.precommit
  107. def test_fast__BaseReverseNetwork_reverse_check_minmax():
  108. def method(model):
  109. return Gradient(model, reverse_verbose=True,
  110. reverse_check_min_max_values=True)
  111. dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
  112. @pytest.mark.precommit
  113. def test_precommit__BaseReverseNetwork_reverse_check_minmax():
  114. def method(model):
  115. return Gradient(model, reverse_verbose=True,
  116. reverse_check_min_max_values=True)
  117. dryrun.test_analyzer(method, "mnist.*")
  118. @pytest.mark.fast
  119. @pytest.mark.precommit
  120. def test_fast__BaseReverseNetwork_reverse_check_finite():
  121. def method(model):
  122. return Gradient(model, reverse_verbose=True, reverse_check_finite=True)
  123. dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
  124. @pytest.mark.precommit
  125. def test_precommit__BaseReverseNetwork_reverse_check_finite():
  126. def method(model):
  127. return Gradient(model, reverse_verbose=True, reverse_check_finite=True)
  128. dryrun.test_analyzer(method, "mnist.*")
  129. ###############################################################################
  130. ###############################################################################
  131. ###############################################################################
  132. @pytest.mark.fast
  133. @pytest.mark.precommit
  134. def test_fast__SerializeAnalyzerBase():
  135. def method(model):
  136. return BaselineGradient(model)
  137. dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")
  138. @pytest.mark.fast
  139. @pytest.mark.precommit
  140. def test_fast__SerializeReverseAnalyzerkBase():
  141. def method(model):
  142. return Gradient(model)
  143. dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")