import tensorflow as tf
from keras import backend as K
from keras import initializers
from keras import regularizers
from keras import constraints
from keras.engine import InputSpec
from keras.layers import Conv3D
#from keras.layers.convolutional import Conv3D


class SeparableConv3D(Conv3D):
    """A custom implementation of 3D separable convolutions.

    The layer takes rank-5 activations. With `data_format='channels_first'`
    the channel axis is 1; otherwise it is 4, i.e. the input is laid out as
    batch x N_1 x N_2 x N_3 x N (N_k is the k-th spatial dimension, N the
    number of channels).

    Depthwise step: the input is split into its N channels and each channel
    is convolved on its own with a kernel of shape
    `kernel_size + (1, depth_multiplier)`, producing `depth_multiplier`
    feature maps per input channel.

    Pointwise step: a 1x1x1 convolution maps the `N * depth_multiplier`
    depthwise feature maps to `filters` output channels.

    The module has only been used (and tested) with a depth multiplier of 1,
    but support for higher depth multipliers is built in.

    # Arguments
        filters: number of output channels of the pointwise convolution.
        kernel_size: three spatial kernel extents of the depthwise step.
        strides: strides of the depthwise convolution.
        padding: padding mode forwarded to `K.conv3d`.
        data_format: `'channels_first'` or anything else handled by the
            parent `Conv3D` (treated as channels-last here).
        depth_multiplier: number of depthwise outputs per input channel.
        activation: optional activation applied to the final output.
        use_bias: whether a bias of shape `(filters,)` is added.
        depthwise_initializer / pointwise_initializer / bias_initializer:
            initializers for the respective weights.
        depthwise_regularizer / pointwise_regularizer / bias_regularizer /
            activity_regularizer: optional regularizers.
        depthwise_constraint / pointwise_constraint / bias_constraint:
            optional weight constraints.
        **kwargs: forwarded unchanged to `Conv3D.__init__`.
    """

    def __init__(self, filters,
                 kernel_size,
                 strides=(1, 1, 1),
                 padding='valid',
                 data_format=None,
                 depth_multiplier=1,
                 activation=None,
                 use_bias=True,
                 depthwise_initializer='glorot_uniform',
                 pointwise_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 depthwise_regularizer=None,
                 pointwise_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 depthwise_constraint=None,
                 pointwise_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # The parent handles everything shared with a plain Conv3D; the
        # depthwise/pointwise-specific settings are stored on this class.
        super(SeparableConv3D, self).__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activation,
            use_bias=use_bias,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            bias_constraint=bias_constraint,
            **kwargs)
        self.depth_multiplier = depth_multiplier
        self.depthwise_initializer = initializers.get(depthwise_initializer)
        self.pointwise_initializer = initializers.get(pointwise_initializer)
        self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
        self.pointwise_regularizer = regularizers.get(pointwise_regularizer)
        self.depthwise_constraint = constraints.get(depthwise_constraint)
        self.pointwise_constraint = constraints.get(pointwise_constraint)

    def build(self, input_shape):
        """Create the depthwise kernel, pointwise kernel and optional bias.

        # Arguments
            input_shape: shape of the layer input; must have at least five
                entries and a defined channel dimension.

        # Raises
            ValueError: if `input_shape` has fewer than five entries or the
                channel dimension is `None`.
        """
        if len(input_shape) < 5:
            raise ValueError('Inputs to `SeparableConv3D` should have rank 5. '
                             'Received input shape: ' + str(input_shape))
        if self.data_format == 'channels_first':
            self.channel_axis = 1
        else:
            self.channel_axis = 4
        if input_shape[self.channel_axis] is None:
            raise ValueError('The channel dimension of the inputs to '
                             '`SeparableConv3D` '
                             'should be defined. Found `None`.')
        self.input_dim = int(input_shape[self.channel_axis])
        depthwise_kernel_shape = (self.kernel_size[0],
                                  self.kernel_size[1],
                                  self.kernel_size[2],
                                  self.input_dim,
                                  self.depth_multiplier)
        pointwise_kernel_shape = (1, 1, 1,
                                  self.depth_multiplier * self.input_dim,
                                  self.filters)

        self.depthwise_kernel = self.add_weight(
            shape=depthwise_kernel_shape,
            initializer=self.depthwise_initializer,
            name='depthwise_kernel',
            regularizer=self.depthwise_regularizer,
            constraint=self.depthwise_constraint)
        self.pointwise_kernel = self.add_weight(
            shape=pointwise_kernel_shape,
            initializer=self.pointwise_initializer,
            name='pointwise_kernel',
            regularizer=self.pointwise_regularizer,
            constraint=self.pointwise_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Set input spec.
        self.input_spec = InputSpec(ndim=5,
                                    axes={self.channel_axis: self.input_dim})
        self.built = True

    def call(self, inputs):
        """Apply the depthwise convolution, then the pointwise convolution.

        # Arguments
            inputs: rank-5 input tensor whose channel axis has
                `self.input_dim` entries.

        # Returns
            The pointwise convolution output, with bias and activation
            applied when configured.
        """
        # One single-channel tensor and one kernel slice per input channel.
        # See https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/slicing_and_joining
        channel_inputs = tf.split(inputs, self.input_dim, self.channel_axis)
        channel_kernels = tf.split(self.depthwise_kernel, self.input_dim, 3)

        depthwise_outputs = [
            K.conv3d(channel_input,
                     channel_kernel,
                     strides=self.strides,
                     padding=self.padding,
                     data_format=self.data_format,
                     dilation_rate=self.dilation_rate)
            for channel_input, channel_kernel in zip(channel_inputs,
                                                     channel_kernels)
        ]

        # Re-join along the channel axis (not the default last axis) so the
        # result is also correct for `channels_first` inputs.
        depthwise_conv = K.concatenate(depthwise_outputs,
                                       axis=self.channel_axis)

        outputs = K.conv3d(depthwise_conv,
                           self.pointwise_kernel,
                           strides=(1, 1, 1),
                           padding=self.padding,
                           data_format=self.data_format,
                           dilation_rate=self.dilation_rate)

        # Test the boolean flag, not the bias variable itself: taking the
        # truth value of a backend variable is an error or ambiguous.
        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def get_config(self):
        """Return the layer config with the separable-specific entries.

        The parent's single-kernel entries are removed because this layer
        holds a depthwise and a pointwise kernel instead of one `kernel`.
        """
        config = super(SeparableConv3D, self).get_config()
        config.pop('kernel_initializer')
        config.pop('kernel_regularizer')
        config.pop('kernel_constraint')
        config['depth_multiplier'] = self.depth_multiplier
        config['depthwise_initializer'] = initializers.serialize(self.depthwise_initializer)
        config['pointwise_initializer'] = initializers.serialize(self.pointwise_initializer)
        config['depthwise_regularizer'] = regularizers.serialize(self.depthwise_regularizer)
        config['pointwise_regularizer'] = regularizers.serialize(self.pointwise_regularizer)
        config['depthwise_constraint'] = constraints.serialize(self.depthwise_constraint)
        config['pointwise_constraint'] = constraints.serialize(self.pointwise_constraint)
        return config