app.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. from __future__ import annotations
  2. from pathlib import Path
  3. import sys
  4. import traceback
  5. import pandas as pd
  6. import plotly.graph_objects as go
  7. from shiny import App, Inputs, Outputs, Session, reactive, render, ui
  8. from shinywidgets import output_widget, render_widget
  9. # -----------------------------------------------------------------------------
  10. # Project paths
  11. # -----------------------------------------------------------------------------
  12. APP_DIR = Path(__file__).resolve().parent
  13. PROJECT_ROOT = APP_DIR.parent
  14. if str(PROJECT_ROOT) not in sys.path:
  15. sys.path.insert(0, str(PROJECT_ROOT))
  16. DATA_RAW = PROJECT_ROOT / "data" / "raw"
  17. DATA_GEN = PROJECT_ROOT / "data" / "gen"
  18. DEFAULT_METADATA_PATH = DATA_GEN / "metadata.pkl"
  19. WWW_DIR = APP_DIR / "www"
  20. # -----------------------------------------------------------------------------
  21. # Local package imports
  22. # -----------------------------------------------------------------------------
  23. from src.metadata import ( # noqa: E402
  24. get_meta_data,
  25. flag_corrupted_files,
  26. flag_AE_patients,
  27. )
  28. from src.image_io import get_processed_image # noqa: E402
  29. from src.plotting import ( # noqa: E402
  30. plot_suv_pdf_plotly,
  31. plot_hot_voxels_plotly,
  32. )
  33. from src.spatial_features import compute_tail_spatial_features # noqa: E402
  34. # -----------------------------------------------------------------------------
  35. # Helper functions
  36. # -----------------------------------------------------------------------------
  37. def _empty_figure(message: str = "No plot available") -> go.Figure:
  38. """Return an empty Plotly figure with a centered annotation."""
  39. fig = go.Figure()
  40. fig.add_annotation(
  41. text=message,
  42. xref="paper",
  43. yref="paper",
  44. x=0.5,
  45. y=0.5,
  46. showarrow=False,
  47. )
  48. fig.update_layout(
  49. template="plotly_white",
  50. xaxis={"visible": False},
  51. yaxis={"visible": False},
  52. height=500,
  53. )
  54. return fig
  55. def _load_metadata() -> pd.DataFrame:
  56. """Load metadata table from disk or build it from DATA_RAW.
  57. If data/gen/metadata.pkl does not exist, the metadata table is recreated
  58. from data/raw and saved to data/gen/metadata.pkl.
  59. """
  60. if DEFAULT_METADATA_PATH.exists():
  61. df_meta = pd.read_pickle(DEFAULT_METADATA_PATH)
  62. else:
  63. df_meta = get_meta_data(str(DATA_RAW))
  64. df_meta = flag_corrupted_files(df_meta)
  65. df_meta = flag_AE_patients(df_meta)
  66. DATA_GEN.mkdir(parents=True, exist_ok=True)
  67. df_meta.to_pickle(DEFAULT_METADATA_PATH)
  68. if not isinstance(df_meta.index, pd.MultiIndex):
  69. required_cols = {"patient_id", "organ", "visit"}
  70. if required_cols.issubset(df_meta.columns):
  71. df_meta = df_meta.set_index(["patient_id", "organ", "visit"])
  72. else:
  73. raise ValueError(
  74. "Metadata table must have MultiIndex (patient_id, organ, visit) "
  75. "or columns patient_id, organ, visit."
  76. )
  77. required_columns = {"PET_path", "SEG_path"}
  78. missing = required_columns.difference(df_meta.columns)
  79. if missing:
  80. raise ValueError(f"Metadata table is missing required columns: {missing}")
  81. return df_meta.sort_index()
  82. DF_META = _load_metadata()
  83. def _metadata_for_display(df_meta: pd.DataFrame) -> pd.DataFrame:
  84. """Return metadata for UI display.
  85. Path-like columns are hidden from the table, but they remain available in
  86. DF_META for loading images.
  87. """
  88. df = df_meta.reset_index().copy()
  89. df.insert(0, "row_id", range(len(df)))
  90. # Hide all path-like columns from the UI table.
  91. df = df.loc[:, ~df.columns.str.contains("path", case=False)]
  92. df = df.loc[:, ~df.columns.str.contains("filename", case=False)]
  93. preferred = [
  94. "row_id",
  95. "patient_id",
  96. "organ",
  97. "visit",
  98. "is_AE_patient",
  99. "is_corrupted",
  100. "PET_filename",
  101. "SEG_filename",
  102. ]
  103. cols = [c for c in preferred if c in df.columns] + [
  104. c for c in df.columns if c not in preferred
  105. ]
  106. return df[cols]
  107. DF_DISPLAY = _metadata_for_display(DF_META)
  108. ORGAN_CHOICES = ["All"] + sorted(DF_DISPLAY["organ"].dropna().astype(str).unique().tolist())
  109. VISIT_CHOICES = ["All"] + [str(v) for v in sorted(DF_DISPLAY["visit"].dropna().unique().tolist())]
  110. def _format_image_id(row: pd.Series) -> str:
  111. """Create a readable image identifier from a selected metadata row."""
  112. return f"{row['patient_id']}_{row['organ']}_VISIT_{row['visit']}"
  113. def _parse_probs(prob_string: str) -> tuple[float, ...]:
  114. """Parse comma-separated percentile values from UI input."""
  115. values: list[float] = []
  116. for part in prob_string.split(","):
  117. part = part.strip()
  118. if not part:
  119. continue
  120. values.append(float(part))
  121. if not values:
  122. raise ValueError("At least one percentile must be supplied.")
  123. for value in values:
  124. if not (0 < value < 100):
  125. raise ValueError("Percentiles must be between 0 and 100.")
  126. return tuple(values)
  127. def _safe_error_message(exc: BaseException) -> str:
  128. """Return a readable error message for the diagnostics tab."""
  129. return "\n".join(
  130. [
  131. f"{type(exc).__name__}: {exc}",
  132. "",
  133. traceback.format_exc(limit=8),
  134. ]
  135. )
  136. # -----------------------------------------------------------------------------
  137. # UI
  138. # -----------------------------------------------------------------------------
  139. app_ui = ui.page_fluid(
  140. ui.head_content(
  141. ui.tags.link(
  142. rel="shortcut icon",
  143. href="/favicon.ico?v=3",
  144. type="image/x-icon",
  145. )
  146. ),
  147. ui.h2("Spatial SUV tail-feature explorer"),
  148. ui.layout_sidebar(
  149. ui.sidebar(
  150. ui.div(
  151. ui.h5("Run analysis"),
  152. ui.p("Select one row in the metadata table, then compute."),
  153. ui.input_action_button(
  154. "run",
  155. "Compute selected row",
  156. class_="btn-success btn-lg",
  157. width="100%",
  158. ),
  159. style="""
  160. position: sticky;
  161. top: 0;
  162. z-index: 1000;
  163. background: white;
  164. padding: 1rem 0 1rem 0;
  165. border-bottom: 1px solid #ddd;
  166. margin-bottom: 1rem;
  167. """,
  168. ),
  169. ui.h4("Metadata filters"),
  170. ui.input_text(
  171. "filter_patient",
  172. "Patient contains",
  173. value="",
  174. placeholder="e.g. A001",
  175. ),
  176. ui.input_select(
  177. "filter_organ",
  178. "Organ",
  179. choices=ORGAN_CHOICES,
  180. selected="All",
  181. ),
  182. ui.input_select(
  183. "filter_visit",
  184. "Visit",
  185. choices=VISIT_CHOICES,
  186. selected="All",
  187. ),
  188. ui.input_select(
  189. "filter_ae",
  190. "AE patient",
  191. choices=["All", "True", "False"],
  192. selected="All",
  193. ),
  194. ui.input_select(
  195. "filter_corrupted",
  196. "Corrupted",
  197. choices=["All", "True", "False"],
  198. selected="All",
  199. ),
  200. ui.hr(),
  201. ui.h4("Analysis settings"),
  202. ui.input_text(
  203. "probs",
  204. "Percentiles",
  205. value="80, 90, 95",
  206. placeholder="80, 90, 95",
  207. ),
  208. ui.input_numeric("bins", "Histogram bins", value=100, min=10, max=500),
  209. ui.input_numeric("min_suv", "Minimum SUV for PDF", value=0.1, min=0),
  210. ui.input_checkbox("log_x", "Log x-axis for SUV PDF", value=True),
  211. ui.hr(),
  212. ui.input_numeric(
  213. "min_component_voxels",
  214. "Minimum component voxels",
  215. value=3,
  216. min=1,
  217. step=1,
  218. ),
  219. ui.input_select(
  220. "component_connectivity",
  221. "Component connectivity",
  222. choices={"6": "6", "18": "18", "26": "26"},
  223. selected="26",
  224. ),
  225. ui.input_select(
  226. "contrast_connectivity",
  227. "Contrast connectivity",
  228. choices={"6": "6", "18": "18", "26": "26"},
  229. selected="26",
  230. ),
  231. ui.input_checkbox("compute_spread", "Compute spread", value=True),
  232. ui.input_checkbox("compute_local_contrast", "Compute local contrast", value=True),
  233. ui.input_checkbox("compute_sphericity", "Compute sphericity", value=True),
  234. ui.input_checkbox("crop_to_roi", "Crop to ROI", value=True),
  235. width=340,
  236. ),
  237. ui.navset_tab(
  238. ui.nav_panel(
  239. "Metadata table",
  240. ui.p("Filter and select one row, then click 'Compute selected row'."),
  241. ui.output_ui("filter_summary"),
  242. ui.output_data_frame("metadata_table"),
  243. ),
  244. ui.nav_panel(
  245. "SUV PDF",
  246. ui.output_ui("selected_summary_pdf"),
  247. output_widget("suv_pdf_plot", height="560px"),
  248. ui.h4("SUV percentiles"),
  249. ui.output_data_frame("suv_percentiles_table"),
  250. ),
  251. ui.nav_panel(
  252. "Hot voxels",
  253. ui.output_ui("selected_summary_hot"),
  254. ui.p(
  255. "Hot-voxel plots are shown in fixed output slots to avoid "
  256. "recreating Plotly widget containers after each computation."
  257. ),
  258. ui.navset_tab(
  259. ui.nav_panel("Hot 1", output_widget("hot_voxel_plot_0", height="760px")),
  260. ui.nav_panel("Hot 2", output_widget("hot_voxel_plot_1", height="760px")),
  261. ui.nav_panel("Hot 3", output_widget("hot_voxel_plot_2", height="760px")),
  262. ui.nav_panel("Hot 4", output_widget("hot_voxel_plot_3", height="760px")),
  263. ui.nav_panel("Hot 5", output_widget("hot_voxel_plot_4", height="760px")),
  264. ),
  265. ),
  266. ui.nav_panel(
  267. "Spatial features",
  268. ui.output_ui("selected_summary_features"),
  269. ui.output_data_frame("features_table"),
  270. ),
  271. ui.nav_panel(
  272. "Errors / diagnostics",
  273. ui.output_text_verbatim("diagnostics"),
  274. ),
  275. ),
  276. ),
  277. )
  278. # -----------------------------------------------------------------------------
  279. # Server
  280. # -----------------------------------------------------------------------------
  281. def server(input: Inputs, output: Outputs, session: Session):
  282. @reactive.calc
  283. def filtered_metadata_display() -> pd.DataFrame:
  284. """Apply sidebar filters to the displayed metadata table."""
  285. df = DF_DISPLAY.copy()
  286. patient_text = input.filter_patient().strip()
  287. organ = input.filter_organ()
  288. visit = input.filter_visit()
  289. ae = input.filter_ae()
  290. corrupted = input.filter_corrupted()
  291. if patient_text:
  292. df = df[
  293. df["patient_id"]
  294. .astype(str)
  295. .str.contains(patient_text, case=False, na=False)
  296. ]
  297. if organ != "All":
  298. df = df[df["organ"].astype(str) == organ]
  299. if visit != "All":
  300. df = df[df["visit"].astype(str) == visit]
  301. if ae != "All" and "is_AE_patient" in df.columns:
  302. df = df[df["is_AE_patient"].astype(bool) == (ae == "True")]
  303. if corrupted != "All" and "is_corrupted" in df.columns:
  304. df = df[df["is_corrupted"].astype(bool) == (corrupted == "True")]
  305. return df.reset_index(drop=True)
  306. @render.ui
  307. def filter_summary():
  308. n_shown = len(filtered_metadata_display())
  309. n_total = len(DF_DISPLAY)
  310. return ui.div(
  311. ui.strong("Rows shown: "),
  312. f"{n_shown} / {n_total}",
  313. )
  314. @render.data_frame
  315. def metadata_table():
  316. return render.DataGrid(
  317. filtered_metadata_display(),
  318. selection_mode="row",
  319. filters=False,
  320. width="100%",
  321. height="650px",
  322. )
  323. @reactive.calc
  324. def selected_row_display() -> pd.Series | None:
  325. selected = metadata_table.cell_selection()["rows"]
  326. if not selected:
  327. return None
  328. df = filtered_metadata_display()
  329. if df.empty:
  330. return None
  331. row_pos = int(selected[0])
  332. if row_pos >= len(df):
  333. return None
  334. return df.iloc[row_pos]
  335. @reactive.calc
  336. @reactive.event(input.run)
  337. def analysis_result():
  338. """Load selected image and compute all requested outputs once per click."""
  339. row_display = selected_row_display()
  340. if row_display is None:
  341. return {
  342. "ok": False,
  343. "error": "No row selected. Select a row in the metadata table first.",
  344. }
  345. try:
  346. probs = _parse_probs(input.probs())
  347. patient_id = row_display["patient_id"]
  348. organ = row_display["organ"]
  349. visit = int(row_display["visit"])
  350. index_key = (patient_id, organ, visit)
  351. row_meta, processed_image = get_processed_image(
  352. DF_META,
  353. patient_id=patient_id,
  354. organ=organ,
  355. visit=visit,
  356. )
  357. image_id = _format_image_id(row_display)
  358. pdf_fig, suv_percentiles = plot_suv_pdf_plotly(
  359. processed_image,
  360. percentiles=probs,
  361. bins=int(input.bins()),
  362. log_x=bool(input.log_x()),
  363. min_suv=float(input.min_suv()),
  364. title=f"SUV distribution: {image_id}",
  365. )
  366. hot_figs: dict[float, go.Figure] = {}
  367. for p in probs:
  368. threshold = float(
  369. suv_percentiles.loc[
  370. suv_percentiles["percentile"] == p,
  371. "suv_threshold",
  372. ].iloc[0]
  373. )
  374. hot_fig = plot_hot_voxels_plotly(
  375. processed_image,
  376. c=threshold,
  377. )
  378. hot_fig.update_layout(
  379. title=f"Hot voxels: {image_id}, p{p:g}, SUV ≥ {threshold:.4g}"
  380. )
  381. hot_figs[p] = hot_fig
  382. features = compute_tail_spatial_features(
  383. image=processed_image,
  384. percentiles=probs,
  385. component_connectivity=int(input.component_connectivity()),
  386. contrast_connectivity=int(input.contrast_connectivity()),
  387. min_component_voxels=int(input.min_component_voxels()),
  388. compute_spread=bool(input.compute_spread()),
  389. compute_local_contrast=bool(input.compute_local_contrast()),
  390. compute_sphericity=bool(input.compute_sphericity()),
  391. crop_to_roi=bool(input.crop_to_roi()),
  392. image_id=index_key,
  393. )
  394. return {
  395. "ok": True,
  396. "row_display": row_display,
  397. "row_meta": row_meta,
  398. "image_id": image_id,
  399. "probs": probs,
  400. "pdf_fig": pdf_fig,
  401. "suv_percentiles": suv_percentiles,
  402. "hot_figs": hot_figs,
  403. "features": features,
  404. "diagnostics": f"Computed successfully for {image_id}",
  405. }
  406. except Exception as exc:
  407. return {
  408. "ok": False,
  409. "row_display": row_display,
  410. "error": _safe_error_message(exc),
  411. }
  412. def _selected_summary(result: dict) -> ui.TagList:
  413. if not result.get("ok"):
  414. return ui.TagList(ui.div(ui.strong("No successful computation yet.")))
  415. row = result["row_display"]
  416. return ui.TagList(
  417. ui.div(
  418. ui.strong("Selected image: "),
  419. f"{row['patient_id']} | {row['organ']} | VISIT_{row['visit']}",
  420. )
  421. )
  422. @render.ui
  423. def selected_summary_pdf():
  424. return _selected_summary(analysis_result())
  425. @render.ui
  426. def selected_summary_hot():
  427. return _selected_summary(analysis_result())
  428. @render.ui
  429. def selected_summary_features():
  430. return _selected_summary(analysis_result())
  431. @render_widget
  432. def suv_pdf_plot():
  433. result = analysis_result()
  434. if not result.get("ok"):
  435. return _empty_figure("Select a row and click 'Compute selected row' to show the SUV PDF.")
  436. return result["pdf_fig"]
  437. @render.data_frame
  438. def suv_percentiles_table():
  439. result = analysis_result()
  440. if not result.get("ok"):
  441. return pd.DataFrame()
  442. return render.DataGrid(result["suv_percentiles"], height="250px")
  443. # Important: hot-voxel output widgets are defined statically in the UI.
  444. # Do not recreate Plotly widget containers dynamically after every run; doing
  445. # so can trigger shinywidgets/Plotly JavaScript state errors on subsequent
  446. # row selections.
  447. def _hot_voxel_figure_by_index(index: int):
  448. result = analysis_result()
  449. if not result.get("ok"):
  450. return _empty_figure("No hot-voxel plot available.")
  451. probs = list(result["probs"])
  452. if index >= len(probs):
  453. return _empty_figure("No percentile assigned to this plot.")
  454. p = probs[index]
  455. fig = result["hot_figs"].get(p)
  456. if fig is None:
  457. return _empty_figure(f"Percentile p{p:g} was not requested.")
  458. return fig
  459. @render_widget
  460. def hot_voxel_plot_0():
  461. return _hot_voxel_figure_by_index(0)
  462. @render_widget
  463. def hot_voxel_plot_1():
  464. return _hot_voxel_figure_by_index(1)
  465. @render_widget
  466. def hot_voxel_plot_2():
  467. return _hot_voxel_figure_by_index(2)
  468. @render_widget
  469. def hot_voxel_plot_3():
  470. return _hot_voxel_figure_by_index(3)
  471. @render_widget
  472. def hot_voxel_plot_4():
  473. return _hot_voxel_figure_by_index(4)
  474. @render.data_frame
  475. def features_table():
  476. result = analysis_result()
  477. if not result.get("ok"):
  478. return pd.DataFrame()
  479. return render.DataGrid(result["features"], height="420px")
  480. @render.text
  481. def diagnostics():
  482. result = analysis_result()
  483. if result.get("ok"):
  484. return result.get("diagnostics", "OK")
  485. return result.get("error", "Unknown error")
  486. app = App(
  487. app_ui,
  488. server,
  489. static_assets=WWW_DIR,
  490. )
  491. if __name__ == "__main__":
  492. from shiny import run_app
  493. # For development, prefer:
  494. # shiny run --reload shiny_app/app.py
  495. # Direct execution via `python shiny_app/app.py` works without auto-reload.
  496. run_app(app, host="127.0.0.1", port=8000, reload=False)