data_access.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # pyright: basic
  2. from __future__ import annotations
  3. from dataclasses import dataclass
  4. from pathlib import Path
  5. from typing import Any
  6. import numpy as np
  7. import pandas as pd
  8. import xarray as xr
  9. from bayesian_torch.utils.util import predictive_entropy
  10. @dataclass
  11. class BackendEvaluation:
  12. backend: str
  13. source_file: Path
  14. image_ids: np.ndarray
  15. y_true: np.ndarray
  16. y_prob: np.ndarray
  17. uncertainty_confidence: np.ndarray
  18. uncertainty_std: np.ndarray
  19. uncertainty_metric: str
  20. def _resolve_dataset_path(model_output_dir: Path) -> Path:
  21. primary = model_output_dir / "model_evaluation_results.nc"
  22. if primary.exists():
  23. return primary
  24. candidates = sorted(model_output_dir.glob("*.nc"))
  25. if not candidates:
  26. raise FileNotFoundError(f"No netCDF file found under {model_output_dir}")
  27. return candidates[0]
  28. def _positive_probability(
  29. predictions: xr.DataArray,
  30. class_index: int,
  31. ) -> tuple[np.ndarray, np.ndarray, str]:
  32. if "img_class" not in predictions.dims:
  33. raise ValueError("predictions is missing required dim: img_class")
  34. if class_index >= predictions.sizes["img_class"]:
  35. raise ValueError(
  36. f"positive class index {class_index} is out of bounds for img_class size {predictions.sizes['img_class']}"
  37. )
  38. if "model" in predictions.dims:
  39. prob_mean = predictions.mean(dim="model").isel(img_class=class_index).values
  40. prob_std = predictions.std(dim="model").isel(img_class=class_index).values
  41. return (
  42. np.asarray(prob_mean, dtype=float),
  43. np.asarray(prob_std, dtype=float),
  44. "std",
  45. )
  46. sample_like = [d for d in predictions.dims if d in {"sample", "mc_sample", "draw"}]
  47. if sample_like:
  48. dim = str(sample_like[0])
  49. prob_mean = predictions.mean(dim=dim).isel(img_class=class_index).values
  50. # For Bayesian MC predictions, uncertainty should come from predictive
  51. # entropy of the predictive distribution rather than classwise std.
  52. mc_preds = predictions.transpose(dim, "img_id", "img_class").values
  53. entropy_uncertainty = predictive_entropy(np.asarray(mc_preds, dtype=float))
  54. return (
  55. np.asarray(prob_mean, dtype=float),
  56. np.asarray(entropy_uncertainty, dtype=float),
  57. "predictive_entropy",
  58. )
  59. prob = predictions.isel(img_class=class_index).values
  60. return (
  61. np.asarray(prob, dtype=float),
  62. np.full_like(np.asarray(prob, dtype=float), np.nan),
  63. "unknown",
  64. )
  65. def _labels_to_binary(labels: xr.DataArray, class_index: int) -> np.ndarray:
  66. if "label" in labels.dims:
  67. if class_index >= labels.sizes["label"]:
  68. raise ValueError(
  69. f"positive class index {class_index} is out of bounds for label size {labels.sizes['label']}"
  70. )
  71. # One-hot labels expected in this repository.
  72. binary = labels.argmax(dim="label").values == class_index
  73. return np.asarray(binary, dtype=int)
  74. # Fallback if labels are already binary.
  75. return np.asarray(labels.values, dtype=int)
  76. def load_backend_evaluation(
  77. config: dict[str, Any],
  78. backend: str,
  79. class_index: int,
  80. ) -> BackendEvaluation:
  81. output_key = f"{backend}_path"
  82. if output_key not in config["output"]:
  83. raise KeyError(f"Missing output path key in config: output.{output_key}")
  84. model_output_dir = Path(config["output"][output_key]).expanduser().resolve()
  85. ds_path = _resolve_dataset_path(model_output_dir)
  86. ds = xr.open_dataset(ds_path)
  87. if "predictions" not in ds or "labels" not in ds:
  88. raise ValueError(
  89. f"Dataset {ds_path} must contain predictions and labels variables"
  90. )
  91. predictions = ds["predictions"]
  92. labels = ds["labels"]
  93. if "img_id" in predictions.coords:
  94. image_ids = np.asarray(predictions.coords["img_id"].values)
  95. elif "img_id" in labels.coords:
  96. image_ids = np.asarray(labels.coords["img_id"].values)
  97. else:
  98. length = predictions.sizes.get("img_id", labels.sizes.get("img_id"))
  99. if length is None:
  100. raise ValueError("Could not infer img_id length from predictions/labels")
  101. image_ids = np.arange(length)
  102. y_true = _labels_to_binary(labels, class_index=class_index)
  103. y_prob, y_std, uncertainty_metric = _positive_probability(
  104. predictions, class_index=class_index
  105. )
  106. conf = 2.0 * np.abs(y_prob - 0.5)
  107. if len(y_true) != len(y_prob):
  108. raise ValueError(
  109. f"Length mismatch after loading backend {backend}: labels={len(y_true)}, probs={len(y_prob)}"
  110. )
  111. return BackendEvaluation(
  112. backend=backend,
  113. source_file=ds_path,
  114. image_ids=image_ids,
  115. y_true=y_true,
  116. y_prob=y_prob,
  117. uncertainty_confidence=conf,
  118. uncertainty_std=y_std,
  119. uncertainty_metric=uncertainty_metric,
  120. )
  121. def load_clinical_table(config: dict[str, Any], root_dir: Path) -> pd.DataFrame:
  122. csv_path = (root_dir / config["data"]["xls_file_path"]).resolve()
  123. df = pd.read_csv(csv_path)
  124. df.columns = df.columns.str.strip()
  125. return df
  126. def physician_column(df: pd.DataFrame) -> str:
  127. exact = "DXCONFID"
  128. if exact in df.columns:
  129. return exact
  130. for col in df.columns:
  131. if "dxconfid" in col.lower():
  132. return col
  133. raise KeyError(
  134. "No physician confidence column with DXCONFID found in clinical table"
  135. )