"""Stitching of tile-cut cells flagged by :func:`~squidpy.experimental.tl.calculate_tiling_qc`.
When segmentation is run tile-by-tile (Cellpose, Stardist, Mesmer, ...) cells
that straddle tile boundaries get cut into 2-4 pieces with characteristic
straight, axis-aligned cut edges. :func:`~squidpy.experimental.tl.calculate_tiling_qc` flags these
as ``is_outlier=True``. This module pairs facing cut edges across boundaries
and assigns each candidate pair a heuristic geometric score in [0, 1].
The score is the flat (unweighted) mean of five dataset-independent geometric
features -- ``iou``, ``endpoint_match``, ``merge_compactness``,
``merge_solidity`` and ``gap_proximity`` -- computed from the cut-edge geometry
and the union mask after closing the seam gap. No model is fitted or shipped;
the features are recorded in ``.uns["tiling_stitch"]``. Users should tune
``min_confidence`` for their data; ``0.7`` is a reasonable starting point, not
a calibrated probability.
The labels element is **never** modified here -- only ``.obs`` columns are
written. Materialising a stitched labels element is opt-in via
:func:`!make_stitched_labels`.
"""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
import spatialdata as sd
import xarray as xr
from scipy.ndimage import binary_closing
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from skimage.measure import label as cc_label
from skimage.measure import regionprops
from skimage.morphology import disk as morph_disk
from spatialdata._logging import logger as logg
from squidpy.experimental.utils._geometry import equivalent_diameter, largest_contour
from squidpy.experimental.utils._labels import iter_chunked_regionprops, resolve_labels_array
from squidpy.experimental.utils._params import resolve_params
if TYPE_CHECKING:
from collections.abc import Iterable
import anndata as ad
# Names exported as this module's public API.
__all__ = ["StitchParams", "assign_stitch_groups"]
# The geometric features whose flat (unweighted) mean is the stitch score.
_SCORE_FEATURES: tuple[str, ...] = ("iou", "endpoint_match", "merge_compactness", "merge_solidity", "gap_proximity")
# The subset of _SCORE_FEATURES computed by the expensive merge-union step
# (_merge_shape_features). The remaining features are cheap geometry values
# known before that step, which is what drives the early-prune in _score_pairs.
_SHAPE_FEATURES: tuple[str, ...] = ("merge_compactness", "merge_solidity")
# Tuning parameters
@dataclass(slots=True)
class StitchParams:
"""Advanced tuning knobs for :func:`~squidpy.experimental.tl.assign_stitch_groups`.
Defaults work for typical 2D segmentation tiles produced by
cellpose-like pipelines. Pass an instance (or a ``Mapping`` of
field names to values) as ``stitch_params`` to override. These are
advanced knobs -- the defaults rarely need changing.
"""
distance_tol: float = 0.75
"""Sub-pixel tolerance for "lies on a bbox edge"."""
min_edge_length: float = 5.0
"""Absolute floor on cut-edge length (pixels)."""
min_edge_length_ratio: float = 0.4
"""Minimum cut-edge length relative to the cell's equivalent diameter."""
min_edge_coverage: float = 0.5
"""Minimum fraction of parallel-axis positions covered by near-edge contour points."""
candidate_min_iou: float = 0.2
"""Loose 1-D IoU floor at candidate enumeration."""
close_radius: int = 3
"""Morphological closing disk radius for the union mask. Also the
length scale for ``gap_proximity`` (normalised by ``2 * close_radius``)."""
def __post_init__(self) -> None:
# Coerce numeric types (accept numpy scalars cleanly) and bounds-check.
self.distance_tol = float(self.distance_tol)
self.min_edge_length = float(self.min_edge_length)
self.min_edge_length_ratio = float(self.min_edge_length_ratio)
self.min_edge_coverage = float(self.min_edge_coverage)
self.candidate_min_iou = float(self.candidate_min_iou)
self.close_radius = int(self.close_radius)
if self.distance_tol < 0:
raise ValueError(f"distance_tol must be >= 0, got {self.distance_tol}.")
if self.min_edge_length < 0:
raise ValueError(f"min_edge_length must be >= 0, got {self.min_edge_length}.")
if not 0.0 <= self.min_edge_length_ratio <= 1.0:
raise ValueError(f"min_edge_length_ratio must be in [0, 1], got {self.min_edge_length_ratio}.")
if not 0.0 <= self.min_edge_coverage <= 1.0:
raise ValueError(f"min_edge_coverage must be in [0, 1], got {self.min_edge_coverage}.")
if not 0.0 <= self.candidate_min_iou <= 1.0:
raise ValueError(f"candidate_min_iou must be in [0, 1], got {self.candidate_min_iou}.")
if self.close_radius < 0:
raise ValueError(f"close_radius must be >= 0, got {self.close_radius}.")
def _resolve_stitch_params(stitch_params: StitchParams | Mapping[str, Any] | None) -> StitchParams:
    """Turn the user-facing ``stitch_params`` argument into a :class:`StitchParams`.

    Delegates to ``resolve_params`` with :class:`StitchParams` as the target
    class; ``label`` is the argument name that helper is given for its messages.
    """
    resolved = resolve_params(stitch_params, StitchParams, label="stitch_params")
    return resolved
