|
|
@@ -5,142 +5,45 @@ from __future__ import annotations
|
|
|
from pathlib import Path
|
|
|
from typing import Any
|
|
|
|
|
|
-import matplotlib.pyplot as plt
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
import torch
|
|
|
from bayesian_torch.utils.util import predictive_entropy
|
|
|
|
|
|
-from data.dataset import (
|
|
|
- divide_dataset_by_patient_id,
|
|
|
- initalize_dataloaders,
|
|
|
- load_adni_data_from_file,
|
|
|
-)
|
|
|
from model.cnn import CNN3D
|
|
|
|
|
|
+from .data_pipeline import build_dataset_and_test_loader
|
|
|
from .metrics import calibration_stats, performance_at_threshold
|
|
|
+from .model_utils import configure_bayesian_sampling_mode
|
|
|
+from .plotting import (
|
|
|
+ plots_dir,
|
|
|
+ save_noise_example_grid,
|
|
|
+ save_noise_metrics_plot,
|
|
|
+)
|
|
|
from .runtime import write_json
|
|
|
+from .uncertainty import confidence_certainty, confidence_uncertainty
|
|
|
|
|
|
|
|
|
-def _enable_bayesian_sampling_mode(model: torch.nn.Module) -> None:
|
|
|
- model.eval()
|
|
|
-
|
|
|
-
|
|
|
-def _central_slice(volume: torch.Tensor) -> np.ndarray:
|
|
|
- tensor = volume.detach().cpu()
|
|
|
- if tensor.ndim == 5:
|
|
|
- tensor = tensor[0]
|
|
|
- if tensor.ndim == 4:
|
|
|
- tensor = tensor[0]
|
|
|
- if tensor.ndim != 3:
|
|
|
- raise ValueError(
|
|
|
- f"Expected a 3D volume after squeezing batch/channel, got shape {tuple(tensor.shape)}"
|
|
|
- )
|
|
|
-
|
|
|
- center_index = tensor.shape[0] // 2
|
|
|
- return tensor[center_index].numpy().astype(float)
|
|
|
-
|
|
|
-
|
|
|
-def _normalize_for_display(image: np.ndarray) -> np.ndarray:
|
|
|
- low = float(np.percentile(image, 1))
|
|
|
- high = float(np.percentile(image, 99))
|
|
|
- if high <= low:
|
|
|
- return np.zeros_like(image, dtype=float)
|
|
|
- clipped = np.clip(image, low, high)
|
|
|
- return (clipped - low) / (high - low)
|
|
|
-
|
|
|
-
|
|
|
-def _save_noise_example_grid(
|
|
|
- original_mri: torch.Tensor,
|
|
|
- noisy_by_sigma: list[tuple[float, torch.Tensor]],
|
|
|
- output_path: Path,
|
|
|
- title: str,
|
|
|
-) -> None:
|
|
|
- if not noisy_by_sigma:
|
|
|
- return
|
|
|
-
|
|
|
- original_slice = _normalize_for_display(_central_slice(original_mri))
|
|
|
- n_rows = len(noisy_by_sigma)
|
|
|
- fig, axes = plt.subplots(n_rows, 2, figsize=(8, 3.2 * n_rows))
|
|
|
- if n_rows == 1:
|
|
|
- axes = np.array([axes])
|
|
|
-
|
|
|
- for row_idx, (sigma, noisy_tensor) in enumerate(noisy_by_sigma):
|
|
|
- noisy_slice = _normalize_for_display(_central_slice(noisy_tensor))
|
|
|
- ax_orig, ax_noisy = axes[row_idx]
|
|
|
- ax_orig.imshow(original_slice, cmap="gray")
|
|
|
- ax_orig.set_title(f"Original")
|
|
|
- ax_orig.axis("off")
|
|
|
-
|
|
|
- ax_noisy.imshow(noisy_slice, cmap="gray")
|
|
|
- ax_noisy.set_title(f"Noisy sigma={sigma:g}")
|
|
|
- ax_noisy.axis("off")
|
|
|
-
|
|
|
- fig.suptitle(title)
|
|
|
- fig.tight_layout()
|
|
|
- output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
- fig.savefig(output_path)
|
|
|
- plt.close(fig)
|
|
|
-
|
|
|
-
|
|
|
-def _confidence_certainty(y_prob: np.ndarray) -> np.ndarray:
|
|
|
- # 2 * |p - 0.5| maps 0 -> very uncertain and 1 -> very certain.
|
|
|
- return 2.0 * np.abs(y_prob - 0.5)
|
|
|
|
|
|
+def _apply_scaled_noise(
|
|
|
+ volume: torch.Tensor, sigma: float, intensity_range: float
|
|
|
+) -> torch.Tensor:
|
|
|
+    # Scale unit Gaussian noise by sigma times the caller-supplied MRI intensity range.
|
|
|
+ return volume + (torch.randn_like(volume) * sigma * intensity_range)
|
|
|
|
|
|
-def _confidence_uncertainty(y_prob: np.ndarray) -> np.ndarray:
|
|
|
- # Invert certainty so that larger values mean more uncertainty, matching std.
|
|
|
- return 1.0 - _confidence_certainty(y_prob)
|
|
|
|
|
|
+def _uniform_sigma_schedule(noise_sigmas: list[float]) -> list[float]:
|
|
|
+ if not noise_sigmas:
|
|
|
+ raise ValueError("noise_sigmas must contain at least one value")
|
|
|
|
|
|
-def _apply_scaled_noise(volume: torch.Tensor, sigma: float) -> torch.Tensor:
|
|
|
- # Make sigma dimensionless with respect to MRI intensity scale.
|
|
|
- # sigma=1.0 means noise std equals total MRI intensity range, which is 65535 .
|
|
|
- return volume + (torch.randn_like(volume) * sigma * 65535)
|
|
|
+ ordered = np.array(sorted(float(s) for s in noise_sigmas), dtype=float)
|
|
|
+ if len(ordered) == 1:
|
|
|
+ return [float(ordered[0])]
|
|
|
|
|
|
-
|
|
|
-def _xls_pre(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
- data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
|
|
|
- data["Sex"] = data["Sex"].astype(str).str.strip()
|
|
|
- data = data.replace({"M": 0, "F": 1})
|
|
|
- return data
|
|
|
-
|
|
|
-
|
|
|
-def _build_test_loader(
|
|
|
- config: dict[str, Any], root_dir: Path
|
|
|
-) -> torch.utils.data.DataLoader:
|
|
|
- mri_files = (root_dir / config["data"]["mri_files_path"]).resolve().glob("*.nii")
|
|
|
- xls_file = (root_dir / config["data"]["xls_file_path"]).resolve()
|
|
|
-
|
|
|
- dataset = load_adni_data_from_file(
|
|
|
- mri_files,
|
|
|
- xls_file,
|
|
|
- device=config["training"]["device"],
|
|
|
- xls_preprocessor=_xls_pre,
|
|
|
+ uniform = np.linspace(
|
|
|
+ float(ordered[0]), float(ordered[-1]), num=len(ordered), dtype=float
|
|
|
)
|
|
|
-
|
|
|
- ptid_df = pd.read_csv(xls_file)
|
|
|
- ptid_df.columns = ptid_df.columns.str.strip()
|
|
|
- ptid_df = ptid_df[["Image Data ID", "PTID"]].dropna(
|
|
|
- subset=["Image Data ID", "PTID"]
|
|
|
- )
|
|
|
- ptid_df["Image Data ID"] = ptid_df["Image Data ID"].astype(int)
|
|
|
- ptid_df["PTID"] = ptid_df["PTID"].astype(str).str.strip()
|
|
|
- ptid_df = ptid_df[ptid_df["PTID"] != ""]
|
|
|
- ptids = list(zip(ptid_df["Image Data ID"].tolist(), ptid_df["PTID"].tolist()))
|
|
|
-
|
|
|
- splits = divide_dataset_by_patient_id(
|
|
|
- dataset,
|
|
|
- ptids,
|
|
|
- tuple(config["data"]["data_splits"]),
|
|
|
- seed=int(config["data"]["seed"]),
|
|
|
- )
|
|
|
-
|
|
|
- _, _, test_loader = initalize_dataloaders(
|
|
|
- splits,
|
|
|
- batch_size=int(config["training"]["batch_size"]),
|
|
|
- )
|
|
|
- return test_loader
|
|
|
+ return [float(s) for s in uniform]
|
|
|
|
|
|
|
|
|
def _load_ensemble_models(config: dict[str, Any]) -> list[torch.nn.Module]:
|
|
|
@@ -208,7 +111,7 @@ def _load_bayesian_model(config: dict[str, Any]) -> torch.nn.Module:
|
|
|
model.to(device)
|
|
|
model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
|
|
|
model.to(device)
|
|
|
- _enable_bayesian_sampling_mode(model)
|
|
|
+ configure_bayesian_sampling_mode(model, stochastic=False)
|
|
|
return model
|
|
|
|
|
|
|
|
|
@@ -216,6 +119,7 @@ def _infer_with_noise_ensemble(
|
|
|
test_loader: torch.utils.data.DataLoader,
|
|
|
models: list[torch.nn.Module],
|
|
|
sigma: float,
|
|
|
+ intensity_range: float,
|
|
|
class_index: int,
|
|
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
|
if not models:
|
|
|
@@ -231,7 +135,7 @@ def _infer_with_noise_ensemble(
|
|
|
mri_device = mri.float().to(device)
|
|
|
xls_device = xls.float().to(device)
|
|
|
labels_device = labels.to(device)
|
|
|
- noisy = _apply_scaled_noise(mri_device, sigma)
|
|
|
+ noisy = _apply_scaled_noise(mri_device, sigma, intensity_range)
|
|
|
preds = []
|
|
|
for model in models:
|
|
|
out = model((noisy, xls_device))
|
|
|
@@ -253,6 +157,7 @@ def _infer_with_noise_bayesian(
|
|
|
test_loader: torch.utils.data.DataLoader,
|
|
|
model: torch.nn.Module,
|
|
|
sigma: float,
|
|
|
+ intensity_range: float,
|
|
|
class_index: int,
|
|
|
mc_passes: int,
|
|
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
|
@@ -266,7 +171,7 @@ def _infer_with_noise_bayesian(
|
|
|
mri_device = mri.float().to(device)
|
|
|
xls_device = xls.float().to(device)
|
|
|
labels_device = labels.to(device)
|
|
|
- noisy = _apply_scaled_noise(mri_device, sigma)
|
|
|
+ noisy = _apply_scaled_noise(mri_device, sigma, intensity_range)
|
|
|
draws = []
|
|
|
for _ in range(mc_passes):
|
|
|
out = model((noisy, xls_device))
|
|
|
@@ -295,8 +200,21 @@ def run_noise_analysis(
|
|
|
calibration_bins: int,
|
|
|
bayesian_mc_passes: int,
|
|
|
) -> dict[str, Any]:
|
|
|
- test_loader = _build_test_loader(config, root_dir)
|
|
|
- examples_dir = output_dir / "noise_examples"
|
|
|
+ noise_sigmas = _uniform_sigma_schedule(noise_sigmas)
|
|
|
+ dataset, test_loader = build_dataset_and_test_loader(
|
|
|
+ config=config,
|
|
|
+ root_dir=root_dir,
|
|
|
+ seed=int(config["data"]["seed"]),
|
|
|
+ )
|
|
|
+ # intensity_min, intensity_max, intensity_range = _compute_mri_intensity_range(
|
|
|
+ # dataset
|
|
|
+ # )
|
|
|
+
|
|
|
+    # Use a fixed intensity range for noise scaling: only consistency matters, not the exact value.
|
|
|
+ intensity_range = 10_000.0
|
|
|
+
|
|
|
+ out_plots_dir = plots_dir(output_dir)
|
|
|
+ examples_dir = out_plots_dir / "noise_examples"
|
|
|
examples_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
rows: list[dict[str, Any]] = []
|
|
|
@@ -308,6 +226,7 @@ def run_noise_analysis(
|
|
|
test_loader,
|
|
|
models,
|
|
|
sigma,
|
|
|
+ intensity_range,
|
|
|
class_index=class_index,
|
|
|
)
|
|
|
perf = performance_at_threshold(y_true, y_prob, threshold)
|
|
|
@@ -315,19 +234,20 @@ def run_noise_analysis(
|
|
|
rows.append(
|
|
|
{
|
|
|
"uncertainty_metric": "std",
|
|
|
- "sigma": float(sigma),
|
|
|
+ "noise_factor": float(sigma),
|
|
|
"accuracy": float(perf["accuracy"]),
|
|
|
"f1": float(perf["f1"]),
|
|
|
"ece": float(cal["ece"]),
|
|
|
"mce": float(cal["mce"]),
|
|
|
"mean_confidence_certainty": float(
|
|
|
- np.nanmean(_confidence_certainty(y_prob))
|
|
|
+ np.nanmean(confidence_certainty(y_prob))
|
|
|
),
|
|
|
"mean_confidence_uncertainty": float(
|
|
|
- np.nanmean(_confidence_uncertainty(y_prob))
|
|
|
+ np.nanmean(confidence_uncertainty(y_prob))
|
|
|
),
|
|
|
"mean_std": float(np.nanmean(y_std)),
|
|
|
"mean_predictive_entropy": float("nan"),
|
|
|
+ "mri_intensity_range": float(intensity_range),
|
|
|
}
|
|
|
)
|
|
|
|
|
|
@@ -337,10 +257,10 @@ def run_noise_analysis(
|
|
|
device = next(models[0].parameters()).device
|
|
|
original_device = original_mri.float().to(device)
|
|
|
for sigma in noise_sigmas:
|
|
|
- noisy_mri = _apply_scaled_noise(original_device, sigma)
|
|
|
+ noisy_mri = _apply_scaled_noise(original_device, sigma, intensity_range)
|
|
|
example_rows.append((float(sigma), noisy_mri.detach().cpu()))
|
|
|
|
|
|
- _save_noise_example_grid(
|
|
|
+ save_noise_example_grid(
|
|
|
original_mri=original_mri,
|
|
|
noisy_by_sigma=example_rows,
|
|
|
output_path=examples_dir / f"{backend}_noise_examples.png",
|
|
|
@@ -354,6 +274,7 @@ def run_noise_analysis(
|
|
|
test_loader,
|
|
|
model,
|
|
|
sigma,
|
|
|
+ intensity_range,
|
|
|
class_index=class_index,
|
|
|
mc_passes=bayesian_mc_passes,
|
|
|
)
|
|
|
@@ -362,20 +283,21 @@ def run_noise_analysis(
|
|
|
rows.append(
|
|
|
{
|
|
|
"uncertainty_metric": "predictive_entropy",
|
|
|
- "sigma": float(sigma),
|
|
|
+ "noise_factor": float(sigma),
|
|
|
"accuracy": float(perf["accuracy"]),
|
|
|
"f1": float(perf["f1"]),
|
|
|
"ece": float(cal["ece"]),
|
|
|
"mce": float(cal["mce"]),
|
|
|
"mean_confidence_certainty": float(
|
|
|
- np.nanmean(_confidence_certainty(y_prob))
|
|
|
+ np.nanmean(confidence_certainty(y_prob))
|
|
|
),
|
|
|
"mean_confidence_uncertainty": float(
|
|
|
- np.nanmean(_confidence_uncertainty(y_prob))
|
|
|
+ np.nanmean(confidence_uncertainty(y_prob))
|
|
|
),
|
|
|
# Compatibility field name retained for downstream code.
|
|
|
"mean_std": float(np.nanmean(y_std)),
|
|
|
"mean_predictive_entropy": float(np.nanmean(y_std)),
|
|
|
+ "mri_intensity_range": float(intensity_range),
|
|
|
}
|
|
|
)
|
|
|
|
|
|
@@ -385,10 +307,10 @@ def run_noise_analysis(
|
|
|
device = next(model.parameters()).device
|
|
|
original_device = original_mri.float().to(device)
|
|
|
for sigma in noise_sigmas:
|
|
|
- noisy_mri = _apply_scaled_noise(original_device, sigma)
|
|
|
+ noisy_mri = _apply_scaled_noise(original_device, sigma, intensity_range)
|
|
|
example_rows.append((float(sigma), noisy_mri.detach().cpu()))
|
|
|
|
|
|
- _save_noise_example_grid(
|
|
|
+ save_noise_example_grid(
|
|
|
original_mri=original_mri,
|
|
|
noisy_by_sigma=example_rows,
|
|
|
output_path=examples_dir / f"{backend}_noise_examples.png",
|
|
|
@@ -397,65 +319,58 @@ def run_noise_analysis(
|
|
|
else:
|
|
|
raise ValueError(f"Unsupported backend for noise analysis: {backend}")
|
|
|
|
|
|
- df = pd.DataFrame(rows).sort_values("sigma")
|
|
|
+ df = pd.DataFrame(rows).sort_values("noise_factor")
|
|
|
csv_path = output_dir / "noise_sensitivity.csv"
|
|
|
df.to_csv(csv_path, index=False)
|
|
|
|
|
|
- fig, ax = plt.subplots(figsize=(10, 5))
|
|
|
- ax.plot(df["sigma"], df["accuracy"], marker="o", label="accuracy")
|
|
|
- ax.plot(df["sigma"], df["f1"], marker="s", label="f1")
|
|
|
- ax.plot(df["sigma"], df["ece"], marker="^", label="ece")
|
|
|
- ax.set_xlabel("Gaussian Noise Sigma")
|
|
|
- ax.set_ylabel("Score")
|
|
|
- ax.set_title(f"Noise Sensitivity ({backend})")
|
|
|
- ax.grid(True, alpha=0.3)
|
|
|
- ax.legend()
|
|
|
- fig.tight_layout()
|
|
|
- plot_path = output_dir / "noise_sensitivity.png"
|
|
|
- fig.savefig(plot_path)
|
|
|
- plt.close(fig)
|
|
|
-
|
|
|
- fig_u, ax_u = plt.subplots(figsize=(10, 5))
|
|
|
- ax_u.plot(
|
|
|
- df["sigma"],
|
|
|
- df["mean_confidence_uncertainty"],
|
|
|
- marker="o",
|
|
|
- label="confidence_uncertainty",
|
|
|
+ plot_path = out_plots_dir / "noise_sensitivity.png"
|
|
|
+ uncertainty_plot_path = out_plots_dir / "noise_uncertainty.png"
|
|
|
+ certainty_plot_path = out_plots_dir / "noise_confidence_certainty.png"
|
|
|
+
|
|
|
+ save_noise_metrics_plot(
|
|
|
+ x=df["noise_factor"],
|
|
|
+ y_by_label=[
|
|
|
+ (df["accuracy"], "o", "accuracy"),
|
|
|
+ (df["f1"], "s", "f1"),
|
|
|
+ (df["ece"], "^", "ece"),
|
|
|
+ ],
|
|
|
+ x_label="Gaussian Noise Factor",
|
|
|
+ y_label="Score",
|
|
|
+ title=f"Noise Sensitivity ({backend})",
|
|
|
+ output_path=plot_path,
|
|
|
+ )
|
|
|
+ save_noise_metrics_plot(
|
|
|
+ x=df["noise_factor"],
|
|
|
+ y_by_label=[
|
|
|
+ (df["mean_confidence_uncertainty"], "o", "confidence_uncertainty"),
|
|
|
+ (df["mean_std"], "s", "std_uncertainty"),
|
|
|
+ ],
|
|
|
+ x_label="Gaussian Noise Factor",
|
|
|
+ y_label="Uncertainty",
|
|
|
+ title=f"Uncertainty vs Noise ({backend})",
|
|
|
+ output_path=uncertainty_plot_path,
|
|
|
)
|
|
|
- ax_u.plot(df["sigma"], df["mean_std"], marker="s", label="std_uncertainty")
|
|
|
- ax_u.set_xlabel("Gaussian Noise Sigma")
|
|
|
- ax_u.set_ylabel("Uncertainty")
|
|
|
- ax_u.set_title(f"Uncertainty vs Noise ({backend})")
|
|
|
- ax_u.grid(True, alpha=0.3)
|
|
|
- ax_u.legend()
|
|
|
- fig_u.tight_layout()
|
|
|
- uncertainty_plot_path = output_dir / "noise_uncertainty.png"
|
|
|
- fig_u.savefig(uncertainty_plot_path)
|
|
|
- plt.close(fig_u)
|
|
|
-
|
|
|
- fig_c, ax_c = plt.subplots(figsize=(10, 5))
|
|
|
- ax_c.plot(
|
|
|
- df["sigma"],
|
|
|
- df["mean_confidence_certainty"],
|
|
|
- marker="o",
|
|
|
- label="confidence_certainty",
|
|
|
+ save_noise_metrics_plot(
|
|
|
+ x=df["noise_factor"],
|
|
|
+ y_by_label=[
|
|
|
+ (df["mean_confidence_certainty"], "o", "confidence_certainty"),
|
|
|
+ ],
|
|
|
+ x_label="Gaussian Noise Factor",
|
|
|
+ y_label="Certainty",
|
|
|
+ title=f"Confidence Certainty vs Noise ({backend})",
|
|
|
+ output_path=certainty_plot_path,
|
|
|
)
|
|
|
- ax_c.set_xlabel("Gaussian Noise Sigma")
|
|
|
- ax_c.set_ylabel("Certainty")
|
|
|
- ax_c.set_title(f"Confidence Certainty vs Noise ({backend})")
|
|
|
- ax_c.grid(True, alpha=0.3)
|
|
|
- ax_c.legend()
|
|
|
- fig_c.tight_layout()
|
|
|
- certainty_plot_path = output_dir / "noise_confidence_certainty.png"
|
|
|
- fig_c.savefig(certainty_plot_path)
|
|
|
- plt.close(fig_c)
|
|
|
|
|
|
out = {
|
|
|
"table": str(csv_path),
|
|
|
"plot": str(plot_path),
|
|
|
"uncertainty_plot": str(uncertainty_plot_path),
|
|
|
"certainty_plot": str(certainty_plot_path),
|
|
|
+ "noise_factors": noise_sigmas,
|
|
|
"noise_sigmas": noise_sigmas,
|
|
|
+ # "mri_intensity_min": float(intensity_min),
|
|
|
+ # "mri_intensity_max": float(intensity_max),
|
|
|
+ "mri_intensity_range": float(intensity_range),
|
|
|
}
|
|
|
write_json(output_dir / "noise_summary.json", out)
|
|
|
return out
|