# Get Python six functionality:
from __future__ import\
    absolute_import, print_function, division, unicode_literals
from builtins import zip


###############################################################################
###############################################################################
###############################################################################


import keras.models
import keras.backend as K
import numpy as np


from . import base
from .. import layers as ilayers
from .. import utils as iutils
from ..utils import keras as kutils


# Public API of this module: the two wrapper interfaces and the concrete
# wrappers built on top of them.
__all__ = ["WrapperBase", "AugmentReduceBase",
           "GaussianSmoother", "PathIntegrator"]


###############################################################################
###############################################################################
###############################################################################


class WrapperBase(base.AnalyzerBase):
    """Interface for wrappers around analyzers

    A wrapper owns no model of its own (``None`` is handed to the base
    class as the model); analysis and serialization are delegated to the
    wrapped analyzer.

    :param subanalyzer: The analyzer to be wrapped.
    """

    def __init__(self, subanalyzer, *args, **kwargs):
        self._subanalyzer = subanalyzer
        # The wrapped analyzer carries the model, the wrapper does not.
        super(WrapperBase, self).__init__(None, *args, **kwargs)

    def analyze(self, *args, **kwargs):
        # Plain pass-through to the wrapped analyzer.
        return self._subanalyzer.analyze(*args, **kwargs)

    def _get_state(self):
        # Nest the subanalyzer's own (class name, state) pair.
        class_name, sub_state = self._subanalyzer.save()
        return {
            "subanalyzer_class_name": class_name,
            "subanalyzer_state": sub_state,
        }

    @classmethod
    def _state_to_kwargs(clazz, state):
        class_name = state.pop("subanalyzer_class_name")
        sub_state = state.pop("subanalyzer_state")
        # Subclasses pop their own keys before delegating here, so nothing
        # may be left over at this point.
        assert len(state) == 0

        return {
            "subanalyzer": base.AnalyzerBase.load(class_name, sub_state),
        }


###############################################################################
###############################################################################
###############################################################################