# Key under which assign_stitch_groups stores its run summary in ``adata.uns``.
_METHOD_KEY = "tiling_stitch"
# Default-constructed parameters; the private helpers below read their default
# argument values from this instance.
_STITCH_DEFAULTS = StitchParams()
# Contract between calculate_tiling_qc and assign_stitch_groups.
# _STITCH_COLUMNS lists the obs columns that stitching writes back into the QC
# table; _STITCH_PARAM_KEYS is the set of top-level kwargs valid for re-running
# assign_stitch_groups (the advanced tuning lives in a nested ``stitch_params``
# dict).
_STITCH_COLUMNS = ("stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence")
_STITCH_PARAM_KEYS = frozenset({"min_confidence", "max_gap", "max_group_size"})
# Dataclasses
@dataclass(frozen=True)
class _CutEdge:
    """A candidate cut edge on a single cell's bbox.

    Immutable record (``frozen=True``) created by :func:`_extract_cut_edges`,
    one per bbox side that carries a sufficiently long straight contour run.

    Attributes
    ----------
    cell_id
        Label ID of the piece carrying this edge.
    axis
        ``"h"`` (horizontal cut: edge is a horizontal line, cell sits above
        or below it) or ``"v"`` (vertical cut).
    coord
        Position of the cut line: y-coord for ``"h"``, x-coord for ``"v"``.
    extent
        ``(min, max)`` along the parallel axis -- the chord at the cut line.
    normal_dir
        ``+1`` if the cell's centroid sits at greater coord than the cut
        line, ``-1`` otherwise. Used to enforce facing pairs.
    length
        Length of the run along the parallel axis (``extent[1] - extent[0]``).
    """

    cell_id: int
    axis: str
    coord: float
    extent: tuple[float, float]
    normal_dir: int
    length: float
@dataclass(frozen=True)
class _StitchPair:
    """A scored candidate pairing of two cut edges across a tile boundary.

    ``confidence`` is the flat mean of the geometric features (see
    :data:`_SCORE_FEATURES`); the individual feature components are kept for
    diagnostics and for the ``min``-based group-confidence aggregation.

    :func:`_score_pairs` builds these with ``cell_a < cell_b`` and with
    ``edge_a`` / ``edge_b`` belonging to ``cell_a`` / ``cell_b`` respectively.
    """

    # Label IDs of the two pieces.
    cell_a: int
    cell_b: int
    # Cut orientation shared by both edges: "h" or "v".
    axis: str
    # Flat mean of the five feature values below.
    confidence: float
    iou: float
    endpoint_match: float
    gap_proximity: float
    merge_solidity: float
    merge_compactness: float
    # The edges the pair was built from; excluded from repr to keep it short.
    # Optional in the type, so consumers reading ``.coord`` / ``.extent`` rely
    # on them having been set.
    edge_a: _CutEdge | None = field(default=None, repr=False)
    edge_b: _CutEdge | None = field(default=None, repr=False)
# Cut-edge extraction
def _read_bbox_slice(labels_da: xr.DataArray | np.ndarray, y0: int, y1: int, x0: int, x1: int) -> np.ndarray:
"""Read a 2-D bbox slice from numpy or xarray, squeezing singleton dims."""
if isinstance(labels_da, np.ndarray):
return labels_da[y0:y1, x0:x1]
arr = labels_da.isel(y=slice(y0, y1), x=slice(x0, x1)).values
while arr.ndim > 2:
arr = arr.squeeze(0)
return arr
def _compute_outlier_bboxes(
    labels_da: xr.DataArray | np.ndarray,
    outlier_ids: Iterable[int],
    chunk_size: int = 4096,
) -> dict[int, tuple[int, int, int, int]]:
    """Collect global bounding boxes for the requested labels in one pass.

    Iterates ``iter_chunked_regionprops`` restricted to ``outlier_ids``. Each
    yielded region carries a chunk-local bbox plus the chunk origin
    ``(y0, x0)``; the bbox is shifted into global coordinates and, when a label
    is reported more than once (it spans several chunks), the boxes are merged
    into their common envelope.

    Returns
    -------
    Mapping ``label_id -> (min_row, min_col, max_row, max_col)``. Labels never
    yielded by the chunk iterator are absent from the mapping.
    """
    wanted = {int(label) for label in outlier_ids}
    envelopes: dict[int, tuple[int, int, int, int]] = {}
    # TODO: faster path -- pre-mask each chunk with np.where(np.isin(chunk,
    # outlier_set), chunk, 0) before regionprops, so non-outlier cells are
    # skipped instead of scanned. Worth doing if outlier fraction is < ~5%.
    for lid, region, y0, x0 in iter_chunked_regionprops(labels_da, chunk_size=chunk_size, label_subset=wanted):
        top, left, bottom, right = region.bbox
        # Chunk-local -> global coordinates.
        box = (top + y0, left + x0, bottom + y0, right + x0)
        seen = envelopes.get(lid)
        if seen is not None:
            # Same label seen in another chunk: grow the envelope.
            box = (
                min(seen[0], box[0]),
                min(seen[1], box[1]),
                max(seen[2], box[2]),
                max(seen[3], box[3]),
            )
        envelopes[lid] = box
    return envelopes
