# Source code for squidpy.im._segment

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Callable

import dask.array as da
import numpy as np
from scanpy import logging as logg
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.filters import threshold_otsu
from skimage.segmentation import watershed

from squidpy._constants._constants import SegmentationBackend
from squidpy._constants._pkg_constants import Key
from squidpy._docs import d, inject_docs
from squidpy._utils import NDArrayA, singledispatchmethod
from squidpy.im._container import ImageContainer

# Names exported by ``from ... import *``.
__all__ = ["SegmentationModel", "SegmentationWatershed", "SegmentationCustom"]
# Integer type every segmentation mask is cast to.
_SEG_DTYPE = np.uint32
# Width of ``_SEG_DTYPE`` in bits.
_SEG_DTYPE_N_BITS = np.dtype(_SEG_DTYPE).itemsize * 8


class SegmentationModel(ABC):
    """
    Abstract base of all segmentation models.

    Bundles the functionality shared by cell and nuclei segmentation: validation of inputs and outputs,
    dispatch on the type of the image and block-wise processing of :mod:`dask` arrays.
    A concrete model inherits from this class and implements :meth:`_segment`.

    Parameters
    ----------
    model
        Underlying segmentation model.
    """

    def __init__(self, model: Any):
        self._model = model

    # NOTE: the docstring of ``segment`` is read by the ``d.get_sections`` / ``d.get_full_description``
    # decorators under the base name "segment", so its text is kept verbatim.
    @singledispatchmethod
    @d.get_full_description(base="segment")
    @d.get_sections(base="segment", sections=["Parameters", "Returns"])
    @d.dedent
    def segment(self, img: NDArrayA | ImageContainer, **kwargs: Any) -> NDArrayA | ImageContainer:
        """
        Segment an image.

        Parameters
        ----------
        %(img_container)s
        %(img_layer)s
            Only used when ``img`` is :class:`squidpy.im.ImageContainer`.
        kwargs
            Keyword arguments for the underlying ``model``.

        Returns
        -------
        Segmentation mask for the high-resolution image of shape ``(height, width, z, 1)``.

        Raises
        ------
        ValueError
            If the number of dimensions is neither 2 nor 3.
        NotImplementedError
            If trying to segment a type for which the segmentation has not been registered.
        """
        # fallback of the dispatch: reached only for types without a registered implementation
        raise NotImplementedError(f"Segmentation of `{type(img).__name__}` is not yet implemented.")

    @staticmethod
    def _precondition(img: NDArrayA | da.Array) -> NDArrayA | da.Array:
        """Return ``img`` as a 3-dimensional array, appending a trailing axis to 2-dimensional input."""
        if img.ndim == 2:
            return img[:, :, np.newaxis]
        if img.ndim == 3:
            return img

        raise ValueError(f"Expected `2` or `3` dimensions, found `{img.ndim}`.")

    @staticmethod
    def _postcondition(img: NDArrayA | da.Array) -> NDArrayA | da.Array:
        """Validate a segmentation result and return it as a 3-dimensional ``_SEG_DTYPE`` array."""
        if img.ndim == 2:
            img = img[..., np.newaxis]
        elif img.ndim != 3:
            raise ValueError(f"Expected segmentation to return `2` or `3` dimensional array, found `{img.ndim}`.")

        if np.issubdtype(img.dtype, np.integer):
            return img.astype(_SEG_DTYPE)

        raise TypeError(f"Expected segmentation to be of integer type, found `{img.dtype}`.")

    @segment.register(np.ndarray)
    def _(self, img: NDArrayA, **kwargs: Any) -> NDArrayA:
        # in-memory arrays are segmented in one go unless the caller explicitly asks for chunking
        chunks = kwargs.pop("chunks", None)
        if chunks is None:
            prepared = SegmentationModel._precondition(img)
            segmented = self._segment(prepared, **kwargs)
            return SegmentationModel._postcondition(segmented)

        # re-dispatch on the chunked dask array
        return self.segment(da.asarray(img).rechunk(chunks), **kwargs)  # type: ignore[no-any-return]

    @segment.register(da.Array)
    def _(self, img: da.Array, chunks: str | int | tuple[int, ...] | None = None, **kwargs: Any) -> NDArrayA:
        img = SegmentationModel._precondition(img)
        if chunks is not None:
            img = img.rechunk(chunks)

        blocks_per_axis = img.numblocks
        # number of low bits needed to store the largest flat block index
        shift = int(np.prod(blocks_per_axis) - 1).bit_length()
        # defaults for the overlap between neighbouring blocks; values given by the caller take precedence
        for option, default in (("depth", {0: 30, 1: 30}), ("boundary", "reflect")):
            kwargs.setdefault(option, default)

        labels = da.map_overlap(
            self._segment_chunk,
            img,
            dtype=_SEG_DTYPE,
            num_blocks=blocks_per_axis,
            shift=shift,
            # last axis, addressed by its positive index (original note: "-1 seems to be bugged")
            drop_axis=img.ndim - 1,
            **kwargs,
        )
        from dask_image.ndmeasure._utils._label import (
            connected_components_delayed,
            label_adjacency_graph,
            relabel_blocks,
        )

        # labels are not contiguous (and never will be), hence the maximum rather than a count
        adjacency = label_adjacency_graph(labels, None, labels.max())
        mapping = connected_components_delayed(adjacency)

        return SegmentationModel._postcondition(relabel_blocks(labels, mapping))

    @segment.register(ImageContainer)
    def _(
        self,
        img: ImageContainer,
        layer: str,
        library_id: str | Sequence[str],
        channel: int | None = None,
        fn_kwargs: Mapping[str, Any] = MappingProxyType({}),
        **kwargs: Any,
    ) -> ImageContainer:
        data = img[layer]
        channel_dim = data.dims[-1]
        # the name of the channel dimension is kept only when the layer has a single channel
        new_channel_dim = (
            channel_dim if data.shape[-1] == 1 else f"{channel_dim}:{'all' if channel is None else channel}"
        )

        # a copy is always requested from ``apply`` below, so a caller-supplied value is discarded
        kwargs.pop("copy", None)
        # TODO(michalk8): allow volumetric segmentation? (precondition/postcondition needs change)
        if isinstance(library_id, str):
            library_ids: Sequence[str] = [library_id]
        elif isinstance(library_id, Sequence):
            library_ids = library_id
        else:
            raise TypeError(
                f"Expected library id to be `None` or of type `str` or `sequence`, found `{type(library_id).__name__}`."
            )
        func = {lid: self.segment for lid in library_ids}

        res: ImageContainer = img.apply(func, layer=layer, channel=channel, fn_kwargs=fn_kwargs, copy=True, **kwargs)
        res._data = res.data.rename({channel_dim: new_channel_dim})

        # flag every resulting layer as a segmentation
        for key in res:
            res[key].attrs["segmentation"] = True

        return res

    @abstractmethod
    def _segment(self, arr: NDArrayA, **kwargs: Any) -> NDArrayA:
        """Segment a single 3-dimensional array; implemented by concrete models."""

    def _segment_chunk(
        self,
        block: NDArrayA,
        block_id: tuple[int, ...],
        num_blocks: tuple[int, ...],
        shift: int,
        **kwargs: Any,
    ) -> NDArrayA:
        """
        Segment one block and tag its labels with the flat index of the block.

        Every non-zero label is shifted left by ``shift`` bits and the block index is stored in the freed
        low bits, which keeps labels coming from different blocks distinct.
        """
        n_dims = len(num_blocks)
        if n_dims not in (2, 3, 4):
            raise ValueError(f"Expected either `2`, `3` or `4` dimensional chunks, found `{len(num_blocks)}`.")

        if n_dims == 2:
            block_num = block_id[0] * num_blocks[1] + block_id[1]
        else:
            if n_dims == 4 and num_blocks[-1] != 1:
                raise ValueError(
                    f"Expected the number of blocks in the Z-dimension to be `1`, found `{num_blocks[-1]}`."
                )
            # NOTE(review): the block index along the third axis is not part of the flat index;
            # presumably that axis always consists of a single block - confirm
            block_num = block_id[0] * (num_blocks[1] * num_blocks[2]) + block_id[1] * num_blocks[2]

        labels = self._segment(block, **kwargs).astype(_SEG_DTYPE)
        foreground: NDArrayA = labels > 0
        # background (0) is left untouched
        labels[foreground] = (labels[foreground] << shift) | block_num

        return labels

    def __repr__(self) -> str:
        return type(self).__name__

    def __str__(self) -> str:
        return self.__repr__()


