# Source code for squidpy.pl._graph

"""Plotting for graph functions."""

from __future__ import annotations

from pathlib import Path
from types import MappingProxyType
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Mapping,
    Sequence,
    Union,  # noqa: F401
)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from anndata import AnnData
from matplotlib.axes import Axes

from squidpy._constants._constants import RipleyStat
from squidpy._constants._pkg_constants import Key
from squidpy._docs import d
from squidpy.gr._utils import (
    _assert_categorical_obs,
    _assert_non_empty_sequence,
    _get_valid_values,
)
from squidpy.pl._color_utils import Palette_t, _get_palette, _maybe_set_colors
from squidpy.pl._utils import _heatmap, save_fig

# Public plotting functions exported by this module (e.g. via ``from ... import *``).
__all__ = ["centrality_scores", "interaction_matrix", "nhood_enrichment", "ripley", "co_occurrence"]


def _get_data(adata: AnnData, cluster_key: str, func_name: str, attr: str = "uns", **kwargs: Any) -> Any:
    key = getattr(Key.uns, func_name)(cluster_key, **kwargs)

    try:
        if attr == "uns":
            return adata.uns[key]
        elif attr == "obsm":
            return adata.obsm[key]
        else:
            raise ValueError(f"attr must be either 'uns' or 'obsm', got {attr}")
    except KeyError:
        raise KeyError(
            f"Unable to get the data from `adata.uns[{key!r}]`. "
            f"Please run `squidpy.gr.{func_name}(..., cluster_key={cluster_key!r})` first."
        ) from None


@d.dedent
def centrality_scores(
    adata: AnnData,
    cluster_key: str,
    score: str | Sequence[str] | None = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({}),
    palette: Palette_t = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = None,
    save: str | Path | None = None,
    **kwargs: Any,
) -> None:
    """
    Plot centrality scores.

    One scatter panel is drawn per selected score. The scores themselves are computed by
    :func:`squidpy.gr.centrality_scores`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    score
        Which scores to plot. If `None`, every computed score is plotted.
    legend_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.legend`.
    %(cat_plotting)s

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    scores_df = _get_data(adata, cluster_key=cluster_key, func_name="centrality_scores").copy()

    legend_opts = dict(legend_kwargs)
    if "loc" not in legend_opts:
        # Default placement: just outside the right edge of each panel.
        legend_opts["loc"] = "center left"
        legend_opts.setdefault("bbox_to_anchor", (1, 0.5))

    available = scores_df.columns.values
    # Cluster labels live in the index; expose them as a column for seaborn.
    scores_df[cluster_key] = scores_df.index.values

    categories = adata.obs[cluster_key].cat.categories
    palette = _get_palette(adata, cluster_key=cluster_key, categories=categories, palette=palette)

    requested = available if score is None else score
    requested = _assert_non_empty_sequence(requested, name="centrality scores")
    selected = sorted(_get_valid_values(requested, available))

    fig, axes = plt.subplots(1, len(selected), figsize=figsize, dpi=dpi, constrained_layout=True)
    # A single panel comes back as a bare Axes; ravel always yields an iterable array.
    for name, panel in zip(selected, np.ravel(axes)):
        sns.scatterplot(
            x=name,
            y=cluster_key,
            data=scores_df,
            hue=cluster_key,
            hue_order=categories,
            palette=palette,
            ax=panel,
            **kwargs,
        )
        panel.set_title(str(name).replace("_", " ").capitalize())
        panel.set_xlabel("value")
        panel.set_yticks([])
        panel.legend(**legend_opts)

    if save is not None:
        save_fig(fig, path=save)
@d.dedent
def interaction_matrix(
    adata: AnnData,
    cluster_key: str,
    annotate: bool = False,
    method: str | None = None,
    title: str | None = None,
    cmap: str = "viridis",
    palette: Palette_t = None,
    cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
    figsize: tuple[float, float] | None = None,
    dpi: int | None = None,
    save: str | Path | None = None,
    ax: Axes | None = None,
    **kwargs: Any,
) -> None:
    """
    Plot cluster interaction matrix.

    The interaction matrix is computed by :func:`squidpy.gr.interaction_matrix` and
    drawn as a heatmap.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(heatmap_plotting)s
    kwargs
        Keyword arguments for :func:`matplotlib.pyplot.text`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    matrix = _get_data(adata, cluster_key=cluster_key, func_name="interaction_matrix")

    # Wrap the matrix in a throwaway AnnData (one observation per cluster) and copy the
    # cluster colors onto it; this is the object handed to `_heatmap`.
    cluster_labels = pd.Categorical(adata.obs[cluster_key].cat.categories)
    heat_ad = AnnData(X=matrix, obs={cluster_key: cluster_labels}, dtype=matrix.dtype)
    _maybe_set_colors(source=adata, target=heat_ad, key=cluster_key, palette=palette)

    if figsize is None:
        # Default size scales with the number of clusters.
        side = 2 * heat_ad.n_obs // 3
        figsize = (side, side)

    fig = _heatmap(
        heat_ad,
        key=cluster_key,
        title="Interaction matrix" if title is None else title,
        method=method,
        cont_cmap=cmap,
        annotate=annotate,
        figsize=figsize,
        dpi=dpi,
        cbar_kwargs=cbar_kwargs,
        ax=ax,
        **kwargs,
    )

    if save is not None:
        save_fig(fig, path=save)
