from __future__ import annotations from copy import deepcopy from datetime import datetime, timezone from pathlib import Path from typing import Any from .data_pipeline import build_dataset, build_dataset_splits def _percent(part: int, whole: int) -> float: if whole <= 0: return 0.0 return 100.0 * float(part) / float(whole) def compute_dataset_summary( config: dict[str, Any], root_dir: Path, positive_class_index: int, ) -> dict[str, Any]: # Force CPU so summary generation works without GPU availability. summary_config = deepcopy(config) summary_config.setdefault("training", {}) summary_config["training"]["device"] = "cpu" dataset, xls_file = build_dataset(summary_config, root_dir) seed = int(config["data"]["seed"]) splits = build_dataset_splits(summary_config, dataset, xls_file, seed=seed) split_names = ["train", "validation", "test"] requested_ratios = [float(v) for v in config["data"]["data_splits"]] if len(splits) != 3: raise ValueError(f"Expected 3 dataset splits, got {len(splits)}.") total_images = int(len(dataset)) labels = (dataset.expected_classes[:, positive_class_index] >= 0.5).int() splits_summary: list[dict[str, Any]] = [] assigned = 0 assigned_positive = 0 for split_name, requested_ratio, subset in zip( split_names, requested_ratios, splits, strict=True, ): indices = list(subset.indices) split_count = int(len(indices)) split_positive = int(labels[indices].sum().item()) if split_count > 0 else 0 split_negative = split_count - split_positive assigned += split_count assigned_positive += split_positive splits_summary.append( { "split": split_name, "requested_ratio": requested_ratio, "image_count": split_count, "image_pct_of_dataset": _percent(split_count, total_images), "positive_count": split_positive, "negative_count": split_negative, "positive_pct_within_split": _percent(split_positive, split_count), "negative_pct_within_split": _percent(split_negative, split_count), } ) if assigned != total_images: raise ValueError( f"Split coverage mismatch: assigned {assigned} images, expected {total_images}." ) total_positive = assigned_positive total_negative = total_images - total_positive return { "generated_utc": datetime.now(timezone.utc).isoformat(), "seed": seed, "positive_class_index": int(positive_class_index), "totals": { "image_count": total_images, "positive_count": total_positive, "negative_count": total_negative, "positive_pct": _percent(total_positive, total_images), "negative_pct": _percent(total_negative, total_images), }, "splits": splits_summary, } def _markdown_table(summary: dict[str, Any]) -> str: lines = [ "| Split | Requested % | Images | Dataset % | Positive | Negative | Positive % (split) | Negative % (split) |", "|---|---:|---:|---:|---:|---:|---:|---:|", ] for item in summary["splits"]: lines.append( "| " f"{item['split'].title()} | " f"{item['requested_ratio'] * 100.0:.2f}% | " f"{item['image_count']} | " f"{item['image_pct_of_dataset']:.2f}% | " f"{item['positive_count']} | " f"{item['negative_count']} | " f"{item['positive_pct_within_split']:.2f}% | " f"{item['negative_pct_within_split']:.2f}% |" ) return "\n".join(lines) def write_dataset_summary_markdown(summary: dict[str, Any], path: Path) -> None: totals = summary["totals"] lines = [ "# Dataset Composition Summary", "", f"Generated (UTC): {summary['generated_utc']}", f"Split seed: {summary['seed']}", f"Positive class index: {summary['positive_class_index']}", "", "## Overall", "", f"- Total images: {totals['image_count']}", f"- Positive images: {totals['positive_count']} ({totals['positive_pct']:.2f}%)", f"- Negative images: {totals['negative_count']} ({totals['negative_pct']:.2f}%)", "", "## Train / Validation / Test Breakdown", "", _markdown_table(summary), "", ] path.write_text("\n".join(lines), encoding="utf-8") def run_dataset_summary( config: dict[str, Any], root_dir: Path, output_dir: Path, positive_class_index: int, ) -> dict[str, Any]: summary = compute_dataset_summary( config=config, root_dir=root_dir, positive_class_index=positive_class_index, ) markdown_path = output_dir / "dataset_summary.md" write_dataset_summary_markdown(summary, markdown_path) return { "summary_markdown": str(markdown_path), "summary": summary, }