from __future__ import annotations
import os
from collections.abc import Mapping, Sequence
from functools import wraps
from inspect import signature
from pathlib import Path
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import matplotlib as mpl
import numpy as np
import pandas as pd
from anndata import AnnData
from dask import array as da
from dask import delayed
from matplotlib import colors as mcolors
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable
from numba import njit, prange
from pandas import CategoricalDtype
from pandas._libs.lib import infer_dtype
from pandas.core.dtypes.common import (
is_bool_dtype,
is_integer_dtype,
is_numeric_dtype,
is_object_dtype,
is_string_dtype,
)
from scanpy import logging as logg
from scanpy import settings
from scipy.cluster import hierarchy as sch
from scipy.sparse import issparse, spmatrix
from skimage import img_as_float32
from skimage.color import rgb2gray
from squidpy._constants._pkg_constants import Key
from squidpy._docs import d
from squidpy._utils import NDArrayA
from squidpy.gr._utils import _assert_categorical_obs
Vector_name_t = tuple[Optional[Union[pd.Series, NDArrayA]], Optional[str]]
@d.dedent
def save_fig(fig: Figure, path: str | Path, make_dir: bool = True, ext: str = "png", **kwargs: Any) -> None:
    """
    Save a figure.
    Parameters
    ----------
    fig
        Figure to save.
    path
        Path where to save the figure. If path is relative, save it under :attr:`scanpy.settings.figdir`.
    make_dir
        Whether to try making the directory if it does not exist.
    ext
        Extension to use if none is provided.
    kwargs
        Keyword arguments for :meth:`matplotlib.figure.Figure.savefig`.
    Returns
    -------
    None
        Just saves the plot.
    """
    # append the default extension only when the path carries none
    _, suffix = os.path.splitext(path)
    fpath = Path(path if suffix else f"{path}.{ext}")
    # relative paths are resolved against scanpy's figure directory
    if not fpath.is_absolute():
        fpath = Path(settings.figdir) / fpath
    if make_dir:
        try:
            fpath.parent.mkdir(parents=True, exist_ok=True)
        except OSError as err:
            logg.debug(f"Unable to create directory `{fpath.parent}`. Reason: `{err}`")
    logg.debug(f"Saving figure to `{fpath!r}`")
    kwargs.setdefault("bbox_inches", "tight")
    kwargs.setdefault("transparent", True)
    fig.savefig(fpath, **kwargs)
@njit(cache=True, fastmath=True)
def _point_inside_triangles(triangles: NDArrayA) -> np.bool_:
    # modified from napari
    # Returns True if the origin lies inside any triangle.
    # `triangles` is assumed to be an (n, 3, 2) array of 2D triangle vertices that
    # have already been translated so the query point sits at the origin — see
    # `_points_inside_triangles`, which calls this with `triangles - points[i]`.
    AB = triangles[:, 1, :] - triangles[:, 0, :]
    AC = triangles[:, 2, :] - triangles[:, 0, :]
    BC = triangles[:, 2, :] - triangles[:, 1, :]
    # sign of the 2D cross product of each edge direction with a vertex position:
    # encodes which side of that edge the origin falls on
    s_AB = -AB[:, 0] * triangles[:, 0, 1] + AB[:, 1] * triangles[:, 0, 0] >= 0
    s_AC = -AC[:, 0] * triangles[:, 0, 1] + AC[:, 1] * triangles[:, 0, 0] >= 0
    s_BC = -BC[:, 0] * triangles[:, 1, 1] + BC[:, 1] * triangles[:, 1, 0] >= 0
    # inside test per triangle: AB/AC signs disagree while AB/BC signs agree;
    # any() collapses over all triangles
    return np.any((s_AB != s_AC) & (s_AB == s_BC))
