dataset_summary.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. from __future__ import annotations
  2. from copy import deepcopy
  3. from datetime import datetime, timezone
  4. from pathlib import Path
  5. from typing import Any
  6. from analysis.data_pipeline import build_dataset, build_dataset_splits
  7. def _percent(part: int, whole: int) -> float:
  8. if whole <= 0:
  9. return 0.0
  10. return 100.0 * float(part) / float(whole)
  11. def compute_dataset_summary(
  12. config: dict[str, Any],
  13. root_dir: Path,
  14. positive_class_index: int,
  15. ) -> dict[str, Any]:
  16. # Force CPU so summary generation works without GPU availability.
  17. summary_config = deepcopy(config)
  18. summary_config.setdefault("training", {})
  19. summary_config["training"]["device"] = "cpu"
  20. dataset, xls_file = build_dataset(summary_config, root_dir)
  21. seed = int(config["data"]["seed"])
  22. splits = build_dataset_splits(summary_config, dataset, xls_file, seed=seed)
  23. split_names = ["train", "validation", "test"]
  24. requested_ratios = [float(v) for v in config["data"]["data_splits"]]
  25. if len(splits) != 3:
  26. raise ValueError(f"Expected 3 dataset splits, got {len(splits)}.")
  27. total_images = int(len(dataset))
  28. labels = (dataset.expected_classes[:, positive_class_index] >= 0.5).int()
  29. splits_summary: list[dict[str, Any]] = []
  30. assigned = 0
  31. assigned_positive = 0
  32. for split_name, requested_ratio, subset in zip(
  33. split_names,
  34. requested_ratios,
  35. splits,
  36. strict=True,
  37. ):
  38. indices = list(subset.indices)
  39. split_count = int(len(indices))
  40. split_positive = int(labels[indices].sum().item()) if split_count > 0 else 0
  41. split_negative = split_count - split_positive
  42. assigned += split_count
  43. assigned_positive += split_positive
  44. splits_summary.append(
  45. {
  46. "split": split_name,
  47. "requested_ratio": requested_ratio,
  48. "image_count": split_count,
  49. "image_pct_of_dataset": _percent(split_count, total_images),
  50. "positive_count": split_positive,
  51. "negative_count": split_negative,
  52. "positive_pct_within_split": _percent(split_positive, split_count),
  53. "negative_pct_within_split": _percent(split_negative, split_count),
  54. }
  55. )
  56. if assigned != total_images:
  57. raise ValueError(
  58. f"Split coverage mismatch: assigned {assigned} images, expected {total_images}."
  59. )
  60. total_positive = assigned_positive
  61. total_negative = total_images - total_positive
  62. return {
  63. "generated_utc": datetime.now(timezone.utc).isoformat(),
  64. "seed": seed,
  65. "positive_class_index": int(positive_class_index),
  66. "totals": {
  67. "image_count": total_images,
  68. "positive_count": total_positive,
  69. "negative_count": total_negative,
  70. "positive_pct": _percent(total_positive, total_images),
  71. "negative_pct": _percent(total_negative, total_images),
  72. },
  73. "splits": splits_summary,
  74. }
  75. def _markdown_table(summary: dict[str, Any]) -> str:
  76. lines = [
  77. "| Split | Requested % | Images | Dataset % | Positive | Negative | Positive % (split) | Negative % (split) |",
  78. "|---|---:|---:|---:|---:|---:|---:|---:|",
  79. ]
  80. for item in summary["splits"]:
  81. lines.append(
  82. "| "
  83. f"{item['split'].title()} | "
  84. f"{item['requested_ratio'] * 100.0:.2f}% | "
  85. f"{item['image_count']} | "
  86. f"{item['image_pct_of_dataset']:.2f}% | "
  87. f"{item['positive_count']} | "
  88. f"{item['negative_count']} | "
  89. f"{item['positive_pct_within_split']:.2f}% | "
  90. f"{item['negative_pct_within_split']:.2f}% |"
  91. )
  92. return "\n".join(lines)
  93. def write_dataset_summary_markdown(summary: dict[str, Any], path: Path) -> None:
  94. totals = summary["totals"]
  95. lines = [
  96. "# Dataset Composition Summary",
  97. "",
  98. f"Generated (UTC): {summary['generated_utc']}",
  99. f"Split seed: {summary['seed']}",
  100. f"Positive class index: {summary['positive_class_index']}",
  101. "",
  102. "## Overall",
  103. "",
  104. f"- Total images: {totals['image_count']}",
  105. f"- Positive images: {totals['positive_count']} ({totals['positive_pct']:.2f}%)",
  106. f"- Negative images: {totals['negative_count']} ({totals['negative_pct']:.2f}%)",
  107. "",
  108. "## Train / Validation / Test Breakdown",
  109. "",
  110. _markdown_table(summary),
  111. "",
  112. ]
  113. path.write_text("\n".join(lines), encoding="utf-8")
  114. def run_dataset_summary(
  115. config: dict[str, Any],
  116. root_dir: Path,
  117. output_dir: Path,
  118. positive_class_index: int,
  119. ) -> dict[str, Any]:
  120. summary = compute_dataset_summary(
  121. config=config,
  122. root_dir=root_dir,
  123. positive_class_index=positive_class_index,
  124. )
  125. markdown_path = output_dir / "dataset_summary.md"
  126. write_dataset_summary_markdown(summary, markdown_path)
  127. return {
  128. "summary_markdown": str(markdown_path),
  129. "summary": summary,
  130. }