class AugmentReduceBase(WrapperBase):
    """Interface for wrappers that augment the input and reduce the analysis.

    This class is an interface for wrappers that:
    * augment the input to the analyzer by creating new samples.
    * reduce the returned analysis to match the initial input shapes.

    Augmentation and reduction are spliced into the subanalyzer's Keras
    analyzer model (see :meth:`create_analyzer_model`), hence only
    subanalyzers derived from ``base.AnalyzerNetworkBase`` are supported.

    :param subanalyzer: The analyzer to be wrapped.
    :param augment_by_n: Number of samples to create (default 2).
    :raises NotImplementedError: If the subanalyzer is not Keras-based.
    """

    def __init__(self, subanalyzer, *args, **kwargs):
        self._augment_by_n = kwargs.pop("augment_by_n", 2)
        # Remember the mode requested by the user; the subanalyzer's own
        # mode may be overwritten below.
        self._neuron_selection_mode = subanalyzer._neuron_selection_mode

        if self._neuron_selection_mode != "all":
            # TODO: this is not transparent, find a better way.
            # analyze() resolves the selected neurons itself and passes
            # explicit indices down, so the subanalyzer runs in "index" mode.
            subanalyzer._neuron_selection_mode = "index"
        super(AugmentReduceBase, self).__init__(subanalyzer,
                                                *args, **kwargs)

        if isinstance(self._subanalyzer, base.AnalyzerNetworkBase):
            # Take the keras analyzer model and
            # add augment and reduce functionality.
            self._keras_based_augment_reduce = True
        else:
            raise NotImplementedError("Keras-based subanalyzer required.")

    def create_analyzer_model(self):
        """Build the subanalyzer's model and wrap it with augment/reduce.

        The subanalyzer's ``_analyzer_model`` is replaced by a model that
        augments the data inputs (:meth:`_augment`), runs the original
        analyzer model on them and reduces the result (:meth:`_reduce`).

        :raises Exception: If the subanalyzer has debug or extra outputs.
        :raises ValueError: If the model input shape is not fully specified
          (except the batch axis).
        """
        if not self._keras_based_augment_reduce:
            return

        self._subanalyzer.create_analyzer_model()

        if self._subanalyzer._n_debug_output > 0:
            raise Exception("No debug output at subanalyzer is supported.")

        model = self._subanalyzer._analyzer_model
        if None in model.input_shape[1:]:
            raise ValueError("The input shape for the model needs "
                             "to be fully specified (except the batch axis). "
                             "Model input shape is: %s" % (model.input_shape,))

        n_data_input = self._subanalyzer._n_data_input
        inputs = model.inputs[:n_data_input]
        extra_inputs = model.inputs[n_data_input:]
        # TODO: check this. Outputs are split by the number of *data inputs*,
        # which assumes one analysis output per data input.
        extra_outputs = model.outputs[n_data_input:]

        if len(extra_outputs) > 0:
            raise Exception("No extra output is allowed "
                            "with this wrapper.")

        new_inputs = iutils.to_list(self._augment(inputs))
        # Only the data inputs are augmented; extra inputs pass through.
        tmp = iutils.to_list(model(new_inputs+extra_inputs))
        new_outputs = iutils.to_list(self._reduce(tmp))
        new_constant_inputs = self._keras_get_constant_inputs()

        new_model = keras.models.Model(
            inputs=inputs+extra_inputs+new_constant_inputs,
            outputs=new_outputs+extra_outputs)
        self._subanalyzer._analyzer_model = new_model

    def analyze(self, X, *args, **kwargs):
        """Analyze ``X`` with the augmenting/reducing subanalyzer.

        For the "max_activation" and "index" neuron selection modes the
        selected indices are determined here and repeated ``augment_by_n``
        times so that every augmented sample gets an index.

        :raises DeprecationWarning: If the subanalyzer is not Keras-based.
        """
        if self._keras_based_augment_reduce is not True:
            raise DeprecationWarning("Not supported anymore.")

        if not hasattr(self._subanalyzer, "_analyzer_model"):
            self.create_analyzer_model()

        ns_mode = self._neuron_selection_mode
        if ns_mode in ["max_activation", "index"]:
            if ns_mode == "max_activation":
                tmp = self._subanalyzer._model.predict(X)
                indices = np.argmax(tmp, axis=1)
            elif args:
                # Neuron selection given positionally.
                args = list(args)
                indices = args.pop(0)
            else:
                indices = kwargs.pop("neuron_selection")

            # broadcast to match augmented samples.
            indices = np.repeat(indices, self._augment_by_n)

            kwargs["neuron_selection"] = indices
        return self._subanalyzer.analyze(X, *args, **kwargs)

    def _keras_get_constant_inputs(self):
        # No constant inputs by default; subclasses may provide some.
        return list()

    def _augment(self, X):
        # Repeat every input tensor augment_by_n times along axis 0.
        repeat = ilayers.Repeat(self._augment_by_n, axis=0)
        return [repeat(x) for x in iutils.to_list(X)]

    def _reduce(self, X):
        # Reshape to (-1, augment_by_n) + per-sample shape and average over
        # axis 1. NOTE(review): assumes ilayers.Reshape includes the batch
        # axis and Repeat keeps the copies of a sample adjacent — confirm.
        X = iutils.to_list(X)
        reshapes = [
            ilayers.Reshape((-1, self._augment_by_n)+K.int_shape(x)[1:])
            for x in X]
        mean = ilayers.Mean(axis=1)

        return [mean(reshape(x)) for x, reshape in zip(X, reshapes)]

    def _get_state(self):
        tampered = self._neuron_selection_mode != "all"
        if tampered:
            # TODO: this is not transparent, find a better way.
            # Temporarily revert the tampering in __init__ so the
            # subanalyzer's saved state holds the user's original mode.
            self._subanalyzer._neuron_selection_mode = \
                self._neuron_selection_mode
        try:
            state = super(AugmentReduceBase, self)._get_state()
        finally:
            if tampered:
                # Re-apply "index" so this wrapper keeps working after save.
                self._subanalyzer._neuron_selection_mode = "index"
        state.update({"augment_by_n": self._augment_by_n})
        return state

    @classmethod
    def _state_to_kwargs(clazz, state):
        augment_by_n = state.pop("augment_by_n")
        kwargs = super(AugmentReduceBase, clazz)._state_to_kwargs(state)
        kwargs.update({"augment_by_n": augment_by_n})
        return kwargs


###############################################################################
###############################################################################
###############################################################################


class GaussianSmoother(AugmentReduceBase):
    """Wrapper that adds noise to the input and averages over analyses

    Each input is repeated ``augment_by_n`` times, and Gaussian noise is
    added to every copy. The final analysis is the average of the
    analyses of the noisy copies.

    :param subanalyzer: The analyzer to be wrapped.
    :param noise_scale: The stddev of the applied noise (default 1).
    :param augment_by_n: Number of samples to create.
    """

    def __init__(self, subanalyzer, *args, **kwargs):
        self._noise_scale = kwargs.pop("noise_scale", 1)
        super(GaussianSmoother, self).__init__(subanalyzer, *args, **kwargs)

    def _augment(self, X):
        # Repeat the inputs (parent behaviour), then perturb each copy.
        repeated = super(GaussianSmoother, self)._augment(X)
        add_noise = ilayers.TestPhaseGaussianNoise(stddev=self._noise_scale)
        return list(map(add_noise, repeated))

    def _get_state(self):
        state = super(GaussianSmoother, self)._get_state()
        state["noise_scale"] = self._noise_scale
        return state

    @classmethod
    def _state_to_kwargs(clazz, state):
        # Remove our key first: the base class expects no leftovers.
        noise_scale = state.pop("noise_scale")
        kwargs = super(GaussianSmoother, clazz)._state_to_kwargs(state)
        kwargs["noise_scale"] = noise_scale
        return kwargs


