sepconv3D.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import tensorflow as tf
  2. from keras import backend as K
  3. from keras import initializers
  4. from keras import regularizers
  5. from keras import constraints
  6. from keras.engine import InputSpec
  7. from keras.layers import Conv3D
  8. #from keras.layers.convolutional import Conv3D
  class SeparableConv3D(Conv3D):
      """A custom implementation of 3D separable convolutions.

      The layer takes activations of shape batch x N_1 x N_2 x N_3 x N
      (``channels_last``) or batch x N x N_1 x N_2 x N_3 (``channels_first``),
      where batch is the batch size, N_k is the k-th spatial dimension and N
      is the number of input channels.

      1. Depthwise step: each of the N input channels is convolved separately
         with its own kernel of size ``kernel_size``, producing
         ``depth_multiplier`` feature maps per channel.
      2. Pointwise step: a 1x1x1 convolution maps the
         ``N * depth_multiplier`` depthwise feature maps to ``filters``
         output channels.

      The optional bias and the activation are then applied to the result.

      The module has only been used (and tested) with a depth multiplier of
      1, but support for higher depth multipliers is built in.

      Arguments:
          filters: Number of output channels (of the pointwise step).
          kernel_size, strides, padding: Spatial parameters of the depthwise
              step; forwarded to ``Conv3D``.
          data_format: 'channels_first', 'channels_last' or None; forwarded
              to ``Conv3D``.
          depth_multiplier: Number of depthwise feature maps per input
              channel.
          activation, use_bias, bias_initializer, bias_regularizer,
          activity_regularizer, bias_constraint: Forwarded to ``Conv3D``.
          depthwise_initializer, depthwise_regularizer, depthwise_constraint:
              Settings for the depthwise kernel.
          pointwise_initializer, pointwise_regularizer, pointwise_constraint:
              Settings for the pointwise kernel.
          **kwargs: Forwarded to ``Conv3D`` (e.g. ``dilation_rate``, ``name``).
      """

      def __init__(self, filters,
                   kernel_size,
                   strides=(1, 1, 1),
                   padding='valid',
                   data_format=None,
                   depth_multiplier=1,
                   activation=None,
                   use_bias=True,
                   depthwise_initializer='glorot_uniform',
                   pointwise_initializer='glorot_uniform',
                   bias_initializer='zeros',
                   depthwise_regularizer=None,
                   pointwise_regularizer=None,
                   bias_regularizer=None,
                   activity_regularizer=None,
                   depthwise_constraint=None,
                   pointwise_constraint=None,
                   bias_constraint=None,
                   **kwargs):
          super(SeparableConv3D, self).__init__(
              filters=filters,
              kernel_size=kernel_size,
              strides=strides,
              padding=padding,
              data_format=data_format,
              activation=activation,
              use_bias=use_bias,
              # Must be forwarded: build() reads self.bias_initializer, which
              # Conv3D sets from this argument; otherwise a user-supplied
              # bias initializer would be silently ignored.
              bias_initializer=bias_initializer,
              bias_regularizer=bias_regularizer,
              activity_regularizer=activity_regularizer,
              bias_constraint=bias_constraint,
              **kwargs)
          self.depth_multiplier = depth_multiplier
          self.depthwise_initializer = initializers.get(depthwise_initializer)
          self.pointwise_initializer = initializers.get(pointwise_initializer)
          self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
          self.pointwise_regularizer = regularizers.get(pointwise_regularizer)
          self.depthwise_constraint = constraints.get(depthwise_constraint)
          self.pointwise_constraint = constraints.get(pointwise_constraint)

      def build(self, input_shape):
          """Create the depthwise, pointwise and (optional) bias weights.

          Conv3D's own single ``kernel`` is never created; this layer uses a
          depthwise kernel plus a 1x1x1 pointwise kernel instead.

          Raises:
              ValueError: if the input is not rank 5 or its channel
                  dimension is undefined.
          """
          if len(input_shape) != 5:
              raise ValueError('Inputs to `SeparableConv3D` should have rank 5. '
                               'Received input shape: ' + str(input_shape))
          self.channel_axis = 1 if self.data_format == 'channels_first' else 4
          if input_shape[self.channel_axis] is None:
              raise ValueError('The channel dimension of the inputs to '
                               '`SeparableConv3D` '
                               'should be defined. Found `None`.')
          self.input_dim = int(input_shape[self.channel_axis])

          # One (k1, k2, k3) filter bank of `depth_multiplier` maps per input
          # channel; call() slices it along axis 3.
          depthwise_kernel_shape = (self.kernel_size[0],
                                    self.kernel_size[1],
                                    self.kernel_size[2],
                                    self.input_dim,
                                    self.depth_multiplier)
          # Mixes all depthwise maps into `filters` output channels.
          pointwise_kernel_shape = (1, 1, 1,
                                    self.depth_multiplier * self.input_dim,
                                    self.filters)

          self.depthwise_kernel = self.add_weight(
              shape=depthwise_kernel_shape,
              initializer=self.depthwise_initializer,
              name='depthwise_kernel',
              regularizer=self.depthwise_regularizer,
              constraint=self.depthwise_constraint)
          self.pointwise_kernel = self.add_weight(
              shape=pointwise_kernel_shape,
              initializer=self.pointwise_initializer,
              name='pointwise_kernel',
              regularizer=self.pointwise_regularizer,
              constraint=self.pointwise_constraint)
          if self.use_bias:
              self.bias = self.add_weight(shape=(self.filters,),
                                          initializer=self.bias_initializer,
                                          name='bias',
                                          regularizer=self.bias_regularizer,
                                          constraint=self.bias_constraint)
          else:
              self.bias = None

          # Pin the channel count so later calls are validated against it.
          self.input_spec = InputSpec(ndim=5,
                                      axes={self.channel_axis: self.input_dim})
          self.built = True

      def call(self, inputs):
          """Apply the depthwise step, then the pointwise step, bias and activation."""
          # Depthwise step: convolve each input channel with its own slice of
          # the depthwise kernel (k1 x k2 x k3 x 1 x depth_multiplier).
          input_channels = tf.split(inputs, self.input_dim,
                                    axis=self.channel_axis)
          channel_kernels = tf.split(self.depthwise_kernel, self.input_dim,
                                     axis=3)
          depthwise_maps = [
              K.conv3d(channel, kernel,
                       strides=self.strides,
                       padding=self.padding,
                       data_format=self.data_format,
                       dilation_rate=self.dilation_rate)
              for channel, kernel in zip(input_channels, channel_kernels)]
          # Re-join along the channel axis. K.concatenate defaults to axis -1,
          # which is only the channel axis for channels_last.
          depthwise_conv = K.concatenate(depthwise_maps, axis=self.channel_axis)

          # Pointwise step: a 1x1x1, unit-stride conv. Padding and dilation
          # cannot change its result, so use 'valid' / no dilation explicitly.
          outputs = K.conv3d(depthwise_conv, self.pointwise_kernel,
                             strides=(1, 1, 1),
                             padding='valid',
                             data_format=self.data_format,
                             dilation_rate=(1, 1, 1))

          # Check the flag, not the variable: the truth value of a weight
          # tensor is not a valid "bias exists" test.
          if self.use_bias:
              outputs = K.bias_add(
                  outputs,
                  self.bias,
                  data_format=self.data_format)
          if self.activation is not None:
              return self.activation(outputs)
          return outputs

      def get_config(self):
          """Return Conv3D's config adapted to the separable kernels.

          The generic ``kernel_*`` entries are dropped because this layer has
          no single kernel. The depthwise/pointwise settings and the depth
          multiplier are added in their place.
          """
          config = super(SeparableConv3D, self).get_config()
          for unused_key in ('kernel_initializer',
                             'kernel_regularizer',
                             'kernel_constraint'):
              config.pop(unused_key)
          config['depth_multiplier'] = self.depth_multiplier
          config['depthwise_initializer'] = initializers.serialize(self.depthwise_initializer)
          config['pointwise_initializer'] = initializers.serialize(self.pointwise_initializer)
          config['depthwise_regularizer'] = regularizers.serialize(self.depthwise_regularizer)
          config['pointwise_regularizer'] = regularizers.serialize(self.pointwise_regularizer)
          config['depthwise_constraint'] = constraints.serialize(self.depthwise_constraint)
          config['pointwise_constraint'] = constraints.serialize(self.pointwise_constraint)
          return config