app.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. from __future__ import annotations
  2. from pathlib import Path
  3. import sys
  4. import traceback
  5. import pandas as pd
  6. import plotly.graph_objects as go
  7. from shiny import App, Inputs, Outputs, Session, reactive, render, ui
  8. from shinywidgets import output_widget, render_widget
  9. # -----------------------------------------------------------------------------
  10. # Project paths
  11. # -----------------------------------------------------------------------------
  12. APP_DIR = Path(__file__).resolve().parent
  13. PROJECT_ROOT = APP_DIR.parent
  14. if str(PROJECT_ROOT) not in sys.path:
  15. sys.path.insert(0, str(PROJECT_ROOT))
  16. DATA_RAW = PROJECT_ROOT / "data" / "raw"
  17. DATA_GEN = PROJECT_ROOT / "data" / "gen"
  18. DEFAULT_METADATA_PATH = DATA_GEN / "metadata.pkl"
  19. # -----------------------------------------------------------------------------
  20. # Local package imports
  21. # -----------------------------------------------------------------------------
  22. from src.metadata import ( # noqa: E402
  23. get_meta_data,
  24. flag_corrupted_files,
  25. flag_AE_patients,
  26. )
  27. from src.image_io import get_processed_image # noqa: E402
  28. from src.plotting import ( # noqa: E402
  29. plot_suv_pdf_plotly,
  30. plot_hot_voxels_plotly,
  31. )
  32. from src.spatial_features import compute_tail_spatial_features # noqa: E402
  33. # -----------------------------------------------------------------------------
  34. # Helper functions
  35. # -----------------------------------------------------------------------------
  36. def _empty_figure(message: str = "No plot available") -> go.Figure:
  37. """Return an empty Plotly figure with a centered annotation."""
  38. fig = go.Figure()
  39. fig.add_annotation(
  40. text=message,
  41. xref="paper",
  42. yref="paper",
  43. x=0.5,
  44. y=0.5,
  45. showarrow=False,
  46. )
  47. fig.update_layout(
  48. template="plotly_white",
  49. xaxis={"visible": False},
  50. yaxis={"visible": False},
  51. height=500,
  52. )
  53. return fig
  54. def _load_metadata() -> pd.DataFrame:
  55. """Load metadata table from disk or build it from DATA_RAW."""
  56. if DEFAULT_METADATA_PATH.exists():
  57. df_meta = pd.read_pickle(DEFAULT_METADATA_PATH)
  58. else:
  59. df_meta = get_meta_data(str(DATA_RAW))
  60. df_meta = flag_corrupted_files(df_meta)
  61. df_meta = flag_AE_patients(df_meta)
  62. DATA_GEN.mkdir(parents=True, exist_ok=True)
  63. df_meta.to_pickle(DEFAULT_METADATA_PATH)
  64. if not isinstance(df_meta.index, pd.MultiIndex):
  65. required_cols = {"patient_id", "organ", "visit"}
  66. if required_cols.issubset(df_meta.columns):
  67. df_meta = df_meta.set_index(["patient_id", "organ", "visit"])
  68. else:
  69. raise ValueError(
  70. "Metadata table must have MultiIndex (patient_id, organ, visit) "
  71. "or columns patient_id, organ, visit."
  72. )
  73. required_columns = {"PET_path", "SEG_path"}
  74. missing = required_columns.difference(df_meta.columns)
  75. if missing:
  76. raise ValueError(f"Metadata table is missing required columns: {missing}")
  77. return df_meta.sort_index()
# Metadata table, loaded (or built and cached) once when this module is imported.
DF_META = _load_metadata()
  79. def _metadata_for_display(df_meta: pd.DataFrame) -> pd.DataFrame:
  80. """Return metadata with index columns exposed and an internal row_id."""
  81. df = df_meta.reset_index().copy()
  82. df.insert(0, "row_id", range(len(df)))
  83. preferred = [
  84. "row_id",
  85. "patient_id",
  86. "organ",
  87. "visit",
  88. "is_AE_patient",
  89. "is_corrupted",
  90. "PET_filename",
  91. "SEG_filename",
  92. "PET_path",
  93. "SEG_path",
  94. ]
  95. cols = [c for c in preferred if c in df.columns] + [
  96. c for c in df.columns if c not in preferred
  97. ]
  98. return df[cols]
