123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660 |
- # Get Python six functionality:
- from __future__ import\
- absolute_import, print_function, division, unicode_literals
- from builtins import range, zip
- ###############################################################################
- ###############################################################################
- ###############################################################################
- import keras
- import keras.backend as K
- import keras.constraints
- import keras.layers
- import keras.regularizers
- from keras.utils import conv_utils
- import numpy as np
- from . import utils as iutils
- from .utils.keras import backend as iK
# Public API of this module. "Divide" is defined below and was missing
# from this list even though it follows the same public naming scheme
# as its siblings (e.g. "SafeDivide"); it is now exported as well.
__all__ = [
    "Constant",
    "Zero",
    "One",
    "ZerosLike",
    "OnesLike",
    "AsFloatX",
    "FiniteCheck",

    "Gradient",
    "GradientWRT",

    "Min",
    "Max",
    "Greater",
    "Less",
    "GreaterThanZero",
    "LessThanZero",
    "GreaterEqual",
    "LessEqual",
    "GreaterEqualThanZero",
    "LessEqualThanZero",
    "Sum",
    "Mean",
    "CountNonZero",

    "Identity",
    "Abs",
    "Square",
    "Clip",
    "Project",
    "Print",

    "Transpose",
    "Dot",
    "Divide",
    "SafeDivide",

    "Repeat",
    "Reshape",
    "MultiplyWithLinspace",
    "TestPhaseGaussianNoise",
    "ExtractConv2DPatches",
    "RunningMeans",
    "Broadcast",
    "Gather",
    "GatherND",
]
- ###############################################################################
- ###############################################################################
- ###############################################################################
def Constant(c, reference=None):
    """Return a backend constant with value ``c``.

    If ``reference`` is given, the constant is created with the same
    dtype as that tensor; otherwise the backend default dtype is used.
    """
    if reference is None:
        return K.constant(c)
    dtype = K.dtype(reference)
    return K.constant(np.dtype(dtype)(c), dtype=dtype)
def Zero(reference=None):
    """Return a constant 0, optionally matching ``reference``'s dtype."""
    return Constant(0, reference=reference)
def One(reference=None):
    """Return a constant 1, optionally matching ``reference``'s dtype."""
    return Constant(1, reference=reference)
class ZerosLike(keras.layers.Layer):
    """Layer mapping each input tensor to a zero tensor of the same shape."""

    def call(self, x):
        return [K.zeros_like(tensor) for tensor in iutils.to_list(x)]
class OnesLike(keras.layers.Layer):
    """Layer mapping each input tensor to a tensor of ones of the same shape."""

    def call(self, x):
        return [K.ones_like(tensor) for tensor in iutils.to_list(x)]
class AsFloatX(keras.layers.Layer):
    """Layer casting every input tensor to the backend float dtype."""

    def call(self, x):
        return [iK.to_floatx(tensor) for tensor in iutils.to_list(x)]
class FiniteCheck(keras.layers.Layer):
    """Layer that, per input tensor, counts its non-finite entries."""

    def call(self, x):
        return [K.sum(iK.to_floatx(iK.is_not_finite(tensor)))
                for tensor in iutils.to_list(x)]
- ###############################################################################
- ###############################################################################
- ###############################################################################
class Gradient(keras.layers.Layer):
    """Returns the gradient of ``sum(output)`` wrt the inputs.

    Expects ``inputs + [output]`` as the call argument.
    """

    def call(self, x):
        inputs, output = x[:-1], x[-1]
        return K.gradients(K.sum(output), inputs)

    def compute_output_shape(self, input_shapes):
        # One gradient tensor per input; shapes match the inputs.
        return input_shapes[:-1]
