123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- # Get Python six functionality:
- from __future__ import\
- absolute_import, print_function, division, unicode_literals
- ###############################################################################
- ###############################################################################
- ###############################################################################
- import keras.layers
- import keras.models
- import numpy as np
- import pytest
- import innvestigate.tools.perturbate
- import innvestigate.utils as iutils
- ###############################################################################
- ###############################################################################
- ###############################################################################
@pytest.mark.fast
@pytest.mark.precommit
def test_fast__PerturbationAnalysis():
    """Check PerturbationAnalysis scores on a tiny bias-free linear model.

    A Flatten + Dense(1) model with hand-set weights is fed two 4x4
    single-channel samples. The model output is verified first, then a
    "gradient" analyzer (postprocess "abs") drives a "zeros" perturbation
    over 2x2 regions for 3 steps, one region per step, and the resulting
    scores are compared against four precomputed values.
    """
    n_samples, height, width = 2, 4, 4

    # Build the input in whatever layout the Keras backend is set to.
    channels_first = keras.backend.image_data_format() == "channels_first"
    input_shape = ((n_samples, 1, height, width) if channels_first
                   else (n_samples, height, width, 1))
    x = np.arange(n_samples * height * width).reshape(input_shape)

    # A single batch holding every sample; the targets are all zeros.
    targets = np.zeros(n_samples)
    generator = iutils.BatchSequence([x, targets], batch_size=n_samples)

    # Linear model: flatten the image, then one bias-free output unit.
    flatten = keras.layers.Flatten(input_shape=input_shape[1:])
    dense = keras.layers.Dense(1, use_bias=False)
    model = keras.models.Sequential([flatten, dense])

    # Weights 0..15, so the output is the dot product of each flattened
    # sample with arange(16): 1240 for sample 0 and 3160 for sample 1.
    kernel = np.arange(height * width * 1).reshape((height * width, 1))
    model.layers[-1].set_weights([kernel])
    model.compile(loss='mean_squared_error', optimizer='sgd')

    predictions = model.predict(x)
    assert np.allclose(predictions, np.array([[1240.], [3160.]]))

    # Analyzer that supplies the relevance maps used for perturbation.
    analyzer = innvestigate.create_analyzer(
        "gradient", model, postprocess="abs")

    # Run the perturbation analysis.
    perturbation = innvestigate.tools.perturbate.Perturbation(
        "zeros", region_shape=(2, 2), in_place=False)
    perturbation_analysis = innvestigate.tools.perturbate.PerturbationAnalysis(
        analyzer,
        model,
        generator,
        perturbation,
        recompute_analysis=False,
        steps=3,
        regions_per_step=1,
        verbose=False,
    )
    scores = perturbation_analysis.compute_perturbation_analysis()

    # Four reference values -- presumably the unperturbed score followed
    # by one score per perturbation step (TODO confirm against the tool).
    reference_scores = np.array([5761600.0, 1654564.0, 182672.0, 21284.0])
    assert np.allclose(scores, reference_scores)
@pytest.mark.fast
@pytest.mark.precommit
def test_fast__Perturbation():
    """Check the region helpers of Perturbation on a hand-built analysis.

    The 4x4 analysis map consists of four constant 2x2 quadrants with the
    values 0, 1, 2 and 3. The test walks through the helper calls in
    order -- reduce_function, aggregate_regions, compute_region_ordering
    and compute_perturbation_mask -- and checks each result against the
    values that follow from this quadrant layout.
    """
    if keras.backend.image_data_format() == "channels_first":
        input_shape = (1, 1, 4, 4)
    else:
        input_shape = (1, 4, 4, 1)

    perturbation = innvestigate.tools.perturbate.Perturbation(
        "zeros", region_shape=(2, 2), in_place=False)

    # One constant value per 2x2 quadrant (the top-left quadrant stays 0).
    analysis = np.zeros((4, 4))
    analysis[:2, 2:] = 1
    analysis[2:, :2] = 2
    analysis[2:, 2:] = 3
    analysis = analysis.reshape(input_shape)

    # The helpers below are called with the channel axis at position 1,
    # so a channels_last map is transposed to that layout first.
    if keras.backend.image_data_format() == "channels_last":
        analysis = np.moveaxis(analysis, 3, 1)

    # Collapse the channel axis but keep it as a singleton dimension.
    analysis = perturbation.reduce_function(analysis, axis=1, keepdims=True)

    # Each 2x2 region is expected to collapse to its quadrant's value.
    aggregated_regions = perturbation.aggregate_regions(analysis)
    assert np.all(np.isclose(aggregated_regions[0, 0, :, :],
                             np.array([[0, 1], [2, 3]])))

    # The region with the largest aggregated value is expected at rank 0.
    ranks = perturbation.compute_region_ordering(aggregated_regions)
    assert np.all(np.isclose(ranks[0, 0, :, :], np.array([[3, 2], [1, 0]])))

    # Selecting 1 region marks only the top-ranked (bottom-right) one.
    perturbation_mask_regions = perturbation.compute_perturbation_mask(
        ranks, 1)
    assert np.all(perturbation_mask_regions == np.array([[0, 0], [0, 1]]))

    # Selecting all 4 regions marks every region.
    perturbation_mask_regions = perturbation.compute_perturbation_mask(
        ranks, 4)
    assert np.all(perturbation_mask_regions == np.array([[1, 1], [1, 1]]))

    # Selecting 0 regions marks none.
    perturbation_mask_regions = perturbation.compute_perturbation_mask(
        ranks, 0)
    assert np.all(perturbation_mask_regions == np.array([[0, 0], [0, 0]]))
|