@njit(parallel=True)
def _points_inside_triangles(points: NDArrayA, triangles: NDArrayA) -> NDArrayA:
    # For every 2D point, test whether it lies inside any of the given triangles;
    # returns a boolean array of length `len(points)`. Points are processed in
    # parallel via numba's `prange`.
    out = np.empty(
        len(
            points,
        ),
        dtype=np.bool_,
    )
    for i in prange(len(out)):
        # translate all triangles so the current point sits at the origin
        out[i] = _point_inside_triangles(triangles - points[i])
    return out
def _min_max_norm(vec: spmatrix | NDArrayA) -> NDArrayA:
if issparse(vec):
if TYPE_CHECKING:
assert isinstance(vec, spmatrix)
vec = vec.toarray().squeeze()
vec = np.asarray(vec, dtype=np.float64)
if vec.ndim != 1:
raise ValueError(f"Expected `1` dimension, found `{vec.ndim}`.")
maxx, minn = np.nanmax(vec), np.nanmin(vec)
return ( # type: ignore[no-any-return]
np.ones_like(vec) if np.isclose(minn, maxx) else ((vec - minn) / (maxx - minn))
)
def _ensure_dense_vector(fn: Callable[..., Vector_name_t]) -> Callable[..., Vector_name_t]:
@wraps(fn)
def decorator(self: ALayer, *args: Any, **kwargs: Any) -> Vector_name_t:
normalize = kwargs.pop("normalize", False)
res, fmt = fn(self, *args, **kwargs)
if res is None:
return None, None
if isinstance(res, pd.Series):
if isinstance(res, CategoricalDtype):
return res, fmt
if is_string_dtype(res) or is_object_dtype(res) or is_bool_dtype(res):
return res.astype("category"), fmt
if is_integer_dtype(res):
unique = res.unique()
n_uniq = len(unique)
if n_uniq <= 2 and (set(unique) & {0, 1}):
return res.astype(bool).astype("category"), fmt
if len(unique) <= len(res) // 100:
return res.astype("category"), fmt
elif not is_numeric_dtype(res):
raise TypeError(f"Unable to process `pandas.Series` of type `{infer_dtype(res)}`.")
res = res.to_numpy()
elif issparse(res):
if TYPE_CHECKING:
assert isinstance(res, spmatrix)
res = res.toarray()
elif not isinstance(res, (np.ndarray, Sequence)):
raise TypeError(f"Unable to process result of type `{type(res).__name__}`.")
res = np.asarray(np.squeeze(res))
if res.ndim != 1:
raise ValueError(f"Expected 1-dimensional array, found `{res.ndim}`.")
return (_min_max_norm(res) if normalize else res), fmt
return decorator
def _only_not_raw(fn: Callable[..., Any | None]) -> Callable[..., Any | None]:
@wraps(fn)
def decorator(self: ALayer, *args: Any, **kwargs: Any) -> Any | None:
return None if self.raw else fn(self, *args, **kwargs)
return decorator
class ALayer:
    """
    Class which helps with :attr:`anndata.AnnData.layers` logic.
    Parameters
    ----------
    %(adata)s
    library_ids
        Library ids; the first one becomes the initially selected :attr:`library_id`.
    is_raw
        Whether we want to access :attr:`anndata.AnnData.raw`.
    palette
        Color palette for categorical variables which don't have colors in :attr:`anndata.AnnData.uns`.
    """
    # AnnData attributes this helper knows how to expose via `get_items`
    VALID_ATTRIBUTES = ("obs", "var", "obsm")
    def __init__(
        self,
        adata: AnnData,
        library_ids: Sequence[str],
        is_raw: bool = False,
        palette: str | None = None,
    ):
        if is_raw and adata.raw is None:
            raise AttributeError("Attribute `.raw` is `None`.")
        self._adata = adata
        self._library_id = library_ids[0]
        # positional index -> library id, used by the `library_id` setter
        self._ix_to_group = dict(zip(range(len(library_ids)), library_ids))
        self._layer: str | None = None
        self._previous_layer: str | None = None
        self._raw = is_raw
        self._palette = palette
    @property
    def adata(self) -> AnnData:
        """The underlying annotated data object.""" # noqa: D401
        return self._adata
    @property
    def layer(self) -> str | None:
        """Layer in :attr:`anndata.AnnData.layers`."""
        return self._layer
    @layer.setter
    def layer(self, layer: str | None = None) -> None:
        if layer not in (None,) + tuple(self.adata.layers.keys()):
            raise KeyError(f"Invalid layer `{layer}`. Valid options are `{[None] + sorted(self.adata.layers.keys())}`.")
        # stash the requested layer; the `raw` setter below copies it into `_layer`
        self._previous_layer = layer
        # handle in raw setter
        self.raw = False
    @property
    def raw(self) -> bool:
        """Whether to access :attr:`anndata.AnnData.raw`."""
        return self._raw
    @raw.setter
    def raw(self, is_raw: bool) -> None:
        if is_raw:
            if self.adata.raw is None:
                raise AttributeError("Attribute `.raw` is `None`.")
            # remember the active layer so it can be restored when leaving raw mode
            self._previous_layer = self.layer
            self._layer = None
        else:
            self._layer = self._previous_layer
        self._raw = is_raw
    @property
    def library_id(self) -> str:
        """Library id that is currently selected."""
        return self._library_id
    @library_id.setter
    def library_id(self, library_id: str | int) -> None:
        # integers are interpreted as positional indices into the initial `library_ids`
        if isinstance(library_id, int):
            library_id = self._ix_to_group[library_id]
        self._library_id = library_id
    @_ensure_dense_vector
    def get_obs(self, name: str, **_: Any) -> tuple[pd.Series | NDArrayA | None, str]:
        """
        Return an observation.
        Parameters
        ----------
        name
            Key in :attr:`anndata.AnnData.obs` to access.
        Returns
        -------
        The values and the formatted ``name``.
        """
        if name not in self.adata.obs.columns:
            raise KeyError(f"Key `{name}` not found in `adata.obs`.")
        return self.adata.obs[name], self._format_key(name, layer_modifier=False)
    @_ensure_dense_vector
    def get_var(self, name: str | int, **_: Any) -> tuple[NDArrayA | None, str]:
        """
        Return a gene.
        Parameters
        ----------
        name
            Gene name in :attr:`anndata.AnnData.var_names` or :attr:`anndata.AnnData.raw.var_names`,
            based on :paramref:`raw`.
        Returns
        -------
        The values and the formatted ``name``.
        """
        adata = self.adata.raw if self.raw else self.adata
        try:
            # NOTE(review): relies on AnnData's private indexing API — may break
            # across anndata versions; verify when bumping the dependency
            ix = adata._normalize_indices((slice(None), name))
        except KeyError:
            raise KeyError(f"Key `{name}` not found in `adata.{'raw.' if self.raw else ''}var_names`.") from None
        return self.adata._get_X(use_raw=self.raw, layer=self.layer)[ix], self._format_key(name, layer_modifier=True)
    def get_items(self, attr: str) -> tuple[str, ...]:
        """
        Return valid keys for an attribute.
        Parameters
        ----------
        attr
            Attribute of :mod:`anndata.AnnData` to access.
        Returns
        -------
        The available items.
        """
        # only `var` has a raw counterpart worth switching to
        adata = self.adata.raw if self.raw and attr in ("var",) else self.adata
        if attr in ("obs", "obsm"):
            return tuple(map(str, getattr(adata, attr).keys()))
        # remaining attributes (e.g. `var`) expose their keys via `.index`
        return tuple(map(str, getattr(adata, attr).index))
    @_ensure_dense_vector
    def get_obsm(self, name: str, index: int | str = 0) -> tuple[NDArrayA | None, str]:
        """
        Return a vector from :attr:`anndata.AnnData.obsm`.
        Parameters
        ----------
        name
            Key in :attr:`anndata.AnnData.obsm`.
        index
            Index of the vector.
        Returns
        -------
        The values and the formatted ``name``.
        """
        if name not in self.adata.obsm:
            raise KeyError(f"Unable to find key `{name!r}` in `adata.obsm`.")
        res = self.adata.obsm[name]
        pretty_name = self._format_key(name, layer_modifier=False, index=index)
        if isinstance(res, pd.DataFrame):
            try:
                if isinstance(index, str):
                    return res[index], pretty_name
                if isinstance(index, int):
                    # positional access: re-format the label with the actual column name
                    return res.iloc[:, index], self._format_key(name, layer_modifier=False, index=res.columns[index])
            except KeyError:
                raise KeyError(f"Key `{index}` not found in `adata.obsm[{name!r}].`") from None
        # non-DataFrame values require an integer (or integer-like string) column index
        if not isinstance(index, int):
            try:
                index = int(index, base=10)
            except ValueError:
                raise ValueError(
                    f"Unable to convert `{index}` to an integer when accessing `adata.obsm[{name!r}]`."
                ) from None
        res = np.asarray(res)
        return (res if res.ndim == 1 else res[:, index]), pretty_name
    def _format_key(self, key: str | int, layer_modifier: bool = False, index: int | str | None = None) -> str:
        # builds display labels such as "key:index", "key:raw" or "key:<layer>"
        if not layer_modifier:
            return str(key) + (f":{index}" if index is not None else "")
        return str(key) + (":raw" if self.raw else f":{self.layer}" if self.layer is not None else "")
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}<raw={self.raw}, layer={self.layer}>"
    def __str__(self) -> str:
        return repr(self)