# Flattened view of DF_META with a row_id column; this is the frame shown in the
# metadata grid and the one that selected row positions are looked up in.
DF_DISPLAY = _metadata_for_display(DF_META)
  100. def _format_image_id(row: pd.Series) -> str:
  101. """Create a readable image identifier from a selected metadata row."""
  102. return f"{row['patient_id']}_{row['organ']}_VISIT_{row['visit']}"
  103. def _parse_probs(prob_string: str) -> tuple[float, ...]:
  104. """Parse comma-separated percentile values from UI input."""
  105. values: list[float] = []
  106. for part in prob_string.split(","):
  107. part = part.strip()
  108. if not part:
  109. continue
  110. values.append(float(part))
  111. if not values:
  112. raise ValueError("At least one percentile must be supplied.")
  113. for value in values:
  114. if not (0 < value < 100):
  115. raise ValueError("Percentiles must be between 0 and 100.")
  116. return tuple(values)
  117. def _safe_error_message(exc: BaseException) -> str:
  118. """Return a readable error message for the diagnostics tab."""
  119. return "\n".join(
  120. [
  121. f"{type(exc).__name__}: {exc}",
  122. "",
  123. traceback.format_exc(limit=8),
  124. ]
  125. )
  126. # -----------------------------------------------------------------------------
  127. # UI
  128. # -----------------------------------------------------------------------------
  129. app_ui = ui.page_fluid(
  130. ui.h2("Spatial SUV tail-feature explorer"),
  131. ui.layout_sidebar(
  132. ui.sidebar(
  133. ui.input_text(
  134. "probs",
  135. "Percentiles",
  136. value="80, 90, 95",
  137. placeholder="80, 90, 95",
  138. ),
  139. ui.input_numeric("bins", "Histogram bins", value=100, min=10, max=500),
  140. ui.input_numeric("min_suv", "Minimum SUV for PDF", value=0.1, min=0),
  141. ui.input_checkbox("log_x", "Log x-axis for SUV PDF", value=True),
  142. ui.hr(),
  143. ui.input_numeric(
  144. "min_component_voxels",
  145. "Minimum component voxels",
  146. value=3,
  147. min=1,
  148. step=1,
  149. ),
  150. ui.input_select(
  151. "component_connectivity",
  152. "Component connectivity",
  153. choices={"6": "6", "18": "18", "26": "26"},
  154. selected="26",
  155. ),
  156. ui.input_select(
  157. "contrast_connectivity",
  158. "Contrast connectivity",
  159. choices={"6": "6", "18": "18", "26": "26"},
  160. selected="26",
  161. ),
  162. ui.input_checkbox("compute_spread", "Compute spread", value=True),
  163. ui.input_checkbox("compute_local_contrast", "Compute local contrast", value=True),
  164. ui.input_checkbox("compute_sphericity", "Compute sphericity", value=True),
  165. ui.input_checkbox("crop_to_roi", "Crop to ROI", value=True),
  166. ui.hr(),
  167. ui.input_action_button("run", "Compute selected row", class_="btn-primary"),
  168. width=330,
  169. ),
  170. ui.navset_tab(
  171. ui.nav_panel(
  172. "Metadata table",
  173. ui.p("Select one row from the table, then click 'Compute selected row'."),
  174. ui.output_data_frame("metadata_table"),
  175. ),
  176. ui.nav_panel(
  177. "SUV PDF",
  178. ui.output_ui("selected_summary_pdf"),
  179. output_widget("suv_pdf_plot", height="560px"),
  180. ui.h4("SUV percentiles"),
  181. ui.output_data_frame("suv_percentiles_table"),
  182. ),
  183. ui.nav_panel(
  184. "Hot voxels",
  185. ui.output_ui("selected_summary_hot"),
  186. ui.output_ui("hot_voxel_tabs"),
  187. ),
  188. ui.nav_panel(
  189. "Spatial features",
  190. ui.output_ui("selected_summary_features"),
  191. ui.output_data_frame("features_table"),
  192. ),
  193. ui.nav_panel(
  194. "Errors / diagnostics",
  195. ui.output_text_verbatim("diagnostics"),
  196. ),
  197. ),
  198. ),
  199. )
  200. # -----------------------------------------------------------------------------
  201. # Server
  202. # -----------------------------------------------------------------------------
  203. def server(input: Inputs, output: Outputs, session: Session):
  204. @render.data_frame
  205. def metadata_table():
  206. return render.DataGrid(
  207. DF_DISPLAY,
  208. selection_mode="row",
  209. filters=True,
  210. height="650px",
  211. )
  212. @reactive.calc
  213. def selected_row_display() -> pd.Series | None:
  214. selected = metadata_table.cell_selection()["rows"]
  215. if not selected:
  216. return None
  217. return DF_DISPLAY.iloc[int(selected[0])]
  218. @reactive.calc
  219. @reactive.event(input.run)
  220. def analysis_result():
  221. """Load selected image and compute all requested outputs once per click."""
  222. row_display = selected_row_display()
  223. if row_display is None:
  224. return {
  225. "ok": False,
  226. "error": "No row selected. Select a row in the metadata table first.",
  227. }
  228. try:
  229. probs = _parse_probs(input.probs())
  230. patient_id = row_display["patient_id"]
  231. organ = row_display["organ"]
  232. visit = int(row_display["visit"])
  233. index_key = (patient_id, organ, visit)
  234. row_meta, processed_image = get_processed_image(
  235. DF_META,
  236. patient_id=patient_id,
  237. organ=organ,
  238. visit=visit,
  239. )
  240. image_id = _format_image_id(row_display)
  241. pdf_fig, suv_percentiles = plot_suv_pdf_plotly(
  242. processed_image,
  243. percentiles=probs,
  244. bins=int(input.bins()),
  245. log_x=bool(input.log_x()),
  246. min_suv=float(input.min_suv()),
  247. title=f"SUV distribution: {image_id}",
  248. )
  249. hot_figs: dict[float, go.Figure] = {}
  250. for p in probs:
  251. threshold = float(
  252. suv_percentiles.loc[
  253. suv_percentiles["percentile"] == p,
  254. "suv_threshold",
  255. ].iloc[0]
  256. )
  257. hot_fig = plot_hot_voxels_plotly(
  258. processed_image,
  259. c=threshold,
  260. )
  261. hot_fig.update_layout(
  262. title=f"Hot voxels: {image_id}, p{p:g}, SUV ≥ {threshold:.4g}"
  263. )
  264. hot_figs[p] = hot_fig
  265. features = compute_tail_spatial_features(
  266. image=processed_image,
  267. percentiles=probs,
  268. component_connectivity=int(input.component_connectivity()),
  269. contrast_connectivity=int(input.contrast_connectivity()),
  270. min_component_voxels=int(input.min_component_voxels()),
  271. compute_spread=bool(input.compute_spread()),
  272. compute_local_contrast=bool(input.compute_local_contrast()),
  273. compute_sphericity=bool(input.compute_sphericity()),
  274. crop_to_roi=bool(input.crop_to_roi()),
  275. image_id=index_key,
  276. )
  277. return {
  278. "ok": True,
  279. "row_display": row_display,
  280. "row_meta": row_meta,
  281. "image_id": image_id,
  282. "probs": probs,
  283. "pdf_fig": pdf_fig,
  284. "suv_percentiles": suv_percentiles,
  285. "hot_figs": hot_figs,
  286. "features": features,
  287. "diagnostics": f"Computed successfully for {image_id}",
  288. }
  289. except Exception as exc:
  290. return {
  291. "ok": False,
  292. "row_display": row_display,
  293. "error": _safe_error_message(exc),
  294. }
  295. def _selected_summary(result: dict) -> ui.TagList:
  296. if not result.get("ok"):
  297. return ui.TagList(ui.div(ui.strong("No successful computation yet.")))
  298. row = result["row_display"]
  299. return ui.TagList(
  300. ui.div(
  301. ui.strong("Selected image: "),
  302. f"{row['patient_id']} | {row['organ']} | VISIT_{row['visit']}",
  303. )
  304. )
  305. @render.ui
  306. def selected_summary_pdf():
  307. return _selected_summary(analysis_result())
  308. @render.ui
  309. def selected_summary_hot():
  310. return _selected_summary(analysis_result())
  311. @render.ui
  312. def selected_summary_features():
  313. return _selected_summary(analysis_result())
  314. @render_widget
  315. def suv_pdf_plot():
  316. result = analysis_result()
  317. if not result.get("ok"):
  318. return _empty_figure("Select a row and click 'Compute selected row' to show the SUV PDF.")
  319. return result["pdf_fig"]
  320. @render.data_frame
  321. def suv_percentiles_table():
  322. result = analysis_result()
  323. if not result.get("ok"):
  324. return pd.DataFrame()
  325. return render.DataGrid(result["suv_percentiles"], height="250px")
  326. @render.ui
  327. def hot_voxel_tabs():
  328. result = analysis_result()
  329. if not result.get("ok"):
  330. return ui.div(ui.em("Select a row and compute to show hot-voxel plots."))
  331. probs = list(result["probs"])
  332. max_plots = 5
  333. if len(probs) > max_plots:
  334. probs = probs[:max_plots]
  335. tabs = []
  336. for i, p in enumerate(probs):
  337. tabs.append(
  338. ui.nav_panel(
  339. f"p{p:g}",
  340. output_widget(f"hot_voxel_plot_{i}", height="760px"),
  341. )
  342. )
  343. return ui.navset_tab(*tabs)
  344. def _hot_voxel_figure_by_index(index: int):
  345. result = analysis_result()
  346. if not result.get("ok"):
  347. return _empty_figure("No hot-voxel plot available.")
  348. probs = list(result["probs"])
  349. if index >= len(probs):
  350. return _empty_figure("No percentile assigned to this plot.")
  351. p = probs[index]
  352. fig = result["hot_figs"].get(p)
  353. if fig is None:
  354. return _empty_figure(f"Percentile p{p:g} was not requested.")
  355. return fig
  356. @render_widget
  357. def hot_voxel_plot_0():
  358. return _hot_voxel_figure_by_index(0)
  359. @render_widget
  360. def hot_voxel_plot_1():
  361. return _hot_voxel_figure_by_index(1)
  362. @render_widget
  363. def hot_voxel_plot_2():
  364. return _hot_voxel_figure_by_index(2)
  365. @render_widget
  366. def hot_voxel_plot_3():
  367. return _hot_voxel_figure_by_index(3)
  368. @render_widget
  369. def hot_voxel_plot_4():
  370. return _hot_voxel_figure_by_index(4)
  371. @render.data_frame
  372. def features_table():
  373. result = analysis_result()
  374. if not result.get("ok"):
  375. return pd.DataFrame()
  376. return render.DataGrid(result["features"], height="420px")
  377. @render.text
  378. def diagnostics():
  379. result = analysis_result()
  380. if result.get("ok"):
  381. return result.get("diagnostics", "OK")
  382. return result.get("error", "Unknown error")
  383. app = App(app_ui, server)
  384. if __name__ == "__main__":
  385. from shiny import run_app
  386. # For development, prefer:
  387. # shiny run --reload shiny_app/app.py
  388. # Direct execution via `python shiny_app/app.py` works without auto-reload.
  389. run_app(app, host="127.0.0.1", port=8000, reload=False)