"""Macenko and Vahadane stain decomposition (fit + apply).
Pure DataArray/numpy layer: no ``sdata``, no public export. The stain-matrix
fits run on tissue pixels (a bounded reduction at the chosen scale); the apply
transform is a single per-pixel matmul and stays lazy.
"""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass, fields
from typing import Any
import numpy as np
import xarray as xr
from squidpy.experimental.im._stain._constants import RUIFROK_HE
from squidpy.experimental.im._stain._conversion import (
_apply_along_channel,
_check_channel_dim,
_working_dtype,
rgb_to_sda,
sda_to_rgb,
)
from squidpy.experimental.im._stain._mask import as_spatial_mask, foreground_mask_from_sda
from squidpy.experimental.im._stain._reference import StainMethod, StainReference
from squidpy.experimental.im._stain._validation import (
StainFittingError,
_unit_columns,
complement_third_column,
reorder_to_canonical,
validate_stain_matrix,
)
_MAXC_PERCENTILE = 99.0
_MAXC_FLOOR = 1e-6
[docs]
@dataclass(slots=True, frozen=True)
class MacenkoParams:
"""Tuning knobs for Macenko stain-matrix fitting."""
alpha: float = 1.0
"""Angular percentile (deg) for the two stain directions; the extremes are taken at ``alpha`` / ``100 - alpha``."""
beta: float = 0.15
"""Mean-absorbance cutoff selecting tissue pixels (optical-density space)."""
def __post_init__(self) -> None:
object.__setattr__(self, "alpha", float(self.alpha))
object.__setattr__(self, "beta", float(self.beta))
if not 0.0 < self.alpha < 50.0:
raise ValueError(f"`alpha` must be in (0, 50), got {self.alpha}.")
if self.beta < 0.0:
raise ValueError(f"`beta` must be >= 0, got {self.beta}.")
[docs]
@dataclass(slots=True, frozen=True)
class VahadaneParams:
"""Tuning knobs for Vahadane (sparse-NMF) stain-matrix fitting."""
beta: float = 0.15
"""Mean-absorbance cutoff selecting tissue pixels (optical-density space)."""
lambda1: float = 0.1
"""L1 sparsity regularisation on the concentration factor of the NMF."""
n_iter: int = 200
"""Maximum NMF iterations."""
random_state: int | None = 0
"""Seed for NMF initialisation tie-breaking; fixed for reproducible fits."""
def __post_init__(self) -> None:
object.__setattr__(self, "beta", float(self.beta))
object.__setattr__(self, "lambda1", float(self.lambda1))
object.__setattr__(self, "n_iter", int(self.n_iter))
if self.beta < 0.0:
raise ValueError(f"`beta` must be >= 0, got {self.beta}.")
if self.lambda1 < 0.0:
raise ValueError(f"`lambda1` must be >= 0, got {self.lambda1}.")
if self.n_iter < 1:
raise ValueError(f"`n_iter` must be >= 1, got {self.n_iter}.")
_MACENKO_DEFAULTS = MacenkoParams()
_VAHADANE_DEFAULTS = VahadaneParams()
_MACENKO_FIELDS = frozenset(f.name for f in fields(MacenkoParams))
_VAHADANE_FIELDS = frozenset(f.name for f in fields(VahadaneParams))
def _resolve_params(params: Any, cls: type, defaults: Any, valid_fields: frozenset[str]) -> Any:
if params is None:
return defaults
if isinstance(params, cls):
return params
if isinstance(params, Mapping):
unknown = set(params) - valid_fields
if unknown:
raise ValueError(
f"Unknown `method_params` field(s): {sorted(unknown)}; expected from {sorted(valid_fields)}."
)
return cls(**params)
raise TypeError(f"`method_params` must be {cls.__name__}, Mapping, or None; got {type(params).__name__}.")
def _resolve_macenko_params(params: MacenkoParams | Mapping[str, Any] | None) -> MacenkoParams:
return _resolve_params(params, MacenkoParams, _MACENKO_DEFAULTS, _MACENKO_FIELDS)
def _resolve_vahadane_params(params: VahadaneParams | Mapping[str, Any] | None) -> VahadaneParams:
return _resolve_params(params, VahadaneParams, _VAHADANE_DEFAULTS, _VAHADANE_FIELDS)
def _tissue_od(
image_rgb: xr.DataArray,
white_point: np.ndarray,
beta: float,
*,
tissue_mask: np.ndarray | None = None,
image_key: str | None,
) -> np.ndarray:
"""Flatten tissue pixels to an ``(N, 3)`` optical-density matrix.
Reduces over the chosen scale only (bounded); the stain fits need the
tissue pixels resident for SVD/NMF, so this is the one materialising step.
When ``tissue_mask`` (a ``(y, x)`` boolean aligned to ``image_rgb``) is
given it selects the tissue pixels; otherwise the absorbance threshold
``beta`` is used.
"""
sda = rgb_to_sda(image_rgb, white_point)
mask = as_spatial_mask(tissue_mask, sda) if tissue_mask is not None else foreground_mask_from_sda(sda, beta)
od = np.asarray(sda.where(mask).transpose("c", "y", "x").data.reshape(3, -1)).T
od = od[np.all(np.isfinite(od), axis=1)]
if od.shape[0] == 0:
raise StainFittingError("no tissue pixels for stain fitting; the mask is empty.", image_key=image_key)
# Keep signed OD: pixels brighter than the estimated background carry
# negative absorbance that Macenko's SVD legitimately uses. Only Vahadane's
# NMF requires non-negativity, and clips locally.
return od
def _macenko_stain_matrix(od: np.ndarray, alpha: float) -> np.ndarray:
"""Recover a ``(3, 2)`` H/E matrix via Macenko's angular-extreme method."""
# right singular vectors of OD = principal absorbance directions through 0
_, _, vh = np.linalg.svd(od, full_matrices=False)
plane = vh[:2].T # (3, 2)
# SVD sign is arbitrary; orient the basis into the data so the projected
# angles cluster around 0 instead of straddling the atan2 branch cut at
# +-180 deg (which would collapse the angular percentiles).
signs = np.sign(od.mean(axis=0) @ plane)
signs[signs == 0] = 1.0
plane = plane * signs
proj = od @ plane # (N, 2)
phi = np.arctan2(proj[:, 1], proj[:, 0])
lo, hi = np.percentile(phi, [alpha, 100.0 - alpha])
extremes = np.stack(
[plane @ np.array([np.cos(lo), np.sin(lo)]), plane @ np.array([np.cos(hi), np.sin(hi)])],
axis=1,
)
return _unit_columns(extremes)
def _vahadane_stain_matrix(od: np.ndarray, params: VahadaneParams) -> np.ndarray:
"""Recover a ``(3, 2)`` H/E matrix via sparse NMF (Vahadane)."""
from sklearn.decomposition import NMF
nmf = NMF(
n_components=2,
init="nndsvda",
random_state=params.random_state,
alpha_W=params.lambda1,
l1_ratio=1.0,
max_iter=params.n_iter,
)
nmf.fit(np.clip(od, 0.0, None)) # NMF requires non-negative absorbance
stains = nmf.components_.T # (3, 2)
if np.any(np.linalg.norm(stains, axis=0) < 1e-8):
raise StainFittingError("Vahadane NMF produced a zero-norm stain vector.")
return _unit_columns(stains)
def _stain_matrix(
od: np.ndarray,
method: StainMethod,
params: Any,
*,
image_key: str | None,
reference: dict[str, np.ndarray] = RUIFROK_HE,
max_angle_deg: float = 45.0,
) -> np.ndarray:
"""Fit, canonicalise, complete and validate a ``(3, 3)`` stain matrix.
``reference`` (the canonical H/E vectors) drives both the column ordering and
the deviation gate; ``max_angle_deg`` is the gate tolerance.
"""
raw = _macenko_stain_matrix(od, params.alpha) if method == "macenko" else _vahadane_stain_matrix(od, params)
matrix = complement_third_column(reorder_to_canonical(raw, reference))
validate_stain_matrix(matrix, reference=reference, max_angle_deg=max_angle_deg, image_key=image_key)
return matrix
def _concentrations(od: np.ndarray, stain_matrix: np.ndarray) -> np.ndarray:
"""Per-pixel stain concentrations ``(N, 3)`` from optical density."""
return od @ np.linalg.pinv(stain_matrix).T
def _max_concentrations(concentrations: np.ndarray) -> np.ndarray:
"""Robust per-stain (H, E) maximum concentrations ``(2,)`` from an ``(N, 3)`` array."""
return np.maximum(np.percentile(concentrations[:, :2], _MAXC_PERCENTILE, axis=0), _MAXC_FLOOR)
def fit_decomposition(
image_rgb: xr.DataArray,
method: StainMethod,
params: Any,
white_point: np.ndarray,
*,
tissue_mask: np.ndarray | None = None,
image_key: str | None = None,
reference: dict[str, np.ndarray] = RUIFROK_HE,
max_angle_deg: float = 45.0,
) -> StainReference:
"""Fit a decomposition :class:`StainReference` (stain matrix + max concentrations)."""
od = _tissue_od(image_rgb, white_point, params.beta, tissue_mask=tissue_mask, image_key=image_key)
matrix = _stain_matrix(od, method, params, image_key=image_key, reference=reference, max_angle_deg=max_angle_deg)
return StainReference(
method=method,
stain_matrix=matrix,
white_point=np.asarray(white_point, dtype=np.float64),
max_concentrations=_max_concentrations(_concentrations(od, matrix)),
)
def _matmul_kernel(x: np.ndarray, *, matrix: np.ndarray, dtype: np.dtype) -> np.ndarray:
return (x.astype(dtype, copy=False) @ matrix.T).astype(dtype, copy=False)
def apply_decomposition(
image_rgb: xr.DataArray,
reference: StainReference,
params: Any,
*,
fit_rgb: xr.DataArray | None = None,
tissue_mask: np.ndarray | None = None,
out_dtype: np.dtype | type = np.uint8,
) -> xr.DataArray:
"""Normalize a source image to a decomposition reference (colour-basis transfer).
Deconvolves with the *source's* own stain matrix and reconvolves with the
reference matrix: only the stain *colour basis* changes, the *amount*
(concentration magnitude) is untouched (HistomicsTK's
``deconvolution_based_normalization``). Use ``method="reinhard"`` for
intensity/appearance matching.
The source matrix is a colour property, so it is fit on ``fit_rgb`` (a
coarse level) when given, while ``image_rgb`` (which may be full
resolution) is only ever touched by the lazy operator - never
materialised to fit a matrix.
"""
_check_channel_dim(image_rgb)
bg = reference.white_point
od_src = _tissue_od(
fit_rgb if fit_rgb is not None else image_rgb, bg, params.beta, tissue_mask=tissue_mask, image_key=None
)
w_src = _stain_matrix(od_src, reference.method, params, image_key=None)
operator = reference.stain_matrix @ np.linalg.pinv(w_src)
sda = rgb_to_sda(image_rgb, bg)
dtype = _working_dtype(sda)
sda_out = _apply_along_channel(sda, _matmul_kernel, out_dtype=dtype, matrix=operator.astype(dtype), dtype=dtype)
return sda_to_rgb(sda_out, bg, out_dtype=out_dtype)
def decompose_to_concentrations(
image_rgb: xr.DataArray, stain_matrix: np.ndarray, white_point: np.ndarray
) -> xr.DataArray:
"""Project an image onto a stain matrix, returning a 3-channel concentration image.
Channels are ``(hematoxylin, eosin, residual)``; the residual is the
concentration along the complement vector and is a diagnostic, not a stain.
"""
_check_channel_dim(image_rgb)
sda = rgb_to_sda(image_rgb, white_point)
dtype = _working_dtype(sda)
pinv = np.linalg.pinv(stain_matrix)
return _apply_along_channel(sda, _matmul_kernel, out_dtype=dtype, matrix=pinv.astype(dtype), dtype=dtype)