"""Reinhard (2001) colour transfer in Ruderman Lab space.
Pure DataArray layer: every function takes and returns ``xr.DataArray`` (or
numpy), stays lazy, touches no ``sdata``, and exposes no public surface. The
thin ``sdata`` wrapper lives in :mod:`._normalize`.
"""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass, fields
from typing import Any
import numpy as np
import xarray as xr
from squidpy.experimental.im._stain._constants import DEFAULT_LUMINOSITY_THRESHOLD
from squidpy.experimental.im._stain._conversion import (
_apply_along_channel,
_check_channel_dim,
_working_dtype,
lab_ruderman_to_rgb,
rgb_to_lab_ruderman,
)
from squidpy.experimental.im._stain._mask import as_spatial_mask, foreground_mask_from_lab
from squidpy.experimental.im._stain._reference import StainReference
# Numerical safeguard against divide-by-zero on flat (constant-colour)
# channels: ``apply_reinhard`` clamps the source sigma to at least this value
# before the transfer kernel divides by it. Not a tuning knob, so kept off the
# public ReinhardParams surface.
_SIGMA_FLOOR: float = 1e-6
@dataclass(slots=True, frozen=True)
class ReinhardParams:
    """Tuning knobs for Reinhard stain normalization.

    Pass an instance (or a ``Mapping`` of field names to values) as
    ``method_params``. Frozen so validation in ``__post_init__`` cannot be
    silently bypassed by later mutation.

    Both fields are coerced on construction (``float`` / ``bool``), and a
    ``ValueError`` is raised when ``luminosity_threshold`` falls outside
    ``(0, 1]``.
    """

    luminosity_threshold: float = DEFAULT_LUMINOSITY_THRESHOLD
    """Normalised Ruderman Lab-L cutoff in ``(0, 1]``; pixels brighter than this are excluded from the fit."""

    mask_background: bool = True
    """If ``True``, fit channel statistics over tissue pixels only; if ``False``, use every pixel (vanilla Reinhard)."""

    def __post_init__(self) -> None:
        """Coerce field types and validate the threshold range."""
        # The dataclass is frozen, so the coerced values must be written via
        # ``object.__setattr__`` instead of ordinary attribute assignment.
        object.__setattr__(self, "luminosity_threshold", float(self.luminosity_threshold))
        object.__setattr__(self, "mask_background", bool(self.mask_background))
        # The chained comparison is False for NaN as well, so NaN is rejected.
        if not 0.0 < self.luminosity_threshold <= 1.0:
            raise ValueError(f"`luminosity_threshold` must be in (0, 1], got {self.luminosity_threshold}.")
# Shared all-defaults instance, returned when the caller passes no
# ``method_params``; sharing is safe because ReinhardParams is frozen.
_REINHARD_DEFAULTS = ReinhardParams()
# Field names accepted when ``method_params`` is given as a Mapping.
_REINHARD_FIELDS = frozenset(spec.name for spec in fields(ReinhardParams))
def _resolve_reinhard_params(method_params: ReinhardParams | Mapping[str, Any] | None) -> ReinhardParams:
"""Normalise the ``method_params`` argument to a :class:`ReinhardParams` instance."""
if method_params is None:
return _REINHARD_DEFAULTS
if isinstance(method_params, ReinhardParams):
return method_params
if isinstance(method_params, Mapping):
unknown = set(method_params) - _REINHARD_FIELDS
if unknown:
raise ValueError(
f"Unknown `method_params` field(s): {sorted(unknown)}; expected from {sorted(_REINHARD_FIELDS)}."
)
return ReinhardParams(**method_params)
raise TypeError(f"`method_params` must be ReinhardParams, Mapping, or None; got {type(method_params).__name__}.")
def _masked_channel_stats(lab: xr.DataArray, mask: xr.DataArray | None) -> tuple[np.ndarray, np.ndarray]:
    """Mean and standard deviation over the ``y``/``x`` dims, restricted to ``mask``.

    When ``mask`` is given, pixels where it is false are replaced by NaN via
    ``lab.where(mask)`` and skipped by the NaN-aware reductions; with
    ``mask=None`` every pixel contributes. Both reductions are bundled into a
    single dataset and evaluated with one ``compute()`` call, so for a lazy
    ``lab`` only the reduced statistics are materialised.

    Returns ``(mu, sigma)`` as float64 arrays with one entry per non-spatial
    element (shape ``(3,)`` for a three-channel Lab image).

    Raises ``ValueError`` if any statistic is non-finite, which is what
    happens when the mask leaves no pixels to average.
    """
    spatial_dims = ("y", "x")
    selected = lab if mask is None else lab.where(mask)
    reduced = xr.Dataset(
        {
            "mu": selected.mean(dim=spatial_dims, skipna=True),
            "sigma": selected.std(dim=spatial_dims, skipna=True),
        }
    ).compute()
    mu, sigma = (np.asarray(reduced[key].values, dtype=np.float64) for key in ("mu", "sigma"))
    if np.isfinite(mu).all() and np.isfinite(sigma).all():
        return mu, sigma
    raise ValueError(
        "Foreground mask leaves zero tissue pixels in at least one channel; "
        "the luminosity_threshold may be too low or the image may be blank."
    )
