123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231 |
- # Get Python six functionality:
- from __future__ import \
- absolute_import, print_function, division, unicode_literals
- ###############################################################################
- ###############################################################################
- ###############################################################################
- import pytest
- from innvestigate.utils.tests import dryrun
- from innvestigate.analyzer import BaselineGradient
- from innvestigate.analyzer import Gradient
- ###############################################################################
- ###############################################################################
- ###############################################################################
@pytest.mark.fast
@pytest.mark.precommit
def test_fast__BasicGraphReversal():
    """Compare BaselineGradient against Gradient on the fast network set.

    Both analyzer factories, plus the network filter
    ``"trivia.*:mnist.log_reg"``, are handed to ``dryrun.test_equal_analyzer``.
    """
    dryrun.test_equal_analyzer(
        lambda model: BaselineGradient(model),
        lambda model: Gradient(model),
        "trivia.*:mnist.log_reg",
    )
@pytest.mark.precommit
def test_precommit__BasicGraphReversal():
    """Compare BaselineGradient against Gradient on every ``mnist.*`` network.

    Both analyzer factories are handed to ``dryrun.test_equal_analyzer``.
    """
    dryrun.test_equal_analyzer(
        lambda model: BaselineGradient(model),
        lambda model: Gradient(model),
        "mnist.*",
    )
# NOTE: the container-graph-reversal tests below are disabled. Re-enabling
# them also requires importing `keras`, which is not among this module's
# imports.
#
# @pytest.mark.fast
# @pytest.mark.precommit
# def test_fast__ContainerGraphReversal():
#
#     def method1(model):
#         return Gradient(model)
#
#     def method2(model):
#         # Create container execution
#         model = keras.models.Model(inputs=model.inputs,
#                                    outputs=model(model.inputs))
#         return Gradient(model)
#
#     dryrun.test_equal_analyzer(method1,
#                                method2,
#                                "trivia.*:mnist.log_reg")
#
#
# @pytest.mark.precommit
# def test_precommit__ContainerGraphReversal():
#
#     def method1(model):
#         return Gradient(model)
#
#     def method2(model):
#         # Create container execution
#         model = keras.models.Model(inputs=model.inputs,
#                                    outputs=model(model.inputs))
#         return Gradient(model)
#
#     dryrun.test_equal_analyzer(method1,
#                                method2,
#                                "mnist.*")
- ###############################################################################
- ###############################################################################
- ###############################################################################
@pytest.mark.fast
@pytest.mark.precommit
def test_fast__AnalyzerNetworkBase_neuron_selection_max():
    """Run Gradient in ``"max_activation"`` neuron-selection mode.

    Networks are selected with the filter ``"trivia.*:mnist.log_reg"``.
    """
    dryrun.test_analyzer(
        lambda model: Gradient(model,
                               neuron_selection_mode="max_activation"),
        "trivia.*:mnist.log_reg",
    )
@pytest.mark.precommit
def test_precommit__AnalyzerNetworkBase_neuron_selection_max():
    """Run Gradient in ``"max_activation"`` mode on every ``mnist.*`` network."""
    dryrun.test_analyzer(
        lambda model: Gradient(model,
                               neuron_selection_mode="max_activation"),
        "mnist.*",
    )
@pytest.mark.fast
@pytest.mark.precommit
def test_fast__AnalyzerNetworkBase_neuron_selection_index():
    """Run Gradient in ``"index"`` neuron-selection mode, always picking 0.

    Networks are selected with the filter ``"trivia.*:mnist.log_reg"``.
    """

    class FixedIndexGradient(Gradient):
        # Overriding ``analyze`` to take only ``X`` lets the neuron index be
        # fixed here (presumably dryrun calls ``analyze(X)`` — confirm).
        def analyze(self, X):
            neuron_index = 0
            return super(FixedIndexGradient, self).analyze(X, neuron_index)

    dryrun.test_analyzer(
        lambda model: FixedIndexGradient(model, neuron_selection_mode="index"),
        "trivia.*:mnist.log_reg",
    )