class GradientWRT(keras.layers.Layer):
    """Returns the gradient wrt another layer and a given output gradient.

    Expects ``Xs + Ys + known_Ys`` as the call argument: the first
    ``n_inputs`` tensors are the inputs, followed by an equal number of
    output tensors and their known (incoming) gradients.

    Note: the original class "docstring" was split over two adjacent
    string-literal statements; only the first became ``__doc__`` and the
    second was a dead statement. They are merged here.

    :param n_inputs: Number of leading input tensors in the call list.
    :param mask: Optional boolean sequence; only gradients at True
      positions are returned.
    """

    def __init__(self, n_inputs, mask=None, **kwargs):
        self.n_inputs = n_inputs
        self.mask = mask
        super(GradientWRT, self).__init__(**kwargs)

    def call(self, x):
        assert isinstance(x, (list, tuple))
        Xs, tmp_Ys = x[:self.n_inputs], x[self.n_inputs:]
        assert len(tmp_Ys) % 2 == 0
        len_Ys = len(tmp_Ys) // 2
        Ys, known_Ys = tmp_Ys[:len_Ys], tmp_Ys[len_Ys:]
        ret = iK.gradients(Xs, Ys, known_Ys)
        if self.mask is not None:
            ret = [x for c, x in zip(self.mask, ret) if c]
        # Remember how many outputs were produced; compute_mask below
        # must return exactly that many masks.
        self.__workaround__len_ret = len(ret)
        return ret

    def compute_output_shape(self, input_shapes):
        if self.mask is None:
            return input_shapes[:self.n_inputs]
        else:
            # Masked-out gradients are dropped from the output list.
            return [x for c, x in zip(self.mask, input_shapes[:self.n_inputs])
                    if c]

    # todo: remove once keras is fixed.
    # this is a workaround for cases when
    # wrapper and skip connections are used together.
    # bring the fix into keras and remove once
    # keras is patched.
    def compute_mask(self, inputs, mask=None):
        """Computes an output mask tensor.

        # Arguments
            inputs: Tensor or list of tensors.
            mask: Tensor or list of tensors.

        # Returns
            None or a tensor (or list of tensors,
                one per output tensor of the layer).
        """
        if not self.supports_masking:
            if mask is not None:
                if isinstance(mask, list):
                    if any(m is not None for m in mask):
                        raise TypeError('Layer ' + self.name +
                                        ' does not support masking, '
                                        'but was passed an input_mask: ' +
                                        str(mask))
                else:
                    raise TypeError('Layer ' + self.name +
                                    ' does not support masking, '
                                    'but was passed an input_mask: ' +
                                    str(mask))
            # masking not explicitly supported: return None as mask
            # this is the workaround for model.run_internal_graph.
            # it is required that there as many masks as outputs:
            return [None for _ in range(self.__workaround__len_ret)]
        # if masking is explicitly supported, by default
        # carry over the input mask
        return mask
- ###############################################################################
- ###############################################################################
- ###############################################################################
class _Reduce(keras.layers.Layer):
    """Base layer for reductions (min/max/sum/mean/...) along ``axis``.

    :param axis: Axis or list of axes to reduce over; ``None`` reduces
      over all axes.
    :param keepdims: If True, reduced axes are kept with size 1.
    """

    def __init__(self, axis=-1, keepdims=False, *args, **kwargs):
        self.axis = axis
        self.keepdims = keepdims
        super(_Reduce, self).__init__(*args, **kwargs)

    def call(self, x):
        return self._apply_reduce(x, axis=self.axis, keepdims=self.keepdims)

    def compute_output_shape(self, input_shape):
        # Fix: the previous implementation dropped reduced axes even for
        # keepdims=True (it compared positions against dim *values*), and
        # repeated np.delete shifted indices for multi-axis reduction.
        if self.axis is None:
            if self.keepdims:
                # Full reduction, rank preserved: every axis becomes 1.
                return tuple(1 for _ in input_shape)
            return (1,)
        # Normalize possibly-negative axes to positive indices.
        axes = [a % len(input_shape) for a in iutils.to_list(self.axis)]
        if self.keepdims:
            # Reduced axes collapse to size 1; the rest are unchanged.
            return tuple(1 if i in axes else size
                         for i, size in enumerate(input_shape))
        # Reduced axes are removed entirely.
        return tuple(size for i, size in enumerate(input_shape)
                     if i not in axes)

    def _apply_reduce(self, x, axis, keepdims):
        raise NotImplementedError()
