__init__.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. import keras.backend as K
  8. import keras.utils
  9. import math
  10. __all__ = [
  11. "model_wo_softmax",
  12. "to_list",
  13. "BatchSequence",
  14. "TargetAugmentedSequence",
  15. "preprocess_images",
  16. "postprocess_images",
  17. ]
  18. ###############################################################################
  19. ###############################################################################
  20. ###############################################################################
  21. def model_wo_softmax(*args, **kwargs):
  22. # Break cyclic import
  23. from .keras.graph import model_wo_softmax
  24. return model_wo_softmax(*args, **kwargs)
  25. ###############################################################################
  26. ###############################################################################
  27. ###############################################################################
  28. def to_list(l):
  29. """ If not list, wraps parameter into a list."""
  30. if not isinstance(l, list):
  31. return [l, ]
  32. else:
  33. return l
  34. ###############################################################################
  35. ###############################################################################
  36. ###############################################################################
  37. class BatchSequence(keras.utils.Sequence):
  38. """Batch sequence generator.
  39. Take a (list of) input tensors and a batch size
  40. and creates a generators that creates a sequence of batches.
  41. :param Xs: One or a list of tensors. First axis needs to have same length.
  42. :param batch_size: Batch size. Default 32.
  43. """
  44. def __init__(self, Xs, batch_size=32):
  45. self.Xs = to_list(Xs)
  46. self.single_tensor = len(Xs) == 1
  47. self.batch_size = batch_size
  48. if not self.single_tensor:
  49. for X in self.Xs[1:]:
  50. assert X.shape[0] == self.Xs[0].shape[0]
  51. super(BatchSequence, self).__init__()
  52. def __len__(self):
  53. return int(math.ceil(float(len(self.Xs[0])) / self.batch_size))
  54. def __getitem__(self, idx):
  55. ret = [X[idx*self.batch_size:(idx+1)*self.batch_size]
  56. for X in self.Xs]
  57. if self.single_tensor:
  58. return ret[0]
  59. else:
  60. return tuple(ret)
  61. class TargetAugmentedSequence(keras.utils.Sequence):
  62. """Augments a sequence with a target on the fly.
  63. Takes a sequence/generator and a function that
  64. creates on the fly for each batch a target.
  65. The generator takes a batch from that sequence,
  66. computes the target and returns both.
  67. :param sequence: A sequence or generator.
  68. :param augment_f: Takes a batch and returns a target.
  69. """
  70. def __init__(self, sequence, augment_f):
  71. self.sequence = sequence
  72. self.augment_f = augment_f
  73. super(TargetAugmentedSequence, self).__init__()
  74. def __len__(self):
  75. return len(self.sequence)
  76. def __getitem__(self, idx):
  77. inputs = self.sequence[idx]
  78. if isinstance(inputs, tuple):
  79. assert len(inputs) == 1
  80. inputs = inputs[0]
  81. targets = self.augment_f(to_list(inputs))
  82. return inputs, targets
  83. ###############################################################################
  84. ###############################################################################
  85. ###############################################################################
  86. def preprocess_images(images, color_coding=None):
  87. """Image preprocessing
  88. Takes a batch of images and:
  89. * Adjust the color axis to the Keras format.
  90. * Fixes the color coding.
  91. :param images: Batch of images with 4 axes.
  92. :param color_coding: Determines the color coding.
  93. Can be None, 'RGBtoBGR' or 'BGRtoRGB'.
  94. :return: The preprocessed batch.
  95. """
  96. ret = images
  97. image_data_format = K.image_data_format()
  98. # todo: not very general:
  99. channels_first = images.shape[1] in [1, 3]
  100. if image_data_format == "channels_first" and not channels_first:
  101. ret = ret.transpose(0, 3, 1, 2)
  102. if image_data_format == "channels_last" and channels_first:
  103. ret = ret.transpose(0, 2, 3, 1)
  104. assert color_coding in [None, "RGBtoBGR", "BGRtoRGB"]
  105. if color_coding in ["RGBtoBGR", "BGRtoRGB"]:
  106. if image_data_format == "channels_first":
  107. ret = ret[:, ::-1, :, :]
  108. if image_data_format == "channels_last":
  109. ret = ret[:, :, :, ::-1]
  110. return ret
  111. def postprocess_images(images, color_coding=None, channels_first=None):
  112. """Image postprocessing
  113. Takes a batch of images and reverts the preprocessing.
  114. :param images: A batch of images with 4 axes.
  115. :param color_coding: The initial color coding,
  116. see :func:`preprocess_images`.
  117. :param channels_first: The output channel format.
  118. :return: The postprocessed images.
  119. """
  120. ret = images
  121. image_data_format = K.image_data_format()
  122. assert color_coding in [None, "RGBtoBGR", "BGRtoRGB"]
  123. if color_coding in ["RGBtoBGR", "BGRtoRGB"]:
  124. if image_data_format == "channels_first":
  125. ret = ret[:, ::-1, :, :]
  126. if image_data_format == "channels_last":
  127. ret = ret[:, :, :, ::-1]
  128. if image_data_format == "channels_first" and not channels_first:
  129. ret = ret.transpose(0, 2, 3, 1)
  130. if image_data_format == "channels_last" and channels_first:
  131. ret = ret.transpose(0, 3, 1, 2)
  132. return ret