1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
# Enable Python 3 semantics on Python 2 (absolute imports, print(), true division, unicode literals):
- from __future__ import\
- absolute_import, print_function, division, unicode_literals
- ###############################################################################
- ###############################################################################
- ###############################################################################
- from .base import AnalyzerNetworkBase
- from .. import layers as ilayers
- from .. import utils as iutils
# Public names exported by this module (the two analyzers defined below).
__all__ = ["Random", "Input"]
- ###############################################################################
- ###############################################################################
- ###############################################################################
class Input(AnalyzerNetworkBase):
    """Returns the input.

    Returns the input as analysis: every model input that is not listed in
    ``stop_analysis_at_tensors`` is passed through an identity layer and
    returned unchanged.

    :param model: A Keras model.
    """

    def _create_analysis(self, model, stop_analysis_at_tensors=None):
        """Build the analysis outputs for ``model``.

        :param model: The Keras model to analyze; its ``inputs`` are used.
        :param stop_analysis_at_tensors: Optional collection of input tensors
          to exclude from the analysis. ``None`` (the default) excludes
          nothing.
        :return: A list with one identity-mapped tensor per analyzed input,
          in the order of ``model.inputs``.
        """
        # ``None`` instead of ``[]`` avoids sharing one mutable default
        # object across calls; callers passing a list are unaffected.
        if stop_analysis_at_tensors is None:
            stop_analysis_at_tensors = []
        tensors_to_analyze = [x for x in iutils.to_list(model.inputs)
                              if x not in stop_analysis_at_tensors]
        # A fresh Identity layer is created for each analyzed input.
        return [ilayers.Identity()(x) for x in tensors_to_analyze]
class Random(AnalyzerNetworkBase):
    """Returns noise.

    Returns Gaussian noise as analysis: every model input that is not listed
    in ``stop_analysis_at_tensors`` is passed through a
    ``TestPhaseGaussianNoise`` layer configured with ``stddev``.

    :param model: A Keras model.
    :param stddev: The standard deviation of the noise.
    """

    def __init__(self, model, stddev=1, **kwargs):
        # Store the noise level before delegating to the base constructor
        # (original ordering preserved; the base class may already call
        # back into this object's methods -- TODO confirm).
        self._stddev = stddev
        super(Random, self).__init__(model, **kwargs)

    def _create_analysis(self, model, stop_analysis_at_tensors=None):
        """Build the noise outputs for ``model``.

        :param model: The Keras model to analyze; its ``inputs`` are used.
        :param stop_analysis_at_tensors: Optional collection of input tensors
          to exclude from the analysis. ``None`` (the default) excludes
          nothing.
        :return: A list with one noise tensor per analyzed input, in the
          order of ``model.inputs``.
        """
        # ``None`` instead of ``[]`` avoids sharing one mutable default
        # object across calls; callers passing a list are unaffected.
        if stop_analysis_at_tensors is None:
            stop_analysis_at_tensors = []
        # A single noise layer instance is applied to every analyzed input.
        noise = ilayers.TestPhaseGaussianNoise(stddev=self._stddev)
        tensors_to_analyze = [x for x in iutils.to_list(model.inputs)
                              if x not in stop_analysis_at_tensors]
        return [noise(x) for x in tensors_to_analyze]

    def _get_state(self):
        """Return the base-class state extended with the ``stddev`` entry."""
        state = super(Random, self)._get_state()
        state.update({"stddev": self._stddev})
        return state

    @classmethod
    def _state_to_kwargs(clazz, state):
        """Convert a state dict back into constructor keyword arguments.

        Note that ``state`` is mutated: the ``"stddev"`` entry is popped
        before the remainder is handed to the base class (presumably the base
        class expects to see only its own keys -- TODO confirm).

        :param state: State dict as produced by ``_get_state``.
        :return: The base-class kwargs plus ``stddev``.
        :raises KeyError: If ``state`` has no ``"stddev"`` entry.
        """
        stddev = state.pop("stddev")
        kwargs = super(Random, clazz)._state_to_kwargs(state)
        kwargs.update({"stddev": stddev})
        return kwargs
|