def _contrasting_color(r: int, g: int, b: int) -> str:
for val in [r, g, b]:
assert 0 <= val <= 255, f"Color value `{val}` is not in `[0, 255]`."
return "#000000" if r * 0.299 + g * 0.587 + b * 0.114 > 186 else "#ffffff"
def _get_black_or_white(value: float, cmap: mcolors.Colormap) -> str:
    """Pick a contrasting text color (black/white) for ``value`` mapped through ``cmap``."""
    if not (0.0 <= value <= 1.0):
        raise ValueError(f"Value must be in range `[0, 1]`, found `{value}`.")
    # scale the RGB part of the RGBA tuple to 0-255 integers
    rgb = [int(channel * 255) for channel in cmap(value)[:3]]
    return _contrasting_color(*rgb)
def _annotate_heatmap(
    im: mpl.image.AxesImage, valfmt: str = "{x:.2f}", cmap: mpl.colors.Colormap | str = "viridis", **kwargs: Any
) -> None:
    """Write each cell's value on top of a heatmap image, in a contrasting color."""
    # modified from matplotlib's site
    if isinstance(cmap, str):
        cmap = plt.colormaps[cmap]
    if isinstance(valfmt, str):
        valfmt = mpl.ticker.StrMethodFormatter(valfmt)
    if TYPE_CHECKING:
        assert callable(valfmt)
    data = im.get_array()
    kw = {"ha": "center", "va": "center"}
    kw.update(**kwargs)
    n_rows, n_cols = data.shape[0], data.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            normed = im.norm(data[row, col])
            if np.isnan(normed):
                continue
            # choose black/white per cell so the label stays readable
            kw.update(color=_get_black_or_white(normed, cmap))
            im.axes.text(col, row, valfmt(data[row, col], None), **kw)