class Min(_Reduce):
    """Reduction layer computing the minimum."""

    def _apply_reduce(self, x, axis, keepdims):
        return K.min(x, axis=axis, keepdims=keepdims)
class Max(_Reduce):
    """Reduction layer computing the maximum."""

    def _apply_reduce(self, x, axis, keepdims):
        return K.max(x, axis=axis, keepdims=keepdims)
class Sum(_Reduce):
    """Reduction layer computing the sum."""

    def _apply_reduce(self, x, axis, keepdims):
        return K.sum(x, axis=axis, keepdims=keepdims)
class Mean(_Reduce):
    """Reduction layer computing the mean."""

    def _apply_reduce(self, x, axis, keepdims):
        return K.mean(x, axis=axis, keepdims=keepdims)
class CountNonZero(_Reduce):
    """Reduction layer counting the entries different from zero."""

    def _apply_reduce(self, x, axis, keepdims):
        nonzero = iK.to_floatx(K.not_equal(x, K.constant(0)))
        return K.sum(nonzero, axis=axis, keepdims=keepdims)
- ###############################################################################
- ###############################################################################
- ###############################################################################
class _Map(keras.layers.Layer):
    """Base layer applying an elementwise map to a single tensor."""

    def call(self, x):
        # Accept a singleton list as a convenience.
        if isinstance(x, list) and len(x) == 1:
            x = x[0]
        return self._apply_map(x)

    def compute_output_shape(self, input_shape):
        # Elementwise maps preserve the shape.
        return input_shape

    def _apply_map(self, x):
        raise NotImplementedError()
class Identity(_Map):
    """Elementwise identity map."""

    def _apply_map(self, x):
        return K.identity(x)
class Abs(_Map):
    """Elementwise absolute value."""

    def _apply_map(self, x):
        return K.abs(x)
class Square(_Map):
    """Elementwise square."""

    def _apply_map(self, x):
        return K.square(x)
class Clip(_Map):
    """Elementwise clipping into ``[min_value, max_value]``.

    :param min_value: Lower clipping bound.
    :param max_value: Upper clipping bound.
    """

    def __init__(self, min_value, max_value, *args, **kwargs):
        self._min_value = min_value
        self._max_value = max_value
        # Fix: forward standard layer args/kwargs (e.g. ``name``) to the
        # base class, consistent with SafeDivide/Repeat in this module;
        # previously they were silently unsupported.
        super(Clip, self).__init__(*args, **kwargs)

    def _apply_map(self, x):
        return K.clip(x, self._min_value, self._max_value)
class Project(_Map):
    """Scales each sample by its absolute maximum and maps it into a range.

    By default values end up in ``[-1, 1]``; if ``output_range`` is a
    pair ``(low, high)`` values are mapped into that interval instead.
    (The misspelled keyword ``input_is_postive_only`` is kept as-is: it
    is part of the public interface.)
    """

    def __init__(self, output_range=False, input_is_postive_only=False):
        self._output_range = output_range
        self._input_is_positive_only = input_is_postive_only
        super(Project, self).__init__()

    def _apply_map(self, x):
        def _div_no_zero(a, b):
            # Guard against division by zero: add 1 where b == 0.
            return a / (b + iK.to_floatx(K.equal(b, K.constant(0))) * 1)

        n_dim = len(K.int_shape(x))
        reduce_axes = tuple(range(1, n_dim))
        if len(reduce_axes) == 1:
            # TODO(albermax): this is only the case when the dimension in this
            # axis is 1, fix this.
            # Cannot reduce
            return x

        absmax = K.max(K.abs(x), axis=reduce_axes, keepdims=True)
        x = _div_no_zero(x, absmax)

        if self._output_range not in (False, True):  # True = (-1, +1)
            output_range = self._output_range
            if not self._input_is_positive_only:
                x = (x + 1) / 2
            x = K.clip(x, 0, 1)
            x = output_range[0] + (x * (output_range[1] - output_range[0]))
        else:
            x = K.clip(x, -1, 1)
        return x
