app.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. from __future__ import annotations
  2. from pathlib import Path
  3. import sys
  4. import traceback
  5. import pandas as pd
  6. import plotly.graph_objects as go
  7. from shiny import App, Inputs, Outputs, Session, reactive, render, ui
  8. from shinywidgets import output_widget, render_widget
  9. # -----------------------------------------------------------------------------
  10. # Project paths
  11. # -----------------------------------------------------------------------------
  12. APP_DIR = Path(__file__).resolve().parent
  13. PROJECT_ROOT = APP_DIR.parent
  14. if str(PROJECT_ROOT) not in sys.path:
  15. sys.path.insert(0, str(PROJECT_ROOT))
  16. DATA_RAW = PROJECT_ROOT / "data" / "raw"
  17. DATA_GEN = PROJECT_ROOT / "data" / "gen"
  18. DEFAULT_METADATA_PATH = DATA_GEN / "metadata.pkl"
  19. WWW_DIR = APP_DIR / "www"
  20. # -----------------------------------------------------------------------------
  21. # Local package imports
  22. # -----------------------------------------------------------------------------
  23. from src.metadata import ( # noqa: E402
  24. get_meta_data,
  25. flag_corrupted_files,
  26. flag_AE_patients,
  27. )
  28. from src.image_io import get_processed_image # noqa: E402
  29. from src.plotting import ( # noqa: E402
  30. plot_suv_pdf_plotly,
  31. plot_hot_voxels_plotly,
  32. )
  33. from src.spatial_features import compute_tail_spatial_features # noqa: E402
  34. # -----------------------------------------------------------------------------
  35. # Helper functions
  36. # -----------------------------------------------------------------------------
  37. def _empty_figure(message: str = "No plot available") -> go.Figure:
  38. """Return an empty Plotly figure with a centered annotation."""
  39. fig = go.Figure()
  40. fig.add_annotation(
  41. text=message,
  42. xref="paper",
  43. yref="paper",
  44. x=0.5,
  45. y=0.5,
  46. showarrow=False,
  47. )
  48. fig.update_layout(
  49. template="plotly_white",
  50. xaxis={"visible": False},
  51. yaxis={"visible": False},
  52. height=500,
  53. )
  54. return fig
  55. def _load_metadata() -> pd.DataFrame:
  56. """Load metadata table from disk or build it from DATA_RAW.
  57. If data/gen/metadata.pkl does not exist, the metadata table is recreated
  58. from data/raw and saved to data/gen/metadata.pkl.
  59. """
  60. if DEFAULT_METADATA_PATH.exists():
  61. df_meta = pd.read_pickle(DEFAULT_METADATA_PATH)
  62. else:
  63. df_meta = get_meta_data(str(DATA_RAW))
  64. df_meta = flag_corrupted_files(df_meta)
  65. df_meta = flag_AE_patients(df_meta)
  66. DATA_GEN.mkdir(parents=True, exist_ok=True)
  67. df_meta.to_pickle(DEFAULT_METADATA_PATH)
  68. if not isinstance(df_meta.index, pd.MultiIndex):
  69. required_cols = {"patient_id", "organ", "visit"}
  70. if required_cols.issubset(df_meta.columns):
  71. df_meta = df_meta.set_index(["patient_id", "organ", "visit"])
  72. else:
  73. raise ValueError(
  74. "Metadata table must have MultiIndex (patient_id, organ, visit) "
  75. "or columns patient_id, organ, visit."
  76. )
  77. required_columns = {"PET_path", "SEG_path"}
  78. missing = required_columns.difference(df_meta.columns)
  79. if missing:
  80. raise ValueError(f"Metadata table is missing required columns: {missing}")
  81. return df_meta.sort_index()
  82. DF_META = _load_metadata()
  83. def _metadata_for_display(df_meta: pd.DataFrame) -> pd.DataFrame:
  84. """Return metadata for UI display.
  85. Path-like columns are hidden from the table, but they remain available in
  86. DF_META for loading images.
  87. """
  88. df = df_meta.reset_index().copy()
  89. df.insert(0, "row_id", range(len(df)))
  90. # Hide all path-like columns from the UI table.
  91. df = df.loc[:, ~df.columns.str.contains("path", case=False)]
  92. df = df.loc[:, ~df.columns.str.contains("filename", case=False)]
  93. preferred = [
  94. "row_id",
  95. "patient_id",
  96. "organ",
  97. "visit",
  98. "is_AE_patient",
  99. "is_corrupted",
  100. "PET_filename",
  101. "SEG_filename",
  102. ]
  103. cols = [c for c in preferred if c in df.columns] + [
  104. c for c in df.columns if c not in preferred
  105. ]
  106. return df[cols]
  107. DF_DISPLAY = _metadata_for_display(DF_META)
  108. ORGAN_CHOICES = ["All"] + sorted(DF_DISPLAY["organ"].dropna().astype(str).unique().tolist())
  109. VISIT_CHOICES = ["All"] + [str(v) for v in sorted(DF_DISPLAY["visit"].dropna().unique().tolist())]
  110. def _format_image_id(row: pd.Series) -> str:
  111. """Create a readable image identifier from a selected metadata row."""
  112. return f"{row['patient_id']}_{row['organ']}_VISIT_{row['visit']}"
  113. def _parse_probs(prob_string: str) -> tuple[float, ...]:
  114. """Parse comma-separated percentile values from UI input."""
  115. values: list[float] = []
  116. for part in prob_string.split(","):
  117. part = part.strip()
  118. if not part:
  119. continue
  120. values.append(float(part))
  121. if not values:
  122. raise ValueError("At least one percentile must be supplied.")
  123. for value in values:
  124. if not (0 < value < 100):
  125. raise ValueError("Percentiles must be between 0 and 100.")
  126. return tuple(values)
  127. def _safe_error_message(exc: BaseException) -> str:
  128. """Return a readable error message for the diagnostics tab."""
  129. return "\n".join(
  130. [
  131. f"{type(exc).__name__}: {exc}",
  132. "",
  133. traceback.format_exc(limit=8),
  134. ]
  135. )
  136. # -----------------------------------------------------------------------------
  137. # UI
  138. # -----------------------------------------------------------------------------
  139. app_ui = ui.page_fluid(
  140. ui.head_content(
  141. ui.tags.link(
  142. rel="shortcut icon",
  143. href="/favicon.ico?v=3",
  144. type="image/x-icon",
  145. )
  146. ),
  147. ui.h2("Spatial SUV tail-feature explorer"),
  148. ui.layout_sidebar(
  149. ui.sidebar(
  150. ui.div(
  151. ui.h5("Run analysis"),
  152. ui.p("Select one row in the metadata table, then compute."),
  153. ui.input_action_button(
  154. "run",
  155. "Compute selected row",
  156. class_="btn-success btn-lg",
  157. width="100%",
  158. ),
  159. style="""
  160. position: sticky;
  161. top: 0;
  162. z-index: 1000;
  163. background: white;
  164. padding: 1rem 0 1rem 0;
  165. border-bottom: 1px solid #ddd;
  166. margin-bottom: 1rem;
  167. """,
  168. ),
  169. ui.h4("Metadata filters"),
  170. ui.input_text(
  171. "filter_patient",
  172. "Patient contains",
  173. value="",
  174. placeholder="e.g. A001",
  175. ),
  176. ui.input_select(
  177. "filter_organ",
  178. "Organ",
  179. choices=ORGAN_CHOICES,
  180. selected="All",
  181. ),
  182. ui.input_select(
  183. "filter_visit",
  184. "Visit",
  185. choices=VISIT_CHOICES,
  186. selected="All",
  187. ),
  188. ui.input_select(
  189. "filter_ae",
  190. "AE patient",
  191. choices=["All", "True", "False"],
  192. selected="All",
  193. ),
  194. ui.input_select(
  195. "filter_corrupted",
  196. "Corrupted",
  197. choices=["All", "True", "False"],
  198. selected="All",
  199. ),
  200. ui.hr(),
  201. ui.h4("Analysis settings"),
  202. ui.input_text(
  203. "probs",
  204. "Percentiles",
  205. value="80, 90, 95",
  206. placeholder="80, 90, 95",
  207. ),
  208. ui.input_numeric("bins", "Histogram bins", value=100, min=10, max=500),
  209. ui.input_numeric("min_suv", "Minimum SUV for PDF", value=0.1, min=0),
  210. ui.input_checkbox("log_x", "Log x-axis for SUV PDF", value=True),
  211. ui.hr(),
  212. ui.input_numeric(
  213. "min_component_voxels",
  214. "Minimum component voxels",
  215. value=3,
  216. min=1,
  217. step=1,
  218. ),
  219. ui.input_select(
  220. "component_connectivity",
  221. "Component connectivity",
  222. choices={"6": "6", "18": "18", "26": "26"},
  223. selected="26",
  224. ),
  225. ui.input_select(
  226. "contrast_connectivity",
  227. "Contrast connectivity",
  228. choices={"6": "6", "18": "18", "26": "26"},
  229. selected="26",
  230. ),
  231. ui.input_checkbox("compute_spread", "Compute spread", value=True),
  232. ui.input_checkbox("compute_local_contrast", "Compute local contrast", value=True),
  233. ui.input_checkbox("compute_sphericity", "Compute sphericity", value=True),
  234. ui.input_checkbox("crop_to_roi", "Crop to ROI", value=True),
  235. width=340,
  236. ),
  237. ui.navset_tab(
  238. ui.nav_panel(
  239. "Metadata table",
  240. ui.p("Filter and select one row, then click 'Compute selected row'."),
  241. ui.output_ui("filter_summary"),
  242. ui.output_data_frame("metadata_table"),
  243. ),
  244. ui.nav_panel(
  245. "SUV PDF",
  246. ui.output_ui("selected_summary_pdf"),
  247. output_widget("suv_pdf_plot", height="560px"),
  248. ui.h4("SUV percentiles"),
  249. ui.output_data_frame("suv_percentiles_table"),
  250. ),
  251. ui.nav_panel(
  252. "Hot voxels",
  253. ui.output_ui("selected_summary_hot"),
  254. ui.output_ui("hot_voxel_tabs"),
  255. ),
  256. ui.nav_panel(
  257. "Spatial features",
  258. ui.output_ui("selected_summary_features"),
  259. ui.output_data_frame("features_table"),
  260. ),
  261. ui.nav_panel(
  262. "Errors / diagnostics",
  263. ui.output_text_verbatim("diagnostics"),
  264. ),
  265. ),
  266. ),
  267. )
  268. # -----------------------------------------------------------------------------
  269. # Server
  270. # -----------------------------------------------------------------------------
  271. def server(input: Inputs, output: Outputs, session: Session):
  272. @reactive.calc
  273. def filtered_metadata_display() -> pd.DataFrame:
  274. """Apply sidebar filters to the displayed metadata table."""
  275. df = DF_DISPLAY.copy()
  276. patient_text = input.filter_patient().strip()
  277. organ = input.filter_organ()
  278. visit = input.filter_visit()
  279. ae = input.filter_ae()
  280. corrupted = input.filter_corrupted()
  281. if patient_text:
  282. df = df[
  283. df["patient_id"]
  284. .astype(str)
  285. .str.contains(patient_text, case=False, na=False)
  286. ]
  287. if organ != "All":
  288. df = df[df["organ"].astype(str) == organ]
  289. if visit != "All":
  290. df = df[df["visit"].astype(str) == visit]
  291. if ae != "All" and "is_AE_patient" in df.columns:
  292. df = df[df["is_AE_patient"].astype(bool) == (ae == "True")]
  293. if corrupted != "All" and "is_corrupted" in df.columns:
  294. df = df[df["is_corrupted"].astype(bool) == (corrupted == "True")]
  295. return df.reset_index(drop=True)
  296. @render.ui
  297. def filter_summary():
  298. n_shown = len(filtered_metadata_display())
  299. n_total = len(DF_DISPLAY)
  300. return ui.div(
  301. ui.strong("Rows shown: "),
  302. f"{n_shown} / {n_total}",
  303. )
  304. @render.data_frame
  305. def metadata_table():
  306. return render.DataGrid(
  307. filtered_metadata_display(),
  308. selection_mode="row",
  309. filters=False,
  310. width="100%",
  311. height="650px",
  312. )
  313. @reactive.calc
  314. def selected_row_display() -> pd.Series | None:
  315. selected = metadata_table.cell_selection()["rows"]
  316. if not selected:
  317. return None
  318. df = filtered_metadata_display()
  319. if df.empty:
  320. return None
  321. row_pos = int(selected[0])
  322. if row_pos >= len(df):
  323. return None
  324. return df.iloc[row_pos]
  325. @reactive.calc
  326. @reactive.event(input.run)
  327. def analysis_result():
  328. """Load selected image and compute all requested outputs once per click."""
  329. row_display = selected_row_display()
  330. if row_display is None:
  331. return {
  332. "ok": False,
  333. "error": "No row selected. Select a row in the metadata table first.",
  334. }
  335. try:
  336. probs = _parse_probs(input.probs())
  337. patient_id = row_display["patient_id"]
  338. organ = row_display["organ"]
  339. visit = int(row_display["visit"])
  340. index_key = (patient_id, organ, visit)
  341. row_meta, processed_image = get_processed_image(
  342. DF_META,
  343. patient_id=patient_id,
  344. organ=organ,
  345. visit=visit,
  346. )
  347. image_id = _format_image_id(row_display)
  348. pdf_fig, suv_percentiles = plot_suv_pdf_plotly(
  349. processed_image,
  350. percentiles=probs,
  351. bins=int(input.bins()),
  352. log_x=bool(input.log_x()),
  353. min_suv=float(input.min_suv()),
  354. title=f"SUV distribution: {image_id}",
  355. )
  356. hot_figs: dict[float, go.Figure] = {}
  357. for p in probs:
  358. threshold = float(
  359. suv_percentiles.loc[
  360. suv_percentiles["percentile"] == p,
  361. "suv_threshold",
  362. ].iloc[0]
  363. )
  364. hot_fig = plot_hot_voxels_plotly(
  365. processed_image,
  366. c=threshold,
  367. )
  368. hot_fig.update_layout(
  369. title=f"Hot voxels: {image_id}, p{p:g}, SUV ≥ {threshold:.4g}"
  370. )
  371. hot_figs[p] = hot_fig
  372. features = compute_tail_spatial_features(
  373. image=processed_image,
  374. percentiles=probs,
  375. component_connectivity=int(input.component_connectivity()),
  376. contrast_connectivity=int(input.contrast_connectivity()),
  377. min_component_voxels=int(input.min_component_voxels()),
  378. compute_spread=bool(input.compute_spread()),
  379. compute_local_contrast=bool(input.compute_local_contrast()),
  380. compute_sphericity=bool(input.compute_sphericity()),
  381. crop_to_roi=bool(input.crop_to_roi()),
  382. image_id=index_key,
  383. )
  384. return {
  385. "ok": True,
  386. "row_display": row_display,
  387. "row_meta": row_meta,
  388. "image_id": image_id,
  389. "probs": probs,
  390. "pdf_fig": pdf_fig,
  391. "suv_percentiles": suv_percentiles,
  392. "hot_figs": hot_figs,
  393. "features": features,
  394. "diagnostics": f"Computed successfully for {image_id}",
  395. }
  396. except Exception as exc:
  397. return {
  398. "ok": False,
  399. "row_display": row_display,
  400. "error": _safe_error_message(exc),
  401. }
  402. def _selected_summary(result: dict) -> ui.TagList:
  403. if not result.get("ok"):
  404. return ui.TagList(ui.div(ui.strong("No successful computation yet.")))
  405. row = result["row_display"]
  406. return ui.TagList(
  407. ui.div(
  408. ui.strong("Selected image: "),
  409. f"{row['patient_id']} | {row['organ']} | VISIT_{row['visit']}",
  410. )
  411. )
  412. @render.ui
  413. def selected_summary_pdf():
  414. return _selected_summary(analysis_result())
  415. @render.ui
  416. def selected_summary_hot():
  417. return _selected_summary(analysis_result())
  418. @render.ui
  419. def selected_summary_features():
  420. return _selected_summary(analysis_result())
  421. @render_widget
  422. def suv_pdf_plot():
  423. result = analysis_result()
  424. if not result.get("ok"):
  425. return _empty_figure("Select a row and click 'Compute selected row' to show the SUV PDF.")
  426. return result["pdf_fig"]
  427. @render.data_frame
  428. def suv_percentiles_table():
  429. result = analysis_result()
  430. if not result.get("ok"):
  431. return pd.DataFrame()
  432. return render.DataGrid(result["suv_percentiles"], height="250px")
  433. @render.ui
  434. def hot_voxel_tabs():
  435. result = analysis_result()
  436. if not result.get("ok"):
  437. return ui.div(ui.em("Select a row and compute to show hot-voxel plots."))
  438. probs = list(result["probs"])
  439. max_plots = 5
  440. if len(probs) > max_plots:
  441. probs = probs[:max_plots]
  442. tabs = []
  443. for i, p in enumerate(probs):
  444. tabs.append(
  445. ui.nav_panel(
  446. f"p{p:g}",
  447. output_widget(f"hot_voxel_plot_{i}", height="760px"),
  448. )
  449. )
  450. return ui.navset_tab(*tabs)
  451. def _hot_voxel_figure_by_index(index: int):
  452. result = analysis_result()
  453. if not result.get("ok"):
  454. return _empty_figure("No hot-voxel plot available.")
  455. probs = list(result["probs"])
  456. if index >= len(probs):
  457. return _empty_figure("No percentile assigned to this plot.")
  458. p = probs[index]
  459. fig = result["hot_figs"].get(p)
  460. if fig is None:
  461. return _empty_figure(f"Percentile p{p:g} was not requested.")
  462. return fig
  463. @render_widget
  464. def hot_voxel_plot_0():
  465. return _hot_voxel_figure_by_index(0)
  466. @render_widget
  467. def hot_voxel_plot_1():
  468. return _hot_voxel_figure_by_index(1)
  469. @render_widget
  470. def hot_voxel_plot_2():
  471. return _hot_voxel_figure_by_index(2)
  472. @render_widget
  473. def hot_voxel_plot_3():
  474. return _hot_voxel_figure_by_index(3)
  475. @render_widget
  476. def hot_voxel_plot_4():
  477. return _hot_voxel_figure_by_index(4)
  478. @render.data_frame
  479. def features_table():
  480. result = analysis_result()
  481. if not result.get("ok"):
  482. return pd.DataFrame()
  483. return render.DataGrid(result["features"], height="420px")
  484. @render.text
  485. def diagnostics():
  486. result = analysis_result()
  487. if result.get("ok"):
  488. return result.get("diagnostics", "OK")
  489. return result.get("error", "Unknown error")
  490. app = App(
  491. app_ui,
  492. server,
  493. static_assets=WWW_DIR,
  494. )
  495. if __name__ == "__main__":
  496. from shiny import run_app
  497. # For development, prefer:
  498. # shiny run --reload shiny_app/app.py
  499. # Direct execution via `python shiny_app/app.py` works without auto-reload.
  500. run_app(app, host="127.0.0.1", port=8000, reload=False)