def _bbox_edge_run(
    contour: np.ndarray,
    perp_axis: int,
    target: float,
    distance_tol: float = _STITCH_DEFAULTS.distance_tol,
    min_coverage: float = _STITCH_DEFAULTS.min_edge_coverage,
) -> tuple[float, float, float] | None:
    """Measure the straight contour run hugging one bbox edge, if there is one.

    Contour points whose ``perp_axis`` coordinate lies within ``distance_tol``
    of ``target`` are treated as being on the edge. The run is accepted only if

    * at least 3 points are on the edge,
    * their span along the other axis is strictly positive, and
    * the fraction of unit-wide bins along that span holding at least one
      point is at least ``min_coverage``.

    A genuine cut produces a dense, long run; a curved outline merely grazing
    its bbox is expected to fail one of these checks.

    Returns
    -------
    ``(ext_lo, ext_hi, length)`` for an accepted run, otherwise ``None``.
    """
    on_edge = np.abs(contour[:, perp_axis] - target) <= distance_tol
    if on_edge.sum() < 3:
        return None

    along = contour[on_edge, 1 - perp_axis]
    ext_lo, ext_hi = float(along.min()), float(along.max())
    length = ext_hi - ext_lo
    if length <= 0:
        return None

    # One bin per unit step along the run, both end points included.
    n_steps = max(int(np.ceil(length)), 1)
    occupied = np.zeros(n_steps + 1, dtype=bool)
    occupied[np.clip((along - ext_lo).astype(int), 0, n_steps)] = True
    if float(occupied.sum()) / (n_steps + 1) < min_coverage:
        return None
    return ext_lo, ext_hi, length
def _extract_cut_edges(
    labels_da: xr.DataArray | np.ndarray,
    outlier_ids: Iterable[int],
    bboxes: dict[int, tuple[int, int, int, int]] | None = None,
    distance_tol: float = _STITCH_DEFAULTS.distance_tol,
    min_edge_length: float = _STITCH_DEFAULTS.min_edge_length,
    min_edge_length_ratio: float = _STITCH_DEFAULTS.min_edge_length_ratio,
    min_edge_coverage: float = _STITCH_DEFAULTS.min_edge_coverage,
) -> tuple[list[_CutEdge], dict[int, np.ndarray]]:
    """Find axis-aligned cut-edge candidates on the bbox sides of each outlier.

    Per outlier cell:

    1. Read the labels inside its bbox and build the boolean mask of the cell,
       then zero-pad that mask by 1 px so the outline closes at the bbox border.
    2. Trace the outline with ``largest_contour`` (presumably the longest
       ``find_contours`` result -- confirm in the helper).
    3. Test each of the 4 bbox sides for a dense straight run with
       :func:`_bbox_edge_run`, and keep runs at least
       ``max(min_edge_length, min_edge_length_ratio * equivalent_diameter)``
       long.

    A piece cut at a tile boundary ends exactly at the cut, so the cut lies on
    a bbox side. A cell can therefore contribute several edges (e.g. two
    perpendicular ones at a 4-tile corner).

    Parameters
    ----------
    labels_da
        Labels array (numpy or xarray).
    outlier_ids
        Label IDs to inspect.
    bboxes
        Optional precomputed ``label_id -> (min_row, min_col, max_row, max_col)``;
        computed with :func:`_compute_outlier_bboxes` when ``None``.
    distance_tol, min_edge_coverage
        Forwarded to :func:`_bbox_edge_run`.
    min_edge_length, min_edge_length_ratio
        Absolute and relative floors on the run length.

    Returns
    -------
    The list of cut edges and a ``{label_id -> boolean bbox mask}`` dict of the
    crops read here, so that the scoring pass can rebuild merge unions without
    reading the labels array again. IDs without a bbox, or whose crop does not
    contain the label, are skipped entirely.
    """
    ids = [int(x) for x in outlier_ids]
    if bboxes is None:
        bboxes = _compute_outlier_bboxes(labels_da, ids)

    found: list[_CutEdge] = []
    crops: dict[int, np.ndarray] = {}
    for lid in ids:
        box = bboxes.get(lid)
        if box is None:
            continue
        top, left, bottom, right = box

        # Boolean mask of this cell inside its bbox; kept for the scoring pass.
        cell_mask = _read_bbox_slice(labels_da, top, bottom, left, right) == lid
        if not cell_mask.any():
            continue
        crops[lid] = cell_mask

        padded = np.pad(cell_mask.astype(np.float32), 1, mode="constant", constant_values=0)
        contour = largest_contour(padded)
        if contour is None:
            continue

        # Padded-crop coordinates -> global coordinates (undo bbox offset and pad).
        contour_global = contour.copy()
        contour_global[:, 0] += top - 1
        contour_global[:, 1] += left - 1

        # Centroid and area straight from the mask; no second regionprops call.
        rows, cols = np.where(padded)
        cy = float(rows.mean()) + top - 1
        cx = float(cols.mean()) + left - 1
        min_len = max(min_edge_length, min_edge_length_ratio * equivalent_diameter(float(padded.sum())))

        # The traced outline sits half a pixel outside the integer pixel
        # boundary, hence the -0.5 on every bbox side.
        # Each entry: (axis, perpendicular contour column, line position, centroid coord).
        sides = (
            ("h", 0, float(top) - 0.5, cy),
            ("h", 0, float(bottom) - 0.5, cy),
            ("v", 1, float(left) - 0.5, cx),
            ("v", 1, float(right) - 0.5, cx),
        )
        for axis, perp_axis, target, centroid in sides:
            run = _bbox_edge_run(contour_global, perp_axis, target, distance_tol, min_edge_coverage)
            if run is None:
                continue
            ext_lo, ext_hi, length = run
            if length < min_len:
                continue
            found.append(
                _CutEdge(
                    cell_id=lid,
                    axis=axis,
                    coord=target,
                    extent=(ext_lo, ext_hi),
                    normal_dir=1 if centroid > target else -1,
                    length=float(length),
                )
            )
    return found, crops
