# Source code for squidpy.pl._ligrec

from __future__ import annotations

from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from matplotlib.axes import Axes
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import Colormap, Normalize
from scanpy import logging as logg
from scipy.cluster import hierarchy as sch

from squidpy._constants._constants import DendrogramAxis
from squidpy._constants._pkg_constants import Key
from squidpy._docs import d
from squidpy._utils import _unique_order_preserving, verbosity
from squidpy.pl._utils import _dendrogram, _filter_kwargs, save_fig

# Joins the levels of a flattened MultiIndex label (and is later used to split
# them again), producing strings of the form "first | second".
_SEP = " | "

__all__ = ["ligrec"]


class CustomDotplot(sc.pl.DotPlot):
    """
    :class:`scanpy.pl.DotPlot` with the legends used by :func:`ligrec`.

    The dot sizes encode ``-log10`` transformed p-values. :func:`ligrec` min-max scales those values
    using ``minn`` (offset) and ``delta`` (range); both are passed here so that the size legend can be
    labelled in the original ``-log10`` scale. When ``alpha`` is given, an additional legend explains
    the white tori that :func:`ligrec` draws on top of significant interactions.
    """

    BASE = 10  # base of the logarithm applied to the p-values

    DEFAULT_LARGEST_DOT = 50.0
    DEFAULT_NUM_COLORBAR_TICKS = 5
    DEFAULT_NUM_LEGEND_DOTS = 5

    def __init__(self, minn: float, delta: float, alpha: float | None, *args: Any, **kwargs: Any):
        """
        Parameters
        ----------
        minn
            Minimum of the ``-log10`` transformed p-values (offset of the min-max scaling).
        delta
            Range (maximum minus minimum) of the ``-log10`` transformed p-values.
        alpha
            Significance threshold shown in the torus legend. If `None`, that legend is not drawn.
        args, kwargs
            Passed to :class:`scanpy.pl.DotPlot`.
        """
        super().__init__(*args, **kwargs)
        self._delta = delta
        self._minn = minn
        self._alpha = alpha

    def _plot_size_legend(self, size_legend_ax: Axes) -> None:
        """Draw the dot-size legend and, if ``alpha`` is set, the torus legend to its right."""
        # interpret dot_min / dot_max as min-max scaled -log10 p-values and map them back to p-values
        y = self.BASE ** -((self.dot_max * self._delta) + self._minn)
        x = self.BASE ** -((self.dot_min * self._delta) + self._minn)
        size_range = -(np.logspace(x, y, self.DEFAULT_NUM_LEGEND_DOTS + 1, base=self.BASE).astype(np.float64))
        lo, hi = np.min(size_range), np.max(size_range)
        # scale to [0, 1]; if the range collapses (dot_min == dot_max) avoid 0 / 0 and show full-size dots
        size_range = (size_range - lo) / (hi - lo) if hi > lo else np.ones_like(size_range)
        # no point in showing dot of size 0
        size_range = size_range[1:]

        # map [0, 1] onto [smallest_dot, largest_dot]; the previous `(largest - smallest) + smallest`
        # collapsed to `largest` and silently ignored `smallest_dot`
        size = size_range**self.size_exponent * (self.largest_dot - self.smallest_dot) + self.smallest_dot

        # plot size bar
        centers = np.arange(len(size)) + 0.5
        size_legend_ax.scatter(
            centers,
            np.repeat(1, len(size)),
            s=size,
            color="black",
            edgecolor="black",
            linewidth=self.dot_edge_lw,
            zorder=100,
        )
        size_legend_ax.set_xticks(centers)
        # label each dot with its value in the original (unscaled) -log10 space
        labels = [f"{(frac * self._delta) + self._minn:.1f}" for frac in size_range]
        size_legend_ax.set_xticklabels(labels, fontsize="small")

        # remove y ticks and labels
        size_legend_ax.tick_params(axis="y", left=False, labelleft=False, labelright=False)
        # remove surrounding lines
        for direction in ["right", "top", "left", "bottom"]:
            size_legend_ax.spines[direction].set_visible(False)

        ymax = size_legend_ax.get_ylim()[1]
        size_legend_ax.set_ylim(-1.05 - self.largest_dot * 0.003, 4)
        size_legend_ax.set_title(self.size_title, y=ymax + 0.25, size="small")

        xmin, xmax = size_legend_ax.get_xlim()
        size_legend_ax.set_xlim(xmin - 0.15, xmax + 0.5)

        if self._alpha is not None:
            ax = self.fig.add_subplot()
            # two full-size black dots: left = "false", right = "true" (gets a white torus below)
            ax.scatter(
                [0.35, 0.65],
                [0, 0],
                s=size[-1],
                color="black",
                edgecolor="black",
                linewidth=self.dot_edge_lw,
                zorder=100,
            )
            ax.scatter(
                [0.65],
                [0],
                s=0.33 * self.largest_dot,
                color="white",
                edgecolor="black",
                linewidth=self.dot_edge_lw,
                zorder=100,
            )
            ax.set_xlim([0, 1])
            ax.set_xticks([0.35, 0.65])
            ax.set_xticklabels(["false", "true"])
            ax.set_yticks([])
            ax.set_title(f"significant\n$p={self._alpha}$", y=ymax + 0.25, size="small")
            ax.set(frame_on=False)

            # place the torus legend directly to the right of the size legend
            left, bottom, width, height = size_legend_ax.get_position().bounds
            ax.set_position([left + width, bottom, width, height])

    def _plot_colorbar(self, color_legend_ax: Axes, normalize: Normalize | None) -> None:
        """Draw the horizontal colorbar for the dot colors (``dot_color_df``)."""
        # ``cmap`` may be a registered colormap name or an already constructed colormap;
        # the registry lookup only accepts names
        cmap = self.cmap if isinstance(self.cmap, Colormap) else plt.colormaps[self.cmap]
        values = self.dot_color_df.values

        ColorbarBase(
            color_legend_ax,
            orientation="horizontal",
            cmap=cmap,
            norm=normalize,
            ticks=np.linspace(
                np.nanmin(values),
                np.nanmax(values),
                self.DEFAULT_NUM_COLORBAR_TICKS,
            ),
            format="%.2f",
        )

        color_legend_ax.set_title(self.color_legend_title, fontsize="small")
        color_legend_ax.xaxis.set_tick_params(labelsize="small")