class Print(_Map):
    """Identity map that also prints the tensor when evaluated."""

    def _apply_map(self, x):
        return K.print_tensor(x)
- ###############################################################################
- ###############################################################################
- ###############################################################################
class Greater(keras.layers.Layer):
    """Elementwise ``a > b`` for two input tensors."""

    def call(self, x):
        return K.greater(x[0], x[1])
class Less(keras.layers.Layer):
    """Elementwise ``a < b`` for two input tensors."""

    def call(self, x):
        return K.less(x[0], x[1])
class GreaterThanZero(keras.layers.Layer):
    """Elementwise ``x > 0``."""

    def call(self, x):
        return K.greater(x, K.constant(0))
class LessThanZero(keras.layers.Layer):
    """Elementwise ``x < 0``."""

    def call(self, x):
        return K.less(x, K.constant(0))
class GreaterEqual(keras.layers.Layer):
    """Elementwise ``a >= b`` for two input tensors."""

    def call(self, x):
        return K.greater_equal(x[0], x[1])
class LessEqual(keras.layers.Layer):
    """Elementwise ``a <= b`` for two input tensors."""

    def call(self, x):
        return K.less_equal(x[0], x[1])
class GreaterEqualThanZero(keras.layers.Layer):
    """Elementwise ``x >= 0``."""

    def call(self, x):
        return K.greater_equal(x, K.constant(0))
class LessEqualThanZero(keras.layers.Layer):
    """Elementwise ``x <= 0``."""

    def call(self, x):
        return K.less_equal(x, K.constant(0))
class Transpose(keras.layers.Layer):
    """Transposes a tensor, optionally with an explicit axis permutation.

    With ``axes=None`` the axis order is fully reversed; otherwise the
    axes are permuted according to the given sequence.
    """

    def __init__(self, axes=None, **kwargs):
        self._axes = axes
        super(Transpose, self).__init__(**kwargs)

    def call(self, x):
        if self._axes is None:
            return K.transpose(x)
        return K.permute_dimensions(x, self._axes)

    def compute_output_shape(self, input_shape):
        if self._axes is None:
            # A plain transpose reverses the axis order.
            return input_shape[::-1]
        return tuple(input_shape[axis] for axis in self._axes)
class Dot(keras.layers.Layer):
    """Matrix product of two input tensors."""

    def call(self, x):
        a, b = x
        return K.dot(a, b)

    def compute_output_shape(self, input_shapes):
        a_shape, b_shape = input_shapes
        return (a_shape[0], b_shape[1])
class Divide(keras.layers.Layer):
    """Elementwise division ``a / b`` of two input tensors."""

    def call(self, x):
        numerator, denominator = x
        return numerator / denominator

    def compute_output_shape(self, input_shapes):
        return input_shapes[0]
class SafeDivide(keras.layers.Layer):
    """Elementwise division that guards against division by zero.

    Where the denominator is zero, ``factor`` (default ``K.epsilon()``)
    is added to it before dividing.
    """

    def __init__(self, *args, **kwargs):
        factor = kwargs.pop("factor", None)
        self._factor = K.epsilon() if factor is None else factor
        super(SafeDivide, self).__init__(*args, **kwargs)

    def call(self, x):
        a, b = x
        return a / (b + iK.to_floatx(K.equal(b, K.constant(0))) * self._factor)

    def compute_output_shape(self, input_shapes):
        return input_shapes[0]