@d.dedent
def nhood_enrichment(
    adata: AnnData,
    cluster_key: str,
    mode: Literal["zscore", "count"] = "zscore",
    annotate: bool = False,
    method: str | None = None,
    title: str | None = None,
    cmap: str = "viridis",
    palette: Palette_t = None,
    cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
    figsize: tuple[float, float] | None = None,
    dpi: int | None = None,
    save: str | Path | None = None,
    ax: Axes | None = None,
    **kwargs: Any,
) -> None:
    """
    Plot neighborhood enrichment.

    The enrichment is computed by :func:`squidpy.gr.nhood_enrichment` and drawn as a
    heatmap.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    mode
        Which :func:`squidpy.gr.nhood_enrichment` result to plot. Valid options are:

            - `'zscore'` - z-score values of enrichment statistic.
            - `'count'` - enrichment count.

    %(heatmap_plotting)s
    kwargs
        Keyword arguments for :func:`matplotlib.pyplot.text`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    results = _get_data(adata, cluster_key=cluster_key, func_name="nhood_enrichment")
    # The stored result holds several matrices; pick the one requested by `mode`.
    matrix = results[mode]

    # Wrap the matrix in a throwaway AnnData (one observation per cluster) and copy the
    # cluster colors onto it; this is the object handed to `_heatmap`.
    cluster_labels = pd.Categorical(adata.obs[cluster_key].cat.categories)
    heat_ad = AnnData(X=matrix, obs={cluster_key: cluster_labels}, dtype=matrix.dtype)
    _maybe_set_colors(source=adata, target=heat_ad, key=cluster_key, palette=palette)

    if figsize is None:
        # Default size scales with the number of clusters.
        side = 2 * heat_ad.n_obs // 3
        figsize = (side, side)

    fig = _heatmap(
        heat_ad,
        key=cluster_key,
        title="Neighborhood enrichment" if title is None else title,
        method=method,
        cont_cmap=cmap,
        annotate=annotate,
        figsize=figsize,
        dpi=dpi,
        cbar_kwargs=cbar_kwargs,
        ax=ax,
        **kwargs,
    )

    if save is not None:
        save_fig(fig, path=save)
@d.dedent
def ripley(
    adata: AnnData,
    cluster_key: str,
    mode: Literal["F", "G", "L"] = "F",
    plot_sims: bool = True,
    palette: Palette_t = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = None,
    save: str | Path | None = None,
    ax: Axes | None = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({}),
    **kwargs: Any,
) -> None:
    """
    Plot Ripley's statistics for each cluster.

    The estimate is computed by :func:`squidpy.gr.ripley`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    mode
        Ripley's statistics to be plotted, one of `'F'`, `'G'` or `'L'`. The statistic must
        already have been computed with :func:`squidpy.gr.ripley`.
    plot_sims
        Whether to overlay simulations in the plot. They are drawn in gray with a
        standard-deviation band.
    %(cat_plotting)s
    ax
        Axes, :class:`matplotlib.axes.Axes`.
    legend_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.legend`.
    kwargs
        Keyword arguments for :func:`seaborn.lineplot`.

    Returns
    -------
    %(plotting_returns)s
    """
    import inspect

    _assert_categorical_obs(adata, key=cluster_key)
    res = _get_data(adata, cluster_key=cluster_key, func_name="ripley", mode=mode)

    mode = RipleyStat(mode)  # type: ignore[assignment]
    if TYPE_CHECKING:
        assert isinstance(mode, RipleyStat)

    legend_kwargs = dict(legend_kwargs)
    if "loc" not in legend_kwargs:
        # Default placement: just outside the right edge of the axes.
        legend_kwargs["loc"] = "center left"
        legend_kwargs.setdefault("bbox_to_anchor", (1, 0.5))

    categories = adata.obs[cluster_key].cat.categories
    palette = _get_palette(adata, cluster_key=cluster_key, categories=categories, palette=palette)
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    else:
        fig = ax.figure

    sns.lineplot(
        y="stats",
        x="bins",
        hue=cluster_key,
        data=res[f"{mode.s}_stat"],
        hue_order=categories,
        palette=palette,
        ax=ax,
        **kwargs,
    )
    if plot_sims:
        # seaborn >= 0.12 deprecates `ci` in favour of `errorbar` (`errorbar="sd"` is the
        # documented replacement for `ci="sd"`). Use whichever this seaborn supports.
        if "errorbar" in inspect.signature(sns.lineplot).parameters:
            band_kwargs: dict[str, Any] = {"errorbar": "sd"}
        else:
            band_kwargs = {"ci": "sd"}
        sns.lineplot(
            y="stats",
            x="bins",
            alpha=0.01,
            color="gray",
            data=res["sims_stat"],
            ax=ax,
            **band_kwargs,
        )

    ax.legend(**legend_kwargs)
    ax.set_ylabel("value")
    ax.set_title(f"Ripley's {mode.s}")

    if save is not None:
        save_fig(fig, path=save)
@d.dedent
def co_occurrence(
    adata: AnnData,
    cluster_key: str,
    palette: Palette_t = None,
    clusters: str | Sequence[str] | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = None,
    save: str | Path | None = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({}),
    **kwargs: Any,
) -> None:
    """
    Plot co-occurrence probability ratio for each cluster.

    One panel is drawn per selected cluster. The co-occurrence is computed by
    :func:`squidpy.gr.co_occurrence`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    clusters
        Cluster instances for which to plot conditional probability. If `None`, all
        categories of ``cluster_key`` are used.
    %(cat_plotting)s
    legend_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.legend`.
    kwargs
        Keyword arguments for :func:`seaborn.lineplot`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    occurrence_data = _get_data(adata, cluster_key=cluster_key, func_name="co_occurrence")

    legend_opts = dict(legend_kwargs)
    if "loc" not in legend_opts:
        # Default placement: just outside the right edge of each panel.
        legend_opts["loc"] = "center left"
        legend_opts.setdefault("bbox_to_anchor", (1, 0.5))

    occ = occurrence_data["occ"]
    # Drop the first entry of `interval`; the remainder is used as the distance axis.
    distances = occurrence_data["interval"][1:]

    categories = adata.obs[cluster_key].cat.categories
    targets = categories if clusters is None else clusters
    targets = _assert_non_empty_sequence(targets, name="clusters")
    targets = sorted(_get_valid_values(targets, categories))

    palette = _get_palette(adata, cluster_key=cluster_key, categories=categories, palette=palette)

    n_panels = len(targets)
    fig, axes = plt.subplots(
        1,
        n_panels,
        figsize=(5 * n_panels, 5) if figsize is None else figsize,
        dpi=dpi,
        constrained_layout=True,
    )

    # A single panel comes back as a bare Axes; ravel always yields an iterable array.
    for cluster, panel in zip(targets, np.ravel(axes)):
        row = np.where(categories == cluster)[0][0]
        long_df = pd.DataFrame(occ[row, :, :].T, columns=categories).melt(
            var_name=cluster_key, value_name="probability"
        )
        # melt stacks the columns one after another, so the distances repeat per category.
        long_df["distance"] = np.tile(distances, len(categories))
        sns.lineplot(
            x="distance",
            y="probability",
            data=long_df,
            dashes=False,
            hue=cluster_key,
            hue_order=categories,
            palette=palette,
            ax=panel,
            **kwargs,
        )
        panel.legend(**legend_opts)
        panel.set_title(rf"$\frac{{p(exp|{cluster})}}{{p(exp)}}$")
        panel.set_ylabel("value")

    if save is not None:
        save_fig(fig, path=save)