visualizations.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. from builtins import range
  5. ###############################################################################
  6. ###############################################################################
  7. ###############################################################################
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. __all__ = [
  11. "project",
  12. "heatmap",
  13. "graymap",
  14. "gamma",
  15. "clip_quantile",
  16. ]
  17. ###############################################################################
  18. ###############################################################################
  19. ###############################################################################
  20. def project(X, output_range=(0, 1), absmax=None, input_is_positive_only=False):
  21. """Projects a tensor into a value range.
  22. Projects the tensor values into the specified range.
  23. :param X: A tensor.
  24. :param output_range: The output value range.
  25. :param absmax: A tensor specifying the absmax used for normalizing.
  26. Default the absmax along the first axis.
  27. :param input_is_positive_only: Is the input value range only positive.
  28. :return: The tensor with the values project into output range.
  29. """
  30. if absmax is None:
  31. absmax = np.max(np.abs(X),
  32. axis=tuple(range(1, len(X.shape))))
  33. absmax = np.asarray(absmax)
  34. mask = absmax != 0
  35. if mask.sum() > 0:
  36. X[mask] /= absmax[mask]
  37. if input_is_positive_only is False:
  38. X = (X+1)/2 # [0, 1]
  39. X = X.clip(0, 1)
  40. X = output_range[0] + (X * (output_range[1]-output_range[0]))
  41. return X
  42. def heatmap(X, cmap_type="seismic", reduce_op="sum", reduce_axis=-1, alpha_cmap=False, **kwargs):
  43. """Creates a heatmap/color map.
  44. Create a heatmap or colormap out of the input tensor.
  45. :param X: A image tensor with 4 axes.
  46. :param cmap_type: The color map to use. Default 'seismic'.
  47. :param reduce_op: Operation to reduce the color axis.
  48. Either 'sum' or 'absmax'.
  49. :param reduce_axis: Axis to reduce.
  50. :param alpha_cmap: Should the alpha component of the cmap be included.
  51. :param kwargs: Arguments passed on to :func:`project`
  52. :return: The tensor as color-map.
  53. """
  54. cmap = plt.cm.get_cmap(cmap_type)
  55. tmp = X
  56. shape = tmp.shape
  57. if reduce_op == "sum":
  58. tmp = tmp.sum(axis=reduce_axis)
  59. elif reduce_op == "absmax":
  60. pos_max = tmp.max(axis=reduce_axis)
  61. neg_max = (-tmp).max(axis=reduce_axis)
  62. abs_neg_max = -neg_max
  63. tmp = np.select([pos_max >= abs_neg_max, pos_max < abs_neg_max],
  64. [pos_max, neg_max])
  65. else:
  66. raise NotImplementedError()
  67. tmp = project(tmp, output_range=(0, 255), **kwargs).astype(np.int64)
  68. if alpha_cmap:
  69. tmp = cmap(tmp.flatten()).T
  70. else:
  71. tmp = cmap(tmp.flatten())[:, :3].T
  72. tmp = tmp.T
  73. shape = list(shape)
  74. shape[reduce_axis] = 3 + alpha_cmap
  75. return tmp.reshape(shape).astype(np.float32)
  76. def graymap(X, **kwargs):
  77. """Same as :func:`heatmap` but uses a gray colormap."""
  78. return heatmap(X, cmap_type="gray", **kwargs)
  79. def gamma(X, gamma=0.5, minamp=0, maxamp=None):
  80. """
  81. Apply gamma correction to an input array X
  82. while maintaining the relative order of entries,
  83. also for negative vs positive values in X.
  84. the fxn firstly determines the max
  85. amplitude in both positive and negative
  86. direction and then applies gamma scaling
  87. to the positive and negative values of the
  88. array separately, according to the common amplitude.
  89. :param gamma: the gamma parameter for gamma scaling
  90. :param minamp: the smallest absolute value to consider.
  91. if not given assumed to be zero (neutral value for relevance,
  92. min value for saliency, ...). values above and below
  93. minamp are treated separately.
  94. :param maxamp: the largest absolute value to consider relative
  95. to the neutral value minamp
  96. if not given determined from the given data.
  97. """
  98. #prepare return array
  99. Y = np.zeros_like(X)
  100. X = X - minamp # shift to given/assumed center
  101. if maxamp is None: maxamp = np.abs(X).max() #infer maxamp if not given
  102. X = X / maxamp # scale linearly
  103. #apply gamma correction for both positive and negative values.
  104. i_pos = X > 0
  105. i_neg = np.invert(i_pos)
  106. Y[i_pos] = X[i_pos]**gamma
  107. Y[i_neg] = -(-X[i_neg])**gamma
  108. #reconstruct original scale and center
  109. Y *= maxamp
  110. Y += minamp
  111. return Y
  112. def clip_quantile(X, quantile=1):
  113. """Clip the values of X into the given quantile."""
  114. if not isinstance(quantile, (list, tuple)):
  115. quantile = (quantile, 100-quantile)
  116. low = np.percentile(X, quantile[0])
  117. high = np.percentile(X, quantile[1])
  118. X[X < low] = low
  119. X[X > high] = high
  120. return X