@pytest.mark.precommit
def test_precommit__AnalyzerNetworkBase_neuron_selection_index():
    """Run Gradient in ``"index"`` mode, always picking neuron 3.

    Networks are selected with the filter ``"mnist.*"``.
    """

    class FixedIndexGradient(Gradient):
        # Overriding ``analyze`` to take only ``X`` lets the neuron index be
        # fixed here (presumably dryrun calls ``analyze(X)`` — confirm).
        def analyze(self, X):
            neuron_index = 3
            return super(FixedIndexGradient, self).analyze(X, neuron_index)

    dryrun.test_analyzer(
        lambda model: FixedIndexGradient(model, neuron_selection_mode="index"),
        "mnist.*",
    )
- ###############################################################################
- ###############################################################################
- ###############################################################################
@pytest.mark.fast
@pytest.mark.precommit
def test_fast__BaseReverseNetwork_reverse_debug():
    """Run Gradient with ``reverse_verbose=True`` on the fast network set."""
    dryrun.test_analyzer(
        lambda model: Gradient(model, reverse_verbose=True),
        "trivia.*:mnist.log_reg",
    )
@pytest.mark.precommit
def test_precommit__BaseReverseNetwork_reverse_debug():
    """Run Gradient with ``reverse_verbose=True`` on every ``mnist.*`` network."""
    dryrun.test_analyzer(
        lambda model: Gradient(model, reverse_verbose=True),
        "mnist.*",
    )
@pytest.mark.fast
@pytest.mark.precommit
def test_fast__BaseReverseNetwork_reverse_check_minmax():
    """Run a verbose Gradient with ``reverse_check_min_max_values`` enabled.

    Networks are selected with the filter ``"trivia.*:mnist.log_reg"``.
    """
    dryrun.test_analyzer(
        lambda model: Gradient(
            model,
            reverse_verbose=True,
            reverse_check_min_max_values=True,
        ),
        "trivia.*:mnist.log_reg",
    )
@pytest.mark.precommit
def test_precommit__BaseReverseNetwork_reverse_check_minmax():
    """Run a verbose Gradient with ``reverse_check_min_max_values`` enabled.

    Networks are selected with the filter ``"mnist.*"``.
    """
    dryrun.test_analyzer(
        lambda model: Gradient(
            model,
            reverse_verbose=True,
            reverse_check_min_max_values=True,
        ),
        "mnist.*",
    )
@pytest.mark.fast
@pytest.mark.precommit
def test_fast__BaseReverseNetwork_reverse_check_finite():
    """Run a verbose Gradient with ``reverse_check_finite`` enabled.

    Networks are selected with the filter ``"trivia.*:mnist.log_reg"``.
    """
    dryrun.test_analyzer(
        lambda model: Gradient(
            model,
            reverse_verbose=True,
            reverse_check_finite=True,
        ),
        "trivia.*:mnist.log_reg",
    )
@pytest.mark.precommit
def test_precommit__BaseReverseNetwork_reverse_check_finite():
    """Run a verbose Gradient with ``reverse_check_finite`` enabled.

    Networks are selected with the filter ``"mnist.*"``.
    """
    dryrun.test_analyzer(
        lambda model: Gradient(
            model,
            reverse_verbose=True,
            reverse_check_finite=True,
        ),
        "mnist.*",
    )
- ###############################################################################
- ###############################################################################
- ###############################################################################
@pytest.mark.fast
@pytest.mark.precommit
def test_fast__SerializeAnalyzerBase():
    """Pass a BaselineGradient factory to ``dryrun.test_serialize_analyzer``.

    Networks are selected with the filter ``"trivia.*:mnist.log_reg"``.
    """
    dryrun.test_serialize_analyzer(
        lambda model: BaselineGradient(model),
        "trivia.*:mnist.log_reg",
    )
@pytest.mark.fast
@pytest.mark.precommit
# NOTE(review): "Analyzerk" in this name looks like a typo for "Analyzer";
# it is left unchanged so the collected test ID stays stable.
def test_fast__SerializeReverseAnalyzerkBase():
    """Pass a Gradient factory to ``dryrun.test_serialize_analyzer``.

    Networks are selected with the filter ``"trivia.*:mnist.log_reg"``.
    """
    dryrun.test_serialize_analyzer(
        lambda model: Gradient(model),
        "trivia.*:mnist.log_reg",
    )
-
|