# Pair candidate enumeration + features
def _extent_overlap(a: tuple[float, float], b: tuple[float, float]) -> float:
return max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
def _merge_shape_features(
    cell_a: int,
    cell_b: int,
    bboxes: dict[int, tuple[int, int, int, int]],
    outlier_crops: dict[int, np.ndarray],
    close_radius: int = _STITCH_DEFAULTS.close_radius,
    *,
    H: int,
    W: int,
) -> dict[str, float]:
    """Score the shape obtained by joining two pieces across their seam.

    The two per-cell boolean crops are pasted into one canvas covering both
    bboxes (padded by ``close_radius + 2`` and clamped to the ``H`` x ``W``
    image), the seam is bridged with a morphological closing using a disk of
    radius ``close_radius``, and the largest 8-connected component is measured.

    Solidity (area / convex hull area) and compactness (``4*pi*A / P^2``) drop
    when two unrelated cells are joined, because the union is concave at the
    join.

    Only the crops collected by :func:`_extract_cut_edges` are used; the labels
    array itself is not read here.

    Returns
    -------
    ``{"merge_solidity": ..., "merge_compactness": ...}``, both capped at 1.0.
    Both are ``0.0`` when either cell lacks a bbox or crop, or when the canvas
    or its closing is empty.
    """
    fallback = {"merge_solidity": 0.0, "merge_compactness": 0.0}
    pieces = (cell_a, cell_b)
    if any(cid not in bboxes for cid in pieces):
        return fallback
    if any(cid not in outlier_crops for cid in pieces):
        return fallback

    boxes = [bboxes[cid] for cid in pieces]
    # Canvas = padded union bbox, clamped to the image so offsets stay valid.
    margin = close_radius + 2
    top = max(min(box[0] for box in boxes) - margin, 0)
    left = max(min(box[1] for box in boxes) - margin, 0)
    bottom = min(max(box[2] for box in boxes) + margin, H)
    right = min(max(box[3] for box in boxes) + margin, W)

    canvas = np.zeros((bottom - top, right - left), dtype=bool)
    for cid, (r0, c0, r1, c1) in zip(pieces, boxes, strict=True):
        # Paste the cell's bbox mask at its offset inside the canvas.
        canvas[r0 - top : r1 - top, c0 - left : c1 - left] |= outlier_crops[cid]
    if not canvas.any():
        return fallback

    closed = binary_closing(canvas, structure=morph_disk(close_radius))
    components = cc_label(closed, connectivity=2)
    if components.max() == 0:
        return fallback

    # Pick the largest foreground component (index 0 is background).
    counts = np.bincount(components.ravel())
    counts[0] = 0
    largest = int(counts.argmax())
    props = regionprops((components == largest).astype(np.uint8))[0]

    perim = max(props.perimeter, 1.0)
    compactness = float(min(4 * np.pi * props.area / (perim * perim), 1.0))
    # skimage can report area/convex_area slightly above 1 for thin or
    # degenerate rasterised regions; cap it so the score stays in [0, 1].
    solidity = float(min(props.solidity, 1.0))
    return {"merge_solidity": solidity, "merge_compactness": compactness}
def _pair_geometry_features(
    e: _CutEdge,
    c: _CutEdge,
    max_gap: float,
    candidate_min_iou: float = _STITCH_DEFAULTS.candidate_min_iou,
) -> dict[str, float] | None:
    """Cheap geometry features for one edge pair, or ``None`` if it is rejected.

    A pair is rejected when the edges have the same ``normal_dir``, when they
    do not face each other, when their extents do not overlap, when the 1-D
    IoU of the extents is below ``candidate_min_iou``, or when the
    perpendicular gap exceeds ``max_gap``.

    Returns
    -------
    ``{"iou", "endpoint_match", "gap"}``. ``gap`` is the raw perpendicular
    distance between the two cut lines.
    """
    if e.normal_dir == c.normal_dir:
        return None

    # Facing: the cell with +1 normal must sit at greater coord than the cell
    # with -1 (with a tiny float tolerance).
    offset = e.coord - c.coord
    if offset * e.normal_dir < -1e-6:
        return None

    overlap = _extent_overlap(e.extent, c.extent)
    if overlap <= 0:
        return None

    union = e.length + c.length - overlap
    iou = overlap / union if union > 0 else 0.0
    if iou < candidate_min_iou:
        return None

    gap = abs(offset)
    if gap > max_gap:
        return None

    longest = max(e.length, c.length)
    if longest > 0:
        mismatch = abs(e.extent[0] - c.extent[0]) + abs(e.extent[1] - c.extent[1])
        endpoint_match = max(0.0, 1.0 - mismatch / longest)
    else:
        endpoint_match = 0.0

    # The gap is returned raw: gap_proximity is derived from it later against
    # the closing reach (2*close_radius), not against max_gap (a search radius).
    return {
        "iou": float(iou),
        "endpoint_match": float(endpoint_match),
        "gap": float(gap),
    }