- ###############################################################################
- ###############################################################################
- ###############################################################################
class Repeat(keras.layers.Layer):
    """Repeats each element ``n`` times along ``axis``.

    Thin wrapper around ``K.repeat_elements``.

    :param n: Repetition factor.
    :param axis: Axis along which elements are repeated.
    """

    def __init__(self, n, axis, *args, **kwargs):
        self._n = n
        self._axis = axis
        super(Repeat, self).__init__(*args, **kwargs)

    def call(self, x):
        return K.repeat_elements(x, self._n, self._axis)

    def compute_output_shape(self, input_shapes):
        if isinstance(input_shapes, list):
            input_shape = input_shapes[0]
        else:
            input_shape = input_shapes
        # Fix: the repeated dimension is self._axis, not axis 0 as the
        # previous code assumed. Normalize negative axes first.
        axis = self._axis % len(input_shape)
        if input_shape[axis] is None:
            # Unknown size stays unknown.
            return input_shape
        return (input_shape[:axis] +
                (input_shape[axis] * self._n,) +
                input_shape[axis + 1:])
class Reshape(keras.layers.Layer):
    """Reshapes a tensor to ``shape`` (negative entries mean 'inferred')."""

    def __init__(self, shape, *args, **kwargs):
        self._shape = shape
        super(Reshape, self).__init__(*args, **kwargs)

    def call(self, x):
        return K.reshape(x, self._shape)

    def compute_output_shape(self, input_shapes):
        # Negative entries (e.g. -1) are not known at this point.
        return tuple(None if dim < 0 else dim for dim in self._shape)
class MultiplyWithLinspace(keras.layers.Layer):
    """Multiplies the input with a linspace along ``axis``.

    The factors are ``start + (end-start) * (i/n)`` for ``i`` in
    ``0..n-1`` (note: ``end`` itself is not reached).

    :param start: First factor value.
    :param end: Limit of the factor ramp (exclusive).
    :param n: Number of factors / expected size of ``axis``.
    :param axis: Axis along which the factors are applied.
    """

    def __init__(self, start, end, n=1, axis=-1, *args, **kwargs):
        self._start = start
        self._end = end
        self._n = n
        self._axis = axis
        super(MultiplyWithLinspace, self).__init__(*args, **kwargs)

    def call(self, x):
        linspace = (self._start +
                    (self._end - self._start) *
                    (K.arange(self._n, dtype=K.floatx()) / self._n))

        # Make broadcastable: size 1 everywhere except on self._axis.
        # Fix: use an integer list; np.ones() yields float64 values,
        # which are not a valid reshape target for all backends.
        shape = [1] * len(K.int_shape(x))
        shape[self._axis] = self._n
        linspace = K.reshape(linspace, shape)
        return x * linspace

    def compute_output_shape(self, input_shapes):
        ret = input_shapes[:]
        # Fix: normalize the axis. With the previous slicing, a negative
        # axis (including the default -1) produced a malformed shape
        # because ret[self._axis+1:] wrapped around to the full tuple.
        axis = self._axis % len(ret)
        return (ret[:axis] +
                (max(self._n, ret[axis]),) +
                ret[axis + 1:])
class TestPhaseGaussianNoise(keras.layers.GaussianNoise):
    """GaussianNoise variant that stays active at test/inference time."""

    def call(self, inputs):
        # Always add Gaussian noise!
        # Forcing training=True overrides the learning-phase switch that
        # would otherwise disable the noise outside of training.
        return super(TestPhaseGaussianNoise, self).call(inputs, training=True)