def _transfer_kernel(
x: np.ndarray,
*,
mu_src: np.ndarray,
sigma_src: np.ndarray,
mu_ref: np.ndarray,
sigma_ref: np.ndarray,
dtype: np.dtype,
) -> np.ndarray:
x = x.astype(dtype, copy=False)
return ((x - mu_src) / sigma_src * sigma_ref + mu_ref).astype(dtype, copy=False)
def _reinhard_mask(lab: xr.DataArray, params: ReinhardParams, tissue_mask: np.ndarray | None) -> xr.DataArray | None:
    """Pick the mask used for the Reinhard channel statistics.

    An explicitly supplied ``tissue_mask`` takes precedence and is passed
    through ``as_spatial_mask`` together with ``lab``. Without one, the
    luminosity-based foreground mask is built from ``lab`` when
    ``params.mask_background`` is set; otherwise ``None`` is returned so that
    every pixel is used (vanilla Reinhard).
    """
    if tissue_mask is None:
        if not params.mask_background:
            return None
        return foreground_mask_from_lab(lab, params.luminosity_threshold)
    return as_spatial_mask(tissue_mask, lab)
def fit_reinhard(
    image_rgb: xr.DataArray, params: ReinhardParams, *, tissue_mask: np.ndarray | None = None
) -> StainReference:
    """Fit the Reinhard channel statistics of a reference image.

    The image is validated with ``_check_channel_dim``, converted with
    ``rgb_to_lab_ruderman``, and reduced to per-channel ``mu``/``sigma`` over
    the selected pixels. The result is wrapped in a
    ``StainReference(method="reinhard")``.

    ``tissue_mask`` (a ``(y, x)`` boolean aligned to ``image_rgb``) selects
    the pixels when given; otherwise the ``mask_background`` /
    ``luminosity_threshold`` fields of ``params`` decide the mask.

    Raises ``ValueError`` (from the statistics step) when the mask leaves no
    pixels to fit on.
    """
    _check_channel_dim(image_rgb)
    lab = rgb_to_lab_ruderman(image_rgb)
    fit_mask = _reinhard_mask(lab, params, tissue_mask)
    mu, sigma = _masked_channel_stats(lab, fit_mask)
    return StainReference(method="reinhard", mu=mu, sigma=sigma)
def apply_reinhard(
    image_rgb: xr.DataArray,
    reference: StainReference,
    params: ReinhardParams,
    *,
    fit_rgb: xr.DataArray | None = None,
    tissue_mask: np.ndarray | None = None,
    out_dtype: np.dtype | type = np.uint8,
) -> xr.DataArray:
    """Apply a Reinhard reference to a source image.

    Standardises by the source's own tissue statistics, rescales to the
    reference statistics, and converts back to RGB. The transform is applied
    to every pixel of ``image_rgb`` (the map is global); the defining
    statistics are reduced on ``fit_rgb`` (a coarse level) when given, so the
    full-resolution image is never materialised to compute them. Without
    ``fit_rgb`` the statistics come from ``image_rgb`` itself, and its Lab
    conversion is reused for the transfer rather than computed twice.

    ``tissue_mask`` selects the source tissue pixels and must be aligned to
    the image the statistics are taken from: ``fit_rgb`` when given,
    ``image_rgb`` otherwise. ``out_dtype`` is forwarded to the Lab-to-RGB
    conversion.

    The source statistics are always computed eagerly; the transfer itself
    is intended to stay lazy when ``image_rgb`` is lazy.

    Raises ``ValueError`` (from the statistics step) when the mask leaves no
    source pixels to fit on.
    """
    _check_channel_dim(image_rgb)
    fit_lab = rgb_to_lab_ruderman(image_rgb if fit_rgb is None else fit_rgb)
    mu_src, sigma_src = _masked_channel_stats(fit_lab, _reinhard_mask(fit_lab, params, tissue_mask))
    # Flat (constant-colour) channels have sigma == 0; clamp so the kernel
    # never divides by zero.
    sigma_src = np.maximum(sigma_src, _SIGMA_FLOOR)
    # With no separate fit level, ``fit_lab`` already is the Lab image to
    # transform; reusing it avoids a second conversion (and, for eager inputs,
    # a second full-size Lab array held in memory).
    lab = fit_lab if fit_rgb is None else rgb_to_lab_ruderman(image_rgb)
    dtype = _working_dtype(lab)
    lab_out = _apply_along_channel(
        lab,
        _transfer_kernel,
        out_dtype=dtype,
        mu_src=mu_src.astype(dtype, copy=False),
        sigma_src=sigma_src.astype(dtype, copy=False),
        mu_ref=np.asarray(reference.mu, dtype=dtype),
        sigma_ref=np.asarray(reference.sigma, dtype=dtype),
        dtype=dtype,
    )
    return lab_ruderman_to_rgb(lab_out, out_dtype=out_dtype)