Source code for squidpy.experimental.tl._tiling_qc

"""QC metrics for detecting tile-boundary segmentation artifacts.

Cells cut by tile borders during segmentation have characteristic
straight edges that natural cell boundaries never produce.  This module
computes per-cell metrics that quantify this artifact:

- **max_straight_edge_ratio**: length of the longest straight contour
  segment normalised by the cell's equivalent diameter.
- **cardinal_alignment_score**: how closely that segment aligns with
  0° or 90° (axis-aligned tile borders).
- **cut_score**: product of the two, combining evidence from shape and
  orientation.
- **smoothed_cut_score**: cut_score multiplied by the mean cut_score of
  the ``n_neighbors`` nearest spatial neighbors - amplifies boundary
  cells while suppressing isolated high-scorers.
- **is_outlier**: boolean flag gated on per-cell cut_score and/or
  spatially smoothed score exceeding their respective MAD thresholds.
- **nhood_outlier_fraction**: fraction of ``n_neighbors`` nearest
  neighbors that are smoothed-score outliers (MAD-based).  Bounded
  [0, 1]; high values precisely trace the FOV tile grid.

All heavy computation is done per-tile via the tiling infrastructure
in :mod:`squidpy.experimental.im._tiling`, so this scales to
100k x 100k images without materialising the full array.
"""

from __future__ import annotations

import math
from collections.abc import Mapping
from dataclasses import asdict, dataclass, fields
from typing import Any, Literal

import anndata as ad
import dask
import numpy as np
import pandas as pd
import spatialdata as sd
import xarray as xr
from dask.diagnostics import ProgressBar
from numba import njit
from skimage.measure import find_contours, regionprops
from sklearn.neighbors import BallTree
from spatialdata._logging import logger as logg
from spatialdata.models import TableModel

from squidpy._utils import _get_n_cores
from squidpy.experimental.im._tiling import (
    build_tile_specs,
    compute_cell_info,
    compute_cell_info_multiscale,
    compute_cell_info_tiled,
    extract_labels_tile_lazy,
)
from squidpy.experimental.tl._tiling_stitch import _STITCH_COLUMNS, _STITCH_PARAM_KEYS, StitchParams
from squidpy.experimental.utils._labels import resolve_labels_array

__all__ = ["TilingQCParams", "calculate_tiling_qc"]


