layers.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. from builtins import range, zip
  5. ###############################################################################
  6. ###############################################################################
  7. ###############################################################################
  8. import keras
  9. import keras.backend as K
  10. import keras.constraints
  11. import keras.layers
  12. import keras.regularizers
  13. from keras.utils import conv_utils
  14. import numpy as np
  15. from . import utils as iutils
  16. from .utils.keras import backend as iK
  17. __all__ = [
  18. "Constant",
  19. "Zero",
  20. "One",
  21. "ZerosLike",
  22. "OnesLike",
  23. "AsFloatX",
  24. "FiniteCheck",
  25. "Gradient",
  26. "GradientWRT",
  27. "Min",
  28. "Max",
  29. "Greater",
  30. "Less",
  31. "GreaterThanZero",
  32. "LessThanZero",
  33. "GreaterEqual",
  34. "LessEqual",
  35. "GreaterEqualThanZero",
  36. "LessEqualThanZero",
  37. "Sum",
  38. "Mean",
  39. "CountNonZero",
  40. "Identity",
  41. "Abs",
  42. "Square",
  43. "Clip",
  44. "Project",
  45. "Print",
  46. "Transpose",
  47. "Dot",
  48. "SafeDivide",
  49. "Repeat",
  50. "Reshape",
  51. "MultiplyWithLinspace",
  52. "TestPhaseGaussianNoise",
  53. "ExtractConv2DPatches",
  54. "RunningMeans",
  55. "Broadcast",
  56. "Gather",
  57. "GatherND",
  58. ]
  59. ###############################################################################
  60. ###############################################################################
  61. ###############################################################################
  62. def Constant(c, reference=None):
  63. if reference is None:
  64. return K.constant(c)
  65. else:
  66. dtype = K.dtype(reference)
  67. return K.constant(np.dtype(dtype)(c), dtype=dtype)
  68. def Zero(reference=None):
  69. return Constant(0, reference=reference)
  70. def One(reference=None):
  71. return Constant(1, reference=reference)
  72. class ZerosLike(keras.layers.Layer):
  73. def call(self, x):
  74. return [K.zeros_like(tmp) for tmp in iutils.to_list(x)]
  75. class OnesLike(keras.layers.Layer):
  76. def call(self, x):
  77. return [K.ones_like(tmp) for tmp in iutils.to_list(x)]
  78. class AsFloatX(keras.layers.Layer):
  79. def call(self, x):
  80. return [iK.to_floatx(tmp) for tmp in iutils.to_list(x)]
  81. class FiniteCheck(keras.layers.Layer):
  82. def call(self, x):
  83. return [K.sum(iK.to_floatx(iK.is_not_finite(tmp)))
  84. for tmp in iutils.to_list(x)]
  85. ###############################################################################
  86. ###############################################################################
  87. ###############################################################################
  88. class Gradient(keras.layers.Layer):
  89. "Returns gradient of sum(output), expects inputs+[output,]."
  90. def call(self, x):
  91. inputs, output = x[:-1], x[-1]
  92. return K.gradients(K.sum(output), inputs)
  93. def compute_output_shape(self, input_shapes):
  94. return input_shapes[:-1]
  95. class GradientWRT(keras.layers.Layer):
  96. "Returns gradient wrt to another layer and given gradient,"
  97. " expects inputs+[output,]."
  98. def __init__(self, n_inputs, mask=None, **kwargs):
  99. self.n_inputs = n_inputs
  100. self.mask = mask
  101. super(GradientWRT, self).__init__(**kwargs)
  102. def call(self, x):
  103. assert isinstance(x, (list, tuple))
  104. Xs, tmp_Ys = x[:self.n_inputs], x[self.n_inputs:]
  105. assert len(tmp_Ys) % 2 == 0
  106. len_Ys = len(tmp_Ys) // 2
  107. Ys, known_Ys = tmp_Ys[:len_Ys], tmp_Ys[len_Ys:]
  108. ret = iK.gradients(Xs, Ys, known_Ys)
  109. if self.mask is not None:
  110. ret = [x for c, x in zip(self.mask, ret) if c]
  111. self.__workaround__len_ret = len(ret)
  112. return ret
  113. def compute_output_shape(self, input_shapes):
  114. if self.mask is None:
  115. return input_shapes[:self.n_inputs]
  116. else:
  117. return [x for c, x in zip(self.mask, input_shapes[:self.n_inputs])
  118. if c]
  119. # todo: remove once keras is fixed.
  120. # this is a workaround for cases when
  121. # wrapper and skip connections are used together.
  122. # bring the fix into keras and remove once
  123. # keras is patched.
  124. def compute_mask(self, inputs, mask=None):
  125. """Computes an output mask tensor.
  126. # Arguments
  127. inputs: Tensor or list of tensors.
  128. mask: Tensor or list of tensors.
  129. # Returns
  130. None or a tensor (or list of tensors,
  131. one per output tensor of the layer).
  132. """
  133. if not self.supports_masking:
  134. if mask is not None:
  135. if isinstance(mask, list):
  136. if any(m is not None for m in mask):
  137. raise TypeError('Layer ' + self.name +
  138. ' does not support masking, '
  139. 'but was passed an input_mask: ' +
  140. str(mask))
  141. else:
  142. raise TypeError('Layer ' + self.name +
  143. ' does not support masking, '
  144. 'but was passed an input_mask: ' +
  145. str(mask))
  146. # masking not explicitly supported: return None as mask
  147. # this is the workaround for model.run_internal_graph.
  148. # it is required that there as many masks as outputs:
  149. return [None for _ in range(self.__workaround__len_ret)]
  150. # if masking is explicitly supported, by default
  151. # carry over the input mask
  152. return mask
  153. ###############################################################################
  154. ###############################################################################
  155. ###############################################################################
  156. class _Reduce(keras.layers.Layer):
  157. def __init__(self, axis=-1, keepdims=False, *args, **kwargs):
  158. self.axis = axis
  159. self.keepdims = keepdims
  160. super(_Reduce, self).__init__(*args, **kwargs)
  161. def call(self, x):
  162. return self._apply_reduce(x, axis=self.axis, keepdims=self.keepdims)
  163. def compute_output_shape(self, input_shape):
  164. if self.axis is None:
  165. if self.keepdims is False:
  166. return (1,)
  167. else:
  168. return tuple(np.ones_like(input_shape))
  169. else:
  170. axes = np.arange(len(input_shape))
  171. if self.keepdims is False:
  172. for i in iutils.to_list(self.axis):
  173. axes = np.delete(axes, i, 0)
  174. else:
  175. for i in iutils.to_list(self.axis):
  176. axes[i] = 1
  177. return tuple([idx
  178. for i, idx in enumerate(input_shape)
  179. if i in axes])
  180. def _apply_reduce(self, x, axis, keepdims):
  181. raise NotImplementedError()
  182. class Min(_Reduce):
  183. def _apply_reduce(self, x, axis, keepdims):
  184. return K.min(x, axis=axis, keepdims=keepdims)
  185. class Max(_Reduce):
  186. def _apply_reduce(self, x, axis, keepdims):
  187. return K.max(x, axis=axis, keepdims=keepdims)
  188. class Sum(_Reduce):
  189. def _apply_reduce(self, x, axis, keepdims):
  190. return K.sum(x, axis=axis, keepdims=keepdims)
  191. class Mean(_Reduce):
  192. def _apply_reduce(self, x, axis, keepdims):
  193. return K.mean(x, axis=axis, keepdims=keepdims)
  194. class CountNonZero(_Reduce):
  195. def _apply_reduce(self, x, axis, keepdims):
  196. return K.sum(iK.to_floatx(K.not_equal(x, K.constant(0))),
  197. axis=axis,
  198. keepdims=keepdims)
  199. ###############################################################################
  200. ###############################################################################
  201. ###############################################################################
  202. class _Map(keras.layers.Layer):
  203. def call(self, x):
  204. if isinstance(x, list) and len(x) == 1:
  205. x = x[0]
  206. return self._apply_map(x)
  207. def compute_output_shape(self, input_shape):
  208. return input_shape
  209. def _apply_map(self, x):
  210. raise NotImplementedError()
  211. class Identity(_Map):
  212. def _apply_map(self, x):
  213. return K.identity(x)
  214. class Abs(_Map):
  215. def _apply_map(self, x):
  216. return K.abs(x)
  217. class Square(_Map):
  218. def _apply_map(self, x):
  219. return K.square(x)
  220. class Clip(_Map):
  221. def __init__(self, min_value, max_value):
  222. self._min_value = min_value
  223. self._max_value = max_value
  224. return super(Clip, self).__init__()
  225. def _apply_map(self, x):
  226. return K.clip(x, self._min_value, self._max_value)
  227. class Project(_Map):
  228. def __init__(self, output_range=False, input_is_postive_only=False):
  229. self._output_range = output_range
  230. self._input_is_positive_only = input_is_postive_only
  231. return super(Project, self).__init__()
  232. def _apply_map(self, x):
  233. def safe_divide(a, b):
  234. return a / (b + iK.to_floatx(K.equal(b, K.constant(0))) * 1)
  235. dims = K.int_shape(x)
  236. n_dim = len(dims)
  237. axes = tuple(range(1, n_dim))
  238. if len(axes) == 1:
  239. # TODO(albermax): this is only the case when the dimension in this
  240. # axis is 1, fix this.
  241. # Cannot reduce
  242. return x
  243. absmax = K.max(K.abs(x),
  244. axis=axes,
  245. keepdims=True)
  246. x = safe_divide(x, absmax)
  247. if self._output_range not in (False, True): # True = (-1, +1)
  248. output_range = self._output_range
  249. if not self._input_is_positive_only:
  250. x = (x+1) / 2
  251. x = K.clip(x, 0, 1)
  252. x = output_range[0] + (x * (output_range[1]-output_range[0]))
  253. else:
  254. x = K.clip(x, -1, 1)
  255. return x
  256. class Print(_Map):
  257. def _apply_map(self, x):
  258. return K.print_tensor(x)
  259. ###############################################################################
  260. ###############################################################################
  261. ###############################################################################
  262. class Greater(keras.layers.Layer):
  263. def call(self, x):
  264. a, b = x
  265. return K.greater(a, b)
  266. class Less(keras.layers.Layer):
  267. def call(self, x):
  268. a, b = x
  269. return K.less(a, b)
  270. class GreaterThanZero(keras.layers.Layer):
  271. def call(self, x):
  272. return K.greater(x, K.constant(0))
  273. class LessThanZero(keras.layers.Layer):
  274. def call(self, x):
  275. return K.less(x, K.constant(0))
  276. class GreaterEqual(keras.layers.Layer):
  277. def call(self, x):
  278. a, b = x
  279. return K.greater_equal(a, b)
  280. class LessEqual(keras.layers.Layer):
  281. def call(self, x):
  282. a, b = x
  283. return K.less_equal(a, b)
  284. class GreaterEqualThanZero(keras.layers.Layer):
  285. def call(self, x):
  286. return K.greater_equal(x, K.constant(0))
  287. class LessEqualThanZero(keras.layers.Layer):
  288. def call(self, x):
  289. return K.less_equal(x, K.constant(0))
  290. class Transpose(keras.layers.Layer):
  291. def __init__(self, axes=None, **kwargs):
  292. self._axes = axes
  293. super(Transpose, self).__init__(**kwargs)
  294. def call(self, x):
  295. if self._axes is None:
  296. return K.transpose(x)
  297. else:
  298. return K.permute_dimensions(x, self._axes)
  299. def compute_output_shape(self, input_shape):
  300. if self._axes is None:
  301. return input_shape[::-1]
  302. else:
  303. return tuple(np.asarray(input_shape)[list(self._axes)])
  304. class Dot(keras.layers.Layer):
  305. def call(self, x):
  306. a, b = x
  307. return K.dot(a, b)
  308. def compute_output_shape(self, input_shapes):
  309. return (input_shapes[0][0], input_shapes[1][1])
  310. class Divide(keras.layers.Layer):
  311. def call(self, x):
  312. a, b = x
  313. return a / b
  314. def compute_output_shape(self, input_shapes):
  315. return input_shapes[0]
  316. class SafeDivide(keras.layers.Layer):
  317. def __init__(self, *args, **kwargs):
  318. factor = kwargs.pop("factor", None)
  319. if factor is None:
  320. factor = K.epsilon()
  321. self._factor = factor
  322. return super(SafeDivide, self).__init__(*args, **kwargs)
  323. def call(self, x):
  324. a, b = x
  325. return a / (b + iK.to_floatx(K.equal(b, K.constant(0))) * self._factor)
  326. def compute_output_shape(self, input_shapes):
  327. return input_shapes[0]
  328. ###############################################################################
  329. ###############################################################################
  330. ###############################################################################
  331. class Repeat(keras.layers.Layer):
  332. def __init__(self, n, axis, *args, **kwargs):
  333. self._n = n
  334. self._axis = axis
  335. return super(Repeat, self).__init__(*args, **kwargs)
  336. def call(self, x):
  337. return K.repeat_elements(x, self._n, self._axis)
  338. def compute_output_shape(self, input_shapes):
  339. if isinstance(input_shapes, list):
  340. input_shape = input_shapes[0]
  341. else:
  342. input_shape = input_shapes
  343. if input_shape[0] is None:
  344. return input_shape
  345. else:
  346. return (input_shape[0]*self._n,)+input_shape[1:]
  347. class Reshape(keras.layers.Layer):
  348. def __init__(self, shape, *args, **kwargs):
  349. self._shape = shape
  350. return super(Reshape, self).__init__(*args, **kwargs)
  351. def call(self, x):
  352. return K.reshape(x, self._shape)
  353. def compute_output_shape(self, input_shapes):
  354. return tuple(x if x >= 0 else None for x in self._shape)
  355. class MultiplyWithLinspace(keras.layers.Layer):
  356. def __init__(self, start, end, n=1, axis=-1, *args, **kwargs):
  357. self._start = start
  358. self._end = end
  359. self._n = n
  360. self._axis = axis
  361. return super(MultiplyWithLinspace, self).__init__(*args, **kwargs)
  362. def call(self, x):
  363. linspace = (self._start +
  364. (self._end-self._start) *
  365. (K.arange(self._n, dtype=K.floatx())/self._n))
  366. # Make broadcastable.
  367. shape = np.ones(len(K.int_shape(x)))
  368. shape[self._axis] = self._n
  369. linspace = K.reshape(linspace, shape)
  370. return x * linspace
  371. def compute_output_shape(self, input_shapes):
  372. ret = input_shapes[:]
  373. ret = (ret[:self._axis] +
  374. (max(self._n, ret[self._axis]),) +
  375. ret[self._axis+1:])
  376. return ret
  377. class TestPhaseGaussianNoise(keras.layers.GaussianNoise):
  378. def call(self, inputs):
  379. # Always add Gaussian noise!
  380. return super(TestPhaseGaussianNoise, self).call(inputs, training=True)
  381. class ExtractConv2DPatches(keras.layers.Layer):
  382. def __init__(self,
  383. kernel_shape,
  384. depth,
  385. strides,
  386. rates,
  387. padding,
  388. *args,
  389. **kwargs):
  390. self._kernel_shape = kernel_shape
  391. self._depth = depth
  392. self._strides = strides
  393. self._rates = rates
  394. self._padding = padding
  395. return super(ExtractConv2DPatches, self).__init__(*args, **kwargs)
  396. def call(self, x):
  397. return iK.extract_conv2d_patches(x,
  398. self._kernel_shape,
  399. self._strides,
  400. self._rates,
  401. self._padding)
  402. def compute_output_shape(self, input_shapes):
  403. if K.image_data_format() == 'channels_first':
  404. space = input_shapes[2:]
  405. new_space = []
  406. for i in range(len(space)):
  407. new_dim = conv_utils.conv_output_length(
  408. space[i],
  409. self._kernel_shape[i],
  410. padding=self._padding,
  411. stride=self._strides[i],
  412. dilation=self._rates[i])
  413. new_space.append(new_dim)
  414. if K.image_data_format() == 'channels_last':
  415. space = input_shapes[1:-1]
  416. new_space = []
  417. for i in range(len(space)):
  418. new_dim = conv_utils.conv_output_length(
  419. space[i],
  420. self._kernel_shape[i],
  421. padding=self._padding,
  422. stride=self._strides[i],
  423. dilation=self._rates[i])
  424. new_space.append(new_dim)
  425. return ((input_shapes[0],) +
  426. tuple(new_space) +
  427. (np.product(self._kernel_shape) * self._depth,))
  428. class RunningMeans(keras.layers.Layer):
  429. def __init__(self, *args, **kwargs):
  430. self.stateful = True
  431. super(RunningMeans, self).__init__(*args, **kwargs)
  432. def build(self, input_shapes):
  433. means_shape, counts_shape = input_shapes
  434. self.means = self.add_weight(shape=means_shape,
  435. initializer="zeros",
  436. name="means",
  437. trainable=False)
  438. self.counts = self.add_weight(shape=counts_shape,
  439. initializer="zeros",
  440. name="counts",
  441. trainable=False)
  442. self.built = True
  443. def call(self, x):
  444. def safe_divide(a, b):
  445. return a / (b + iK.to_floatx(K.equal(b, K.constant(0))) * 1)
  446. means, counts = x
  447. new_counts = counts + self.counts
  448. # If new_means are not used for the model output,
  449. # the following part of the code will be executed after
  450. # self.counts is updated, therefore we cannot use it
  451. # hereafter.
  452. factor_new = safe_divide(counts, new_counts)
  453. factor_old = K.ones_like(factor_new) - factor_new
  454. new_means = self.means * factor_old + means * factor_new
  455. # Update state.
  456. self.add_update([
  457. K.update(self.means, new_means),
  458. K.update(self.counts, new_counts),
  459. ])
  460. return [new_means, new_counts]
  461. def compute_output_shape(self, input_shapes):
  462. return input_shapes
  463. class Broadcast(keras.layers.Layer):
  464. def call(self, x):
  465. target_shapped, x = x
  466. return target_shapped * 0 + x
  467. def compute_output_shape(self, input_shapes):
  468. return input_shapes[0]
  469. class Gather(keras.layers.Layer):
  470. def call(self, inputs):
  471. x, index = inputs
  472. return iK.gather(x, 1, index)
  473. def compute_output_shape(self, input_shapes):
  474. return (input_shapes[0][0], input_shapes[1][0])+input_shapes[0][2:]
  475. class GatherND(keras.layers.Layer):
  476. def call(self, inputs):
  477. x, indices = inputs
  478. return iK.gather_nd(x, indices)
  479. def compute_output_shape(self, input_shapes):
  480. return input_shapes[1][:2]+input_shapes[0][2:]