###############################################################################
###############################################################################
###############################################################################


class PathIntegrator(AugmentReduceBase):
    """Integrates the analysis along a path.

    This analyzer:
    * creates a linear path from the reference input to the input.
    * creates `steps` intermediate inputs along that path and
      creates an analysis for each of them.
    * averages the analyses and multiplies them with
      (input - reference_input).

    This wrapper is used to implement Integrated Gradients.
    We refer to the paper for further information.

    :param subanalyzer: The analyzer to be wrapped.
    :param steps: Number of steps for integration (default 16).
    :param reference_inputs: The reference input (default 0).
    """

    def __init__(self, subanalyzer, *args, **kwargs):
        steps = kwargs.pop("steps", 16)
        self._reference_inputs = kwargs.pop("reference_inputs", 0)
        # Keras tensors for the reference inputs, built lazily on first use.
        self._keras_constant_inputs = None
        super(PathIntegrator, self).__init__(subanalyzer,
                                             *args,
                                             augment_by_n=steps,
                                             **kwargs)

    def _keras_set_constant_inputs(self, inputs):
        # Wrap the (numpy) reference inputs as backend variables and expose
        # them as Keras Input tensors so they can join the analyzer model.
        tmp = [K.variable(x) for x in inputs]
        self._keras_constant_inputs = [
            keras.layers.Input(tensor=x, shape=x.shape[1:])
            for x in tmp]

    def _keras_get_constant_inputs(self):
        return self._keras_constant_inputs

    def _compute_difference(self, X):
        """Return ``x - reference`` for every input tensor in ``X``."""
        if self._keras_constant_inputs is None:
            tmp = kutils.broadcast_np_tensors_to_keras_tensors(
                X, self._reference_inputs)
            self._keras_set_constant_inputs(tmp)

        reference_inputs = self._keras_get_constant_inputs()
        return [keras.layers.Subtract()([x, ri])
                for x, ri in zip(X, reference_inputs)]

    def _augment(self, X):
        """Build the `augment_by_n` interpolated inputs along the path.

        Also stores the input/reference difference for :meth:`_reduce`,
        which must therefore be called after this method.
        """
        difference = self._compute_difference(X)
        self._keras_difference = difference
        # Make broadcastable: insert a singleton step axis at axis 1.
        difference = [ilayers.Reshape((-1, 1)+K.int_shape(x)[1:])(x)
                      for x in difference]

        # Compute path steps: the difference scaled by `augment_by_n`
        # factors from 0 to 1 along axis 1.
        multiply_with_linspace = ilayers.MultiplyWithLinspace(
            0, 1,
            n=self._augment_by_n,
            axis=1)
        path_steps = [multiply_with_linspace(d) for d in difference]

        # Points on the path: reference + scaled difference. Then merge the
        # step axis back into the leading axis.
        reference_inputs = self._keras_get_constant_inputs()
        ret = [keras.layers.Add()([x, p])
               for x, p in zip(reference_inputs, path_steps)]
        ret = [ilayers.Reshape((-1,)+K.int_shape(x)[2:])(x) for x in ret]
        return ret

    def _reduce(self, X):
        """Average the analyses and multiply with ``input - reference``."""
        tmp = super(PathIntegrator, self)._reduce(X)
        # Consume the difference stored by _augment.
        difference = self._keras_difference
        del self._keras_difference

        return [keras.layers.Multiply()([x, d])
                for x, d in zip(tmp, difference)]

    def _get_state(self):
        state = super(PathIntegrator, self)._get_state()
        state.update({"reference_inputs": self._reference_inputs})
        return state

    @classmethod
    def _state_to_kwargs(clazz, state):
        reference_inputs = state.pop("reference_inputs")
        kwargs = super(PathIntegrator, clazz)._state_to_kwargs(state)
        kwargs.update({"reference_inputs": reference_inputs})
        # We use steps instead.
        kwargs.update({"steps": kwargs["augment_by_n"]})
        del kwargs["augment_by_n"]
        return kwargs