misc.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. from .base import AnalyzerNetworkBase
  8. from .. import layers as ilayers
  9. from .. import utils as iutils
  10. __all__ = ["Random", "Input"]
  11. ###############################################################################
  12. ###############################################################################
  13. ###############################################################################
  14. class Input(AnalyzerNetworkBase):
  15. """Returns the input.
  16. Returns the input as analysis.
  17. :param model: A Keras model.
  18. """
  19. def _create_analysis(self, model, stop_analysis_at_tensors=[]):
  20. tensors_to_analyze = [x for x in iutils.to_list(model.inputs)
  21. if x not in stop_analysis_at_tensors]
  22. return [ilayers.Identity()(x) for x in tensors_to_analyze]
  23. class Random(AnalyzerNetworkBase):
  24. """Returns noise.
  25. Returns the Gaussian noise as analysis.
  26. :param model: A Keras model.
  27. :param stddev: The standard deviation of the noise.
  28. """
  29. def __init__(self, model, stddev=1, **kwargs):
  30. self._stddev = stddev
  31. super(Random, self).__init__(model, **kwargs)
  32. def _create_analysis(self, model, stop_analysis_at_tensors=[]):
  33. noise = ilayers.TestPhaseGaussianNoise(stddev=self._stddev)
  34. tensors_to_analyze = [x for x in iutils.to_list(model.inputs)
  35. if x not in stop_analysis_at_tensors]
  36. return [noise(x) for x in tensors_to_analyze]
  37. def _get_state(self):
  38. state = super(Random, self)._get_state()
  39. state.update({"stddev": self._stddev})
  40. return state
  41. @classmethod
  42. def _state_to_kwargs(clazz, state):
  43. stddev = state.pop("stddev")
  44. kwargs = super(Random, clazz)._state_to_kwargs(state)
  45. kwargs.update({"stddev": stddev})
  46. return kwargs