app.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
"""
Interactive Shiny app for inspecting processed PET SUV NIfTI images.

Run from the project root:

    shiny run --reload app.py

Required packages:

    pip install shiny shinywidgets shinyswatch plotly pandas nibabel scipy scikit-image

Expected project structure:

    project_root/
    ├── app.py
    ├── data/
    │   ├── gen/
    │   └── raw/
    └── spatial_suv_charact/
        ├── image_io.py
        ├── metadata.py
        ├── plotting.py
        ├── spatial_features.py
        └── suv_stats.py
"""
from __future__ import annotations

import os
import sys
import traceback
from pathlib import Path

import pandas as pd
import plotly.graph_objects as go
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shinywidgets import output_widget, render_widget
  28. # -----------------------------------------------------------------------------
  29. # Project paths
  30. # -----------------------------------------------------------------------------
  31. APP_DIR = Path(__file__).resolve().parent
  32. PROJECT_ROOT = APP_DIR
  33. if str(PROJECT_ROOT) not in sys.path:
  34. sys.path.insert(0, str(PROJECT_ROOT))
  35. import os
  36. DATA_RAW = os.path.join(PROJECT_ROOT, "data", "raw")
  37. DATA_GEN = os.path.join(PROJECT_ROOT, "data", "gen")
  38. DEFAULT_METADATA_PATH = os.path.join(DATA_GEN, "metadata.pkl")
  39. # -----------------------------------------------------------------------------
  40. # Local package imports
  41. # -----------------------------------------------------------------------------
  42. from spatial_suv_charact.metadata import ( # noqa: E402
  43. get_meta_data,
  44. flag_corrupted_files,
  45. flag_AE_patients,
  46. )
  47. from spatial_suv_charact.image_io import get_processed_image # noqa: E402
  48. from spatial_suv_charact.plotting import ( # noqa: E402
  49. plot_suv_pdf_plotly,
  50. plot_hot_voxels_plotly,
  51. )
  52. from spatial_suv_charact.spatial_features import compute_tail_spatial_features # noqa: E402
  53. # -----------------------------------------------------------------------------
  54. # Helper functions
  55. # -----------------------------------------------------------------------------
  56. def _empty_figure(message: str = "No plot available") -> go.Figure:
  57. """Return an empty Plotly figure with a centered annotation."""
  58. fig = go.Figure()
  59. fig.add_annotation(
  60. text=message,
  61. xref="paper",
  62. yref="paper",
  63. x=0.5,
  64. y=0.5,
  65. showarrow=False,
  66. )
  67. fig.update_layout(
  68. template="plotly_white",
  69. xaxis={"visible": False},
  70. yaxis={"visible": False},
  71. height=500,
  72. )
  73. return fig
  74. def _load_metadata() -> pd.DataFrame:
  75. """Load metadata table from disk or build it from DATA_RAW."""
  76. if DEFAULT_METADATA_PATH.exists():
  77. df_meta = pd.read_pickle(DEFAULT_METADATA_PATH)
  78. else:
  79. df_meta = get_meta_data(str(DATA_RAW))
  80. df_meta = flag_corrupted_files(df_meta)
  81. df_meta = flag_AE_patients(df_meta)
  82. DATA_GEN.mkdir(parents=True, exist_ok=True)
  83. df_meta.to_pickle(DEFAULT_METADATA_PATH)
  84. if not isinstance(df_meta.index, pd.MultiIndex):
  85. required_cols = {"patient_id", "organ", "visit"}
  86. if required_cols.issubset(df_meta.columns):
  87. df_meta = df_meta.set_index(["patient_id", "organ", "visit"])
  88. else:
  89. raise ValueError(
  90. "Metadata table must have MultiIndex (patient_id, organ, visit) "
  91. "or columns patient_id, organ, visit."
  92. )
  93. required_columns = {"PET_path", "SEG_path"}
  94. missing = required_columns.difference(df_meta.columns)
  95. if missing:
  96. raise ValueError(f"Metadata table is missing required columns: {missing}")
  97. return df_meta.sort_index()
  98. DF_META = _load_metadata()
  99. def _metadata_for_display(df_meta: pd.DataFrame) -> pd.DataFrame:
  100. """Return metadata with index columns exposed and an internal row_id."""
  101. df = df_meta.reset_index().copy()
  102. df.insert(0, "row_id", range(len(df)))
  103. preferred = [
  104. "row_id",
  105. "patient_id",
  106. "organ",
  107. "visit",
  108. "is_AE_patient",
  109. "is_corrupted",
  110. "PET_filename",
  111. "SEG_filename",
  112. "PET_path",
  113. "SEG_path",
  114. ]
  115. cols = [c for c in preferred if c in df.columns] + [
  116. c for c in df.columns if c not in preferred
  117. ]
  118. return df[cols]
  119. DF_DISPLAY = _metadata_for_display(DF_META)
  120. def _format_image_id(row: pd.Series) -> str:
  121. """Create a readable image identifier from a selected metadata row."""
  122. return f"{row['patient_id']}_{row['organ']}_VISIT_{row['visit']}"
  123. def _parse_probs(prob_string: str) -> tuple[float, ...]:
  124. """Parse comma-separated percentile values from UI input."""
  125. values: list[float] = []
  126. for part in prob_string.split(","):
  127. part = part.strip()
  128. if not part:
  129. continue
  130. values.append(float(part))
  131. if not values:
  132. raise ValueError("At least one percentile must be supplied.")
  133. for value in values:
  134. if not (0 < value < 100):
  135. raise ValueError("Percentiles must be between 0 and 100.")
  136. return tuple(values)
  137. def _safe_error_message(exc: BaseException) -> str:
  138. """Return a readable error message for the diagnostics tab."""
  139. return "\n".join(
  140. [
  141. f"{type(exc).__name__}: {exc}",
  142. "",
  143. traceback.format_exc(limit=8),
  144. ]
  145. )
  146. # -----------------------------------------------------------------------------
  147. # UI
  148. # -----------------------------------------------------------------------------
  149. app_ui = ui.page_fluid(
  150. ui.h2("Spatial SUV tail-feature explorer"),
  151. ui.layout_sidebar(
  152. ui.sidebar(
  153. ui.input_text(
  154. "probs",
  155. "Percentiles",
  156. value="80, 90, 95",
  157. placeholder="80, 90, 95",
  158. ),
  159. ui.input_numeric("bins", "Histogram bins", value=100, min=10, max=500),
  160. ui.input_numeric("min_suv", "Minimum SUV for PDF", value=0.1, min=0),
  161. ui.input_checkbox("log_x", "Log x-axis for SUV PDF", value=True),
  162. ui.hr(),
  163. ui.input_numeric(
  164. "min_component_voxels",
  165. "Minimum component voxels",
  166. value=3,
  167. min=1,
  168. step=1,
  169. ),
  170. ui.input_select(
  171. "component_connectivity",
  172. "Component connectivity",
  173. choices={"6": "6", "18": "18", "26": "26"},
  174. selected="26",
  175. ),
  176. ui.input_select(
  177. "contrast_connectivity",
  178. "Contrast connectivity",
  179. choices={"6": "6", "18": "18", "26": "26"},
  180. selected="26",
  181. ),
  182. ui.input_checkbox("compute_spread", "Compute spread", value=True),
  183. ui.input_checkbox("compute_local_contrast", "Compute local contrast", value=True),
  184. ui.input_checkbox("compute_sphericity", "Compute sphericity", value=True),
  185. ui.input_checkbox("crop_to_roi", "Crop to ROI", value=True),
  186. ui.hr(),
  187. ui.input_action_button("run", "Compute selected row", class_="btn-primary"),
  188. width=330,
  189. ),
  190. ui.navset_tab(
  191. ui.nav_panel(
  192. "Metadata table",
  193. ui.p("Select one row from the table, then click 'Compute selected row'."),
  194. ui.output_data_frame("metadata_table"),
  195. ),
  196. ui.nav_panel(
  197. "SUV PDF",
  198. ui.output_ui("selected_summary_pdf"),
  199. output_widget("suv_pdf_plot", height="560px"),
  200. ui.h4("SUV percentiles"),
  201. ui.output_data_frame("suv_percentiles_table"),
  202. ),
  203. ui.nav_panel(
  204. "Hot voxels",
  205. ui.output_ui("selected_summary_hot"),
  206. ui.output_ui("hot_voxel_tabs"),
  207. ),
  208. ui.nav_panel(
  209. "Spatial features",
  210. ui.output_ui("selected_summary_features"),
  211. ui.output_data_frame("features_table"),
  212. ),
  213. ui.nav_panel(
  214. "Errors / diagnostics",
  215. ui.output_text_verbatim("diagnostics"),
  216. ),
  217. ),
  218. ),
  219. )
  220. # -----------------------------------------------------------------------------
  221. # Server
  222. # -----------------------------------------------------------------------------
  223. def server(input: Inputs, output: Outputs, session: Session):
  224. @render.data_frame
  225. def metadata_table():
  226. return render.DataGrid(
  227. DF_DISPLAY,
  228. selection_mode="row",
  229. filters=True,
  230. height="650px",
  231. )
  232. @reactive.calc
  233. def selected_row_display() -> pd.Series | None:
  234. selected = metadata_table.cell_selection()["rows"]
  235. if not selected:
  236. return None
  237. return DF_DISPLAY.iloc[int(selected[0])]
  238. @reactive.calc
  239. @reactive.event(input.run)
  240. def analysis_result():
  241. """Load selected image and compute all requested outputs once per click."""
  242. row_display = selected_row_display()
  243. if row_display is None:
  244. return {
  245. "ok": False,
  246. "error": "No row selected. Select a row in the metadata table first.",
  247. }
  248. try:
  249. probs = _parse_probs(input.probs())
  250. patient_id = row_display["patient_id"]
  251. organ = row_display["organ"]
  252. visit = int(row_display["visit"])
  253. index_key = (patient_id, organ, visit)
  254. row_meta, processed_image = get_processed_image(
  255. DF_META,
  256. patient_id=patient_id,
  257. organ=organ,
  258. visit=visit,
  259. )
  260. image_id = _format_image_id(row_display)
  261. pdf_fig, suv_percentiles = plot_suv_pdf_plotly(
  262. processed_image,
  263. percentiles=probs,
  264. bins=int(input.bins()),
  265. log_x=bool(input.log_x()),
  266. min_suv=float(input.min_suv()),
  267. title=f"SUV distribution: {image_id}",
  268. )
  269. hot_figs: dict[float, go.Figure] = {}
  270. for p in probs:
  271. threshold = float(
  272. suv_percentiles.loc[
  273. suv_percentiles["percentile"] == p,
  274. "suv_threshold",
  275. ].iloc[0]
  276. )
  277. hot_fig = plot_hot_voxels_plotly(
  278. processed_image,
  279. c=threshold,
  280. )
  281. hot_fig.update_layout(
  282. title=f"Hot voxels: {image_id}, p{p:g}, SUV ≥ {threshold:.4g}"
  283. )
  284. hot_figs[p] = hot_fig
  285. features = compute_tail_spatial_features(
  286. image=processed_image,
  287. percentiles=probs,
  288. component_connectivity=int(input.component_connectivity()),
  289. contrast_connectivity=int(input.contrast_connectivity()),
  290. min_component_voxels=int(input.min_component_voxels()),
  291. compute_spread=bool(input.compute_spread()),
  292. compute_local_contrast=bool(input.compute_local_contrast()),
  293. compute_sphericity=bool(input.compute_sphericity()),
  294. crop_to_roi=bool(input.crop_to_roi()),
  295. image_id=index_key,
  296. )
  297. return {
  298. "ok": True,
  299. "row_display": row_display,
  300. "row_meta": row_meta,
  301. "image_id": image_id,
  302. "probs": probs,
  303. "pdf_fig": pdf_fig,
  304. "suv_percentiles": suv_percentiles,
  305. "hot_figs": hot_figs,
  306. "features": features,
  307. "diagnostics": f"Computed successfully for {image_id}",
  308. }
  309. except Exception as exc:
  310. return {
  311. "ok": False,
  312. "row_display": row_display,
  313. "error": _safe_error_message(exc),
  314. }
  315. def _selected_summary(result: dict) -> ui.TagList:
  316. if not result.get("ok"):
  317. return ui.TagList(ui.div(ui.strong("No successful computation yet.")))
  318. row = result["row_display"]
  319. return ui.TagList(
  320. ui.div(
  321. ui.strong("Selected image: "),
  322. f"{row['patient_id']} | {row['organ']} | VISIT_{row['visit']}",
  323. )
  324. )
  325. @render.ui
  326. def selected_summary_pdf():
  327. return _selected_summary(analysis_result())
  328. @render.ui
  329. def selected_summary_hot():
  330. return _selected_summary(analysis_result())
  331. @render.ui
  332. def selected_summary_features():
  333. return _selected_summary(analysis_result())
  334. @render_widget
  335. def suv_pdf_plot():
  336. result = analysis_result()
  337. if not result.get("ok"):
  338. return _empty_figure("Select a row and click 'Compute selected row' to show the SUV PDF.")
  339. return result["pdf_fig"]
  340. @render.data_frame
  341. def suv_percentiles_table():
  342. result = analysis_result()
  343. if not result.get("ok"):
  344. return pd.DataFrame()
  345. return render.DataGrid(result["suv_percentiles"], height="250px")
  346. @render.ui
  347. def hot_voxel_tabs():
  348. result = analysis_result()
  349. if not result.get("ok"):
  350. return ui.div(ui.em("Select a row and compute to show hot-voxel plots."))
  351. probs = list(result["probs"])
  352. max_plots = 5
  353. if len(probs) > max_plots:
  354. probs = probs[:max_plots]
  355. tabs = []
  356. for i, p in enumerate(probs):
  357. tabs.append(
  358. ui.nav_panel(
  359. f"p{p:g}",
  360. output_widget(f"hot_voxel_plot_{i}", height="760px"),
  361. )
  362. )
  363. return ui.navset_tab(*tabs)
  364. def _hot_voxel_figure_by_index(index: int):
  365. result = analysis_result()
  366. if not result.get("ok"):
  367. return _empty_figure("No hot-voxel plot available.")
  368. probs = list(result["probs"])
  369. if index >= len(probs):
  370. return _empty_figure("No percentile assigned to this plot.")
  371. p = probs[index]
  372. fig = result["hot_figs"].get(p)
  373. if fig is None:
  374. return _empty_figure(f"Percentile p{p:g} was not requested.")
  375. return fig
  376. @render_widget
  377. def hot_voxel_plot_0():
  378. return _hot_voxel_figure_by_index(0)
  379. @render_widget
  380. def hot_voxel_plot_1():
  381. return _hot_voxel_figure_by_index(1)
  382. @render_widget
  383. def hot_voxel_plot_2():
  384. return _hot_voxel_figure_by_index(2)
  385. @render_widget
  386. def hot_voxel_plot_3():
  387. return _hot_voxel_figure_by_index(3)
  388. @render_widget
  389. def hot_voxel_plot_4():
  390. return _hot_voxel_figure_by_index(4)
  391. @render.data_frame
  392. def features_table():
  393. result = analysis_result()
  394. if not result.get("ok"):
  395. return pd.DataFrame()
  396. return render.DataGrid(result["features"], height="420px")
  397. @render.text
  398. def diagnostics():
  399. result = analysis_result()
  400. if result.get("ok"):
  401. return result.get("diagnostics", "OK")
  402. return result.get("error", "Unknown error")
  403. app = App(app_ui, server)
  404. if __name__ == "__main__":
  405. from shiny import run_app
  406. run_app("app:app", host="127.0.0.1", port=8000, reload=True)