def _get_cmap_norm(
    adata: AnnData,
    key: str,
    order: tuple[list[int], list[int]] | None = None,  # fixed redundant `| None | None`
) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, mcolors.BoundaryNorm, mcolors.BoundaryNorm, int]:
    """
    Build categorical row/column colormaps and boundary norms for ``adata.obs[key]``.

    Parameters
    ----------
    adata
        Annotated data object; colors are read from ``adata.uns``.
    key
        Categorical key in ``adata.obs``.
    order
        Optional ``(row_order, col_order)`` index lists used to reorder the colors.

    Returns
    -------
    Row colormap, column colormap, row norm, column norm and the number of categories.
    """
    n_cls = adata.obs[key].nunique()
    colors = adata.uns[Key.uns.colors(key)]

    if order is not None:
        row_order, col_order = order
        row_colors = [colors[i] for i in row_order]
        col_colors = [colors[i] for i in col_order]
    else:
        row_colors = col_colors = colors

    row_cmap = mcolors.ListedColormap(row_colors)
    col_cmap = mcolors.ListedColormap(col_colors)
    # one color bin per category: boundaries at 0, 1, ..., n_cls
    row_norm = mcolors.BoundaryNorm(np.arange(n_cls + 1), row_cmap.N)
    col_norm = mcolors.BoundaryNorm(np.arange(n_cls + 1), col_cmap.N)
    return row_cmap, col_cmap, row_norm, col_norm, n_cls