class SegmentationWatershed(SegmentationModel):
    """Watershed segmentation model built on top of :mod:`skimage`."""

    def __init__(self) -> None:
        # there is no underlying model object for the watershed
        super().__init__(model=None)

    def _segment(
        self,
        arr: NDArrayA,
        thresh: float | None = None,
        geq: bool = True,
        **kwargs: Any,
    ) -> NDArrayA | da.Array:
        """
        Segment an image by a watershed seeded at the local maxima of its distance transform.

        Parameters
        ----------
        arr
            Image whose last axis is squeezed out before the segmentation.
        thresh
            Threshold defining the mask to segment. If `None`, Otsu's threshold of the image is used.
        geq
            If `True`, the mask is ``arr >= thresh``, otherwise ``arr < thresh``.
        kwargs
            Accepted but not used.

        Returns
        -------
        Array of labels produced by the watershed.
        """
        image = arr.squeeze(-1)  # we always pass a 3D image
        cutoff = threshold_otsu(image) if thresh is None else thresh
        if geq:
            foreground: NDArrayA = image >= cutoff
        else:
            foreground = image < cutoff

        # seeds of the watershed are the local maxima of the distance transform inside the mask
        edt = ndi.distance_transform_edt(foreground)
        peaks = peak_local_max(edt, footprint=np.ones((5, 5)), labels=foreground)
        seeds = np.zeros(edt.shape, dtype=np.bool_)
        seeds[tuple(peaks.T)] = True
        markers = ndi.label(seeds)[0]

        return np.asarray(watershed(-edt, markers, mask=foreground))
