"""Source code for ``squidpy.pl._var_by_distance``: plot variables against distance to an anchor point."""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from pathlib import Path
from types import MappingProxyType
from typing import Any

import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from anndata import AnnData
from cycler import Cycler
from matplotlib import rcParams
from matplotlib.axes import Axes
from scipy.sparse import issparse

from squidpy._compat import panel_grid, set_default_colors_for_categorical_obs
from squidpy._docs import d
from squidpy.pl._utils import save_fig

# Public API of this module: only the plotting entry point is exported.
__all__ = ["var_by_distance"]


def _label_panel(ax: Axes, anchor_key: Any, title: str | None, axis_label: str | None) -> None:
    """Set the optional panel title and the x-axis label (default ``distance to {anchor_key}``)."""
    if title is not None:
        ax.set(title=title)
    ax.set(xlabel=f"distance to {anchor_key}" if axis_label is None else axis_label)


def _design_matrix_with_vars(adata: AnnData, design_matrix_key: str, var: list[str], color: str | None) -> Any:
    """Return a copy of the design matrix extended with the plotted variables (and scatter color column)."""
    # copy so that repeated calls do not accumulate columns in `adata.obsm`
    df = adata.obsm[design_matrix_key].copy()
    for name in var:
        if name in adata.var_names:
            values = adata[:, name].X
            # a single gene column is (n_obs, 1); flatten it to a 1D column
            df[name] = np.asarray(values.toarray() if issparse(values) else values).ravel()
        elif name in adata.obs:
            df[name] = adata.obs[name].values
        else:
            raise ValueError(f"Variable {name} not found in `adata.var` or `adata.obs`.")
    # `color` is documented as an `adata.obs` column; fetch it if the design matrix lacks it
    if color is not None and color not in df.columns and color in adata.obs:
        df[color] = adata.obs[color].values
    return df


def _scatter_observations(
    ax: Axes,
    df: Any,
    x: str,
    y: str,
    color: str | None,
    scatter_palette: Any,
    scatterplot_kwargs: Mapping[str, Any],
) -> None:
    """Overlay the raw observations of ``y`` against ``x`` on ``ax``."""
    # `plt.scatter` (rather than `ax.scatter`) also registers the collection as the current
    # mappable, so a later `plt.colorbar()` keeps working; `sca` pins it to the intended axes.
    plt.sca(ax)
    if color is None:
        plt.scatter(data=df, x=x, y=y, color="grey", **scatterplot_kwargs)
    elif df[color].dtype.name == "category":
        # one color per category, sampled from `scatter_palette` in order of appearance
        categories = df[color].unique()
        mapper = cm.ScalarMappable(norm=colors.Normalize(vmin=0, vmax=len(categories)), cmap=scatter_palette)
        for idx, category in enumerate(categories):
            plt.scatter(
                data=df.loc[df[color] == category],
                x=x,
                y=y,
                color=mapper.to_rgba(idx),
                **scatterplot_kwargs,
            )
    else:
        # continuous values are mapped through the colormap directly
        plt.scatter(data=df, x=x, y=y, c=color, cmap=scatter_palette, **scatterplot_kwargs)