def _heatmap(
    adata: AnnData,
    key: str,
    title: str = "",
    method: str | None = None,
    cont_cmap: str | mcolors.Colormap = "viridis",
    annotate: bool = True,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = None,
    cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
    ax: Axes | None = None,
    **kwargs: Any,
) -> mpl.figure.Figure:
    """
    Plot a heatmap of ``adata.X`` with row/column category color bars for ``adata.obs[key]``.

    Parameters
    ----------
    adata
        Annotated data object; ``key`` must be categorical in ``adata.obs`` and its
        colors must be present in ``adata.uns``.
    key
        Categorical key in ``adata.obs``.
    title
        Title placed above the column color bar (or the dendrogram, if drawn).
    method
        Linkage method; if not `None`, rows/columns are hierarchically clustered and
        a column dendrogram is drawn on top.
    cont_cmap
        Continuous colormap for the heatmap values; NaNs are shown in grey.
    annotate
        Whether to write each value inside its cell.
    figsize
        Figure size, only used when ``ax`` is `None`.
    dpi
        Figure DPI, only used when ``ax`` is `None`.
    cbar_kwargs
        Keyword arguments for :meth:`matplotlib.figure.Figure.colorbar`.
    ax
        Axes to draw into; a new figure is created when `None`.
    kwargs
        ``vmin``/``vmax`` override the value normalization; the remainder is passed
        to the cell annotation.

    Returns
    -------
    The figure containing the heatmap.
    """
    _assert_categorical_obs(adata, key=key)
    cbar_kwargs = dict(cbar_kwargs)
    if ax is None:
        fig, ax = plt.subplots(constrained_layout=True, dpi=dpi, figsize=figsize)
    else:
        fig = ax.figure
    if method is not None:
        # optimal leaf ordering is expensive — only request it for small data
        row_order, col_order, _, col_link = _dendrogram(adata.X, method, optimal_ordering=adata.n_obs <= 1500)
    else:
        row_order = col_order = np.arange(len(adata.uns[Key.uns.colors(key)])).tolist()
    # reverse rows so the first entry ends up at the top of the plot
    row_order = row_order[::-1]
    row_labels = adata.obs[key].iloc[row_order]
    data = adata[row_order, col_order].X
    row_cmap, col_cmap, row_norm, col_norm, n_cls = _get_cmap_norm(adata, key, order=(row_order, col_order))
    row_sm = mpl.cm.ScalarMappable(cmap=row_cmap, norm=row_norm)
    col_sm = mpl.cm.ScalarMappable(cmap=col_cmap, norm=col_norm)
    norm = mpl.colors.Normalize(vmin=kwargs.pop("vmin", np.nanmin(data)), vmax=kwargs.pop("vmax", np.nanmax(data)))
    if isinstance(cont_cmap, str):
        cont_cmap = plt.colormaps[cont_cmap]
    cont_cmap.set_bad(color="grey")
    # flip vertically so the image matches the reversed row order above
    im = ax.imshow(data[::-1], cmap=cont_cmap, norm=norm)
    ax.grid(False)
    ax.tick_params(top=False, bottom=False, labeltop=False, labelbottom=False)
    ax.set_xticks([])
    ax.set_yticks([])
    if annotate:
        _annotate_heatmap(im, cmap=cont_cmap, **kwargs)
    # carve out room around the main axes for category bars and the colorbar
    divider = make_axes_locatable(ax)
    row_cats = divider.append_axes("left", size="2%", pad=0)
    col_cats = divider.append_axes("top", size="2%", pad=0)
    cax = divider.append_axes("right", size="1%", pad=0.1)
    if method is not None: # cluster rows but don't plot dendrogram
        col_ax = divider.append_axes("top", size="5%")
        sch.dendrogram(col_link, no_labels=True, ax=col_ax, color_threshold=0, above_threshold_color="black")
        col_ax.axis("off")
    _ = fig.colorbar(
        im,
        cax=cax,
        ticks=np.linspace(norm.vmin, norm.vmax, 10),
        orientation="vertical",
        format="%0.2f",
        **cbar_kwargs,
    )
    # column labels colorbar
    c = fig.colorbar(col_sm, cax=col_cats, orientation="horizontal")
    c.set_ticks([])
    (col_cats if method is None else col_ax).set_title(title)
    # row labels colorbar
    c = fig.colorbar(row_sm, cax=row_cats, orientation="vertical", ticklocation="left")
    c.set_ticks(np.arange(n_cls) + 0.5)
    c.set_ticklabels(row_labels)
    c.set_label(key)
    return fig
