1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- # Get Python six functionality:
- from __future__ import\
- absolute_import, print_function, division, unicode_literals
- from builtins import range
- ###############################################################################
- ###############################################################################
- ###############################################################################
- import keras.models
- import keras.engine.topology
- from ... import utils as iutils
# Names exported by this module's public API.
__all__ = ["TestAnalysisHelper"]
- ###############################################################################
- ###############################################################################
- ###############################################################################
class TestAnalysisHelper(object):
    """Helper for retrieving model output and analysis in test cases."""

    def __init__(self, model, analyzer, weights=None):
        """Build a Keras model from `model` and wrap it with an analyzer.

        :param model: A Keras layer object or a list of layer objects.
          In this case a sequential model will be built. The first layer
          must have set input_shape or batch_input_shape.
          Alternatively a tuple with input and output tensors, in which
          case the Keras functional model API will be used.
        :param analyzer: Either an analyzer class or a function
          that takes a keras model and returns an analyzer.
        :param weights: If not None, these weights are set on the model
          (via ``set_weights``) right after it is created.
        """
        if isinstance(model, keras.engine.topology.Layer):
            # A single layer is handled like a one-element layer list.
            model = [model]
        if isinstance(model, list):
            self._model = keras.models.Sequential(model)
        else:
            # Unpacked as positional arguments (inputs, outputs).
            self._model = keras.models.Model(*model)
        # Normalized to a list so single- and multi-input models can be
        # indexed uniformly in run().
        self._input_shapes = iutils.to_list(self._model.input_shape)
        if weights is not None:
            self._model.set_weights(weights)
        self._analyzer = analyzer(self._model)

    @property
    def weights(self):
        """Return the current weights of the wrapped model."""
        return self._model.get_weights()

    def run(self, inputs):
        """Runs the model and the analyzer on the given inputs.

        Any input with one dimension fewer than the corresponding model
        input shape is treated as a single sample. It gets a batch axis of
        size 1, and that axis is removed again from the results.

        :param inputs: A single input array or a list of input arrays.
          The caller's list is not modified.
        :return: Tuple with model output and analyzer output. If `inputs`
          was a list, both elements are lists. Otherwise the first model
          output and the first analysis are returned.
        """
        return_list = isinstance(inputs, list)
        # Copy: to_list may hand back the caller's own list, and the
        # batch-axis augmentation below assigns into it.
        inputs = list(iutils.to_list(inputs))

        augmented = []
        for i, x in enumerate(inputs):
            # Indexing _input_shapes by i keeps the original IndexError when
            # more inputs are passed than the model accepts.
            if len(x.shape) == len(self._input_shapes[i]) - 1:
                # Augment by batch axis.
                augmented.append(i)
                inputs[i] = x.reshape((1,) + x.shape)

        outputs = iutils.to_list(self._model.predict_on_batch(inputs))
        analysis = iutils.to_list(self._analyzer.analyze(inputs))

        if augmented:
            # The batch now has size 1. Model outputs do not correspond
            # one-to-one with inputs, so the batch axis must be removed from
            # every output, not from outputs at the augmented input indices.
            outputs = [output[0] for output in outputs]
            # NOTE(review): the analyzer is assumed to return one analysis
            # per model input, so only the augmented positions are stripped.
            analysis = list(analysis)
            for i in augmented:
                analysis[i] = analysis[i][0]

        if return_list:
            return outputs, analysis
        return outputs[0], analysis[0]
|