"""Plotly visualizations for SUV NIfTI images.""" from __future__ import annotations import numpy as np import pandas as pd import nibabel as nib import plotly.graph_objects as go def _image_to_array(image: nib.Nifti1Image | np.ndarray) -> np.ndarray: """Return image data as a floating-point numpy array.""" if isinstance(image, nib.Nifti1Image): return image.get_fdata(dtype=np.float64) return np.asarray(image, dtype=float) def plot_suv_pdf_plotly( image: nib.Nifti1Image | np.ndarray, percentiles: tuple[float, ...] = (50, 75, 90, 95, 99), bins: int = 100, min_suv: float = 0.0, xlim: tuple[float, float] | None = None, histnorm: str | None = "probability density", title: str = "SUV distribution", log_x: bool = False, ) -> tuple[go.Figure, pd.DataFrame]: """Plot SUV histogram / empirical PDF with optional percentile markers. Percentiles are always computed on the original SUV scale. When ``log_x=True``, the histogram is drawn in ``log10(SUV)`` coordinates and the tick labels are shown in original SUV units. This avoids incorrectly placed percentile lines on log-scaled axes. Parameters ---------- image: 3D NIfTI image or 3D numpy array. percentiles: Percentile markers to draw and tabulate. bins: Number of histogram bins. min_suv: Only finite voxels with ``SUV > min_suv`` are used. xlim: Optional x-axis limits in original SUV units. histnorm: Plotly histogram normalization. Common choices are ``"probability density"``, ``"probability"`` or ``None``. title: Figure title. log_x: If True, use log10 coordinates internally and label ticks in original SUV units. Returns ------- tuple ``(fig, percentile_df)``. """ data = _image_to_array(image) if data.ndim != 3: raise ValueError("Input image must be 3D.") suv_values = data[np.isfinite(data) & (data > min_suv)] if suv_values.size == 0: raise ValueError("No valid SUV values found.") if log_x: if np.any(suv_values <= 0): raise ValueError("For log_x=True, all SUV values must be positive.") if xlim is not None and (xlim[0] <= 0 or xlim[1] <= 0): raise ValueError("For log_x=True, xlim values must be positive.") if percentiles: percentile_values = np.percentile(suv_values, percentiles) percentile_df = pd.DataFrame( { "percentile": percentiles, "suv_threshold": percentile_values, "n_voxels_ge_threshold": [ int(np.sum(suv_values >= value)) for value in percentile_values ], "n_voxels_lt_threshold": [ int(np.sum(suv_values < value)) for value in percentile_values ], } ) else: percentile_values = [] percentile_df = pd.DataFrame() fig = go.Figure() if log_x: plot_values = np.log10(suv_values) if xlim is not None: log_xlim = (np.log10(xlim[0]), np.log10(xlim[1])) plot_values = plot_values[(plot_values >= log_xlim[0]) & (plot_values <= log_xlim[1])] else: log_xlim = None fig.add_trace( go.Histogram( x=plot_values, nbinsx=bins, histnorm=histnorm, name="SUV histogram", opacity=0.75, marker_line_width=1, ) ) for p, value in zip(percentiles, percentile_values): fig.add_vline( x=np.log10(value), line_dash="dash", line_width=2, annotation_text=f"p{p:g} = {value:.3g}", annotation_position="top", ) if xlim is not None: tick_suv = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100]) tick_suv = tick_suv[(tick_suv >= xlim[0]) & (tick_suv <= xlim[1])] else: lo, hi = np.nanmin(suv_values), np.nanmax(suv_values) tick_suv = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100]) tick_suv = tick_suv[(tick_suv >= lo) & (tick_suv <= hi)] fig.update_xaxes( title_text="SUV", tickvals=np.log10(tick_suv), ticktext=[f"{v:g}" for v in tick_suv], ) if log_xlim is not None: fig.update_xaxes(range=list(log_xlim)) else: plot_values = suv_values.copy() if xlim is not None: plot_values = plot_values[(plot_values >= xlim[0]) & (plot_values <= xlim[1])] fig.add_trace( go.Histogram( x=plot_values, nbinsx=bins, histnorm=histnorm, name="SUV histogram", opacity=0.75, marker_line_width=1, ) ) for p, value in zip(percentiles, percentile_values): fig.add_vline( x=value, line_dash="dash", line_width=2, annotation_text=f"p{p:g} = {value:.3g}", annotation_position="top", ) fig.update_xaxes(title_text="SUV") if xlim is not None: fig.update_xaxes(range=list(xlim)) yaxis_title = { "probability density": "Probability density", "probability": "Probability", None: "Number of voxels", }.get(histnorm, str(histnorm)) fig.update_layout( title=title, xaxis_title="SUV, shown on log scale" if log_x else "SUV", yaxis_title=yaxis_title, bargap=0.02, width=850, height=500, template="plotly_white", ) return fig, percentile_df def plot_hot_voxels_plotly( image: nib.Nifti1Image | np.ndarray, c: float, max_points: int = 50_000, random_state: int | None = 0, show: bool = False, ) -> go.Figure: """Create a 3D scatter plot of voxels with ``SUV > c``. Parameters ---------- image: 3D NIfTI image or 3D numpy array. c: SUV threshold. max_points: Maximum number of voxels to display. If the thresholded region is larger, voxels are randomly downsampled for plotting speed. random_state: Seed for reproducible downsampling. Use ``None`` for non-reproducible downsampling. show: If True, immediately display the figure with ``fig.show()``. The figure is returned in all cases. Returns ------- plotly.graph_objects.Figure 3D scatter figure. """ data = _image_to_array(image) if data.ndim != 3: raise ValueError("Input image must be 3D.") mask = np.isfinite(data) & (data > c) coords = np.argwhere(mask) if coords.size == 0: raise ValueError(f"No voxels found above threshold c={c}.") values = data[mask] if len(coords) > max_points: rng = np.random.default_rng(random_state) idx = rng.choice(len(coords), size=max_points, replace=False) coords = coords[idx] values = values[idx] fig = go.Figure( data=go.Scatter3d( x=coords[:, 0], y=coords[:, 1], z=coords[:, 2], mode="markers", marker={ "size": 2, "color": values, "colorscale": "Hot", "opacity": 0.5, "colorbar": {"title": "SUV"}, }, ) ) fig.update_layout( scene={ "xaxis_title": "i", "yaxis_title": "j", "zaxis_title": "k", "aspectmode": "data", }, width=800, height=800, ) if show: fig.show() return fig