[docs] @dataclass(slots=True, frozen=True) class TilingQCParams: """Advanced tuning knobs for :func:`~squidpy.experimental.tl.calculate_tiling_qc`. Pass an instance (or a ``Mapping`` of field names to values) as ``tiling_qc_params`` to override. Frozen so that validation done in ``__post_init__`` cannot be silently bypassed by later mutation. """ distance_tol: float = 0.75 """Maximum perpendicular distance (pixels) from the fitted line for a contour point to count as straight.""" min_area: int = 20 """Cells smaller than this (pixels at analysis resolution) are skipped (NaN scores).""" max_contour_points: int = 500 """Cap on contour resolution; longer contours are arc-length-resampled before the O(n^2) collinearity scan.""" def __post_init__(self) -> None: # frozen=True forbids direct assignment; use object.__setattr__ for coercion. object.__setattr__(self, "distance_tol", float(self.distance_tol)) object.__setattr__(self, "min_area", int(self.min_area)) object.__setattr__(self, "max_contour_points", int(self.max_contour_points)) if self.distance_tol < 0: raise ValueError(f"`distance_tol` must be >= 0, got {self.distance_tol}.") if self.min_area < 1: raise ValueError(f"`min_area` must be >= 1, got {self.min_area}.") if self.max_contour_points < 3: raise ValueError( f"`max_contour_points` must be >= 3 (collinearity needs 3 points), got {self.max_contour_points}." )
_QC_DEFAULTS = TilingQCParams() _QC_FIELDS = frozenset(f.name for f in fields(TilingQCParams)) def _resolve_qc_params(qc_params: TilingQCParams | Mapping[str, Any] | None) -> TilingQCParams: """Normalise the ``tiling_qc_params`` argument to a :class:`TilingQCParams` instance.""" if qc_params is None: return _QC_DEFAULTS if isinstance(qc_params, TilingQCParams): return qc_params if isinstance(qc_params, Mapping): unknown = set(qc_params) - _QC_FIELDS if unknown: raise ValueError( f"Unknown `tiling_qc_params` field(s): {sorted(unknown)}; expected from {sorted(_QC_FIELDS)}." ) return TilingQCParams(**qc_params) raise TypeError(f"`tiling_qc_params` must be TilingQCParams, Mapping, or None; got {type(qc_params).__name__}.") # Standard consistency factor sd ~ 1.4826 x MAD for normal distributions. _MAD_TO_SD = 1.4826 _TILE_SCORE_COLUMNS = ["max_straight_edge_ratio", "cardinal_alignment_score", "cut_score"] _POST_SCORE_COLUMNS = ["smoothed_cut_score", "is_outlier", "nhood_outlier_fraction"] _SCORE_COLUMNS = _TILE_SCORE_COLUMNS + _POST_SCORE_COLUMNS _NAN_TILE_SCORES = dict.fromkeys(_TILE_SCORE_COLUMNS, np.nan) def _has_distributed_client() -> bool: """Return True iff a ``dask.distributed.Client`` is active in this process. Mirrors the public dask idiom: if a Client is in scope, ``dask.compute`` will pick it up automatically — we only need to know whether to fall back to the local threaded scheduler. """ try: # ImportError guards against partial dask installs without the distributed extra; # ValueError is what get_client() raises when no Client is currently active. from dask.distributed import get_client get_client() except (ImportError, ValueError): return False return True # Core geometry @njit(cache=True, nogil=True) def _collinear_scan( contour: np.ndarray, cum_arc: np.ndarray, total_arc: float, distance_tol: float, ) -> tuple[float, float]: """Numba-accelerated two-pointer collinearity scan. For each start index, extends the end index as long as all intermediate points stay within ``distance_tol`` of the start→end line. Returns ``(best_length, best_angle)``. """ n = contour.shape[0] best_len = 0.0 best_angle = 0.0 for start in range(n - 2): remaining_arc = total_arc - cum_arc[start] if remaining_arc <= best_len: break for end in range(start + 2, n): d0 = contour[end, 0] - contour[start, 0] d1 = contour[end, 1] - contour[start, 1] seg_len = math.sqrt(d0 * d0 + d1 * d1) if seg_len < 1e-12: continue max_perp = 0.0 for k in range(start + 1, end): r0 = contour[k, 0] - contour[start, 0] r1 = contour[k, 1] - contour[start, 1] perp = abs(d0 * r1 - d1 * r0) / seg_len if perp > max_perp: max_perp = perp if perp > distance_tol: break if max_perp > distance_tol: break if seg_len > best_len: best_len = seg_len best_angle = math.atan2(d0, d1) return best_len, best_angle def _resample_contour(contour: np.ndarray, max_points: int) -> np.ndarray: """Resample a contour to at most *max_points* via arc-length interpolation. Fully vectorised using :func:`numpy.searchsorted` - no Python loops. Preserves geometry far better than naive stride-based subsampling because points are placed equidistantly along the contour arc. """ n = len(contour) if n <= max_points: return contour diffs = np.diff(contour, axis=0) seg_lengths = np.sqrt((diffs**2).sum(axis=1)) cum_arc = np.empty(n, dtype=np.float64) cum_arc[0] = 0.0 cum_arc[1:] = np.cumsum(seg_lengths) total = cum_arc[-1] if total < 1e-12: return contour[:max_points] targets = np.linspace(0.0, total, max_points) idx = np.searchsorted(cum_arc, targets, side="right") - 1 idx = np.clip(idx, 0, n - 2) seg = cum_arc[idx + 1] - cum_arc[idx] safe_seg = np.where(seg < 1e-12, 1.0, seg) frac = np.where(seg < 1e-12, 0.0, (targets - cum_arc[idx]) / safe_seg) return contour[idx] + frac[:, np.newaxis] * (contour[idx + 1] - contour[idx]) def _longest_collinear_segment( contour: np.ndarray, distance_tol: float = _QC_DEFAULTS.distance_tol, max_contour_points: int = _QC_DEFAULTS.max_contour_points, ) -> tuple[float, float]: """Find the longest collinear run of contour points. Uses a numba-compiled two-pointer scan with three contour rotations to handle the closure point. Long contours are resampled to at most ``max_contour_points`` via arc-length interpolation to bound worst-case runtime. Parameters ---------- contour ``(N, 2)`` array of ``(row, col)`` contour coordinates. distance_tol Maximum perpendicular distance (pixels) from the start→end line for a point to be considered part of the straight segment. max_contour_points Cap on contour resolution; longer contours are resampled to this length before the collinearity scan. Returns ------- run_length Euclidean length of the longest straight segment (pixels). run_angle Angle (radians, ``[-π, π]``) of that segment. """ n = len(contour) if n < 3: return 0.0, 0.0 pts = np.asarray(contour, dtype=np.float64) pts = _resample_contour(pts, max_contour_points) n = len(pts) # find_contours returns closed contours (first ≈ last point) closed = np.sqrt(((pts[0] - pts[-1]) ** 2).sum()) < 1.0 # For closed contours, drop the duplicate last point and precompute # segment lengths once - rotations reuse the same distances. if closed and n > 6: core = pts[:-1] core_diffs = np.diff(core, axis=0) core_seg_lens = np.sqrt((core_diffs**2).sum(axis=1)) rotations = [0, len(core) // 3, 2 * len(core) // 3] else: core = pts core_diffs = np.diff(core, axis=0) core_seg_lens = np.sqrt((core_diffs**2).sum(axis=1)) rotations = [0] best_len = 0.0 best_angle = 0.0 # Scan at multiple rotations so straight segments crossing the # closure point are not split. for shift in rotations: if shift == 0: rotated = core sl = core_seg_lens else: rotated = np.roll(core, -shift, axis=0) sl = np.roll(core_seg_lens, -shift) cum_arc = np.empty(len(rotated), dtype=np.float64) cum_arc[0] = 0.0 cum_arc[1:] = np.cumsum(sl) length, angle = _collinear_scan(rotated, cum_arc, cum_arc[-1], distance_tol) if length > best_len: best_len = length best_angle = angle return best_len, best_angle def _cardinal_alignment(angle: float) -> float: """Score how close an angle is to a cardinal direction (0° or 90°). Returns a value in ``[0, 1]`` where 1 means perfectly axis-aligned and 0 means maximally diagonal (45°). """ a = abs(angle) % np.pi dist = min(a, abs(a - np.pi / 2), abs(a - np.pi)) # Map [0, π/4] → [1, 0] return float(1.0 - dist / (np.pi / 4)) def _straight_edge_metrics( contour: np.ndarray, cell_area: float, distance_tol: float = _QC_DEFAULTS.distance_tol, max_contour_points: int = _QC_DEFAULTS.max_contour_points, ) -> tuple[float, float, float]: """Compute straight-edge metrics for a single cell contour. Parameters ---------- contour ``(N, 2)`` contour coordinates from :func:`skimage.measure.find_contours`. cell_area Area of the cell in pixels (for normalisation). distance_tol Perpendicular distance tolerance for collinearity (pixels). Returns ------- straight_edge_ratio Longest collinear segment / equivalent diameter. cardinal_score Cardinal alignment of the longest straight segment. cut_score Product of the two. """ eq_diam = np.sqrt(4 * cell_area / np.pi) if eq_diam == 0: return 0.0, 0.0, 0.0 run_length, run_angle = _longest_collinear_segment(contour, distance_tol, max_contour_points) straight_ratio = run_length / eq_diam cardinal = _cardinal_alignment(run_angle) cut_score = straight_ratio * cardinal return float(straight_ratio), float(cardinal), float(cut_score) # Per-tile scoring def _score_tile( tile_labels: np.ndarray, distance_tol: float = _QC_DEFAULTS.distance_tol, min_area: int = _QC_DEFAULTS.min_area, downsample: int = 1, max_contour_points: int = _QC_DEFAULTS.max_contour_points, ) -> pd.DataFrame: """Compute tiling QC metrics for all cells in a numpy label tile. Parameters ---------- tile_labels ``(H, W)`` label array (background = 0, owned cells only). distance_tol Perpendicular distance tolerance for collinearity (pixels). min_area Cells smaller than this (in pixels at analysis resolution) are skipped and get NaN values. downsample Factor by which to downsample each cell's bounding-box crop before contour extraction. ``1`` = full resolution, ``2`` = half, etc. Straight edges are scale-invariant so moderate downsampling (2–4x) is safe and much faster for large cells. Returns ------- DataFrame with columns ``max_straight_edge_ratio``, ``cardinal_alignment_score``, ``cut_score``, indexed by cell label. """ regions = regionprops(tile_labels) if not regions: return pd.DataFrame(columns=_TILE_SCORE_COLUMNS, dtype=float) rows: dict[int, dict[str, float]] = {} for region in regions: lid = region.label area = region.area if area < min_area * (downsample**2): rows[lid] = dict(_NAN_TILE_SCORES) continue # Pad with 1px of zeros so find_contours can trace cells # that touch the crop edge (e.g., cells filling their bbox). min_row, min_col, max_row, max_col = region.bbox crop = (tile_labels[min_row:max_row, min_col:max_col] == lid).astype(np.float32) crop = np.pad(crop, 1, mode="constant", constant_values=0) if downsample > 1: crop = crop[::downsample, ::downsample] contours = find_contours(crop, 0.5) if not contours: rows[lid] = dict(_NAN_TILE_SCORES) continue contour = max(contours, key=len) analysis_area = area / (downsample**2) if downsample > 1 else area ser, cas, cs = _straight_edge_metrics(contour, analysis_area, distance_tol, max_contour_points) rows[lid] = { "max_straight_edge_ratio": ser, "cardinal_alignment_score": cas, "cut_score": cs, } return pd.DataFrame.from_dict(rows, orient="index") # Centroid computation (shared logic with _feature.py) def _compute_centroids_for_labels( sdata: sd.SpatialData, labels_key: str, labels_da: xr.DataArray, scale: str | None, ) -> dict: """Compute cell centroids using the most efficient strategy available.""" if isinstance(sdata.labels[labels_key], xr.DataTree): logg.info("Computing centroids from coarse scale.") return compute_cell_info_multiscale(sdata.labels[labels_key], target_scale=scale or "scale0") n_pixels = labels_da.sizes.get("y", 1) * labels_da.sizes.get("x", 1) if n_pixels <= 4096 * 4096: lbl_np = labels_da.values if lbl_np.ndim > 2: lbl_np = lbl_np.squeeze() return compute_cell_info(lbl_np) logg.info("Computing centroids in tiled mode (large single-scale labels).") return compute_cell_info_tiled(labels_da) # Public API _METHOD_KEY = "tiling_qc"
[docs] def calculate_tiling_qc( sdata: sd.SpatialData, labels_key: str, scale: str | None = None, tile_size: int = 2048, overlap_margin: int | Literal["auto"] = "auto", downsample: int = 1, outlier_use_cut: bool = True, outlier_use_smoothed: bool = True, nmads_cut: float = 1.5, nmads_smoothed: float = 3, n_neighbors: int = 10, tiling_qc_params: TilingQCParams | Mapping[str, Any] | None = None, n_jobs: int = -1, table_key_added: str | None = None, inplace: bool = True, ) -> ad.AnnData | None: """Score cells for tile-boundary segmentation artifacts. Computes per-cell metrics that detect artificially straight edges caused by tiled segmentation. Large images are processed via the cell-aware tiling infrastructure in ``squidpy.experimental.im._tiling``. Results are stored in a QC table (default ``sdata.tables["{labels_key}_qc"]``). Scores live in ``.obs``; the ``.X`` matrix is empty. Algorithm parameters are recorded in ``.uns["tiling_qc"]``. Parameters ---------- sdata SpatialData object. labels_key Key in ``sdata.labels`` with segmentation masks. scale Scale level for multi-scale labels. tile_size Side length of the tiling grid (pixels). overlap_margin Overlap around each tile. ``"auto"`` computes the minimum from the largest cell's bounding box. downsample Factor by which to downsample each cell's bounding-box crop before contour extraction. Straightness is scale-invariant, so ``2``--``4`` is safe and much faster on large cells. outlier_use_cut Gate ``is_outlier`` on the per-cell ``cut_score`` exceeding its own MAD threshold. Requires the cell itself to have a straight cardinal-aligned edge. outlier_use_smoothed Gate ``is_outlier`` on the spatially smoothed score (``smoothed_cut_score``) exceeding its MAD threshold. Requires the cell to be in a spatial cluster of high-scorers. nmads_cut Number of MADs for the ``cut_score`` outlier gate. Threshold is ``median + nmads_cut x MAD x 1.4826``. nmads_smoothed Number of MADs for the ``smoothed_cut_score`` outlier gate. Threshold is ``median + nmads_smoothed x MAD x 1.4826``. n_neighbors Number of nearest spatial neighbors used to compute ``smoothed_cut_score`` and ``nhood_outlier_fraction``. In a perfect grid each cell has 8 immediate neighbours; the default of 10 leaves a little wiggle room for biological irregularity without wasting compute on distant cells. tiling_qc_params Advanced tuning knobs as a :class:`TilingQCParams` instance or a ``Mapping`` of its field names to values. See :class:`TilingQCParams` for each field's meaning and default. ``None`` (default) uses all defaults. n_jobs Number of threads for tile processing. ``-1`` (default) uses all available CPUs. Ignored when an active ``dask.distributed.Client`` is in scope (the client's own worker pool is used instead). table_key_added Key under which to store the result in ``sdata.tables``. Defaults to ``"{labels_key}_qc"``. inplace If ``True``, store result in ``sdata.tables``. Otherwise return the AnnData directly. Returns ------- :class:`~anndata.AnnData` when ``inplace=False``, otherwise ``None``. The AnnData ``.obs`` contains five scores per cell: - ``max_straight_edge_ratio``: longest collinear boundary segment / equivalent diameter. - ``cardinal_alignment_score``: axis-alignment of that segment (1 = cardinal, 0 = diagonal). - ``cut_score``: product of the two. - ``smoothed_cut_score``: ``cut_score x mean(neighbor cut_scores)`` over the ``n_neighbors`` nearest spatial neighbors. Amplifies cells on FOV boundaries while suppressing isolated high-scorers. - ``is_outlier``: boolean, ``True`` when the enabled outlier gates are satisfied (``cut_score`` and/or ``smoothed_cut_score`` exceeding their respective MAD thresholds). - ``nhood_outlier_fraction``: fraction of ``n_neighbors`` nearest neighbors that are smoothed-score outliers (MAD-based). Bounded [0, 1]; high values trace the tile grid. Notes ----- Tile processing is parallelised via :func:`dask.compute`. When an active ``dask.distributed.Client`` is in scope it is picked up automatically and used for execution; otherwise a local threaded scheduler with ``n_jobs`` workers is used. If you invoke this function from inside a dask worker task (e.g., via ``client.submit(calculate_tiling_qc, ...)``), wrap the call in ``distributed.secede`` / ``distributed.rejoin`` to release the worker slot before the inner tile tasks are submitted; without that, the cluster can deadlock when all workers are busy holding the outer job. """ if labels_key not in sdata.labels: raise ValueError(f"Labels key '{labels_key}' not found, valid keys: {list(sdata.labels.keys())}") if not outlier_use_cut and not outlier_use_smoothed: raise ValueError("At least one outlier gate must be enabled (outlier_use_cut or outlier_use_smoothed).") if outlier_use_cut and nmads_cut <= 0: raise ValueError(f"nmads_cut must be positive, got {nmads_cut}.") if outlier_use_smoothed and nmads_smoothed <= 0: raise ValueError(f"nmads_smoothed must be positive, got {nmads_smoothed}.") if n_neighbors < 1: raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}.") qc_params = _resolve_qc_params(tiling_qc_params) labels_da = resolve_labels_array(sdata, labels_key, scale) cell_info = _compute_centroids_for_labels(sdata, labels_key, labels_da, scale) if not cell_info: raise ValueError("No cells found in labels (all zeros).") H = int(labels_da.sizes.get("y", labels_da.shape[-2])) W = int(labels_da.sizes.get("x", labels_da.shape[-1])) specs = build_tile_specs((H, W), cell_info, tile_size=tile_size, overlap_margin=overlap_margin) logg.info( f"Tiling QC: {len(specs)} tiles ({tile_size}x{tile_size}, margin={overlap_margin}, downsample={downsample}x)." ) @dask.delayed def _process_one(spec): tile_lbl = extract_labels_tile_lazy(labels_da, spec) return _score_tile( tile_lbl, distance_tol=qc_params.distance_tol, min_area=qc_params.min_area, downsample=downsample, max_contour_points=qc_params.max_contour_points, ) tasks = [_process_one(spec) for spec in specs] if _has_distributed_client(): if n_jobs != -1: logg.warning( "`n_jobs` is ignored when an active dask.distributed Client is in scope. " "Parallelism is controlled by the client." ) results = dask.compute(*tasks) else: num_workers = _get_n_cores(n_jobs) with ProgressBar(): results = dask.compute(*tasks, scheduler="threads", num_workers=num_workers) tile_dfs = [df for df in results if not df.empty] if not tile_dfs: raise ValueError("No cells scored - labels may be empty or all below min_area.") combined = pd.concat(tile_dfs, axis=0).sort_index() if combined.index.duplicated().any(): dups = combined.index[combined.index.duplicated()].unique().tolist() raise RuntimeError(f"Duplicate cell IDs across tiles - tile ownership may be broken. Duplicates: {dups}") # --- Spatial context post-processing --- n_cells = len(combined) centroid_y = np.array([cell_info[lid].centroid_y for lid in combined.index]) centroid_x = np.array([cell_info[lid].centroid_x for lid in combined.index]) centroids = np.column_stack([centroid_y, centroid_x]) if n_cells <= 1: combined["smoothed_cut_score"] = combined["cut_score"] combined["is_outlier"] = False combined["nhood_outlier_fraction"] = 0.0 else: effective_k = min(n_neighbors, n_cells - 1) tree = BallTree(centroids) _, indices = tree.query(centroids, k=effective_k + 1) # +1 because query includes self neighbor_idx = indices[:, 1:] cut_scores = combined["cut_score"].values.copy() cut_scores = np.where(np.isnan(cut_scores), 0.0, cut_scores) neighbor_mean = cut_scores[neighbor_idx].mean(axis=1) smoothed = cut_scores * neighbor_mean combined["smoothed_cut_score"] = smoothed # Build is_outlier from enabled gates (AND when both active). # A gate whose MAD is degenerate has no signal — treat it as a # no-op so it cannot poison the other gate's result. If no gate # produced a meaningful filter, fall back to "no outliers". is_outlier = np.ones(n_cells, dtype=bool) gates_applied = 0 if outlier_use_cut: median_c = np.median(cut_scores) mad_c = np.median(np.abs(cut_scores - median_c)) if mad_c >= 1e-12: is_outlier &= cut_scores >= median_c + nmads_cut * mad_c * _MAD_TO_SD gates_applied += 1 if outlier_use_smoothed: median_s = np.median(smoothed) mad_s = np.median(np.abs(smoothed - median_s)) if mad_s >= 1e-12: is_outlier &= smoothed >= median_s + nmads_smoothed * mad_s * _MAD_TO_SD gates_applied += 1 if gates_applied == 0: is_outlier[:] = False combined["is_outlier"] = is_outlier neighbor_outlier_frac = combined["is_outlier"].values[neighbor_idx].mean(axis=1) combined["nhood_outlier_fraction"] = neighbor_outlier_frac adata = ad.AnnData( X=np.empty((n_cells, 0), dtype=np.float32), ) adata.obs_names = [f"cell_{i}" for i in combined.index] adata.obs["region"] = pd.Categorical([labels_key] * n_cells) adata.obs["label_id"] = combined.index.values adata.uns["spatialdata_attrs"] = { "region": labels_key, "region_key": "region", "instance_key": "label_id", } # TODO: migrate tiling QC scores to .obsm once spatialdata-plot # supports rendering labels colored by obsm keys. # See scverse/spatialdata-plot#587. for col in combined.columns: adata.obs[col] = combined[col].values adata.obs["centroid_y"] = centroid_y adata.obs["centroid_x"] = centroid_x adata.uns[_METHOD_KEY] = { "scale": scale, "tile_size": tile_size, "overlap_margin": overlap_margin, "downsample": downsample, "outlier_use_cut": outlier_use_cut, "outlier_use_smoothed": outlier_use_smoothed, "nmads_cut": nmads_cut, "nmads_smoothed": nmads_smoothed, "n_neighbors": n_neighbors, "tiling_qc_params": asdict(qc_params), } if inplace: table_key = table_key_added if table_key_added is not None else f"{labels_key}_qc" _warn_if_dropping_stitch_columns(sdata, table_key, labels_key) sdata.tables[table_key] = TableModel.parse(adata) return None return adata
def _warn_if_dropping_stitch_columns(sdata: sd.SpatialData, table_key: str, labels_key: str) -> None: """Warn if re-running QC would drop downstream stitch results. ``calculate_tiling_qc`` replaces the QC table wholesale, so any columns added by :func:`~squidpy.experimental.tl.assign_stitch_groups` to a previous version of this table are about to disappear. We emit an actionable warning listing the previous stitch parameters (from ``.uns["tiling_stitch"]``) and a copy-pasteable invocation to restore them. """ if table_key not in sdata.tables: return existing = sdata.tables[table_key] present = [c for c in _STITCH_COLUMNS if c in existing.obs.columns] if not present: return prev_params = existing.uns.get("tiling_stitch", {}) if hasattr(existing, "uns") else {} parts = [f"labels_key={labels_key!r}"] parts.extend(f"{k}={v!r}" for k, v in prev_params.items() if k in _STITCH_PARAM_KEYS) nested = prev_params.get("stitch_params") if isinstance(nested, dict) and nested: defaults = asdict(StitchParams()) diff = {k: v for k, v in nested.items() if k in defaults and defaults[k] != v} if diff: parts.append(f"stitch_params={diff!r}") rerun = f"sq.experimental.tl.assign_stitch_groups(sdata, {', '.join(parts)})" logg.warning( f"Re-running calculate_tiling_qc dropped previous stitch columns " f"({', '.join(present)}) from sdata.tables[{table_key!r}]. " f"To restore them, run: {rerun}" )