123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- # Get Python six functionality:
- from __future__ import\
- absolute_import, print_function, division, unicode_literals
- ###############################################################################
- ###############################################################################
- ###############################################################################
- import pytest
- from innvestigate.utils.tests import dryrun
- from innvestigate.analyzer import WrapperBase
- from innvestigate.analyzer import AugmentReduceBase
- from innvestigate.analyzer import GaussianSmoother
- from innvestigate.analyzer import PathIntegrator
- from innvestigate.analyzer import Gradient
- ###############################################################################
- ###############################################################################
- ###############################################################################
- @pytest.mark.fast
- @pytest.mark.precommit
- def test_fast__WrapperBase():
- def method(model):
- return WrapperBase(Gradient(model))
- dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
- @pytest.mark.precommit
- def test_precommit__WrapperBase():
- def method(model):
- return WrapperBase(Gradient(model))
- dryrun.test_analyzer(method, "mnist.*")
- @pytest.mark.fast
- @pytest.mark.precommit
- def test_fast__SerializeWrapperBase():
- def method(model):
- return WrapperBase(Gradient(model))
- dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")
- ###############################################################################
- ###############################################################################
- ###############################################################################
- @pytest.mark.fast
- @pytest.mark.precommit
- def test_fast__AugmentReduceBase():
- def method(model):
- return AugmentReduceBase(Gradient(model))
- dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
- @pytest.mark.precommit
- def test_precommit__AugmentReduceBase():
- def method(model):
- return AugmentReduceBase(Gradient(model))
- dryrun.test_analyzer(method, "mnist.*")
- @pytest.mark.fast
- @pytest.mark.precommit
- def test_fast__SerializeAugmentReduceBase():
- def method(model):
- return AugmentReduceBase(Gradient(model))
- dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")
- ###############################################################################
- ###############################################################################
- ###############################################################################
- @pytest.mark.fast
- @pytest.mark.precommit
- def test_fast__GaussianSmoother():
- def method(model):
- return GaussianSmoother(Gradient(model))
- dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
- @pytest.mark.precommit
- def test_precommit__GaussianSmoother():
- def method(model):
- return GaussianSmoother(Gradient(model))
- dryrun.test_analyzer(method, "mnist.*")
- @pytest.mark.fast
- @pytest.mark.precommit
- def test_fast__SerializeGaussianSmoother():
- def method(model):
- return GaussianSmoother(Gradient(model))
- dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")
- ###############################################################################
- ###############################################################################
- ###############################################################################
- @pytest.mark.fast
- @pytest.mark.precommit
- def test_fast__PathIntegrator():
- def method(model):
- return PathIntegrator(Gradient(model))
- dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
- @pytest.mark.precommit
- def test_precommit__PathIntegrator():
- def method(model):
- return PathIntegrator(Gradient(model))
- dryrun.test_analyzer(method, "mnist.*")
- @pytest.mark.fast
- @pytest.mark.precommit
- def test_fast__SerializePathIntegrator():
- def method(model):
- return PathIntegrator(Gradient(model))
- dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")
|