def _enumerate_pair_candidates(
    edges: list[_CutEdge],
    max_gap: float,
    candidate_min_iou: float = _STITCH_DEFAULTS.candidate_min_iou,
) -> list[tuple[_CutEdge, _CutEdge, dict[str, float]]]:
    """List every plausible pairing of facing cut edges with its geometry features.

    Edges are bucketed by axis and sorted by cut coordinate, so each edge only
    has to be compared with the edges inside a ``+/- max_gap`` coordinate
    window. Each unordered pair is emitted once, edges of the same cell are
    never paired, and pairs rejected by :func:`_pair_geometry_features` are
    dropped. Nothing is scored or selected here.
    """
    grouped: dict[str, list[_CutEdge]] = {"h": [], "v": []}
    for edge in edges:
        grouped[edge.axis].append(edge)

    candidates: list[tuple[_CutEdge, _CutEdge, dict[str, float]]] = []
    for same_axis in grouped.values():
        same_axis.sort(key=lambda edge: edge.coord)
        coords = np.array([edge.coord for edge in same_axis])
        for i, first in enumerate(same_axis):
            lo = int(np.searchsorted(coords, first.coord - max_gap, side="left"))
            hi = int(np.searchsorted(coords, first.coord + max_gap, side="right"))
            # Only look at later positions, so each unordered pair appears once.
            for second in same_axis[max(lo, i + 1) : hi]:
                if second.cell_id == first.cell_id:
                    continue
                feats = _pair_geometry_features(first, second, max_gap, candidate_min_iou=candidate_min_iou)
                if feats is not None:
                    candidates.append((first, second, feats))
    return candidates
# Scoring
def _gap_proximity(gap: float, close_radius: int) -> float:
"""Map the raw perpendicular gap to [0, 1] against the closing reach.
Normalised by ``2 * close_radius`` -- the scale at which morphological
closing could actually bridge the seam -- so the feature is independent of
the ``max_gap`` search radius and only reaches 0 when the gap genuinely
exceeds what closing can join. When closing is disabled (``close_radius=0``)
the feature is inactive and returns ``1.0`` rather than collapsing the score.
"""
reach = 2 * close_radius
# gap<=0 (touching/overlapping) or reach<=0 (closing disabled, close_radius=0)
# -> the feature is inactive (neutral 1.0), never a silent score cliff.
if gap <= 0 or reach <= 0:
return 1.0
return max(0.0, 1.0 - gap / reach)
def _score_pair_features(features: dict[str, float]) -> float:
    """Compute the heuristic stitch score of one pair.

    The score is the plain arithmetic mean of the entries named in
    :data:`_SCORE_FEATURES`; any other keys in ``features`` are ignored. It is
    in [0, 1] when every feature is, and it is not a calibrated probability --
    users pick ``min_confidence`` based on their false-merge tolerance.
    """
    total = 0
    for name in _SCORE_FEATURES:
        total += features[name]
    return float(total / len(_SCORE_FEATURES))
def _max_achievable_score(known_features: dict[str, float]) -> float:
    """Best score a pair could still reach given only its cheap features.

    The not-yet-computed shape features (:data:`_SHAPE_FEATURES`) are set to
    their best case, ``1.0``, and the result is passed through
    :func:`_score_pair_features`, so the bound follows the real score whenever
    the feature set or weighting changes.
    """
    optimistic = dict(known_features)
    for name in _SHAPE_FEATURES:
        optimistic[name] = 1.0
    return _score_pair_features(optimistic)
def _score_pairs(
    candidates: list[tuple[_CutEdge, _CutEdge, dict[str, float]]],
    bboxes: dict[int, tuple[int, int, int, int]],
    outlier_crops: dict[int, np.ndarray],
    min_confidence: float,
    close_radius: int = _STITCH_DEFAULTS.close_radius,
    *,
    H: int,
    W: int,
) -> list[_StitchPair]:
    """Score candidate pairs and return those reaching ``min_confidence``.

    For each candidate the cheap features are completed with ``gap_proximity``.
    If even perfect shape features could not lift the score to
    ``min_confidence`` the pair is dropped before the costly union
    reconstruction; otherwise the shape features are computed and the final
    score is thresholded.

    Returns
    -------
    One :class:`_StitchPair` per ``(cell_a, cell_b, axis)`` with
    ``cell_a < cell_b`` (the highest-confidence entry wins on duplicates),
    sorted by descending confidence, then ``cell_a``, then ``cell_b``.
    """
    best: dict[tuple[int, int, str], _StitchPair] = {}
    for e, c, geom in candidates:
        known = {**geom, "gap_proximity": _gap_proximity(geom["gap"], close_radius)}
        # Early prune: the optimistic bound already misses the threshold.
        if _max_achievable_score(known) < min_confidence:
            continue

        shape = _merge_shape_features(e.cell_id, c.cell_id, bboxes, outlier_crops, close_radius=close_radius, H=H, W=W)
        feats = {**known, **shape}
        confidence = _score_pair_features(feats)
        if confidence < min_confidence:
            continue

        # Canonical order (cell_a < cell_b) keeps grouping deterministic.
        ea, eb = (e, c) if e.cell_id < c.cell_id else (c, e)
        pair = _StitchPair(
            cell_a=ea.cell_id,
            cell_b=eb.cell_id,
            axis=e.axis,
            confidence=confidence,
            iou=feats["iou"],
            endpoint_match=feats["endpoint_match"],
            gap_proximity=feats["gap_proximity"],
            merge_solidity=feats["merge_solidity"],
            merge_compactness=feats["merge_compactness"],
            edge_a=ea,
            edge_b=eb,
        )

        # Keep only the strongest entry per (cell_a, cell_b, axis).
        key = (pair.cell_a, pair.cell_b, pair.axis)
        incumbent = best.get(key)
        if incumbent is None or incumbent.confidence < pair.confidence:
            best[key] = pair

    return sorted(best.values(), key=lambda p: (-p.confidence, p.cell_a, p.cell_b))
