from __future__ import annotations
from types import MappingProxyType
from typing import (
Any,
Callable,
Mapping,
Sequence,
Union, # noqa: F401
)
import dask.array as da
from dask_image.ndfilters import gaussian_filter as dask_gf
from scanpy import logging as logg
from scipy.ndimage import gaussian_filter as scipy_gf
from squidpy._constants._constants import Processing
from squidpy._constants._pkg_constants import Key
from squidpy._docs import d, inject_docs
from squidpy._utils import NDArrayA
from squidpy.im._container import ImageContainer
# Public API of this module. (A stray Sphinx ``[docs]`` link artifact that followed this line was
# removed: as Python it evaluated the undefined name ``docs`` and raised NameError on import.)
__all__ = ["process"]
@d.dedent
@inject_docs(p=Processing)
def process(
    img: ImageContainer,
    layer: str | None = None,
    library_id: str | Sequence[str] | None = None,
    method: str | Callable[..., NDArrayA] = "smooth",
    chunks: int | None = None,
    lazy: bool = False,
    layer_added: str | None = None,
    channel_dim: str | None = None,
    copy: bool = False,
    apply_kwargs: Mapping[str, Any] = MappingProxyType({}),
    **kwargs: Any,
) -> ImageContainer | None:
    """
    Process an image by applying a transformation.

    Parameters
    ----------
    %(img_container)s
    %(img_layer)s
    %(library_id)s
        If `None`, all Z-dimensions are processed at once, treating the image as a 3D volume.
    method
        Processing method to use. Valid options are:

            - `{p.SMOOTH.s!r}` - :func:`scipy.ndimage.gaussian_filter` or, if ``chunks`` is set,
              :func:`dask_image.ndfilters.gaussian_filter`.
            - `{p.GRAY.s!r}` - :func:`skimage.color.rgb2gray`.

        %(custom_fn)s
    %(chunks_lazy)s
    %(layer_added)s
        If `None`, use ``'{{layer}}_{{method}}'``.
    channel_dim
        Name of the channel dimension of the new image layer. Default is the same as the original, if the
        processing function does not change the number of channels, and ``'{{channel}}_{{processing}}'`` otherwise.
    %(copy_cont)s
    apply_kwargs
        Keyword arguments for :meth:`squidpy.im.ImageContainer.apply`.
    kwargs
        Keyword arguments for ``method``. For `{p.SMOOTH.s!r}`, ``sigma`` defaults to ``[1, 1, 0, 0]``
        (y, x, z, channels) if ``library_id = None`` and to ``[1, 1, 0]`` (y, x, channels) otherwise.
        A scalar ``sigma`` (int or float) is applied to the y and x axes only.

    Returns
    -------
    If ``copy = True``, returns a new container with the processed image in ``'{{layer_added}}'``.

    Otherwise, modifies the ``img`` with the following key:

        - :class:`squidpy.im.ImageContainer` ``['{{layer_added}}']`` - the processed image.

    Raises
    ------
    NotImplementedError
        If ``method`` has not been implemented.
    ValueError
        If ``method = {p.SMOOTH.s!r}`` and ``sigma`` does not have one entry per image axis.
    """
    from numbers import Real

    from squidpy.pl._utils import _to_grayscale

    layer = img._get_layer(layer)
    method = Processing(method) if isinstance(method, (str, Processing)) else method  # type: ignore[assignment]
    # copy so the caller's mapping (or the shared MappingProxyType default) is never mutated
    apply_kwargs = dict(apply_kwargs)
    apply_kwargs["lazy"] = lazy

    if channel_dim is None:
        channel_dim = str(img[layer].dims[-1])
    layer_new = Key.img.process(method, layer, layer_added=layer_added)

    if callable(method):
        callback = method
    elif method == Processing.SMOOTH:  # type: ignore[comparison-overlap]
        if library_id is None:
            # the whole stack is passed at once: y, x, z, c
            expected_ndim = 4
            kwargs.setdefault("sigma", [1, 1, 0, 0])
        else:
            # the callback is applied per library_id: y, x, c
            expected_ndim = 3
            kwargs.setdefault("sigma", [1, 1, 0])

        sigma = kwargs["sigma"]
        if isinstance(sigma, Real):
            # scalar sigma (int, float or NumPy scalar) smooths only y and x, never across z or channels
            kwargs["sigma"] = sigma = [sigma, sigma] + [0] * (expected_ndim - 2)
        if len(sigma) != expected_ndim:
            raise ValueError(f"Expected `sigma` to be of length `{expected_ndim}`, found `{len(sigma)}`.")

        if chunks is not None:
            # dask_image already handles map_overlap, so chunk here and disable chunking in `img.apply`
            chunks_, chunks = chunks, None

            def _smooth_dask(arr: Any, **fn_kwargs: Any) -> Any:
                return dask_gf(da.asarray(arr).rechunk(chunks_), **fn_kwargs)

            callback = _smooth_dask
        else:
            callback = scipy_gf
    elif method == Processing.GRAY:  # type: ignore[comparison-overlap]
        # axis 3 is the channel axis of a (y, x, z, c) array
        apply_kwargs["drop_axis"] = 3
        callback = _to_grayscale
    else:
        raise NotImplementedError(f"Method `{method}` is not yet implemented.")

    # to which library_ids should this function be applied?
    if library_id is not None:
        callback = {lid: callback for lid in img._get_library_ids(library_id)}  # type: ignore[assignment]

    start = logg.info(f"Processing image using `{method}` method")

    res: ImageContainer = img.apply(
        callback, layer=layer, copy=True, drop=copy, chunks=chunks, fn_kwargs=kwargs, **apply_kwargs
    )

    # if the method changes the number of channels, derive a new channel dimension name
    if res[layer].shape[-1] != img[layer].shape[-1]:
        if layer_added is None:
            # the default new layer name is '{layer}_{method}'; strip the whole layer prefix so that
            # underscores inside the layer name do not leak into the modifier
            prefix = f"{layer}_"
            if layer_new.startswith(prefix):
                modifier = layer_new[len(prefix) :]
            else:
                modifier = "_".join(layer_new.split("_")[1:])
        else:
            modifier = layer_added
        channel_dim = f"{channel_dim}_{modifier}"

    # applied unconditionally so that a user-supplied `channel_dim` is honoured as well
    res._data = res.data.rename({res[layer].dims[-1]: channel_dim})
    logg.info("Finish", time=start)

    if copy:
        return res.rename(layer, layer_new)

    img.add_img(
        img=res[layer],
        layer=layer_new,
        copy=False,
        lazy=lazy,
        dims=res[layer].dims,
        library_id=img[layer].coords["z"].values,
    )
    return None