@d.dedent
def ligrec(
    adata: AnnData | Mapping[str, pd.DataFrame],
    cluster_key: str | None = None,
    source_groups: str | Sequence[str] | None = None,
    target_groups: str | Sequence[str] | None = None,
    means_range: tuple[float, float] = (-np.inf, np.inf),
    pvalue_threshold: float = 1.0,
    remove_empty_interactions: bool = True,
    remove_nonsig_interactions: bool = False,
    dendrogram: str | None = None,
    alpha: float | None = 0.001,
    swap_axes: bool = False,
    title: str | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = None,
    save: str | Path | None = None,
    **kwargs: Any,
) -> None:
    """
    Plot the result of a receptor-ligand permutation test.

    The result was computed by :func:`squidpy.gr.ligrec`.

    :math:`molecule_1` belongs to the source clusters displayed on the top (or on the right, if ``swap_axes = True``),
    whereas :math:`molecule_2` belongs to the target clusters.

    Parameters
    ----------
    %(adata)s
        It can also be a :class:`dict` (or any mapping), as returned by :func:`squidpy.gr.ligrec`.
    %(cluster_key)s
        Only used when ``adata`` is of type :class:`AnnData`.
    source_groups
        Source interaction clusters. If `None`, select all clusters.
    target_groups
        Target interaction clusters. If `None`, select all clusters.
    means_range
        Only show interactions whose means are within this **closed** interval.
    pvalue_threshold
        Only show interactions with p-value <= ``pvalue_threshold``.
    remove_empty_interactions
        Remove rows and columns that only contain interactions with `NaN` values.
    remove_nonsig_interactions
        Remove rows and columns that only contain interactions whose p-values are larger than ``alpha``.
    dendrogram
        How to cluster based on the p-values. Valid options are:

            - `None` - do not perform clustering.
            - `'interacting_molecules'` - cluster the interacting molecules.
            - `'interacting_clusters'` - cluster the interacting clusters.
            - `'both'` - cluster both rows and columns. Note that in this case, the dendrogram is not shown.

    alpha
        Significance threshold in ``[0, 1]``. All elements with p-values <= ``alpha`` will be marked by tori
        instead of dots.
    swap_axes
        Whether to show the cluster combinations as rows and the interacting pairs as columns.
    title
        Title of the plot.
    %(plotting_save)s
    kwargs
        Keyword arguments for :meth:`scanpy.pl.DotPlot.style` or :meth:`scanpy.pl.DotPlot.legend`.

    Returns
    -------
    %(plotting_returns)s
    """

    def filter_values(
        pvals: pd.DataFrame, means: pd.DataFrame, *, mask: pd.DataFrame, kind: str
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Drop rows, then columns, where ``mask`` is `False` everywhere; raise if nothing remains."""
        mask_rows = mask.any(axis=1)
        pvals = pvals.loc[mask_rows]
        means = means.loc[mask_rows]

        if pvals.empty:
            raise ValueError(f"After removing rows with only {kind} interactions, none remain.")

        mask_cols = mask.any(axis=0)
        pvals = pvals.loc[:, mask_cols]
        means = means.loc[:, mask_cols]

        if pvals.empty:
            raise ValueError(f"After removing columns with only {kind} interactions, none remain.")

        return pvals, means

    def get_dendrogram(adata: AnnData, linkage: str = "complete") -> Mapping[str, Any]:
        """Cluster the rows of ``adata.X`` and return the dendrogram info in the layout DotPlot expects."""
        z_var = sch.linkage(
            adata.X,
            metric="correlation",
            method=linkage,
            optimal_ordering=adata.n_obs <= 1500,  # matplotlib will most likely give up first
        )
        dendro_info = sch.dendrogram(z_var, labels=adata.obs_names.values, no_plot=True)
        # this is what the DotPlot requires
        return {
            "linkage": z_var,
            "groupby": ["groups"],
            "cor_method": "pearson",
            "use_rep": None,
            "linkage_method": linkage,
            "categories_ordered": dendro_info["ivl"],
            "categories_idx_ordered": dendro_info["leaves"],
            "dendrogram_info": dendro_info,
        }

    if dendrogram is not None:
        dendrogram = DendrogramAxis(dendrogram)  # type: ignore[assignment]
        if TYPE_CHECKING:
            assert isinstance(dendrogram, DendrogramAxis)

    if isinstance(adata, AnnData):
        if cluster_key is None:
            raise ValueError("Please provide `cluster_key` when supplying an `AnnData` object.")

        cluster_key = Key.uns.ligrec(cluster_key)
        if cluster_key not in adata.uns_keys():
            raise KeyError(f"Key `{cluster_key}` not found in `adata.uns`.")
        adata = adata.uns[cluster_key]

    # accept any mapping, consistent with the `Mapping[str, pd.DataFrame]` annotation
    if not isinstance(adata, Mapping):
        raise TypeError(
            f"Expected `adata` to be either of type `anndata.AnnData` or `dict`, found `{type(adata).__name__}`."
        )
    if len(means_range) != 2:
        raise ValueError(f"Expected `means_range` to be a sequence of size `2`, found `{len(means_range)}`.")
    means_range = tuple(sorted(means_range))  # type: ignore[assignment]

    if alpha is not None and not (0 <= alpha <= 1):
        raise ValueError(f"Expected `alpha` to be in range `[0, 1]`, found `{alpha}`.")

    if source_groups is None:
        source_groups = adata["pvalues"].columns.get_level_values(0)
    elif isinstance(source_groups, str):
        source_groups = (source_groups,)

    if target_groups is None:
        target_groups = adata["pvalues"].columns.get_level_values(1)
    elif isinstance(target_groups, str):
        target_groups = (target_groups,)

    # look up the valid cluster labels once instead of once per requested group
    valid_sources = adata["means"].columns.get_level_values(0)
    valid_targets = adata["means"].columns.get_level_values(1)
    for s in source_groups:
        if s not in valid_sources:
            raise ValueError(f"Invalid cluster in source group: {s}.")
    for t in target_groups:
        if t not in valid_targets:
            raise ValueError(f"Invalid cluster in target group: {t}.")

    if title is None:
        title = "Receptor-ligand test"

    source_groups, _ = _unique_order_preserving(source_groups)  # type: ignore[assignment]
    target_groups, _ = _unique_order_preserving(target_groups)  # type: ignore[assignment]

    pvals: pd.DataFrame = adata["pvalues"].loc[:, (source_groups, target_groups)]
    means: pd.DataFrame = adata["means"].loc[:, (source_groups, target_groups)]

    if pvals.empty:
        raise ValueError("No valid clusters have been selected.")

    means = means[(means >= means_range[0]) & (means <= means_range[1])]
    pvals = pvals[pvals <= pvalue_threshold]

    if remove_empty_interactions:
        pvals, means = filter_values(pvals, means, mask=~(pd.isnull(means) | pd.isnull(pvals)), kind="NaN")
    if remove_nonsig_interactions and alpha is not None:
        pvals, means = filter_values(pvals, means, mask=pvals <= alpha, kind="non-significant")

    if dendrogram == DendrogramAxis.INTERACTING_CLUSTERS:
        # rows are now cluster combinations, not interacting pairs
        pvals = pvals.T
        means = means.T

    # (first, last) column position of every top-level column label; used as DotPlot variable groups
    start, label_ranges = 0, {}
    for cls, size in (pvals.T.groupby(level=0)).size().to_dict().items():
        label_ranges[cls] = (start, start + size - 1)
        start += size
    label_ranges = {k: label_ranges[k] for k in sorted(label_ranges.keys())}
    group_keys = list(label_ranges)

    # pseudocount so that p == 0 stays finite under -log10; it must stay positive even for alpha == 0,
    # otherwise -log10(0) = inf would propagate into the min-max scaling below
    eps = min(1e-3, alpha) if alpha else 1e-3

    pvals = pvals[group_keys]
    pvals = -np.log10(pvals + eps).fillna(0)

    pvals.columns = list(map(_SEP.join, pvals.columns.to_flat_index()))
    pvals.index = list(map(_SEP.join, pvals.index.to_flat_index()))

    means = means[group_keys].fillna(0)
    means.columns = list(map(_SEP.join, means.columns.to_flat_index()))
    means.index = list(map(_SEP.join, means.index.to_flat_index()))
    means = np.log2(means + 1)

    var = pd.DataFrame(pvals.columns)
    var = var.set_index(var.columns[0])

    adata = AnnData(pvals.values, obs={"groups": pd.Categorical(pvals.index)}, var=var)
    adata.obs_names = pvals.index
    # min-max scale the -log10 p-values; `minn` and `delta` let the legend undo this
    minn = np.nanmin(adata.X)
    delta = np.nanmax(adata.X) - minn
    adata.X = (adata.X - minn) / delta

    try:
        if dendrogram == DendrogramAxis.BOTH:
            row_order, col_order, _, _ = _dendrogram(
                adata.X, method="complete", metric="correlation", optimal_ordering=adata.n_obs <= 1500
            )
            adata = adata[row_order, :][:, col_order]
            pvals = pvals.iloc[row_order, :].iloc[:, col_order]
            means = means.iloc[row_order, :].iloc[:, col_order]
        elif dendrogram is not None:
            adata.uns["dendrogram"] = get_dendrogram(adata)
    except IndexError:
        # just in case pandas indexing fails
        raise
    except Exception as e:  # noqa: BLE001
        logg.warning(f"Unable to create a dendrogram. Reason: `{e}`")
        dendrogram = None

    kwargs["dot_edge_lw"] = 0
    kwargs.setdefault("cmap", "viridis")
    kwargs.setdefault("grid", True)
    kwargs.pop("color_on", None)  # interferes with tori

    dp = (
        CustomDotplot(
            delta=delta,
            minn=minn,
            alpha=alpha,
            adata=adata,
            var_names=adata.var_names,
            groupby="groups",
            dot_color_df=means,
            dot_size_df=pvals,
            title=title,
            var_group_labels=None if dendrogram == DendrogramAxis.BOTH else list(label_ranges.keys()),
            var_group_positions=None if dendrogram == DendrogramAxis.BOTH else list(label_ranges.values()),
            standard_scale=None,
            figsize=figsize,
        )
        .style(
            **_filter_kwargs(sc.pl.DotPlot.style, kwargs),
        )
        .legend(
            size_title=r"$-\log_{10} ~ P$",
            colorbar_title=r"$log_2(\frac{molecule_1 + molecule_2}{2} + 1)$",
            **_filter_kwargs(sc.pl.DotPlot.legend, kwargs),
        )
    )
    if dendrogram in (DendrogramAxis.INTERACTING_MOLS, DendrogramAxis.INTERACTING_CLUSTERS):
        # ignore the warning about mismatching groups
        with verbosity(0):
            dp.add_dendrogram(size=1.6, dendrogram_key="dendrogram")
    if swap_axes:
        dp.swap_axes()

    dp.make_figure()

    if dendrogram != DendrogramAxis.BOTH:
        # labels read "first | second": drop the first part (already shown as the group label) and keep the second
        mainplot_ax = dp.ax_dict["mainplot_ax"]
        labs = mainplot_ax.get_yticklabels() if swap_axes else mainplot_ax.get_xticklabels()
        for text in labs:
            text.set_text(text.get_text().split(_SEP)[1])
        if swap_axes:
            mainplot_ax.set_yticklabels(labs)
        else:
            mainplot_ax.set_xticklabels(labs)

    if alpha is not None:
        # `pvals` now holds -log10(p + eps); this comparison approximates `p <= alpha` in that space
        yy, xx = np.where((pvals.values + alpha) >= -np.log10(alpha))
        if len(xx) and len(yy):
            # for dendrogram='both', they are already re-ordered
            mapper = (
                np.argsort(adata.uns["dendrogram"]["categories_idx_ordered"])
                if "dendrogram" in adata.uns
                else np.arange(len(pvals))
            )
            logg.info(f"Found `{len(yy)}` significant interactions at level `{alpha}`")
            ss = 0.33 * (adata.X[yy, xx] * (dp.largest_dot - dp.smallest_dot) + dp.smallest_dot)
            # remap rows to display order only after computing `ss`, which indexes `adata.X` in its original order
            yy = mapper[yy]
            if swap_axes:
                xx, yy = yy, xx
            dp.ax_dict["mainplot_ax"].scatter(xx + 0.5, yy + 0.5, color="white", s=ss, lw=0)

    if dpi is not None:
        dp.fig.set_dpi(dpi)

    if save is not None:
        save_fig(dp.fig, save)