plotting.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. """Plotly visualizations for SUV NIfTI images."""
  2. from __future__ import annotations
  3. import numpy as np
  4. import pandas as pd
  5. import nibabel as nib
  6. import plotly.graph_objects as go
  7. def _image_to_array(image: nib.Nifti1Image | np.ndarray) -> np.ndarray:
  8. """Return image data as a floating-point numpy array."""
  9. if isinstance(image, nib.Nifti1Image):
  10. return image.get_fdata(dtype=np.float64)
  11. return np.asarray(image, dtype=float)
  12. def plot_suv_pdf_plotly(
  13. image: nib.Nifti1Image | np.ndarray,
  14. percentiles: tuple[float, ...] = (50, 75, 90, 95, 99),
  15. bins: int = 100,
  16. min_suv: float = 0.0,
  17. xlim: tuple[float, float] | None = None,
  18. histnorm: str | None = "probability density",
  19. title: str = "SUV distribution",
  20. log_x: bool = False,
  21. ) -> tuple[go.Figure, pd.DataFrame]:
  22. """Plot SUV histogram / empirical PDF with optional percentile markers.
  23. Percentiles are always computed on the original SUV scale. When
  24. ``log_x=True``, the histogram is drawn in ``log10(SUV)`` coordinates and the
  25. tick labels are shown in original SUV units. This avoids incorrectly placed
  26. percentile lines on log-scaled axes.
  27. Parameters
  28. ----------
  29. image:
  30. 3D NIfTI image or 3D numpy array.
  31. percentiles:
  32. Percentile markers to draw and tabulate.
  33. bins:
  34. Number of histogram bins.
  35. min_suv:
  36. Only finite voxels with ``SUV > min_suv`` are used.
  37. xlim:
  38. Optional x-axis limits in original SUV units.
  39. histnorm:
  40. Plotly histogram normalization. Common choices are ``"probability
  41. density"``, ``"probability"`` or ``None``.
  42. title:
  43. Figure title.
  44. log_x:
  45. If True, use log10 coordinates internally and label ticks in original
  46. SUV units.
  47. Returns
  48. -------
  49. tuple
  50. ``(fig, percentile_df)``.
  51. """
  52. data = _image_to_array(image)
  53. if data.ndim != 3:
  54. raise ValueError("Input image must be 3D.")
  55. suv_values = data[np.isfinite(data) & (data > min_suv)]
  56. if suv_values.size == 0:
  57. raise ValueError("No valid SUV values found.")
  58. if log_x:
  59. if np.any(suv_values <= 0):
  60. raise ValueError("For log_x=True, all SUV values must be positive.")
  61. if xlim is not None and (xlim[0] <= 0 or xlim[1] <= 0):
  62. raise ValueError("For log_x=True, xlim values must be positive.")
  63. if percentiles:
  64. percentile_values = np.percentile(suv_values, percentiles)
  65. percentile_df = pd.DataFrame(
  66. {
  67. "percentile": percentiles,
  68. "suv_threshold": percentile_values,
  69. "n_voxels_ge_threshold": [
  70. int(np.sum(suv_values >= value)) for value in percentile_values
  71. ],
  72. "n_voxels_lt_threshold": [
  73. int(np.sum(suv_values < value)) for value in percentile_values
  74. ],
  75. }
  76. )
  77. else:
  78. percentile_values = []
  79. percentile_df = pd.DataFrame()
  80. fig = go.Figure()
  81. if log_x:
  82. plot_values = np.log10(suv_values)
  83. if xlim is not None:
  84. log_xlim = (np.log10(xlim[0]), np.log10(xlim[1]))
  85. plot_values = plot_values[(plot_values >= log_xlim[0]) & (plot_values <= log_xlim[1])]
  86. else:
  87. log_xlim = None
  88. fig.add_trace(
  89. go.Histogram(
  90. x=plot_values,
  91. nbinsx=bins,
  92. histnorm=histnorm,
  93. name="SUV histogram",
  94. opacity=0.75,
  95. marker_line_width=1,
  96. )
  97. )
  98. for p, value in zip(percentiles, percentile_values):
  99. fig.add_vline(
  100. x=np.log10(value),
  101. line_dash="dash",
  102. line_width=2,
  103. annotation_text=f"p{p:g} = {value:.3g}",
  104. annotation_position="top",
  105. )
  106. if xlim is not None:
  107. tick_suv = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100])
  108. tick_suv = tick_suv[(tick_suv >= xlim[0]) & (tick_suv <= xlim[1])]
  109. else:
  110. lo, hi = np.nanmin(suv_values), np.nanmax(suv_values)
  111. tick_suv = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100])
  112. tick_suv = tick_suv[(tick_suv >= lo) & (tick_suv <= hi)]
  113. fig.update_xaxes(
  114. title_text="SUV",
  115. tickvals=np.log10(tick_suv),
  116. ticktext=[f"{v:g}" for v in tick_suv],
  117. )
  118. if log_xlim is not None:
  119. fig.update_xaxes(range=list(log_xlim))
  120. else:
  121. plot_values = suv_values.copy()
  122. if xlim is not None:
  123. plot_values = plot_values[(plot_values >= xlim[0]) & (plot_values <= xlim[1])]
  124. fig.add_trace(
  125. go.Histogram(
  126. x=plot_values,
  127. nbinsx=bins,
  128. histnorm=histnorm,
  129. name="SUV histogram",
  130. opacity=0.75,
  131. marker_line_width=1,
  132. )
  133. )
  134. for p, value in zip(percentiles, percentile_values):
  135. fig.add_vline(
  136. x=value,
  137. line_dash="dash",
  138. line_width=2,
  139. annotation_text=f"p{p:g} = {value:.3g}",
  140. annotation_position="top",
  141. )
  142. fig.update_xaxes(title_text="SUV")
  143. if xlim is not None:
  144. fig.update_xaxes(range=list(xlim))
  145. yaxis_title = {
  146. "probability density": "Probability density",
  147. "probability": "Probability",
  148. None: "Number of voxels",
  149. }.get(histnorm, str(histnorm))
  150. fig.update_layout(
  151. title=title,
  152. xaxis_title="SUV, shown on log scale" if log_x else "SUV",
  153. yaxis_title=yaxis_title,
  154. bargap=0.02,
  155. width=850,
  156. height=500,
  157. template="plotly_white",
  158. )
  159. return fig, percentile_df
  160. def plot_hot_voxels_plotly(
  161. image: nib.Nifti1Image | np.ndarray,
  162. c: float,
  163. max_points: int = 50_000,
  164. random_state: int | None = 0,
  165. show: bool = False,
  166. ) -> go.Figure:
  167. """Create a 3D scatter plot of voxels with ``SUV > c``.
  168. Parameters
  169. ----------
  170. image:
  171. 3D NIfTI image or 3D numpy array.
  172. c:
  173. SUV threshold.
  174. max_points:
  175. Maximum number of voxels to display. If the thresholded region is
  176. larger, voxels are randomly downsampled for plotting speed.
  177. random_state:
  178. Seed for reproducible downsampling. Use ``None`` for non-reproducible
  179. downsampling.
  180. show:
  181. If True, immediately display the figure with ``fig.show()``. The figure
  182. is returned in all cases.
  183. Returns
  184. -------
  185. plotly.graph_objects.Figure
  186. 3D scatter figure.
  187. """
  188. data = _image_to_array(image)
  189. if data.ndim != 3:
  190. raise ValueError("Input image must be 3D.")
  191. mask = np.isfinite(data) & (data > c)
  192. coords = np.argwhere(mask)
  193. if coords.size == 0:
  194. raise ValueError(f"No voxels found above threshold c={c}.")
  195. values = data[mask]
  196. if len(coords) > max_points:
  197. rng = np.random.default_rng(random_state)
  198. idx = rng.choice(len(coords), size=max_points, replace=False)
  199. coords = coords[idx]
  200. values = values[idx]
  201. fig = go.Figure(
  202. data=go.Scatter3d(
  203. x=coords[:, 0],
  204. y=coords[:, 1],
  205. z=coords[:, 2],
  206. mode="markers",
  207. marker={
  208. "size": 2,
  209. "color": values,
  210. "colorscale": "Hot",
  211. "opacity": 0.5,
  212. "colorbar": {"title": "SUV"},
  213. },
  214. )
  215. )
  216. fig.update_layout(
  217. scene={
  218. "xaxis_title": "i",
  219. "yaxis_title": "j",
  220. "zaxis_title": "k",
  221. "aspectmode": "data",
  222. },
  223. width=800,
  224. height=800,
  225. )
  226. if show:
  227. fig.show()
  228. return fig