test_perturbate.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. import keras.layers
  8. import keras.models
  9. import numpy as np
  10. import pytest
  11. import innvestigate.tools.perturbate
  12. import innvestigate.utils as iutils
  13. ###############################################################################
  14. ###############################################################################
  15. ###############################################################################
  16. @pytest.mark.fast
  17. @pytest.mark.precommit
  18. def test_fast__PerturbationAnalysis():
  19. # Some test data
  20. if keras.backend.image_data_format() == "channels_first":
  21. input_shape = (2, 1, 4, 4)
  22. else:
  23. input_shape = (2, 4, 4, 1)
  24. x = np.arange(2 * 4 * 4).reshape(input_shape)
  25. generator = iutils.BatchSequence([x, np.zeros(x.shape[0])], batch_size=x.shape[0])
  26. # Simple model
  27. model = keras.models.Sequential([
  28. keras.layers.Flatten(input_shape=x.shape[1:]),
  29. keras.layers.Dense(1, use_bias=False),
  30. ])
  31. weights = np.arange(4 * 4 * 1).reshape((4 * 4, 1))
  32. model.layers[-1].set_weights([weights])
  33. model.compile(loss='mean_squared_error', optimizer='sgd')
  34. expected_output = np.array([[1240.], [3160.]])
  35. assert np.all(np.isclose(model.predict(x), expected_output))
  36. # Analyzer
  37. analyzer = innvestigate.create_analyzer("gradient",
  38. model,
  39. postprocess="abs")
  40. # Run perturbation analysis
  41. perturbation = innvestigate.tools.perturbate.Perturbation("zeros", region_shape=(2, 2), in_place=False)
  42. perturbation_analysis = innvestigate.tools.perturbate.PerturbationAnalysis(analyzer, model, generator, perturbation, recompute_analysis=False,
  43. steps=3, regions_per_step=1, verbose=False)
  44. scores = perturbation_analysis.compute_perturbation_analysis()
  45. expected_scores = np.array([5761600.0, 1654564.0, 182672.0, 21284.0])
  46. assert np.all(np.isclose(scores, expected_scores))
  47. @pytest.mark.fast
  48. @pytest.mark.precommit
  49. def test_fast__Perturbation():
  50. if keras.backend.image_data_format() == "channels_first":
  51. input_shape = (1, 1, 4, 4)
  52. else:
  53. input_shape = (1, 4, 4, 1)
  54. x = np.arange(1 * 4 * 4).reshape(input_shape)
  55. perturbation = innvestigate.tools.perturbate.Perturbation("zeros", region_shape=(2, 2), in_place=False)
  56. analysis = np.zeros((4, 4))
  57. analysis[:2, 2:] = 1
  58. analysis[2:, :2] = 2
  59. analysis[2:, 2:] = 3
  60. analysis = analysis.reshape(input_shape)
  61. if keras.backend.image_data_format() == "channels_last":
  62. x = np.moveaxis(x, 3, 1)
  63. analysis = np.moveaxis(analysis, 3, 1)
  64. analysis = perturbation.reduce_function(analysis, axis=1, keepdims=True)
  65. aggregated_regions = perturbation.aggregate_regions(analysis)
  66. assert np.all(np.isclose(aggregated_regions[0, 0, :, :], np.array([[0, 1], [2, 3]])))
  67. ranks = perturbation.compute_region_ordering(aggregated_regions)
  68. assert np.all(np.isclose(ranks[0, 0, :, :], np.array([[3, 2], [1, 0]])))
  69. perturbation_mask_regions = perturbation.compute_perturbation_mask(ranks, 1)
  70. assert np.all(perturbation_mask_regions == np.array([[0, 0], [0, 1]]))
  71. perturbation_mask_regions = perturbation.compute_perturbation_mask(ranks, 4)
  72. assert np.all(perturbation_mask_regions == np.array([[1, 1], [1, 1]]))
  73. perturbation_mask_regions = perturbation.compute_perturbation_mask(ranks, 0)
  74. assert np.all(perturbation_mask_regions == np.array([[0, 0], [0, 0]]))