data_access.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # pyright: basic
  2. from __future__ import annotations
  3. from dataclasses import dataclass
  4. from pathlib import Path
  5. from typing import Any
  6. import numpy as np
  7. import pandas as pd
  8. import xarray as xr
  9. from bayesian_torch.utils.util import predictive_entropy
  10. from .uncertainty import confidence_certainty
@dataclass
class BackendEvaluation:
    """Per-image evaluation arrays loaded for a single model backend.

    Instances are assembled by ``load_backend_evaluation`` from the
    ``predictions`` and ``labels`` variables of one netCDF dataset.
    """

    # Backend name exactly as passed to ``load_backend_evaluation``.
    backend: str
    # Path of the netCDF file the arrays were read from.
    source_file: Path
    # ``img_id`` coordinate values, or ``np.arange(n)`` when no such coordinate exists.
    image_ids: np.ndarray
    # Integer labels produced by ``_labels_to_binary``.
    y_true: np.ndarray
    # Positive-class probability per image (first value from ``_positive_probability``).
    y_prob: np.ndarray
    # Result of ``confidence_certainty(y_prob)``.
    uncertainty_confidence: np.ndarray
    # Second value from ``_positive_probability``. Despite the name this is not
    # always a standard deviation: see ``uncertainty_metric`` for what it holds.
    uncertainty_std: np.ndarray
    # One of "std", "predictive_entropy" or "unknown", naming what
    # ``uncertainty_std`` contains.
    uncertainty_metric: str
  21. def _resolve_dataset_path(model_output_dir: Path) -> Path:
  22. primary = model_output_dir / "model_evaluation_results.nc"
  23. if primary.exists():
  24. return primary
  25. candidates = sorted(model_output_dir.glob("*.nc"))
  26. if not candidates:
  27. raise FileNotFoundError(f"No netCDF file found under {model_output_dir}")
  28. return candidates[0]
  29. def _positive_probability(
  30. predictions: xr.DataArray,
  31. class_index: int,
  32. ) -> tuple[np.ndarray, np.ndarray, str]:
  33. if "img_class" not in predictions.dims:
  34. raise ValueError("predictions is missing required dim: img_class")
  35. if class_index >= predictions.sizes["img_class"]:
  36. raise ValueError(
  37. f"positive class index {class_index} is out of bounds for img_class size {predictions.sizes['img_class']}"
  38. )
  39. if "model" in predictions.dims:
  40. prob_mean = predictions.mean(dim="model").isel(img_class=class_index).values
  41. prob_std = predictions.std(dim="model").isel(img_class=class_index).values
  42. return (
  43. np.asarray(prob_mean, dtype=float),
  44. np.asarray(prob_std, dtype=float),
  45. "std",
  46. )
  47. sample_like = [d for d in predictions.dims if d in {"sample", "mc_sample", "draw"}]
  48. if sample_like:
  49. dim = str(sample_like[0])
  50. prob_mean = predictions.mean(dim=dim).isel(img_class=class_index).values
  51. # For Bayesian MC predictions, uncertainty should come from predictive
  52. # entropy of the predictive distribution rather than classwise std.
  53. mc_preds = predictions.transpose(dim, "img_id", "img_class").values
  54. entropy_uncertainty = predictive_entropy(np.asarray(mc_preds, dtype=float))
  55. return (
  56. np.asarray(prob_mean, dtype=float),
  57. np.asarray(entropy_uncertainty, dtype=float),
  58. "predictive_entropy",
  59. )
  60. prob = predictions.isel(img_class=class_index).values
  61. return (
  62. np.asarray(prob, dtype=float),
  63. np.full_like(np.asarray(prob, dtype=float), np.nan),
  64. "unknown",
  65. )
  66. def _labels_to_binary(labels: xr.DataArray, class_index: int) -> np.ndarray:
  67. if "label" in labels.dims:
  68. if class_index >= labels.sizes["label"]:
  69. raise ValueError(
  70. f"positive class index {class_index} is out of bounds for label size {labels.sizes['label']}"
  71. )
  72. # One-hot labels expected in this repository.
  73. binary = labels.argmax(dim="label").values == class_index
  74. return np.asarray(binary, dtype=int)
  75. # Fallback if labels are already binary.
  76. return np.asarray(labels.values, dtype=int)
  77. def load_backend_evaluation(
  78. config: dict[str, Any],
  79. backend: str,
  80. class_index: int,
  81. ) -> BackendEvaluation:
  82. output_key = f"{backend}_path"
  83. if output_key not in config["output"]:
  84. raise KeyError(f"Missing output path key in config: output.{output_key}")
  85. model_output_dir = Path(config["output"][output_key]).expanduser().resolve()
  86. ds_path = _resolve_dataset_path(model_output_dir)
  87. ds = xr.open_dataset(ds_path)
  88. if "predictions" not in ds or "labels" not in ds:
  89. raise ValueError(
  90. f"Dataset {ds_path} must contain predictions and labels variables"
  91. )
  92. predictions = ds["predictions"]
  93. labels = ds["labels"]
  94. if "img_id" in predictions.coords:
  95. image_ids = np.asarray(predictions.coords["img_id"].values)
  96. elif "img_id" in labels.coords:
  97. image_ids = np.asarray(labels.coords["img_id"].values)
  98. else:
  99. length = predictions.sizes.get("img_id", labels.sizes.get("img_id"))
  100. if length is None:
  101. raise ValueError("Could not infer img_id length from predictions/labels")
  102. image_ids = np.arange(length)
  103. y_true = _labels_to_binary(labels, class_index=class_index)
  104. y_prob, y_std, uncertainty_metric = _positive_probability(
  105. predictions, class_index=class_index
  106. )
  107. conf = confidence_certainty(y_prob)
  108. if len(y_true) != len(y_prob):
  109. raise ValueError(
  110. f"Length mismatch after loading backend {backend}: labels={len(y_true)}, probs={len(y_prob)}"
  111. )
  112. return BackendEvaluation(
  113. backend=backend,
  114. source_file=ds_path,
  115. image_ids=image_ids,
  116. y_true=y_true,
  117. y_prob=y_prob,
  118. uncertainty_confidence=conf,
  119. uncertainty_std=y_std,
  120. uncertainty_metric=uncertainty_metric,
  121. )
  122. def load_clinical_table(config: dict[str, Any], root_dir: Path) -> pd.DataFrame:
  123. csv_path = (root_dir / config["data"]["xls_file_path"]).resolve()
  124. df = pd.read_csv(csv_path)
  125. df.columns = df.columns.str.strip()
  126. return df
  127. def physician_column(df: pd.DataFrame) -> str:
  128. exact = "DXCONFID"
  129. if exact in df.columns:
  130. return exact
  131. for col in df.columns:
  132. if "dxconfid" in col.lower():
  133. return col
  134. raise KeyError(
  135. "No physician confidence column with DXCONFID found in clinical table"
  136. )