"""Public, ``sdata``-aware entry points for stain normalization.
The single integration boundary for the stain module: the only file that
reads ``sdata.images[...]``, writes back via :class:`Image2DModel`, and is
re-exported publicly. Apart from the tissue-mask and label-scale helpers
imported from :mod:`squidpy.experimental.im._utils`, everything it calls is a
pure DataArray-layer primitive (:mod:`._reinhard`, :mod:`._decomposition`,
:mod:`._white_point`, :mod:`._conversion`).
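The two main entry points, :func:`fit_stain_reference` and
:func:`normalize_stains`, dispatch on the fitting ``method`` (``"reinhard"``
colour transfer, or ``"macenko"``/``"vahadane"`` absorbance decomposition); a
third entry, :func:`decompose_stains`, projects an image onto its stain
matrix, and :func:`estimate_white_point` optionally estimates the white point
from a slide's background.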
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Literal
import numpy as np
import spatialdata as sd
import xarray as xr
from numpy.typing import DTypeLike
from spatialdata.models import Image2DModel
from spatialdata.transformations import get_transformation
from squidpy._utils import _get_scale_factors
from squidpy.experimental.im._stain._constants import RUIFROK_HE
from squidpy.experimental.im._stain._conversion import _check_channel_dim, cast_to_image_dtype
from squidpy.experimental.im._stain._decomposition import (
MacenkoParams,
VahadaneParams,
_resolve_macenko_params,
_resolve_vahadane_params,
apply_decomposition,
decompose_to_concentrations,
fit_decomposition,
)
from squidpy.experimental.im._stain._reference import StainMethod, StainReference
from squidpy.experimental.im._stain._reinhard import (
ReinhardParams,
_resolve_reinhard_params,
apply_reinhard,
fit_reinhard,
)
from squidpy.experimental.im._stain._white_point import (
default_white_point,
validate_rgb_range,
white_point_from_background,
)
from squidpy.experimental.im._utils import (
_choose_label_scale_for_image,
get_element_data,
get_mask_materialized,
resolve_tissue_mask,
)
# Methods whose fit yields a stain matrix, i.e. the ones decompose_stains accepts.
_DECOMPOSITION_METHODS = ("macenko", "vahadane")
# Every fitting method the public dispatchers accept: colour transfer plus the decompositions.
_VALID_METHODS = ("reinhard", *_DECOMPOSITION_METHODS)
# Channel labels, in order, for the 3-channel output of decompose_to_concentrations.
_CONCENTRATION_CHANNELS = ["hematoxylin", "eosin", "residual"]
# Public union accepted by the `method_params` argument of the dispatchers: a
# Params instance, a mapping of its fields, or None for defaults.
MethodParams = ReinhardParams | MacenkoParams | VahadaneParams | Mapping[str, Any] | None
def _resolve_image(
    sdata: sd.SpatialData,
    image_key: str,
    scale: str,
    *,
    prefer: Literal["coarsest", "finest"],
) -> xr.DataArray:
    """Look up ``image_key`` and return the level picked by ``scale``/``prefer``.

    The level is selected by :func:`get_element_data` and then checked by
    :func:`_check_channel_dim` before being returned.

    Raises
    ------
    ValueError
        If ``image_key`` is not an image element of ``sdata``.
    """
    images = sdata.images
    if image_key in images:
        image = get_element_data(images[image_key], scale, "image", image_key, prefer=prefer)
        _check_channel_dim(image)
        return image
    raise ValueError(f"image_key {image_key!r} not found, valid keys: {list(images.keys())}")
def _resolve_mask_key_and_scale(
    sdata: sd.SpatialData, image_key: str, target_da: xr.DataArray, tissue_mask_key: str | None
) -> tuple[str, str, tuple[int, int]]:
    """Find the tissue-label key and the label scale best matching ``target_da``.

    Used by both tissue-mask helpers in this module. The mask is looked up
    with ``auto_create=False``, so a :func:`!detect_tissue` labels element
    must already exist; a missing one raises instead of being created.

    Returns
    -------
    ``(mask_key, label_scale, (height, width))`` where ``(height, width)`` is
    ``target_da``'s ``(y, x)`` size.
    """
    key = resolve_tissue_mask(sdata, image_key, "auto", tissue_mask_key, auto_create=False)
    height, width = int(target_da.sizes["y"]), int(target_da.sizes["x"])
    best_scale = _choose_label_scale_for_image(sdata.labels[key], (height, width))
    return key, best_scale, (height, width)
def _resolve_tissue_bool_mask(
    sdata: sd.SpatialData, image_key: str, fit_da: xr.DataArray, tissue_mask_key: str | None
) -> np.ndarray:
    """Materialise a ``(y, x)`` boolean tissue mask on ``fit_da``'s grid.

    Meant for the fits, which run on a coarse level, so the materialised mask
    stays small. When the closest label scale does not exactly match
    ``fit_da``'s ``(y, x)`` size, the mask is nearest-neighbour resized.
    """
    mask_key, label_scale, (height, width) = _resolve_mask_key_and_scale(
        sdata, image_key, fit_da, tissue_mask_key
    )
    tissue = get_mask_materialized(sdata, mask_key, label_scale) > 0
    if tissue.shape == (height, width):
        return tissue

    from skimage.transform import resize

    # order=0 (nearest) keeps the labels binary; re-threshold the float output.
    return resize(tissue, (height, width), order=0, preserve_range=True) > 0.5
def _resolve_output_tissue_mask(
    sdata: sd.SpatialData, image_key: str, target_da: xr.DataArray, tissue_mask_key: str | None
) -> xr.DataArray:
    """Return a ``(y, x)`` boolean tissue mask on ``target_da``'s (output) grid.

    Used to composite the original background back into the normalized image.
    Unlike :func:`_resolve_tissue_bool_mask`, the label level chosen for the
    output resolution is thresholded (``> 0``) without materialising it, then
    relabelled with ``target_da``'s ``y``/``x`` coordinates. Only if that
    level's size differs from ``target_da`` is the label level materialised and
    nearest-resized; note that this fallback yields an in-memory array of the
    full ``target_da`` ``(y, x)`` size.
    """
    mask_key, label_scale, target_hw = _resolve_mask_key_and_scale(sdata, image_key, target_da, tissue_mask_key)
    yx_coords = {dim: target_da.coords[dim] for dim in ("y", "x") if dim in target_da.coords}
    labels = get_element_data(sdata.labels[mask_key], label_scale, "label", mask_key).squeeze()
    tissue = labels > 0

    if target_hw != (int(tissue.sizes["y"]), int(tissue.sizes["x"])):
        from skimage.transform import resize

        dense = resize(np.asarray(tissue.data) > 0, target_hw, order=0, preserve_range=True) > 0.5
        return xr.DataArray(dense, dims=("y", "x"), coords=yx_coords)

    # Same size: stay lazy and just adopt the target's coordinates.
    return tissue.assign_coords(yx_coords)
def _resolve_method_params(method: str, method_params: MethodParams) -> Any:
"""Pick the right Params dataclass for ``method`` and resolve a mapping/instance/None."""
if method == "reinhard":
return _resolve_reinhard_params(method_params)
if method == "macenko":
return _resolve_macenko_params(method_params)
if method == "vahadane":
return _resolve_vahadane_params(method_params)
raise ValueError(f"Unknown method {method!r}; expected one of {list(_VALID_METHODS)}.")
def _write_image(
    sdata: sd.SpatialData,
    source_node: Any,
    image_key_added: str,
    data_array: xr.DataArray,
    *,
    c_coords: list[Any] | None = None,
) -> None:
    """Store a derived image under ``image_key_added``, keeping the source's transforms/pyramid.

    The element is rebuilt from the bare ``data_array.data`` (a derived
    DataArray may still carry the source's ``transform`` attr, which would
    clash with the transformations passed explicitly), together with its dims,
    channel coords (``c_coords``, or ``data_array``'s ``c`` coordinate when
    omitted), the source's transformations and the source's scale factors.
    This mirrors detect_tissue.

    Raises
    ------
    ValueError
        If ``image_key_added`` is already present in ``sdata.images``.
    """
    if image_key_added in sdata.images:
        raise ValueError(f"image_key_added={image_key_added!r} already exists in sdata.images.")
    if c_coords is None and "c" in data_array.coords:
        c_coords = data_array.coords["c"].values.tolist()

    transformations = get_transformation(source_node, get_all=True)
    # `_get_scale_factors` gives [] for a single-scale source, but parse needs
    # None there - an empty list would build a degenerate one-level pyramid.
    scale_factors = _get_scale_factors(source_node) or None
    sdata.images[image_key_added] = Image2DModel.parse(
        data_array.data,
        dims=data_array.dims,
        c_coords=c_coords,
        transformations=transformations,
        scale_factors=scale_factors,
    )
def estimate_white_point(
    sdata: sd.SpatialData,
    image_key: str,
    *,
    tissue_mask_key: str | None = None,
    scale: str | Literal["auto"] = "auto",
) -> np.ndarray:
    """Estimate the white point ``I_0`` from a slide's background (non-tissue median).

    Opt-in alternative to the fixed dtype-aware default white point, for a slide
    whose unstained background is genuinely not full white. Samples the
    per-channel median over **non-tissue** pixels (background = the complement of
    the :func:`!detect_tissue` mask).

    Parameters
    ----------
    sdata, image_key
        The SpatialData object and the RGB image key.
    tissue_mask_key
        Tissue-label element key (defaults to ``f"{image_key}_tissue"``); a
        tissue mask is required, as for :func:`fit_stain_reference`.
    scale
        Scale level to sample on. ``"auto"`` (default) uses the coarsest level.
        The sampled level is materialised to take the median, so keep this
        coarse - do not pass a fine level on a whole-slide image.

    Returns
    -------
    Shape-``(3,)`` white point; pass it as ``white_point`` to
    :func:`fit_stain_reference` / :func:`decompose_stains`.

    Raises
    ------
    ValueError
        If ``image_key`` is not in ``sdata.images``.
    """
    da = _resolve_image(sdata, image_key, scale, prefer="coarsest")
    validate_rgb_range(da)
    tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, da, tissue_mask_key)
    # Background is everything the tissue mask does not cover.
    return white_point_from_background(da, ~tissue_mask)
def fit_stain_reference(
    sdata: sd.SpatialData,
    image_key: str,
    *,
    method: StainMethod = "macenko",
    scale: str | Literal["auto"] = "auto",
    method_params: MethodParams = None,
    white_point: np.ndarray | None = None,
    tissue_mask_key: str | None = None,
    max_angle_deg: float = 45.0,
    canonical_reference: Mapping[str, np.ndarray] | None = None,
) -> StainReference:
    """Fit a stain reference from an image in a :class:`~spatialdata.SpatialData` object.

    Parameters
    ----------
    sdata
        SpatialData object containing the image.
    image_key
        Key of the RGB image in ``sdata.images`` to fit on.
    method
        Fitting method: ``"macenko"`` (default) or ``"vahadane"`` (physical
        stain-matrix decomposition, usable by both :func:`normalize_stains` and
        :func:`decompose_stains`), or ``"reinhard"`` (faster statistical colour
        transfer, no stain separation). Macenko is the default because its one
        documented weakness - artifact pixels contaminating the fit - is removed
        by the mandatory tissue mask.
    scale
        Scale level to fit on. ``"auto"`` (default) uses the coarsest level,
        which is cheap and sufficient for colour statistics.
    method_params
        A :class:`ReinhardParams`/:class:`MacenkoParams`/:class:`VahadaneParams`
        instance, a mapping of its fields, or ``None`` for defaults. Must match
        ``method``.
    white_point
        Per-channel white point ``I_0`` of shape ``(3,)`` for the decomposition
        methods. If ``None``, a fixed full-white ``[255, 255, 255]`` is used (the
        HistomicsTK/Macenko convention), so unstained pixels round-trip to
        white. Pass :func:`estimate_white_point` only for slides with a
        known non-white background. Ignored by Reinhard.
    tissue_mask_key
        Key of a tissue-label element in ``sdata.labels`` (as produced by
        :func:`!detect_tissue`) restricting the fit to
        tissue pixels. If ``None``, ``f"{image_key}_tissue"`` is used. A tissue
        mask is **required**: if neither exists, a :class:`KeyError` asks you to
        run :func:`!detect_tissue` first.
    max_angle_deg
        Tolerance of the H/E sanity gate for the decomposition methods: the fit
        raises :class:`!StainFittingError` if either recovered stain vector
        deviates more than this many degrees from its canonical reference.
        Default ``45``. Ignored by Reinhard.
    canonical_reference
        Canonical H/E reference for the decomposition methods, a mapping with
        ``"hematoxylin"`` and ``"eosin"`` keys to ``(3,)`` RGB optical-density
        unit vectors. Drives both the H/E column ordering and the deviation
        gate. If ``None``, the Ruifrok H&E vectors are used. Ignored by Reinhard.

    Returns
    -------
    The fitted :class:`StainReference`. Nothing is written to ``sdata``.

    Raises
    ------
    ValueError
        If ``method`` is unknown, ``image_key`` is missing, or a supplied
        ``white_point`` (decomposition methods only) is not of shape ``(3,)``.
    """
    if method not in _VALID_METHODS:
        raise ValueError(f"Unknown method {method!r}; expected one of {list(_VALID_METHODS)}.")
    da = _resolve_image(sdata, image_key, scale, prefer="coarsest")
    validate_rgb_range(da)
    params = _resolve_method_params(method, method_params)
    tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, da, tissue_mask_key)
    if method == "reinhard":
        # Colour transfer: no white point / stain vectors involved.
        return fit_reinhard(da, params, tissue_mask=tissue_mask)

    if white_point is None:
        bg = default_white_point(da)
    else:
        bg = np.asarray(white_point, np.float64)
        # Reject malformed input here with a clear message instead of letting
        # it surface as a broadcasting error deep inside the decomposition.
        if bg.shape != (3,):
            raise ValueError(f"white_point must have shape (3,), got {bg.shape}.")
    reference = RUIFROK_HE if canonical_reference is None else dict(canonical_reference)
    return fit_decomposition(
        da,
        method,
        params,
        bg,
        tissue_mask=tissue_mask,
        image_key=image_key,
        reference=reference,
        max_angle_deg=max_angle_deg,
    )
def normalize_stains(
    sdata: sd.SpatialData,
    image_key: str,
    reference: StainReference,
    *,
    scale: str | Literal["auto"] = "auto",
    method_params: MethodParams = None,
    image_key_added: str | None = None,
    inplace: bool = True,
    output_dtype: DTypeLike | None = None,
    tissue_mask_key: str | None = None,
    preserve_background: bool = True,
) -> xr.DataArray | None:
    """Normalize an image to a fitted stain reference.

    Parameters
    ----------
    sdata
        SpatialData object containing the source image.
    image_key
        Key of the RGB image in ``sdata.images`` to normalize.
    reference
        A :class:`StainReference` fitted with :func:`fit_stain_reference`.
        Dispatch is on ``reference.method``.
    scale
        Scale level to normalize. ``"auto"`` (default) uses the finest level
        so the result is not downsampled; source statistics are reduced
        lazily so memory stays bounded.
    method_params
        Params matching ``reference.method`` (instance, mapping, or ``None``).
    image_key_added
        Key for the written image when ``inplace=True``. If ``None`` (default),
        ``f"{image_key}_normalized"`` is used. Ignored when ``inplace=False``.
    inplace
        If ``True`` (default), write the normalized image to
        ``sdata.images[image_key_added]`` (rebuilding the pyramid for multiscale
        sources, preserving transforms) and return ``None``; raises if the key
        already exists. If ``False``, leave ``sdata`` untouched and return the
        lazy normalized :class:`~xarray.DataArray`.
    output_dtype
        Dtype of the result. If ``None`` (default), the source image's dtype is
        used. The reconstruction is clipped to that dtype's valid range and
        rounded (for integer dtypes) at the write boundary.
    tissue_mask_key
        Key of a tissue-label element in ``sdata.labels`` restricting the
        *source* statistics to tissue pixels. As for
        :func:`fit_stain_reference`, a tissue mask is required (defaults to
        ``f"{image_key}_tissue"``; raises if missing).
    preserve_background
        If ``True`` (default), non-tissue (background) pixels are passed through
        from the source image, so the normalization recolours only tissue. The
        colour map is a global linear transform that would otherwise tint
        background/white pixels. Set ``False`` for full-frame normalization.

    Returns
    -------
    ``None`` if ``inplace=True`` (the image is written), otherwise the lazy
    normalized :class:`xarray.DataArray` with channels labelled ``r``/``g``/``b``.

    Raises
    ------
    ValueError
        If ``image_key`` is missing, ``reference.method`` is unknown, or (with
        ``inplace=True``) the target key already exists.
    """
    da = _resolve_image(sdata, image_key, scale, prefer="finest")
    target_key = image_key_added if image_key_added is not None else f"{image_key}_normalized"
    # Fail fast, before any statistics are reduced, if the write would clash.
    if inplace and target_key in sdata.images:
        raise ValueError(f"image_key_added={target_key!r} already exists in sdata.images.")
    params = _resolve_method_params(reference.method, method_params)

    # Source statistics (Reinhard mu/sigma or the decomposition source matrix)
    # are reduced on a coarse level with a tissue mask; the lazy transform is
    # then applied to the full-resolution `da`.
    fit_rgb = _resolve_image(sdata, image_key, scale, prefer="coarsest")
    validate_rgb_range(fit_rgb)  # reject mis-typed source (e.g. 0-255 float) before the dtype-clipped reconstruction
    tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, fit_rgb, tissue_mask_key)
    out_dtype = da.dtype if output_dtype is None else np.dtype(output_dtype)  # clip range + final cast

    if reference.method == "reinhard":
        normalized = apply_reinhard(
            da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask, out_dtype=out_dtype
        )
    else:
        normalized = apply_decomposition(
            da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask, out_dtype=out_dtype
        )

    if preserve_background:
        # Keep non-tissue pixels identical to the source: the global colour
        # map would otherwise recolour background/white pixels (HistomicsTK's
        # `mask_out`). Stays lazy when the label level lines up with `da`.
        # NOTE(review): the background comes from `da` in the *source* dtype's
        # value range; when `output_dtype` differs from the source dtype,
        # confirm that `cast_to_image_dtype` below maps both parts consistently.
        keep = _resolve_output_tissue_mask(sdata, image_key, da, tissue_mask_key)
        normalized = normalized.where(keep, da)

    # Deferred cast at the write boundary: the reconstruction was kept in float
    # (clipped to `out_dtype`'s range); round + cast here so the stored image is
    # the requested dtype and integer background stays byte-identical.
    normalized = cast_to_image_dtype(normalized, out_dtype)
    # The output is a 3-channel RGB image; tag it r/g/b so RGB-aware viewers
    # (spatialdata-plot) use one hue-preserving scale, not per-channel auto-contrast.
    normalized = normalized.assign_coords(c=["r", "g", "b"])
    if not inplace:
        return normalized
    _write_image(sdata, sdata.images[image_key], target_key, normalized)
    return None
def decompose_stains(
    sdata: sd.SpatialData,
    image_key: str,
    reference_or_method: StainReference | Literal["macenko", "vahadane"],
    *,
    scale: str | Literal["auto"] = "auto",
    method_params: MethodParams = None,
    white_point: np.ndarray | None = None,
    image_key_added: str | None = None,
    inplace: bool = True,
    output_dtype: DTypeLike = np.float16,
    tissue_mask_key: str | None = None,
    include_residual: bool = True,
) -> dict[str, xr.DataArray] | None:
    """Decompose an image into separate per-stain concentration maps.

    Parameters
    ----------
    sdata, image_key
        The SpatialData object and the RGB image key to decompose.
    reference_or_method
        Either a decomposition :class:`StainReference` (its stain matrix and
        white point are used) or a method name (``"macenko"``/``"vahadane"``)
        to fit on this image first. The reference is the provenance record of
        how the maps were produced (method, stain matrix, white point).
    scale
        Scale level to project. ``"auto"`` (default) projects the finest level;
        when a method name is given, the fit then runs on the coarsest level as
        in :func:`fit_stain_reference`. An explicit level is used for both the
        fit and the projection.
    method_params, white_point, tissue_mask_key
        As for :func:`fit_stain_reference`; only used when a method name is
        given (a reference is projected as-is and needs no tissue mask).
    image_key_added
        Key *prefix* for the written images when ``inplace=True``. If ``None``
        (default), ``image_key`` is used, so each stain is written as its own
        single-channel image ``sdata.images[f"{image_key}_{stain}"]`` (e.g.
        ``f"{image_key}_hematoxylin"``). Ignored when ``inplace=False``.
    inplace
        If ``True`` (default), write each stain as a separate single-channel
        image under the ``image_key_added`` prefix and return ``None``. All
        target keys are checked to be free before anything is fitted or
        written, so an existing key never leaves a half-decomposed ``sdata``.
        If ``False``, leave ``sdata`` untouched and return the maps as a dict.
    output_dtype
        Dtype of the concentration maps. Defaults to ``float16`` (half the
        storage; ~3 significant figures, adequate for concentrations); pass
        ``float32`` for strict quantification.
    include_residual
        If ``True`` (default), also produce the ``"residual"`` map. The residual
        is the absorbance along the complement direction - a diagnostic of
        decomposition quality (extra chromogen, artifacts, or a poor fit), not a
        biological stain. Set ``False`` to keep only ``hematoxylin``/``eosin``.

    Returns
    -------
    ``None`` if ``inplace=True`` (the maps are written as separate images),
    otherwise a ``dict`` mapping each stain name to its ``(y, x)`` concentration
    :class:`~xarray.DataArray` (``"hematoxylin"``, ``"eosin"``, and
    ``"residual"`` unless dropped).

    Raises
    ------
    ValueError
        If ``image_key`` is missing, the reference is not a decomposition
        reference with a stain matrix, the method name is invalid, or (with
        ``inplace=True``) a target key already exists.
    """
    da = _resolve_image(sdata, image_key, scale, prefer="finest")

    # Validate the reference / method name now (cheap), but defer any fit until
    # the output keys are known to be free, so a rejected write does not pay
    # for a fit first.
    given_reference: StainReference | None = None
    if isinstance(reference_or_method, StainReference):
        if reference_or_method.method not in _DECOMPOSITION_METHODS or reference_or_method.stain_matrix is None:
            raise ValueError("decompose_stains requires a macenko/vahadane reference with a stain matrix.")
        given_reference = reference_or_method
    elif reference_or_method not in _DECOMPOSITION_METHODS:
        raise ValueError(f"method must be one of {list(_DECOMPOSITION_METHODS)}; got {reference_or_method!r}.")

    names = ["hematoxylin", "eosin"] + (["residual"] if include_residual else [])
    prefix = image_key_added if image_key_added is not None else image_key
    target_keys = [f"{prefix}_{name}" for name in names]
    if inplace:  # validate all keys free up front, so a partial write can't leave a half-decomposed sdata
        clashes = [k for k in target_keys if k in sdata.images]
        if clashes:
            raise ValueError(f"decompose_stains would overwrite existing image(s): {clashes}.")

    if given_reference is not None:
        reference = given_reference
    else:
        reference = fit_stain_reference(
            sdata,
            image_key,
            method=reference_or_method,
            scale=scale,
            method_params=method_params,
            white_point=white_point,
            tissue_mask_key=tissue_mask_key,
        )
    stain_matrix, bg = reference.stain_matrix, reference.white_point

    # The projection always yields all three channels; `names` selects the kept ones.
    concentrations = decompose_to_concentrations(da, stain_matrix, bg).assign_coords(c=_CONCENTRATION_CHANNELS)
    concentrations = concentrations.astype(np.dtype(output_dtype))
    if not inplace:
        return {name: concentrations.sel(c=name) for name in names}
    source = sdata.images[image_key]
    for name, key in zip(names, target_keys, strict=True):
        # keep the c dim (length 1) so Image2DModel.parse accepts it
        _write_image(sdata, source, key, concentrations.sel(c=[name]), c_coords=[name])
    return None