# Get Python six functionality: from __future__ import\ absolute_import, print_function, division, unicode_literals from builtins import range, zip ############################################################################### ############################################################################### ############################################################################### import keras import keras.backend as K import keras.constraints import keras.layers import keras.regularizers from keras.utils import conv_utils import numpy as np from . import utils as iutils from .utils.keras import backend as iK __all__ = [ "Constant", "Zero", "One", "ZerosLike", "OnesLike", "AsFloatX", "FiniteCheck", "Gradient", "GradientWRT", "Min", "Max", "Greater", "Less", "GreaterThanZero", "LessThanZero", "GreaterEqual", "LessEqual", "GreaterEqualThanZero", "LessEqualThanZero", "Sum", "Mean", "CountNonZero", "Identity", "Abs", "Square", "Clip", "Project", "Print", "Transpose", "Dot", "SafeDivide", "Repeat", "Reshape", "MultiplyWithLinspace", "TestPhaseGaussianNoise", "ExtractConv2DPatches", "RunningMeans", "Broadcast", "Gather", "GatherND", ] ############################################################################### ############################################################################### ############################################################################### def Constant(c, reference=None): if reference is None: return K.constant(c) else: dtype = K.dtype(reference) return K.constant(np.dtype(dtype)(c), dtype=dtype) def Zero(reference=None): return Constant(0, reference=reference) def One(reference=None): return Constant(1, reference=reference) class ZerosLike(keras.layers.Layer): def call(self, x): return [K.zeros_like(tmp) for tmp in iutils.to_list(x)] class OnesLike(keras.layers.Layer): def call(self, x): return [K.ones_like(tmp) for tmp in iutils.to_list(x)] class AsFloatX(keras.layers.Layer): def call(self, x): return [iK.to_floatx(tmp) for tmp in iutils.to_list(x)] class FiniteCheck(keras.layers.Layer): def call(self, x): return [K.sum(iK.to_floatx(iK.is_not_finite(tmp))) for tmp in iutils.to_list(x)] ############################################################################### ############################################################################### ############################################################################### class Gradient(keras.layers.Layer): "Returns gradient of sum(output), expects inputs+[output,]." def call(self, x): inputs, output = x[:-1], x[-1] return K.gradients(K.sum(output), inputs) def compute_output_shape(self, input_shapes): return input_shapes[:-1] class GradientWRT(keras.layers.Layer): "Returns gradient wrt to another layer and given gradient," " expects inputs+[output,]." def __init__(self, n_inputs, mask=None, **kwargs): self.n_inputs = n_inputs self.mask = mask super(GradientWRT, self).__init__(**kwargs) def call(self, x): assert isinstance(x, (list, tuple)) Xs, tmp_Ys = x[:self.n_inputs], x[self.n_inputs:] assert len(tmp_Ys) % 2 == 0 len_Ys = len(tmp_Ys) // 2 Ys, known_Ys = tmp_Ys[:len_Ys], tmp_Ys[len_Ys:] ret = iK.gradients(Xs, Ys, known_Ys) if self.mask is not None: ret = [x for c, x in zip(self.mask, ret) if c] self.__workaround__len_ret = len(ret) return ret def compute_output_shape(self, input_shapes): if self.mask is None: return input_shapes[:self.n_inputs] else: return [x for c, x in zip(self.mask, input_shapes[:self.n_inputs]) if c] # todo: remove once keras is fixed. # this is a workaround for cases when # wrapper and skip connections are used together. # bring the fix into keras and remove once # keras is patched. def compute_mask(self, inputs, mask=None): """Computes an output mask tensor. # Arguments inputs: Tensor or list of tensors. mask: Tensor or list of tensors. # Returns None or a tensor (or list of tensors, one per output tensor of the layer). """ if not self.supports_masking: if mask is not None: if isinstance(mask, list): if any(m is not None for m in mask): raise TypeError('Layer ' + self.name + ' does not support masking, ' 'but was passed an input_mask: ' + str(mask)) else: raise TypeError('Layer ' + self.name + ' does not support masking, ' 'but was passed an input_mask: ' + str(mask)) # masking not explicitly supported: return None as mask # this is the workaround for model.run_internal_graph. # it is required that there as many masks as outputs: return [None for _ in range(self.__workaround__len_ret)] # if masking is explicitly supported, by default # carry over the input mask return mask ############################################################################### ############################################################################### ############################################################################### class _Reduce(keras.layers.Layer): def __init__(self, axis=-1, keepdims=False, *args, **kwargs): self.axis = axis self.keepdims = keepdims super(_Reduce, self).__init__(*args, **kwargs) def call(self, x): return self._apply_reduce(x, axis=self.axis, keepdims=self.keepdims) def compute_output_shape(self, input_shape): if self.axis is None: if self.keepdims is False: return (1,) else: return tuple(np.ones_like(input_shape)) else: axes = np.arange(len(input_shape)) if self.keepdims is False: for i in iutils.to_list(self.axis): axes = np.delete(axes, i, 0) else: for i in iutils.to_list(self.axis): axes[i] = 1 return tuple([idx for i, idx in enumerate(input_shape) if i in axes]) def _apply_reduce(self, x, axis, keepdims): raise NotImplementedError() class Min(_Reduce): def _apply_reduce(self, x, axis, keepdims): return K.min(x, axis=axis, keepdims=keepdims) class Max(_Reduce): def _apply_reduce(self, x, axis, keepdims): return K.max(x, axis=axis, keepdims=keepdims) class Sum(_Reduce): def _apply_reduce(self, x, axis, keepdims): return K.sum(x, axis=axis, keepdims=keepdims) class Mean(_Reduce): def _apply_reduce(self, x, axis, keepdims): return K.mean(x, axis=axis, keepdims=keepdims) class CountNonZero(_Reduce): def _apply_reduce(self, x, axis, keepdims): return K.sum(iK.to_floatx(K.not_equal(x, K.constant(0))), axis=axis, keepdims=keepdims) ############################################################################### ############################################################################### ############################################################################### class _Map(keras.layers.Layer): def call(self, x): if isinstance(x, list) and len(x) == 1: x = x[0] return self._apply_map(x) def compute_output_shape(self, input_shape): return input_shape def _apply_map(self, x): raise NotImplementedError() class Identity(_Map): def _apply_map(self, x): return K.identity(x) class Abs(_Map): def _apply_map(self, x): return K.abs(x) class Square(_Map): def _apply_map(self, x): return K.square(x) class Clip(_Map): def __init__(self, min_value, max_value): self._min_value = min_value self._max_value = max_value return super(Clip, self).__init__() def _apply_map(self, x): return K.clip(x, self._min_value, self._max_value) class Project(_Map): def __init__(self, output_range=False, input_is_postive_only=False): self._output_range = output_range self._input_is_positive_only = input_is_postive_only return super(Project, self).__init__() def _apply_map(self, x): def safe_divide(a, b): return a / (b + iK.to_floatx(K.equal(b, K.constant(0))) * 1) dims = K.int_shape(x) n_dim = len(dims) axes = tuple(range(1, n_dim)) if len(axes) == 1: # TODO(albermax): this is only the case when the dimension in this # axis is 1, fix this. # Cannot reduce return x absmax = K.max(K.abs(x), axis=axes, keepdims=True) x = safe_divide(x, absmax) if self._output_range not in (False, True): # True = (-1, +1) output_range = self._output_range if not self._input_is_positive_only: x = (x+1) / 2 x = K.clip(x, 0, 1) x = output_range[0] + (x * (output_range[1]-output_range[0])) else: x = K.clip(x, -1, 1) return x class Print(_Map): def _apply_map(self, x): return K.print_tensor(x) ############################################################################### ############################################################################### ############################################################################### class Greater(keras.layers.Layer): def call(self, x): a, b = x return K.greater(a, b) class Less(keras.layers.Layer): def call(self, x): a, b = x return K.less(a, b) class GreaterThanZero(keras.layers.Layer): def call(self, x): return K.greater(x, K.constant(0)) class LessThanZero(keras.layers.Layer): def call(self, x): return K.less(x, K.constant(0)) class GreaterEqual(keras.layers.Layer): def call(self, x): a, b = x return K.greater_equal(a, b) class LessEqual(keras.layers.Layer): def call(self, x): a, b = x return K.less_equal(a, b) class GreaterEqualThanZero(keras.layers.Layer): def call(self, x): return K.greater_equal(x, K.constant(0)) class LessEqualThanZero(keras.layers.Layer): def call(self, x): return K.less_equal(x, K.constant(0)) class Transpose(keras.layers.Layer): def __init__(self, axes=None, **kwargs): self._axes = axes super(Transpose, self).__init__(**kwargs) def call(self, x): if self._axes is None: return K.transpose(x) else: return K.permute_dimensions(x, self._axes) def compute_output_shape(self, input_shape): if self._axes is None: return input_shape[::-1] else: return tuple(np.asarray(input_shape)[list(self._axes)]) class Dot(keras.layers.Layer): def call(self, x): a, b = x return K.dot(a, b) def compute_output_shape(self, input_shapes): return (input_shapes[0][0], input_shapes[1][1]) class Divide(keras.layers.Layer): def call(self, x): a, b = x return a / b def compute_output_shape(self, input_shapes): return input_shapes[0] class SafeDivide(keras.layers.Layer): def __init__(self, *args, **kwargs): factor = kwargs.pop("factor", None) if factor is None: factor = K.epsilon() self._factor = factor return super(SafeDivide, self).__init__(*args, **kwargs) def call(self, x): a, b = x return a / (b + iK.to_floatx(K.equal(b, K.constant(0))) * self._factor) def compute_output_shape(self, input_shapes): return input_shapes[0] ############################################################################### ############################################################################### ############################################################################### class Repeat(keras.layers.Layer): def __init__(self, n, axis, *args, **kwargs): self._n = n self._axis = axis return super(Repeat, self).__init__(*args, **kwargs) def call(self, x): return K.repeat_elements(x, self._n, self._axis) def compute_output_shape(self, input_shapes): if isinstance(input_shapes, list): input_shape = input_shapes[0] else: input_shape = input_shapes if input_shape[0] is None: return input_shape else: return (input_shape[0]*self._n,)+input_shape[1:] class Reshape(keras.layers.Layer): def __init__(self, shape, *args, **kwargs): self._shape = shape return super(Reshape, self).__init__(*args, **kwargs) def call(self, x): return K.reshape(x, self._shape) def compute_output_shape(self, input_shapes): return tuple(x if x >= 0 else None for x in self._shape) class MultiplyWithLinspace(keras.layers.Layer): def __init__(self, start, end, n=1, axis=-1, *args, **kwargs): self._start = start self._end = end self._n = n self._axis = axis return super(MultiplyWithLinspace, self).__init__(*args, **kwargs) def call(self, x): linspace = (self._start + (self._end-self._start) * (K.arange(self._n, dtype=K.floatx())/self._n)) # Make broadcastable. shape = np.ones(len(K.int_shape(x))) shape[self._axis] = self._n linspace = K.reshape(linspace, shape) return x * linspace def compute_output_shape(self, input_shapes): ret = input_shapes[:] ret = (ret[:self._axis] + (max(self._n, ret[self._axis]),) + ret[self._axis+1:]) return ret class TestPhaseGaussianNoise(keras.layers.GaussianNoise): def call(self, inputs): # Always add Gaussian noise! return super(TestPhaseGaussianNoise, self).call(inputs, training=True) class ExtractConv2DPatches(keras.layers.Layer): def __init__(self, kernel_shape, depth, strides, rates, padding, *args, **kwargs): self._kernel_shape = kernel_shape self._depth = depth self._strides = strides self._rates = rates self._padding = padding return super(ExtractConv2DPatches, self).__init__(*args, **kwargs) def call(self, x): return iK.extract_conv2d_patches(x, self._kernel_shape, self._strides, self._rates, self._padding) def compute_output_shape(self, input_shapes): if K.image_data_format() == 'channels_first': space = input_shapes[2:] new_space = [] for i in range(len(space)): new_dim = conv_utils.conv_output_length( space[i], self._kernel_shape[i], padding=self._padding, stride=self._strides[i], dilation=self._rates[i]) new_space.append(new_dim) if K.image_data_format() == 'channels_last': space = input_shapes[1:-1] new_space = [] for i in range(len(space)): new_dim = conv_utils.conv_output_length( space[i], self._kernel_shape[i], padding=self._padding, stride=self._strides[i], dilation=self._rates[i]) new_space.append(new_dim) return ((input_shapes[0],) + tuple(new_space) + (np.product(self._kernel_shape) * self._depth,)) class RunningMeans(keras.layers.Layer): def __init__(self, *args, **kwargs): self.stateful = True super(RunningMeans, self).__init__(*args, **kwargs) def build(self, input_shapes): means_shape, counts_shape = input_shapes self.means = self.add_weight(shape=means_shape, initializer="zeros", name="means", trainable=False) self.counts = self.add_weight(shape=counts_shape, initializer="zeros", name="counts", trainable=False) self.built = True def call(self, x): def safe_divide(a, b): return a / (b + iK.to_floatx(K.equal(b, K.constant(0))) * 1) means, counts = x new_counts = counts + self.counts # If new_means are not used for the model output, # the following part of the code will be executed after # self.counts is updated, therefore we cannot use it # hereafter. factor_new = safe_divide(counts, new_counts) factor_old = K.ones_like(factor_new) - factor_new new_means = self.means * factor_old + means * factor_new # Update state. self.add_update([ K.update(self.means, new_means), K.update(self.counts, new_counts), ]) return [new_means, new_counts] def compute_output_shape(self, input_shapes): return input_shapes class Broadcast(keras.layers.Layer): def call(self, x): target_shapped, x = x return target_shapped * 0 + x def compute_output_shape(self, input_shapes): return input_shapes[0] class Gather(keras.layers.Layer): def call(self, inputs): x, index = inputs return iK.gather(x, 1, index) def compute_output_shape(self, input_shapes): return (input_shapes[0][0], input_shapes[1][0])+input_shapes[0][2:] class GatherND(keras.layers.Layer): def call(self, inputs): x, indices = inputs return iK.gather_nd(x, indices) def compute_output_shape(self, input_shapes): return input_shapes[1][:2]+input_shapes[0][2:]