class ExtractConv2DPatches(keras.layers.Layer):
    """Extracts conv2d-style patches from an image tensor.

    :param kernel_shape: Spatial shape of the patch window.
    :param depth: Number of input channels (used for the output depth).
    :param strides: Spatial strides of the window.
    :param rates: Spatial dilation rates.
    :param padding: Padding mode, e.g. "valid" or "same".
    """

    def __init__(self,
                 kernel_shape,
                 depth,
                 strides,
                 rates,
                 padding,
                 *args,
                 **kwargs):
        self._kernel_shape = kernel_shape
        self._depth = depth
        self._strides = strides
        self._rates = rates
        self._padding = padding
        super(ExtractConv2DPatches, self).__init__(*args, **kwargs)

    def call(self, x):
        return iK.extract_conv2d_patches(x,
                                         self._kernel_shape,
                                         self._strides,
                                         self._rates,
                                         self._padding)

    def compute_output_shape(self, input_shapes):
        # Fix: the two data-format branches duplicated the same loop and
        # left new_space undefined (NameError) if neither condition
        # matched. Select the spatial dims once; K.image_data_format()
        # only ever returns 'channels_first' or 'channels_last'.
        if K.image_data_format() == 'channels_first':
            space = input_shapes[2:]
        else:  # 'channels_last'
            space = input_shapes[1:-1]

        new_space = [
            conv_utils.conv_output_length(
                space[i],
                self._kernel_shape[i],
                padding=self._padding,
                stride=self._strides[i],
                dilation=self._rates[i])
            for i in range(len(space))
        ]

        # One flattened patch (kernel area * depth) per spatial location.
        return ((input_shapes[0],) +
                tuple(new_space) +
                (np.product(self._kernel_shape) * self._depth,))
class RunningMeans(keras.layers.Layer):
    """Stateful layer maintaining running means and counts across batches.

    Call inputs are ``[means, counts]`` for the current batch; outputs
    are the updated running ``[new_means, new_counts]``.
    """

    def __init__(self, *args, **kwargs):
        # Stateful: the means/counts weights persist across batches.
        self.stateful = True
        super(RunningMeans, self).__init__(*args, **kwargs)

    def build(self, input_shapes):
        means_shape, counts_shape = input_shapes

        # Non-trainable state variables, updated via add_update in call().
        self.means = self.add_weight(shape=means_shape,
                                     initializer="zeros",
                                     name="means",
                                     trainable=False)
        self.counts = self.add_weight(shape=counts_shape,
                                      initializer="zeros",
                                      name="counts",
                                      trainable=False)
        self.built = True

    def call(self, x):
        def safe_divide(a, b):
            # Guard against division by zero: add 1 where b == 0.
            return a / (b + iK.to_floatx(K.equal(b, K.constant(0))) * 1)

        means, counts = x

        new_counts = counts + self.counts

        # If new_means are not used for the model output,
        # the following part of the code will be executed after
        # self.counts is updated, therefore we cannot use it
        # hereafter.
        factor_new = safe_divide(counts, new_counts)
        factor_old = K.ones_like(factor_new) - factor_new
        # Incremental (weighted) mean update.
        new_means = self.means * factor_old + means * factor_new

        # Update state.
        self.add_update([
            K.update(self.means, new_means),
            K.update(self.counts, new_counts),
        ])

        return [new_means, new_counts]

    def compute_output_shape(self, input_shapes):
        return input_shapes
class Broadcast(keras.layers.Layer):
    """Broadcasts the second input to the shape of the first.

    Implemented as ``template * 0 + x``.
    """

    def call(self, x):
        template, x = x
        return template * 0 + x

    def compute_output_shape(self, input_shapes):
        return input_shapes[0]
class Gather(keras.layers.Layer):
    """Gathers slices along axis 1 of ``x`` given an index tensor."""

    def call(self, inputs):
        x, index = inputs
        return iK.gather(x, 1, index)

    def compute_output_shape(self, input_shapes):
        x_shape, index_shape = input_shapes
        return (x_shape[0], index_shape[0]) + x_shape[2:]
class GatherND(keras.layers.Layer):
    """Gathers entries of ``x`` according to an N-d ``indices`` tensor."""

    def call(self, inputs):
        x, indices = inputs
        return iK.gather_nd(x, indices)

    def compute_output_shape(self, input_shapes):
        x_shape, indices_shape = input_shapes
        return indices_shape[:2] + x_shape[2:]
|