test_init.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # Get Python six functionality:
  2. from __future__ import \
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. import pytest
  8. import keras.layers
  9. import keras.models
  10. from innvestigate import create_analyzer
  11. from innvestigate.analyzer import analyzers
  12. ###############################################################################
  13. ###############################################################################
  14. ###############################################################################
  15. @pytest.mark.fast
  16. @pytest.mark.precommit
  17. def test_fast__create_analyzers():
  18. fake_model = keras.models.Sequential([
  19. keras.layers.Dense(10, input_shape=(10,))
  20. ])
  21. for name in analyzers.keys():
  22. try:
  23. create_analyzer(name, fake_model)
  24. except KeyError:
  25. # Name should be found!
  26. raise
  27. except:
  28. # Some analyzers require parameters...
  29. pass
  30. @pytest.mark.fast
  31. @pytest.mark.precommit
  32. def test_fast__create_analyzers_wrong_name():
  33. fake_model = keras.models.Sequential([
  34. keras.layers.Dense(10, input_shape=(10,))
  35. ])
  36. with pytest.raises(KeyError):
  37. create_analyzer("wrong name", fake_model)