class SegmentationCustom(SegmentationModel):
    """
    Segmentation model based on a user-defined function.

    Parameters
    ----------
    func
        Segmentation function to use. Can be any :func:`callable`, as long as it has the following signature:
        :class:`numpy.ndarray` ``(height, width, channels)`` **->** :class:`numpy.ndarray` ``(height, width[, 1])``.
        The segmentation must be of an integer type (it is cast to :class:`numpy.uint32`), where 0 marks
        background.

    Raises
    ------
    TypeError
        If ``func`` is not callable.
    """

    def __init__(self, func: Callable[..., NDArrayA]):
        if not callable(func):
            # same exception type as before, now saying what was actually passed
            raise TypeError(f"Expected `func` to be a callable, found `{type(func).__name__}`.")
        super().__init__(model=func)

    def _segment(self, arr: NDArrayA, **kwargs: Any) -> NDArrayA:
        """Run the user-defined function on ``arr`` and return its result as a :class:`numpy.ndarray`."""
        return np.asarray(self._model(arr, **kwargs))

    def __repr__(self) -> str:
        # objects without ``__name__`` (e.g. ``functools.partial``) are shown as ``None``
        return f"{self.__class__.__name__}[function={getattr(self._model, '__name__', None)}]"

    def __str__(self) -> str:
        return repr(self)
@d.dedent
@inject_docs(m=SegmentationBackend)
def segment(
    img: ImageContainer,
    layer: str | None = None,
    library_id: str | Sequence[str] | None = None,
    method: str | SegmentationModel | Callable[..., NDArrayA] = "watershed",
    channel: int | None = 0,
    chunks: str | int | tuple[int, int] | None = None,
    lazy: bool = False,
    layer_added: str | None = None,
    copy: bool = False,
    **kwargs: Any,
) -> ImageContainer | None:
    """
    Segment an image.

    Parameters
    ----------
    %(img_container)s
    %(img_layer)s
    %(library_id)s
        If `None`, all Z-dimensions are segmented separately.
    method
        Segmentation method to use. Valid options are:

            - `{m.WATERSHED.s!r}` - :func:`skimage.segmentation.watershed`.

        %(custom_fn)s
    channel
        Channel index to use for segmentation. If `None`, use all channels.
    %(chunks_lazy)s
    %(layer_added)s
        If `None`, use ``'segmented_{{model}}'``.
    thresh
        Threshold for creation of masked image. The areas to segment should be contained in this mask.
        If `None`, it is determined by `Otsu's method <https://en.wikipedia.org/wiki/Otsu%27s_method>`_.
        Only used if ``method = {m.WATERSHED.s!r}``.
    geq
        Treat ``thresh`` as upper or lower bound for defining areas to segment. If ``geq = True``, mask is defined
        as ``mask = arr >= thresh``, meaning high values in ``arr`` denote areas to segment.
        Only used if ``method = {m.WATERSHED.s!r}``.
    %(copy_cont)s
    %(segment_kwargs)s

    Returns
    -------
    If ``copy = True``, returns a new container with the segmented image in ``'{{layer_added}}'``.

    Otherwise, modifies the ``img`` with the following key:

        - :class:`squidpy.im.ImageContainer` ``['{{layer_added}}']`` - the segmented image.
    """
    layer = img._get_layer(layer)
    # A ready-made model instance is neither callable nor a valid backend name; without the explicit
    # check it would be handed to the enum lookup and rejected before ever being used.
    if callable(method) or isinstance(method, SegmentationModel):
        kind = SegmentationBackend.CUSTOM
    else:
        kind = SegmentationBackend(method)
    layer_new = Key.img.segment(kind, layer_added=layer_added)
    # chunking reaches the model through ``fn_kwargs`` below
    kwargs["chunks"] = chunks
    library_id = img._get_library_ids(library_id)

    if not isinstance(method, SegmentationModel):
        if kind == SegmentationBackend.WATERSHED:
            if channel is None and img[layer].shape[-1] > 1:
                raise ValueError("Watershed segmentation does not work with multiple channels.")
            method: SegmentationModel = SegmentationWatershed()  # type: ignore[no-redef]
        elif kind == SegmentationBackend.CUSTOM:
            if not callable(method):
                raise TypeError(f"Expected `method` to be a callable, found `{type(method)}`.")
            method = SegmentationCustom(func=method)
        else:
            raise NotImplementedError(f"Model `{kind}` is not yet implemented.")

    if TYPE_CHECKING:
        assert isinstance(method, SegmentationModel)

    start = logg.info(f"Segmenting an image of shape `{img[layer].shape}` using `{method}`")
    res: ImageContainer = method.segment(
        img,
        layer=layer,
        channel=channel,
        library_id=library_id,
        chunks=None,
        fn_kwargs=kwargs,
        copy=True,
        drop=copy,
        lazy=lazy,
    )
    logg.info("Finish", time=start)

    if copy:
        return res.rename(layer, layer_new)

    # in-place mode: store the segmentation in ``img`` under the new layer name and return nothing
    img.add_img(
        res[layer],
        layer=layer_new,
        copy=False,
        lazy=lazy,
        dims=res[layer].dims,
        library_id=res[layer].coords["z"].values,
    )
    return None