test_wrapper.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. import pytest
  8. from innvestigate.utils.tests import dryrun
  9. from innvestigate.analyzer import WrapperBase
  10. from innvestigate.analyzer import AugmentReduceBase
  11. from innvestigate.analyzer import GaussianSmoother
  12. from innvestigate.analyzer import PathIntegrator
  13. from innvestigate.analyzer import Gradient
  14. ###############################################################################
  15. ###############################################################################
  16. ###############################################################################
  17. @pytest.mark.fast
  18. @pytest.mark.precommit
  19. def test_fast__WrapperBase():
  20. def method(model):
  21. return WrapperBase(Gradient(model))
  22. dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
  23. @pytest.mark.precommit
  24. def test_precommit__WrapperBase():
  25. def method(model):
  26. return WrapperBase(Gradient(model))
  27. dryrun.test_analyzer(method, "mnist.*")
  28. @pytest.mark.fast
  29. @pytest.mark.precommit
  30. def test_fast__SerializeWrapperBase():
  31. def method(model):
  32. return WrapperBase(Gradient(model))
  33. dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")
  34. ###############################################################################
  35. ###############################################################################
  36. ###############################################################################
  37. @pytest.mark.fast
  38. @pytest.mark.precommit
  39. def test_fast__AugmentReduceBase():
  40. def method(model):
  41. return AugmentReduceBase(Gradient(model))
  42. dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
  43. @pytest.mark.precommit
  44. def test_precommit__AugmentReduceBase():
  45. def method(model):
  46. return AugmentReduceBase(Gradient(model))
  47. dryrun.test_analyzer(method, "mnist.*")
  48. @pytest.mark.fast
  49. @pytest.mark.precommit
  50. def test_fast__SerializeAugmentReduceBase():
  51. def method(model):
  52. return AugmentReduceBase(Gradient(model))
  53. dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")
  54. ###############################################################################
  55. ###############################################################################
  56. ###############################################################################
  57. @pytest.mark.fast
  58. @pytest.mark.precommit
  59. def test_fast__GaussianSmoother():
  60. def method(model):
  61. return GaussianSmoother(Gradient(model))
  62. dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
  63. @pytest.mark.precommit
  64. def test_precommit__GaussianSmoother():
  65. def method(model):
  66. return GaussianSmoother(Gradient(model))
  67. dryrun.test_analyzer(method, "mnist.*")
  68. @pytest.mark.fast
  69. @pytest.mark.precommit
  70. def test_fast__SerializeGaussianSmoother():
  71. def method(model):
  72. return GaussianSmoother(Gradient(model))
  73. dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")
  74. ###############################################################################
  75. ###############################################################################
  76. ###############################################################################
  77. @pytest.mark.fast
  78. @pytest.mark.precommit
  79. def test_fast__PathIntegrator():
  80. def method(model):
  81. return PathIntegrator(Gradient(model))
  82. dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
  83. @pytest.mark.precommit
  84. def test_precommit__PathIntegrator():
  85. def method(model):
  86. return PathIntegrator(Gradient(model))
  87. dryrun.test_analyzer(method, "mnist.*")
  88. @pytest.mark.fast
  89. @pytest.mark.precommit
  90. def test_fast__SerializePathIntegrator():
  91. def method(model):
  92. return PathIntegrator(Gradient(model))
  93. dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")