spatial_features.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. """Spatial and heterogeneity features for high-SUV tail regions.
  2. This module provides :func:`compute_tail_spatial_features`.
  3. Design principle
  4. ----------------
  5. The default call is intentionally backward-compatible with the original
  6. implementation:
  7. compute_tail_spatial_features(image, connectivity=26)
  8. uses 26-connectivity for both connected components and local contrast,
  9. computes spatial spread, computes sphericity, and does not crop the image.
  10. Performance options are available as explicit opt-ins:
  11. compute_sphericity=False
  12. crop_to_roi=True
  13. contrast_connectivity=6
  14. This avoids silently changing the scientific meaning of already-computed
  15. features.
  16. """
  17. from __future__ import annotations
  18. from functools import lru_cache
  19. import numpy as np
  20. import pandas as pd
  21. import nibabel as nib
  22. from scipy import ndimage
  23. from skimage.measure import marching_cubes, mesh_surface_area
  24. Array3D = np.ndarray
  25. Spacing3D = tuple[float, float, float]
  26. def _validate_connectivity(connectivity: int) -> None:
  27. """Validate a 3D connectivity code.
  28. Parameters
  29. ----------
  30. connectivity:
  31. Must be one of 6, 18, or 26.
  32. """
  33. if connectivity not in (6, 18, 26):
  34. raise ValueError("connectivity must be one of: 6, 18, 26.")
  35. @lru_cache(maxsize=None)
  36. def _connectivity_structure(connectivity: int) -> Array3D:
  37. """Return a ``scipy.ndimage`` binary structure for 3D connectivity."""
  38. _validate_connectivity(connectivity)
  39. if connectivity == 6:
  40. return ndimage.generate_binary_structure(rank=3, connectivity=1)
  41. if connectivity == 18:
  42. return ndimage.generate_binary_structure(rank=3, connectivity=2)
  43. return ndimage.generate_binary_structure(rank=3, connectivity=3)
  44. @lru_cache(maxsize=None)
  45. def _neighbor_offsets(connectivity: int) -> tuple[tuple[int, int, int], ...]:
  46. """Return unique 3D neighbor offsets for the selected connectivity.
  47. Only one offset from each symmetric pair is returned. Therefore each
  48. neighboring voxel pair is counted once, matching the original implementation.
  49. """
  50. _validate_connectivity(connectivity)
  51. offsets: list[tuple[int, int, int]] = []
  52. for di in (-1, 0, 1):
  53. for dj in (-1, 0, 1):
  54. for dk in (-1, 0, 1):
  55. if di == 0 and dj == 0 and dk == 0:
  56. continue
  57. dist1 = abs(di) + abs(dj) + abs(dk)
  58. if connectivity == 6 and dist1 != 1:
  59. continue
  60. if connectivity == 18 and dist1 > 2:
  61. continue
  62. # Keep one direction only to avoid double-counting.
  63. # This is the same rule as in the original implementation.
  64. if (di, dj, dk) > (0, 0, 0):
  65. offsets.append((di, dj, dk))
  66. return tuple(offsets)
  67. def _component_entropy(component_sizes: np.ndarray) -> float:
  68. """Entropy of the connected-component size distribution.
  69. If the component sizes are ``s_j``, this returns
  70. - sum_j p_j log(p_j), where p_j = s_j / sum_k s_k.
  71. """
  72. component_sizes = np.asarray(component_sizes, dtype=float)
  73. if component_sizes.size == 0 or component_sizes.sum() <= 0:
  74. return np.nan
  75. p = component_sizes / component_sizes.sum()
  76. p = p[p > 0]
  77. return float(-np.sum(p * np.log(p)))
  78. def _mask_bbox(mask: Array3D, margin: int = 0) -> tuple[slice, slice, slice] | None:
  79. """Return bounding-box slices for a 3D mask.
  80. This implementation avoids ``np.argwhere(mask)`` and instead uses 1D
  81. projections, which is usually faster and less memory-intensive.
  82. """
  83. if mask.ndim != 3:
  84. raise ValueError("mask must be 3D.")
  85. if not np.any(mask):
  86. return None
  87. if margin < 0:
  88. raise ValueError("margin must be >= 0.")
  89. i_nonzero = np.flatnonzero(mask.any(axis=(1, 2)))
  90. j_nonzero = np.flatnonzero(mask.any(axis=(0, 2)))
  91. k_nonzero = np.flatnonzero(mask.any(axis=(0, 1)))
  92. lo = np.array([i_nonzero[0], j_nonzero[0], k_nonzero[0]], dtype=int)
  93. hi = np.array(
  94. [i_nonzero[-1] + 1, j_nonzero[-1] + 1, k_nonzero[-1] + 1],
  95. dtype=int,
  96. )
  97. if margin > 0:
  98. lo = np.maximum(lo - margin, 0)
  99. hi = np.minimum(hi + margin, mask.shape)
  100. return tuple(slice(int(lo[d]), int(hi[d])) for d in range(3))
  101. def _crop_to_mask_bbox(
  102. suv: Array3D,
  103. valid_mask: Array3D,
  104. margin: int = 1,
  105. ) -> tuple[Array3D, Array3D]:
  106. """Crop SUV image and valid-mask to the valid-mask bounding box."""
  107. bbox = _mask_bbox(valid_mask, margin=margin)
  108. if bbox is None:
  109. return suv, valid_mask
  110. return suv[bbox], valid_mask[bbox]
  111. def _spatial_spread_from_mask(
  112. region_mask: Array3D,
  113. values: np.ndarray,
  114. voxel_spacing: Spacing3D,
  115. ) -> tuple[float, float]:
  116. """Compute unweighted and SUV-weighted spatial spread in mm^2.
  117. This is equivalent to the original implementation based on
  118. ``np.argwhere(region_mask)`` followed by multiplication with voxel spacing,
  119. but it avoids constructing a single large ``(n, 3)`` coordinate matrix.
  120. """
  121. n_voxels = int(np.count_nonzero(region_mask))
  122. if n_voxels == 0:
  123. return np.nan, np.nan
  124. ii, jj, kk = np.nonzero(region_mask)
  125. dx, dy, dz = voxel_spacing
  126. x = ii.astype(np.float64, copy=False) * dx
  127. y = jj.astype(np.float64, copy=False) * dy
  128. z = kk.astype(np.float64, copy=False) * dz
  129. cx = float(np.mean(x))
  130. cy = float(np.mean(y))
  131. cz = float(np.mean(z))
  132. spread = float(np.mean((x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2))
  133. weights = np.asarray(values, dtype=np.float64)
  134. weight_sum = float(np.sum(weights))
  135. if weight_sum <= 0:
  136. weighted_spread = np.nan
  137. else:
  138. wcx = float(np.sum(weights * x) / weight_sum)
  139. wcy = float(np.sum(weights * y) / weight_sum)
  140. wcz = float(np.sum(weights * z) / weight_sum)
  141. weighted_spread = float(
  142. np.sum(weights * ((x - wcx) ** 2 + (y - wcy) ** 2 + (z - wcz) ** 2))
  143. / weight_sum
  144. )
  145. return spread, weighted_spread
  146. def _local_contrast(
  147. suv: Array3D,
  148. region_mask: Array3D,
  149. connectivity: int = 26,
  150. ) -> float:
  151. """Continuous local contrast inside a binary region.
  152. Computes the mean squared SUV difference between neighboring voxel pairs
  153. where both voxels are inside ``region_mask``.
  154. This is numerically equivalent to the original concatenate-based version,
  155. but it accumulates the sum of squared differences and the number of pairs
  156. directly. This avoids large temporary arrays.
  157. """
  158. if suv.ndim != 3:
  159. raise ValueError("suv must be a 3D array.")
  160. _validate_connectivity(connectivity)
  161. region_mask = np.asarray(region_mask, dtype=bool)
  162. if region_mask.shape != suv.shape:
  163. raise ValueError("region_mask must have the same shape as suv.")
  164. if np.count_nonzero(region_mask) < 2:
  165. return np.nan
  166. nx, ny, nz = suv.shape
  167. total_sqdiff = 0.0
  168. total_pairs = 0
  169. for di, dj, dk in _neighbor_offsets(connectivity):
  170. src_slices = (
  171. slice(max(0, -di), min(nx, nx - di)),
  172. slice(max(0, -dj), min(ny, ny - dj)),
  173. slice(max(0, -dk), min(nz, nz - dk)),
  174. )
  175. dst_slices = (
  176. slice(max(0, di), min(nx, nx + di)),
  177. slice(max(0, dj), min(ny, ny + dj)),
  178. slice(max(0, dk), min(nz, nz + dk)),
  179. )
  180. pair_mask = region_mask[src_slices] & region_mask[dst_slices]
  181. n_pairs = int(np.count_nonzero(pair_mask))
  182. if n_pairs == 0:
  183. continue
  184. diff = suv[src_slices][pair_mask] - suv[dst_slices][pair_mask]
  185. total_sqdiff += float(np.sum(diff * diff))
  186. total_pairs += n_pairs
  187. if total_pairs == 0:
  188. return np.nan
  189. return float(total_sqdiff / total_pairs)
  190. def _largest_component_sphericity(
  191. largest_component_mask: Array3D,
  192. voxel_spacing: Spacing3D,
  193. *,
  194. crop_component: bool = False,
  195. ) -> float:
  196. """Approximate sphericity of the largest connected component.
  197. Sphericity is defined as
  198. psi = pi^(1/3) * (6V)^(2/3) / A,
  199. where ``V`` is physical volume and ``A`` is triangulated surface area.
  200. Parameters
  201. ----------
  202. crop_component:
  203. If False, exactly follows the original full-volume marching-cubes
  204. calculation. If True, crop the largest component to its bounding box
  205. before marching cubes. Cropping is faster but may very slightly change
  206. the surface mesh if the component touches the crop boundary.
  207. """
  208. n_voxels = int(np.count_nonzero(largest_component_mask))
  209. if n_voxels < 4:
  210. return np.nan
  211. voxel_volume = float(np.prod(voxel_spacing))
  212. volume = float(n_voxels * voxel_volume)
  213. component_for_surface = largest_component_mask
  214. if crop_component:
  215. bbox = _mask_bbox(largest_component_mask, margin=1)
  216. if bbox is None:
  217. return np.nan
  218. component_for_surface = largest_component_mask[bbox]
  219. try:
  220. verts, faces, _, _ = marching_cubes(
  221. component_for_surface.astype(float, copy=False),
  222. level=0.5,
  223. spacing=voxel_spacing,
  224. )
  225. area = float(mesh_surface_area(verts, faces))
  226. except Exception:
  227. return np.nan
  228. if area <= 0:
  229. return np.nan
  230. return float((np.pi ** (1 / 3)) * ((6 * volume) ** (2 / 3)) / area)
  231. def compute_tail_spatial_features(
  232. image: nib.Nifti1Image,
  233. voxel_spacing: Spacing3D | None = None,
  234. percentiles: tuple[float, ...] = (90, 95, 97.5, 99),
  235. connectivity: int = 26,
  236. min_component_voxels: int = 1,
  237. image_id: str | None = None,
  238. *,
  239. component_connectivity: int | None = None,
  240. contrast_connectivity: int | None = None,
  241. compute_spread: bool = True,
  242. compute_local_contrast: bool = True,
  243. compute_sphericity: bool = True,
  244. crop_to_roi: bool = False,
  245. crop_margin: int = 1,
  246. crop_component_for_sphericity: bool = False,
  247. ) -> pd.DataFrame:
  248. """Compute spatial and heterogeneity features of high-SUV tail regions.
  249. The input image is assumed to be already processed, for example PET SUV
  250. multiplied by an organ segmentation mask. Finite positive voxels are treated
  251. as the region of interest. For each percentile ``q``, the high-SUV tail is
  252. defined as
  253. R_q = {voxel : SUV(voxel) >= percentile_q(SUV within ROI)}.
  254. Backward compatibility
  255. ----------------------
  256. By default, this function preserves the semantics of the original version:
  257. - ``connectivity`` is used for both components and local contrast;
  258. - spread is computed;
  259. - sphericity is computed;
  260. - the volume is not cropped before feature calculation.
  261. Faster screening runs should explicitly disable or alter expensive features.
  262. Parameters
  263. ----------
  264. image:
  265. 3D NIfTI image containing SUV values per voxel.
  266. voxel_spacing:
  267. Physical voxel size in mm. If None, it is read from the NIfTI header.
  268. percentiles:
  269. Percentile thresholds used to define high-SUV regions.
  270. connectivity:
  271. Backward-compatible connectivity parameter. Used whenever
  272. ``component_connectivity`` or ``contrast_connectivity`` is not supplied.
  273. min_component_voxels:
  274. Components smaller than this number of voxels are ignored in
  275. component-level summaries.
  276. image_id:
  277. Optional image or patient identifier added to the output table.
  278. component_connectivity:
  279. Connectivity used for connected-component labeling. If None, uses
  280. ``connectivity``.
  281. contrast_connectivity:
  282. Connectivity used for local contrast. If None, uses ``connectivity``.
  283. Set this to 6 explicitly for a faster face-neighbor contrast.
  284. compute_spread:
  285. If False, skip spatial spread features.
  286. compute_local_contrast:
  287. If False, skip local contrast features.
  288. compute_sphericity:
  289. If False, skip marching-cubes sphericity computation.
  290. crop_to_roi:
  291. If True, crop the image to the nonzero ROI bounding box before feature
  292. calculation. This should preserve component counts and component entropy,
  293. but the default is False to match the original implementation exactly.
  294. crop_margin:
  295. Margin in voxels added when ``crop_to_roi=True``.
  296. crop_component_for_sphericity:
  297. If True, crop the largest component before marching cubes. This is
  298. faster but not the exact original sphericity calculation, so the default
  299. is False.
  300. Returns
  301. -------
  302. pandas.DataFrame
  303. One row per percentile threshold.
  304. """
  305. if not isinstance(image, nib.Nifti1Image):
  306. raise TypeError("image must be a nib.Nifti1Image.")
  307. if component_connectivity is None:
  308. component_connectivity = connectivity
  309. if contrast_connectivity is None:
  310. contrast_connectivity = connectivity
  311. _validate_connectivity(component_connectivity)
  312. _validate_connectivity(contrast_connectivity)
  313. if min_component_voxels < 1:
  314. raise ValueError("min_component_voxels must be >= 1.")
  315. if crop_margin < 0:
  316. raise ValueError("crop_margin must be >= 0.")
  317. suv = image.get_fdata(dtype=np.float64)
  318. if suv.ndim != 3:
  319. raise ValueError("Input NIfTI image must be 3D.")
  320. if voxel_spacing is None:
  321. voxel_spacing = image.header.get_zooms()[:3]
  322. if len(voxel_spacing) != 3:
  323. raise ValueError("voxel_spacing must have length 3.")
  324. voxel_spacing = tuple(float(x) for x in voxel_spacing)
  325. voxel_volume = float(np.prod(voxel_spacing))
  326. valid_mask = np.isfinite(suv) & (suv > 0)
  327. if crop_to_roi:
  328. suv, valid_mask = _crop_to_mask_bbox(
  329. suv=suv,
  330. valid_mask=valid_mask,
  331. margin=crop_margin,
  332. )
  333. values = suv[valid_mask]
  334. if values.size == 0:
  335. raise ValueError("No finite positive SUV values found in the image.")
  336. percentile_values = np.percentile(values, percentiles)
  337. structure = _connectivity_structure(component_connectivity)
  338. n_roi_voxels = int(np.count_nonzero(valid_mask))
  339. roi_volume_mm3 = float(n_roi_voxels * voxel_volume)
  340. rows: list[dict[str, object]] = []
  341. for q, threshold_raw in zip(percentiles, percentile_values):
  342. threshold = float(threshold_raw)
  343. tail_mask = valid_mask & (suv >= threshold)
  344. n_tail_voxels = int(np.count_nonzero(tail_mask))
  345. base_row: dict[str, object] = {
  346. "image_id": image_id,
  347. "percentile": q,
  348. "threshold": threshold,
  349. "n_roi_voxels": n_roi_voxels,
  350. "roi_volume_mm3": roi_volume_mm3,
  351. "n_tail_voxels": n_tail_voxels,
  352. "tail_volume_mm3": float(n_tail_voxels * voxel_volume),
  353. "tail_fraction": float(n_tail_voxels / n_roi_voxels),
  354. }
  355. if n_tail_voxels == 0:
  356. rows.append(
  357. base_row
  358. | {
  359. "tail_mean": np.nan,
  360. "tail_median": np.nan,
  361. "tail_max": np.nan,
  362. "tail_std": np.nan,
  363. "tail_cv": np.nan,
  364. "tail_sum": np.nan,
  365. "tail_excess_mean": np.nan,
  366. "tail_excess_sum": np.nan,
  367. "tail_local_contrast": np.nan,
  368. "tail_local_contrast_norm": np.nan,
  369. "tail_spread_mm2": np.nan,
  370. "tail_weighted_spread_mm2": np.nan,
  371. "n_components": 0,
  372. "largest_component_voxels": 0,
  373. "largest_component_volume_mm3": np.nan,
  374. "largest_component_fraction": np.nan,
  375. "component_entropy": np.nan,
  376. "largest_component_sphericity": np.nan,
  377. }
  378. )
  379. continue
  380. tail_values = suv[tail_mask]
  381. tail_mean = float(np.mean(tail_values))
  382. tail_median = float(np.median(tail_values))
  383. tail_max = float(np.max(tail_values))
  384. tail_std = float(np.std(tail_values, ddof=1)) if n_tail_voxels > 1 else 0.0
  385. tail_cv = float(tail_std / tail_mean) if tail_mean > 0 else np.nan
  386. tail_sum = float(np.sum(tail_values))
  387. tail_excess = tail_values - threshold
  388. tail_excess_mean = float(np.mean(tail_excess))
  389. tail_excess_sum = float(np.sum(tail_excess))
  390. if compute_local_contrast:
  391. tail_local_contrast = _local_contrast(
  392. suv=suv,
  393. region_mask=tail_mask,
  394. connectivity=contrast_connectivity,
  395. )
  396. else:
  397. tail_local_contrast = np.nan
  398. tail_local_contrast_norm = (
  399. float(tail_local_contrast / tail_mean**2)
  400. if np.isfinite(tail_local_contrast) and tail_mean > 0
  401. else np.nan
  402. )
  403. if compute_spread:
  404. tail_spread_mm2, tail_weighted_spread_mm2 = _spatial_spread_from_mask(
  405. region_mask=tail_mask,
  406. values=tail_values,
  407. voxel_spacing=voxel_spacing,
  408. )
  409. else:
  410. tail_spread_mm2 = np.nan
  411. tail_weighted_spread_mm2 = np.nan
  412. labeled, _ = ndimage.label(tail_mask, structure=structure)
  413. component_sizes_all = np.bincount(labeled.ravel())[1:]
  414. component_sizes = component_sizes_all[component_sizes_all >= min_component_voxels]
  415. n_components = int(component_sizes.size)
  416. if n_components > 0:
  417. largest_component_voxels = int(np.max(component_sizes))
  418. largest_component_volume_mm3 = float(largest_component_voxels * voxel_volume)
  419. largest_component_fraction = float(largest_component_voxels / n_tail_voxels)
  420. component_entropy = _component_entropy(component_sizes)
  421. largest_label = int(
  422. np.flatnonzero(component_sizes_all == largest_component_voxels)[0] + 1
  423. )
  424. if compute_sphericity:
  425. largest_component_mask = labeled == largest_label
  426. largest_component_sphericity = _largest_component_sphericity(
  427. largest_component_mask=largest_component_mask,
  428. voxel_spacing=voxel_spacing,
  429. crop_component=crop_component_for_sphericity,
  430. )
  431. else:
  432. largest_component_sphericity = np.nan
  433. else:
  434. largest_component_voxels = 0
  435. largest_component_volume_mm3 = np.nan
  436. largest_component_fraction = np.nan
  437. component_entropy = np.nan
  438. largest_component_sphericity = np.nan
  439. rows.append(
  440. base_row
  441. | {
  442. "tail_mean": tail_mean,
  443. "tail_median": tail_median,
  444. "tail_max": tail_max,
  445. "tail_std": tail_std,
  446. "tail_cv": tail_cv,
  447. "tail_sum": tail_sum,
  448. "tail_excess_mean": tail_excess_mean,
  449. "tail_excess_sum": tail_excess_sum,
  450. "tail_local_contrast": tail_local_contrast,
  451. "tail_local_contrast_norm": tail_local_contrast_norm,
  452. "tail_spread_mm2": tail_spread_mm2,
  453. "tail_weighted_spread_mm2": tail_weighted_spread_mm2,
  454. "n_components": n_components,
  455. "largest_component_voxels": largest_component_voxels,
  456. "largest_component_volume_mm3": largest_component_volume_mm3,
  457. "largest_component_fraction": largest_component_fraction,
  458. "component_entropy": component_entropy,
  459. "largest_component_sphericity": largest_component_sphericity,
  460. }
  461. )
  462. return pd.DataFrame(rows)