# Group assembly (union-find + validation)
def _validate_group_geometry(
pairs_in_group: list[_StitchPair],
size: int,
max_gap: float,
) -> bool:
"""Geometric sanity check for groups of size >= 3.
Two cases:
- **Corner group** (size 4, both axes present): the cut edges' endpoints
must converge near a single junction point (one ``h`` cut crossing one
``v`` cut defines the junction). If the spread of edge extents from
the junction is greater than ``max_gap``, the group is implausible.
- **Chain group** (size 3 or 4, all pairs share one axis): legitimate
same-axis chains (e.g., a cell split by 3 horizontal seams into 4
vertically-stacked pieces) have pairs at N-1 *distinct* seam
coordinates. Multiple pairs at the same seam coord would imply
geometrically impossible "two cuts at the same seam" pairings -- a
signature of a false-positive cluster -- so we reject.
"""
h_pairs = [p for p in pairs_in_group if p.axis == "h"]
v_pairs = [p for p in pairs_in_group if p.axis == "v"]
# Chain case: only one axis present and size >= 3.
if not h_pairs or not v_pairs:
if size < 3:
return True # 2-piece groups are trivially valid on one axis
# Each pair's seam coord is roughly midway between its two edges.
seam_coords = [round((p.edge_a.coord + p.edge_b.coord) / 2.0, 1) for p in pairs_in_group]
# Allow a max_gap-sized tolerance for "distinct" seams.
sorted_coords = sorted(seam_coords)
for prev, cur in zip(sorted_coords, sorted_coords[1:], strict=False):
if cur - prev <= max_gap:
return False
return True
# Mixed-axis case: only validate the 4-piece corner pattern. 3-piece
# L-shapes (one h pair + one v pair sharing a corner cell) are
# geometrically valid and don't have a junction to converge on.
if size != 4:
return True
# Corner case: both axes present, size 4. Junction y/x is the mean of edge coords.
h_edges = [p.edge_a for p in h_pairs] + [p.edge_b for p in h_pairs]
v_edges = [p.edge_a for p in v_pairs] + [p.edge_b for p in v_pairs]
junction_y = float(np.mean([e.coord for e in h_edges]))
junction_x = float(np.mean([e.coord for e in v_edges]))
for e in h_edges:
if min(abs(e.extent[0] - junction_x), abs(e.extent[1] - junction_x)) > max_gap:
return False
for e in v_edges:
if min(abs(e.extent[0] - junction_y), abs(e.extent[1] - junction_y)) > max_gap:
return False
return True
def _assemble_groups(
pairs: list[_StitchPair],
candidate_ids: Iterable[int],
max_group_size: int,
max_gap: float,
) -> tuple[dict[int, int], dict[int, float]]:
"""Build stitch groups via union-find with size + corner validation.
Returns
-------
groups
``cell_id -> group_id`` (group_id == own cell_id for unstitched).
confidences
``cell_id -> stitch_confidence`` -- min over pairwise confidences in
the cell's group; ``1.0`` for confirmed-solo (no surviving pair).
"""
# Build undirected connected components via scipy. Cells map to a
# contiguous [0, n) index space; pairs become symmetric edges in a CSR
# adjacency matrix. We then re-key components by the smallest cell_id
# they contain so the group root is deterministic.
candidate_list = sorted({int(c) for c in candidate_ids})
if not candidate_list:
return {}, {}
id_to_idx = {cid: i for i, cid in enumerate(candidate_list)}
n = len(candidate_list)
valid_pairs = [p for p in pairs if p.cell_a in id_to_idx and p.cell_b in id_to_idx]
if valid_pairs:
rows = [id_to_idx[p.cell_a] for p in valid_pairs]
cols = [id_to_idx[p.cell_b] for p in valid_pairs]
adj = csr_matrix((np.ones(len(rows), dtype=np.int8), (rows, cols)), shape=(n, n))
_, comp_labels = connected_components(adj, directed=False)
else:
comp_labels = np.arange(n)
cells_by_comp: dict[int, list[int]] = {}
for i, comp in enumerate(comp_labels):
cells_by_comp.setdefault(int(comp), []).append(candidate_list[i])
members: dict[int, list[int]] = {}
root_of_cell: dict[int, int] = {}
for comp_members in cells_by_comp.values():
comp_members.sort()
root = comp_members[0]
members[root] = comp_members
for cid in comp_members:
root_of_cell[cid] = root
pairs_by_group: dict[int, list[_StitchPair]] = {}
for p in valid_pairs:
pairs_by_group.setdefault(root_of_cell[p.cell_a], []).append(p)
groups: dict[int, int] = {}
confidences: dict[int, float] = {}
for root, mem in members.items():
size = len(mem)
group_pairs = pairs_by_group.get(root, [])
# Size cap: collapse oversized groups back to singletons.
if size > max_group_size:
for m in mem:
groups[m] = m
confidences[m] = 1.0
continue
# Geometric validation for 3+ piece groups: corner-junction for
# mixed-axis 4-groups, chain (distinct seam coords) for same-axis 3+.
if size >= 3 and not _validate_group_geometry(group_pairs, size, max_gap):
for m in mem:
groups[m] = m
confidences[m] = 1.0
continue
if size == 1:
groups[mem[0]] = mem[0]
confidences[mem[0]] = 1.0
continue
# Group confidence = min over pairwise confidences (weakest link).
group_conf = float(min(p.confidence for p in group_pairs))
for m in mem:
groups[m] = root
confidences[m] = group_conf
return groups, confidences
# Public entry point
def assign_stitch_groups(
    sdata: sd.SpatialData,
    labels_key: str,
    qc_table_key: str | None = None,
    min_confidence: float = 0.7,
    max_gap: float = 3.0,
    max_group_size: int = 4,
    stitch_params: StitchParams | Mapping[str, Any] | None = None,
    inplace: bool = True,
) -> ad.AnnData | None:
    """Assign tile-cut cell pieces to stitch groups.

    Reads ``is_outlier=True`` cells flagged by
    :func:`~squidpy.experimental.tl.calculate_tiling_qc`, pairs facing cut
    edges across tile boundaries, scores each pair via a transparent geometric
    composite, and assembles high-confidence pairs into stitch groups via
    connected components of the pair graph. This only *annotates* which pieces
    belong together -- it does **not** modify the labels element.
    Materialising a stitched labels element is opt-in via
    :func:`!make_stitched_labels`.

    The score per pair is the flat (unweighted) mean of five geometric features
    in [0, 1]: ``iou`` (1-D extent overlap), ``endpoint_match`` (chord endpoints
    coincide), ``merge_compactness`` (``4*pi*A / P^2`` of the closed union mask),
    ``merge_solidity`` (union area / convex hull area), and ``gap_proximity``
    (seam gap relative to the morphological closing reach). No coefficients are
    fitted or shipped; the names of the score features and a summary of the run
    are recorded in ``.uns["tiling_stitch"]``.

    Parameters
    ----------
    sdata
        :class:`~spatialdata.SpatialData` with a labels element and a QC
        table from :func:`~squidpy.experimental.tl.calculate_tiling_qc`.
    labels_key
        Key in ``sdata.labels``.
    qc_table_key
        Key of the QC table. Defaults to ``"{labels_key}_qc"``.
    min_confidence
        Threshold on ``stitch_confidence``. ``0.7`` (default) is a starting
        point; raise it for stricter precision, lower for recall. Tune for
        your data -- the score is heuristic, not a calibrated probability.
    max_gap
        Maximum perpendicular distance (px) between facing cut edges for a pair
        to be *considered* a candidate. This is a search radius only; it does
        not scale the score.
    max_group_size
        Cap on group size; oversized groups (likely false merges) collapse
        to singletons.
    stitch_params
        Advanced tuning knobs as a :class:`StitchParams` instance or a
        ``Mapping`` of its field names to values. See :class:`StitchParams`
        for each field's meaning and default. ``None`` (default) uses
        all defaults.
    inplace
        If ``True``, write back into ``sdata.tables[qc_table_key]``.
        Otherwise return the modified AnnData.

    Returns
    -------
    The QC :class:`~anndata.AnnData` with four new ``.obs`` columns
    (``stitch_group_id``, ``is_stitched``, ``n_pieces``, ``stitch_confidence``)
    when ``inplace=False``, otherwise ``None``.

    Raises
    ------
    ValueError
        If ``labels_key`` is not in ``sdata.labels``; if ``min_confidence`` is
        outside [0, 1], ``max_gap`` is negative or ``max_group_size`` is below
        1; if the QC table is missing; or if the QC table lacks the
        ``is_outlier`` or ``label_id`` column.
    """
    # Argument validation.
    if labels_key not in sdata.labels:
        raise ValueError(f"Labels key '{labels_key}' not found in sdata.labels.")
    if min_confidence < 0 or min_confidence > 1:
        raise ValueError(f"min_confidence must be in [0, 1], got {min_confidence}.")
    if max_gap < 0:
        raise ValueError(f"max_gap must be non-negative, got {max_gap}.")
    if max_group_size < 1:
        raise ValueError(f"max_group_size must be >= 1, got {max_group_size}.")
    params = _resolve_stitch_params(stitch_params)
    table_key = qc_table_key if qc_table_key is not None else f"{labels_key}_qc"
    if table_key not in sdata.tables:
        raise ValueError(f"QC table '{table_key}' not found. Run calculate_tiling_qc first.")
    # Work on a copy: the stored table is only replaced at the very end, and
    # only when ``inplace`` is true.
    adata = sdata.tables[table_key].copy()
    if "is_outlier" not in adata.obs.columns:
        raise ValueError(f"QC table '{table_key}' is missing 'is_outlier'; re-run calculate_tiling_qc.")
    if "label_id" not in adata.obs.columns:
        raise ValueError(f"QC table '{table_key}' is missing 'label_id'.")
    # Columns left over from a previous stitch run are dropped and rewritten.
    existing = [c for c in _STITCH_COLUMNS if c in adata.obs.columns]
    if existing:
        logg.warning(f"Overwriting existing stitch columns: {existing}.")
        adata.obs.drop(columns=existing, inplace=True)
    # Resolve which labels DataArray was used at QC time (multi-scale aware).
    qc_params = adata.uns.get("tiling_qc", {})
    scale = qc_params.get("scale")
    labels_da = resolve_labels_array(sdata, labels_key, scale)
    label_ids = adata.obs["label_id"].astype(int).to_numpy()
    is_outlier = adata.obs["is_outlier"].to_numpy(dtype=bool)
    outlier_ids = label_ids[is_outlier].tolist()
    n_outliers = len(outlier_ids)
    logg.info(f"Stitching {n_outliers} outlier cells (out of {len(label_ids)} total).")
    if n_outliers == 0:
        # Nothing to do: fall through with empty results so the obs columns and
        # the .uns summary are still written.
        logg.warning("No outliers flagged; nothing to stitch.")
        groups: dict[int, int] = {}
        confidences: dict[int, float] = {}
        edges: list[_CutEdge] = []
        pairs: list[_StitchPair] = []
    else:
        # Pipeline: bboxes -> cut edges (+ per-cell crops) -> candidate pairs
        # -> scored pairs -> groups.
        bboxes = _compute_outlier_bboxes(labels_da, outlier_ids)
        missing = [lid for lid in outlier_ids if lid not in bboxes]
        if missing:
            logg.warning(
                f"{len(missing)} outlier label_id(s) flagged in the QC table do not appear "
                f"in '{labels_key}' (e.g. {missing[:5]}); they will not be stitched."
            )
        edges, outlier_crops = _extract_cut_edges(
            labels_da,
            outlier_ids,
            bboxes=bboxes,
            distance_tol=params.distance_tol,
            min_edge_length=params.min_edge_length,
            min_edge_length_ratio=params.min_edge_length_ratio,
            min_edge_coverage=params.min_edge_coverage,
        )
        # Image extent is taken from the trailing two axes (y, x).
        H, W = labels_da.shape[-2], labels_da.shape[-1]
        candidates = _enumerate_pair_candidates(edges, max_gap=max_gap, candidate_min_iou=params.candidate_min_iou)
        pairs = _score_pairs(
            candidates, bboxes, outlier_crops, min_confidence, close_radius=params.close_radius, H=H, W=W
        )
        groups, confidences = _assemble_groups(pairs, outlier_ids, max_group_size=max_group_size, max_gap=max_gap)
    # Write .obs columns with three states distinguished by stitch_confidence:
    # - non-outlier cell -> own label_id, False, 1, NaN (not evaluated)
    # - outlier solo -> own label_id, False, 1, 1.0 (checked, no partner)
    # - outlier stitched -> shared root, True, n, composite score
    n = len(label_ids)
    stitch_group_id = label_ids.copy()
    is_stitched = np.zeros(n, dtype=bool)
    n_pieces = np.ones(n, dtype=np.int32)
    stitch_confidence = np.full(n, np.nan, dtype=np.float64)
    # group root -> number of member cells.
    group_sizes: dict[int, int] = {}
    if outlier_ids:
        for root in groups.values():
            group_sizes[root] = group_sizes.get(root, 0) + 1
        # label_id -> row position in the obs table.
        id_to_idx = {int(lid): i for i, lid in enumerate(label_ids)}
        for cid, root in groups.items():
            i = id_to_idx[int(cid)]
            stitch_group_id[i] = int(root)
            size = group_sizes[root]
            n_pieces[i] = size
            is_stitched[i] = size > 1
            stitch_confidence[i] = float(confidences.get(cid, 1.0))
    adata.obs["stitch_group_id"] = stitch_group_id
    adata.obs["is_stitched"] = is_stitched
    adata.obs["n_pieces"] = n_pieces
    adata.obs["stitch_confidence"] = stitch_confidence
    n_groups = sum(1 for s in group_sizes.values() if s > 1)
    n_stitched = int(is_stitched.sum())
    # Histogram of multi-piece group sizes (size -> number of groups).
    # Use string keys so the dict round-trips through zarr-backed .uns cleanly.
    pieces_dist: dict[str, int] = {}
    for s in group_sizes.values():
        if s > 1:
            key = str(int(s))
            pieces_dist[key] = pieces_dist.get(key, 0) + 1
    # NOTE(review): "n_candidate_pairs" records len(pairs), i.e. the pairs that
    # survived scoring at min_confidence, not the raw ``candidates`` list --
    # confirm this is the intended meaning of the key.
    adata.uns[_METHOD_KEY] = {
        "min_confidence": float(min_confidence),
        "max_gap": float(max_gap),
        "max_group_size": int(max_group_size),
        "stitch_params": asdict(params),
        "n_outliers": int(n_outliers),
        "n_candidate_pairs": int(len(pairs)),
        "n_stitched_groups": int(n_groups),
        "n_stitched_cells": int(n_stitched),
        "n_pieces_distribution": pieces_dist,
        "score_features": list(_SCORE_FEATURES),
    }
    if not inplace:
        return adata
    sdata.tables[table_key] = adata
    return None