test_deeplift.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. import keras.layers
  8. import keras.models
  9. import numpy as np
  10. import pytest
  11. try:
  12. import deeplift
  13. except ImportError:
  14. deeplift = None
  15. from innvestigate.utils.tests import dryrun
  16. from innvestigate.analyzer import DeepLIFT
  17. from innvestigate.analyzer import DeepLIFTWrapper
  18. ###############################################################################
  19. ###############################################################################
  20. ###############################################################################
  21. @pytest.mark.fast
  22. @pytest.mark.precommit
  23. def test_fast__DeepLIFT():
  24. def method(model):
  25. return DeepLIFT(model)
  26. dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
  27. @pytest.mark.precommit
  28. def test_precommit__DeepLIFT():
  29. def method(model):
  30. return DeepLIFT(model)
  31. dryrun.test_analyzer(method, "mnist.*")
  32. @pytest.mark.precommit
  33. def test_precommit__DeepLIFT_Rescale():
  34. def method(model):
  35. if keras.backend.image_data_format() == "channels_first":
  36. input_shape = (1, 28, 28)
  37. else:
  38. input_shape = (28, 28, 1)
  39. model = keras.models.Sequential([
  40. keras.layers.Dense(10, input_shape=input_shape),
  41. keras.layers.ReLU(),
  42. ])
  43. return DeepLIFT(model)
  44. dryrun.test_analyzer(method, "mnist.log_reg")
  45. @pytest.mark.precommit
  46. def test_precommit__DeepLIFT_neuron_selection_index():
  47. class CustomAnalyzer(DeepLIFT):
  48. def analyze(self, X):
  49. index = 0
  50. return super(CustomAnalyzer, self).analyze(X, index)
  51. def method(model):
  52. return CustomAnalyzer(model, neuron_selection_mode="index")
  53. dryrun.test_analyzer(method, "mnist.*")
  54. @pytest.mark.precommit
  55. def test_precommit__DeepLIFT_larger_batch_size():
  56. class CustomAnalyzer(DeepLIFT):
  57. def analyze(self, X):
  58. X = np.concatenate((X, X), axis=0)
  59. return super(CustomAnalyzer, self).analyze(X)[0:1]
  60. def method(model):
  61. return CustomAnalyzer(model)
  62. dryrun.test_analyzer(method, "mnist.*")
  63. @pytest.mark.skip("There is a design issue to be fixed.")
  64. @pytest.mark.precommit
  65. def test_precommit__DeepLIFT_larger_batch_size_with_index():
  66. class CustomAnalyzer(DeepLIFT):
  67. def analyze(self, X):
  68. index = 0
  69. X = np.concatenate((X, X), axis=0)
  70. return super(CustomAnalyzer, self).analyze(X, index)[0:1]
  71. def method(model):
  72. return CustomAnalyzer(model, neuron_selection_mode="index")
  73. dryrun.test_analyzer(method, "mnist.*")
  74. @pytest.mark.slow
  75. @pytest.mark.application
  76. @pytest.mark.imagenet
  77. def test_imagenet__DeepLIFT():
  78. def method(model):
  79. return DeepLIFT(model)
  80. dryrun.test_analyzer(method, "imagenet.*")
  81. ###############################################################################
  82. ###############################################################################
  83. ###############################################################################
  84. require_deeplift = pytest.mark.skipif(deeplift is None,
  85. reason="Package deeplift is required.")
  86. @require_deeplift
  87. @pytest.mark.fast
  88. @pytest.mark.precommit
  89. @pytest.mark.skip(reason="DeepLIFT does not work with skip connection.")
  90. def test_fast__DeepLIFTWrapper():
  91. def method(model):
  92. return DeepLIFTWrapper(model)
  93. dryrun.test_analyzer(method, "trivia.*:mnist.log_reg")
  94. @require_deeplift
  95. @pytest.mark.precommit
  96. def test_precommit__DeepLIFTWrapper():
  97. def method(model):
  98. return DeepLIFTWrapper(model)
  99. dryrun.test_analyzer(method, "mnist.*")
  100. @require_deeplift
  101. @pytest.mark.precommit
  102. def test_precommit__DeepLIFTWrapper_neuron_selection_index():
  103. class CustomAnalyzer(DeepLIFTWrapper):
  104. def analyze(self, X):
  105. index = 0
  106. return super(CustomAnalyzer, self).analyze(X, index)
  107. def method(model):
  108. return CustomAnalyzer(model, neuron_selection_mode="index")
  109. dryrun.test_analyzer(method, "mnist.*")
  110. @require_deeplift
  111. @pytest.mark.precommit
  112. def test_precommit__DeepLIFTWrapper_larger_batch_size():
  113. class CustomAnalyzer(DeepLIFTWrapper):
  114. def analyze(self, X):
  115. X = np.concatenate((X, X), axis=0)
  116. return super(CustomAnalyzer, self).analyze(X)[0:1]
  117. def method(model):
  118. return CustomAnalyzer(model)
  119. dryrun.test_analyzer(method, "mnist.*")
  120. @require_deeplift
  121. @pytest.mark.precommit
  122. def test_precommit__DeepLIFTWrapper_larger_batch_size_with_index():
  123. class CustomAnalyzer(DeepLIFTWrapper):
  124. def analyze(self, X):
  125. index = 0
  126. X = np.concatenate((X, X), axis=0)
  127. return super(CustomAnalyzer, self).analyze(X, index)[0:1]
  128. def method(model):
  129. return CustomAnalyzer(model, neuron_selection_mode="index")
  130. dryrun.test_analyzer(method, "mnist.*")
  131. @require_deeplift
  132. @pytest.mark.slow
  133. @pytest.mark.application
  134. @pytest.mark.imagenet
  135. def test_imagenet__DeepLIFTWrapper():
  136. def method(model):
  137. return DeepLIFTWrapper(model)
  138. dryrun.test_analyzer(method, "imagenet.*")
  139. ###############################################################################
  140. ###############################################################################
  141. ###############################################################################
  142. @pytest.mark.fast
  143. @pytest.mark.precommit
  144. def test_fast__DeepLIFT_serialize():
  145. def method(model):
  146. return DeepLIFT(model)
  147. dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")
  148. @pytest.mark.fast
  149. @pytest.mark.precommit
  150. def test_fast__DeepLIFTWrapper_serialize():
  151. def method(model):
  152. return DeepLIFTWrapper(model)
  153. with pytest.raises(AssertionError):
  154. # Issue in deeplift.
  155. dryrun.test_serialize_analyzer(method, "trivia.*:mnist.log_reg")