layer.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. from builtins import range
  5. ###############################################################################
  6. ###############################################################################
  7. ###############################################################################
  8. import keras.models
  9. import keras.engine.topology
  10. from ... import utils as iutils
  11. __all__ = [
  12. "TestAnalysisHelper",
  13. ]
  14. ###############################################################################
  15. ###############################################################################
  16. ###############################################################################
  17. class TestAnalysisHelper(object):
  18. def __init__(self, model, analyzer, weights=None):
  19. """ Helper class for retrieving output and analysis in test cases.
  20. :param model: A Keras layer object or a list of layer objects.
  21. In this case a sequntial model will be build. The first layer
  22. must have set input_shape or batch_input_shape.
  23. Alternatively a tuple with input and output tensors, in which
  24. case the keras modle api will be used.
  25. :param analyzer: Either an analyzer class or a function
  26. that takes a keras model and returns an analyzer.
  27. :param weights: After creating the model set the given weights.
  28. """
  29. if isinstance(model, keras.engine.topology.Layer):
  30. model = [model]
  31. if isinstance(model, list):
  32. self._model = keras.models.Sequential(model)
  33. else:
  34. self._model = keras.models.Model(*model)
  35. self._input_shapes = iutils.to_list(self._model.input_shape)
  36. if weights is not None:
  37. self._model.set_weights(weights)
  38. self._analyzer = analyzer(self._model)
  39. @property
  40. def weights(self):
  41. return self._model.get_weights()
  42. def run(self, inputs):
  43. """Runs the model given the inputs.
  44. :return: Tuple with model output and analyzer output.
  45. """
  46. return_list = True
  47. if not isinstance(inputs, list):
  48. return_list = False
  49. inputs = iutils.to_list(inputs)
  50. augmented = []
  51. for i in range(len(inputs)):
  52. if len(inputs[i].shape) == len(self._input_shapes[i])-1:
  53. # Augment by batch axis.
  54. augmented.append(i)
  55. inputs[i] = inputs[i].reshape((1,)+inputs[i].shape)
  56. outputs = iutils.to_list(self._model.predict_on_batch(inputs))
  57. analysis = iutils.to_list(self._analyzer.analyze(inputs))
  58. for i in augmented:
  59. # Remove batch axis.
  60. outputs[i] = outputs[i][0]
  61. analysis[i] = analysis[i][0]
  62. if return_list:
  63. return outputs, analysis
  64. else:
  65. return outputs[0], analysis[0]