@d.dedent
def var_by_distance(
    adata: AnnData,
    var: str | list[str],
    anchor_key: str | list[str],
    design_matrix_key: str = "design_matrix",
    stack_vars: bool = False,
    covariate: str | None = None,
    order: int = 5,
    show_scatter: bool = True,
    color: str | None = None,
    line_palette: str | Sequence[str] | Cycler | None = None,
    scatter_palette: str | Sequence[str] | Cycler | None = "viridis",
    dpi: int | None = None,
    figsize: tuple[int, int] | None = None,
    save: str | Path | None = None,
    title: str | None = None,
    axis_label: str | None = None,
    return_ax: bool | None = None,
    regplot_kwargs: Mapping[str, Any] = MappingProxyType({}),
    scatterplot_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> Axes | None:
    """
    Plot a variable using a smooth regression line with increasing distance to an anchor point.

    Parameters
    ----------
    %(adata)s
    var
        Variables to plot on y-axis. Each must be in ``adata.var_names`` or ``adata.obs``.
    anchor_key
        Anchor point column from which distances are taken.
    design_matrix_key
        Name of the design matrix in ``adata.obsm``, previously computed with
        :func:`squidpy.tl.var_by_distance`, to use. The design matrix is copied, not modified.
    stack_vars
        Whether to show multiple variables on the same plot. Only works if 'covariate' is not specified.
        Stacked plots only draw regression lines; ``show_scatter`` is ignored.
    covariate
        A covariate for which separate regression lines are plotted for each category.
    order
        Order of the polynomial fit for :func:`seaborn.regplot`.
    show_scatter
        Whether to show a scatter plot underlying the regression line.
    color
        Variable in ``adata.obs`` (or column of the design matrix) used to color the scatter plot
        if ``show_scatter=True``.
    line_palette
        Color(s) of the regression line(s). With a covariate, a categorical palette indexed per category;
        if `None` or a string, colors for ``adata.obs[covariate]`` are written to
        ``adata.uns['{covariate}_colors']`` and used. When stacking, one color per variable; if `None`
        or a string, seaborn's ``'bright'`` palette is used.
    scatter_palette
        Color palette for the scatter plot underlying the :func:`seaborn.regplot`.
    dpi
        Dots per inch.
    figsize
        Size of the figure in inches.
    save
        Whether to save the plot.
    title
        Panel titles.
    axis_label
        Panel axis labels.
    return_ax
        Whether to return :class:`matplotlib.axes.Axes` object(s): a list of axes when several
        variables are plotted in separate panels, otherwise a single axes.
    regplot_kwargs
        Kwargs for :func:`seaborn.regplot`.
    scatterplot_kwargs
        Kwargs for :func:`matplotlib.pyplot.scatter`.

    Returns
    -------
    %(plotting_returns)s

    Raises
    ------
    ValueError
        If both ``stack_vars`` and ``covariate`` are given, if ``var`` is empty, or if a variable
        is found neither in ``adata.var_names`` nor in ``adata.obs``.
    """
    # validate before any figure is created so that no empty figure is left behind
    if stack_vars and covariate is not None:
        raise ValueError("Can't stack variables and plot covariate at the same time.")

    var = [var] if isinstance(var, str) else list(var)
    if not var:
        raise ValueError("At least one variable must be specified in `var`.")

    dpi = rcParams["figure.dpi"] if dpi is None else dpi
    regplot_kwargs = dict(regplot_kwargs)
    scatterplot_kwargs = dict(scatterplot_kwargs)

    df = _design_matrix_with_vars(adata, design_matrix_key, var, color)

    # several variables plotted separately share one figure laid out as a panel grid
    use_grid = len(var) > 1 and not stack_vars
    axs: list[Axes] = []

    if stack_vars:
        # all variables on a single axes, one regression line each
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        if isinstance(line_palette, str) or line_palette is None:
            line_palette = sns.color_palette("bright", len(var))
        for idx, v in enumerate(var):
            sns.regplot(
                data=df,
                x=anchor_key,
                y=v,
                label=v,
                order=order,
                color=line_palette[idx],
                scatter=False,
                ax=ax,
                line_kws=regplot_kwargs,
            )
        ax.legend(title=None)
        ax.set(ylabel="var")
        _label_panel(ax, anchor_key, title, axis_label)
    else:
        if use_grid:
            fig, grid = panel_grid(
                hspace=0.25,
                wspace=0.75 / rcParams["figure.figsize"][0] + 0.02,
                ncols=4,
                num_panels=len(var),
            )
        for idx, v in enumerate(var):
            if use_grid:
                ax = plt.subplot(grid[idx])
                axs.append(ax)
            else:
                # a single variable needs only one axes object
                fig, ax = plt.subplots(1, 1, figsize=figsize)

            if covariate is None:
                # one regression line over all observations
                sns.regplot(
                    data=df,
                    x=anchor_key,
                    y=v,
                    order=order,
                    color=line_palette,
                    scatter=show_scatter,
                    ax=ax,
                    line_kws=regplot_kwargs,
                )
            else:
                # one regression line per covariate category
                if isinstance(line_palette, str) or line_palette is None:
                    set_default_colors_for_categorical_obs(adata, covariate)
                    line_palette = adata.uns[covariate + "_colors"]
                covariate_instances = df[covariate].unique()
                for cov_idx, co in enumerate(covariate_instances):
                    sns.regplot(
                        data=df.loc[df[covariate] == co],
                        x=anchor_key,
                        y=v,
                        order=order,
                        color=line_palette[cov_idx],
                        scatter=show_scatter,
                        ax=ax,
                        label=co,
                        line_kws=regplot_kwargs,
                    )
                handles, _ = ax.get_legend_handles_labels()
                ax.legend(handles, covariate_instances)

            if show_scatter:
                _scatter_observations(ax, df, anchor_key, v, color, scatter_palette, scatterplot_kwargs)
            _label_panel(ax, anchor_key, title, axis_label)

    if save is not None:
        save_fig(fig, path=save, transparent=False, dpi=dpi)

    if return_ax:
        # a list only when each variable has its own panel; otherwise the single axes
        return axs if use_grid else ax
    return None