def _filter_kwargs(func: Callable[..., Any], kwargs: Mapping[str, Any]) -> dict[str, Any]:
style_args = {k for k in signature(func).parameters.keys()} # noqa: C416
return {k: v for k, v in kwargs.items() if k in style_args}
def _dendrogram(data: NDArrayA, method: str, **kwargs: Any) -> tuple[list[int], list[int], list[int], list[int]]:
    """Hierarchically cluster rows and columns of ``data``; return leaf orders and linkages."""
    link_kwargs = _filter_kwargs(sch.linkage, kwargs)
    dendro_kwargs = _filter_kwargs(sch.dendrogram, kwargs)

    def _cluster(mat: NDArrayA) -> tuple[list[int], Any]:
        # one linkage + (non-plotted) dendrogram pass, shared by rows and columns
        link = sch.linkage(mat, method=method, **link_kwargs)
        dendro = sch.dendrogram(link, no_plot=True, **dendro_kwargs)
        return dendro["leaves"], link

    row_order, row_link = _cluster(data)
    col_order, col_link = _cluster(data.T)
    return row_order, col_order, row_link, col_link
def sanitize_anndata(adata: AnnData) -> None:
    """Transform string annotations to categoricals."""
    # delegates to AnnData's private sanitizer — behavior may change across anndata versions
    adata._sanitize()
def _assert_value_in_obs(adata: AnnData, key: str, val: Sequence[Any] | Any) -> None:
if key not in adata.obs:
raise KeyError(f"Key `{key}` not found in `adata.obs`.")
if not isinstance(val, list):
val = [val]
val = set(val) - set(adata.obs[key].unique())
if len(val) != 0:
raise ValueError(f"Values `{val}` not found in `adata.obs[{key}]`.")
def _to_grayscale(img: NDArrayA | da.Array) -> NDArrayA | da.Array:
    """Convert an RGB image (last axis of size 3) to grayscale; dask arrays stay lazy."""
    if img.shape[-1] != 3:
        raise ValueError(f"Expected channel dimension to be `3`, found `{img.shape[-1]}`.")
    if not isinstance(img, da.Array):
        return rgb2gray(img)
    # keep the computation lazy: convert to float32 via a delayed task,
    # then reduce channels with the luma weights
    img = da.from_delayed(delayed(img_as_float32)(img), shape=img.shape, dtype=np.float32)
    coeffs = np.array([0.2125, 0.7154, 0.0721], dtype=img.dtype)
    return img @ coeffs