backend.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. from builtins import zip
  5. ###############################################################################
  6. ###############################################################################
  7. ###############################################################################
  8. import keras.backend as K
  9. __all__ = [
  10. "to_floatx",
  11. "gradients",
  12. "is_not_finite",
  13. "extract_conv2d_patches",
  14. "gather",
  15. "gather_nd",
  16. ]
  17. ###############################################################################
  18. ###############################################################################
  19. ###############################################################################
  20. def to_floatx(x):
  21. return K.cast(x, K.floatx())
  22. ###############################################################################
  23. ###############################################################################
  24. ###############################################################################
  25. def gradients(Xs, Ys, known_Ys):
  26. """Partial derivatives
  27. Computes the partial derivatives between Ys and Xs and
  28. using the gradients for Ys known_Ys.
  29. :param Xs: List of input tensors.
  30. :param Ys: List of output tensors that depend on Xs.
  31. :param known_Ys: Gradients for Ys.
  32. :return: Gradients for Xs given known_Ys
  33. """
  34. backend = K.backend()
  35. if backend == "theano":
  36. # no global import => do not break if module is not present
  37. assert len(Ys) == 1
  38. import theano.gradient
  39. known_Ys = {k: v for k, v in zip(Ys, known_Ys)}
  40. # todo: check the stop gradient issue here!
  41. return theano.gradient.grad(K.sum(Ys[0]), Xs, known_grads=known_Ys)
  42. elif backend == "tensorflow":
  43. # no global import => do not break if module is not present
  44. import tensorflow
  45. return tensorflow.gradients(Ys, Xs,
  46. grad_ys=known_Ys,
  47. stop_gradients=Xs)
  48. else:
  49. # todo: add cntk
  50. raise NotImplementedError()
  51. ###############################################################################
  52. ###############################################################################
  53. ###############################################################################
  54. def is_not_finite(x):
  55. """Checks if tensor x is finite, if not throws an exception."""
  56. backend = K.backend()
  57. if backend == "theano":
  58. # no global import => do not break if module is not present
  59. import theano.tensor
  60. return theano.tensor.or_(theano.tensor.isnan(x),
  61. theano.tensor.isinf(x))
  62. elif backend == "tensorflow":
  63. # no global import => do not break if module is not present
  64. import tensorflow
  65. #x = tensorflow.check_numerics(x, "innvestigate - is_finite check")
  66. return tensorflow.logical_not(tensorflow.is_finite(x))
  67. else:
  68. # todo: add cntk
  69. raise NotImplementedError()
  70. ###############################################################################
  71. ###############################################################################
  72. ###############################################################################
  73. def extract_conv2d_patches(x, kernel_shape, strides, rates, padding):
  74. """Extracts conv2d patches like TF function extract_image_patches.
  75. :param x: Input image.
  76. :param kernel_shape: Shape of the Keras conv2d kernel.
  77. :param strides: Strides of the Keras conv2d layer.
  78. :param rates: Dilation rates of the Keras conv2d layer.
  79. :param padding: Paddings of the Keras conv2d layer.
  80. :return: The extracted patches.
  81. """
  82. backend = K.backend()
  83. if backend == "theano":
  84. # todo: add theano function.
  85. raise NotImplementedError()
  86. elif backend == "tensorflow":
  87. # no global import => do not break if module is not present
  88. import tensorflow
  89. if K.image_data_format() == "channels_first":
  90. x = K.permute_dimensions(x, (0, 2, 3, 1))
  91. kernel_shape = [1, kernel_shape[0], kernel_shape[1], 1]
  92. strides = [1, strides[0], strides[1], 1]
  93. rates = [1, rates[0], rates[1], 1]
  94. ret = tensorflow.extract_image_patches(x,
  95. kernel_shape,
  96. strides,
  97. rates,
  98. padding.upper())
  99. if K.image_data_format() == "channels_first":
  100. # todo: check if we need to permute again.xs
  101. pass
  102. return ret
  103. else:
  104. # todo: add cntk
  105. raise NotImplementedError()
  106. ###############################################################################
  107. ###############################################################################
  108. ###############################################################################
  109. def gather(x, axis, indices):
  110. """Works as TensorFlow's gather."""
  111. backend = K.backend()
  112. if backend == "theano":
  113. # todo: add theano function.
  114. raise NotImplementedError()
  115. elif backend == "tensorflow":
  116. # no global import => do not break if module is not present
  117. import tensorflow
  118. return tensorflow.gather(x, indices, axis=axis)
  119. else:
  120. # todo: add cntk
  121. raise NotImplementedError()
  122. def gather_nd(x, indices):
  123. """Works as TensorFlow's gather_nd."""
  124. backend = K.backend()
  125. if backend == "theano":
  126. # todo: add theano function.
  127. raise NotImplementedError()
  128. elif backend == "tensorflow":
  129. # no global import => do not break if module is not present
  130. import tensorflow
  131. return tensorflow.gather_nd(x, indices)
  132. else:
  133. # todo: add cntk
  134. raise NotImplementedError()