Source code for squidpy.pl._var_by_distance

from __future__ import annotations

from collections.abc import Mapping, Sequence
from pathlib import Path
from types import MappingProxyType
from typing import Any

import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from anndata import AnnData
from cycler import Cycler
from matplotlib import rcParams
from matplotlib.axes import Axes
from scanpy.plotting._tools.scatterplots import _panel_grid
from scanpy.plotting._utils import _set_default_colors_for_categorical_obs
from scipy.sparse import issparse

from squidpy._docs import d
from squidpy.pl._utils import save_fig

# Public API of this module: only the plotting entry point is exported;
# helpers prefixed with an underscore stay private.
__all__ = ["var_by_distance"]


@d.dedent
def var_by_distance(
    adata: AnnData,
    var: str | list[str],
    anchor_key: str | list[str],
    design_matrix_key: str = "design_matrix",
    color: str | None = None,
    covariate: str | None = None,
    order: int = 5,
    show_scatter: bool = True,
    line_palette: str | Sequence[str] | Cycler | None = None,
    scatter_palette: str | Sequence[str] | Cycler | None = "viridis",
    dpi: int | None = None,
    figsize: tuple[int, int] | None = None,
    save: str | Path | None = None,
    title: str | None = None,
    axis_label: str | None = None,
    return_ax: bool | None = None,
    regplot_kwargs: Mapping[str, Any] = MappingProxyType({}),
    scatterplot_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> Axes | None:
    """
    Plot a variable using a smooth regression line with increasing distance to an anchor point.

    The design matrix stored in ``adata.obsm[design_matrix_key]`` is copied before the expression
    values of ``var`` are added to it, so plotting leaves the stored design matrix unchanged.

    Parameters
    ----------
    %(adata)s
    var
        Variables to plot on y-axis. A single name produces one panel; several names produce a
        grid of panels with up to 4 columns.
    anchor_key
        Anchor point column from which distances are taken.
    design_matrix_key
        Name of the design matrix, previously computed with :func:`squidpy.tl.var_by_distance`, to use.
    color
        Column of the design matrix used to color the scatter points. Categorical columns get one
        color per category taken from ``scatter_palette``; other columns are mapped continuously.
    covariate
        A covariate for which separate regression lines are plotted for each category.
    order
        Order of the polynomial fit for :func:`seaborn.regplot`.
    show_scatter
        Whether to show a scatter plot underlying the regression line.
    line_palette
        Categorical color palette used in case a covariate is specified. If `None` or a string,
        scanpy's default categorical colors for ``covariate`` are used (a string is not interpreted
        as a palette name) and ``adata.uns`` is restored to its previous state afterwards.
        Without a covariate, this value is passed as the line color.
    scatter_palette
        Color palette (colormap) for the scatter plot underlying the :func:`seaborn.regplot`.
    dpi
        Dots per inch used when saving. Defaults to ``rcParams["figure.dpi"]``.
    figsize
        Size of the figure in inches. Only used when a single variable is plotted.
    save
        Path where the figure is saved. If `None`, the figure is not saved.
    title
        Title applied to every panel.
    axis_label
        Label of the x-axis. Defaults to ``"distance to <anchor_key>"``.
    return_ax
        Whether to return :class:`matplotlib.axes.Axes` object(s).
    regplot_kwargs
        Kwargs passed as ``line_kws`` to :func:`seaborn.regplot`.
    scatterplot_kwargs
        Kwargs for :func:`matplotlib.pyplot.scatter`.

    Returns
    -------
    %(plotting_returns)s
    """
    dpi = rcParams["figure.dpi"] if dpi is None else dpi
    regplot_kwargs = dict(regplot_kwargs)
    scatterplot_kwargs = dict(scatterplot_kwargs)

    # Accept a single name or any sequence of names; from here on always work with a list.
    var_names = [var] if isinstance(var, str) else list(var)
    if not var_names:
        raise ValueError("Expected at least one variable name in `var`, found none.")

    # Work on a copy so the stored design matrix does not accumulate expression columns.
    df = adata.obsm[design_matrix_key].copy()
    expression = adata[:, var_names].X
    # `.A` is gone from recent SciPy sparse matrices; `.toarray()` works across versions.
    expression = expression.toarray() if issparse(expression) else np.asarray(expression)
    for column, name in enumerate(var_names):
        df[name] = expression[:, column]

    if covariate is not None:
        # Resolve the covariate colors once, so every panel uses the same palette.
        line_palette = _resolve_line_palette(adata, covariate, line_palette)

    multi_panel = len(var_names) > 1
    if multi_panel:
        fig, grid = _panel_grid(
            hspace=0.25,
            wspace=0.75 / rcParams["figure.figsize"][0] + 0.02,
            ncols=4,
            num_panels=len(var_names),
        )

    axs: list[Axes] = []
    for panel, name in enumerate(var_names):
        if multi_panel:
            ax = plt.subplot(grid[panel])
        else:
            # a single variable needs no grid, one axes object suffices
            fig, ax = plt.subplots(1, 1, figsize=figsize)
        axs.append(ax)

        _draw_regression(
            ax,
            df,
            anchor_key,
            name,
            order=order,
            show_scatter=show_scatter,
            covariate=covariate,
            line_palette=line_palette,
            line_kws=regplot_kwargs,
        )
        if show_scatter:
            _draw_scatter(
                ax,
                df,
                anchor_key,
                name,
                color=color,
                cmap=scatter_palette,
                scatter_kws=scatterplot_kwargs,
            )

        if title is not None:
            ax.set(title=title)
        ax.set(xlabel=f"distance to {anchor_key}" if axis_label is None else axis_label)

    if save is not None:
        save_fig(fig, path=save, transparent=False, dpi=dpi)

    if return_ax:
        return axs if multi_panel else axs[0]
    return None


def _resolve_line_palette(adata: AnnData, covariate: str, line_palette: Any) -> Any:
    """Return the per-category line colors for ``covariate`` without leaving changes in ``adata.uns``."""
    # An explicit color sequence is used as given.
    if line_palette is not None and not isinstance(line_palette, str):
        return line_palette

    key = f"{covariate}_colors"
    had_key = key in adata.uns
    previous = adata.uns.get(key)
    # scanpy writes its default categorical colors into ``adata.uns[key]`` (overwriting any value).
    _set_default_colors_for_categorical_obs(adata, covariate)
    palette = list(adata.uns[key])

    # The colors are only needed for this plot: put ``adata.uns`` back the way it was.
    if had_key:
        adata.uns[key] = previous
    else:
        del adata.uns[key]
    return palette


def _draw_regression(
    ax: Axes,
    df: Any,
    x: str,
    y: str,
    *,
    order: int,
    show_scatter: bool,
    covariate: str | None,
    line_palette: Any,
    line_kws: Mapping[str, Any],
) -> None:
    """Draw the polynomial regression fit of ``y`` against ``x`` on ``ax``, one line per covariate category."""
    if covariate is None:
        # without a covariate, one regression over all observations
        sns.regplot(
            data=df,
            x=x,
            y=y,
            order=order,
            color=line_palette,
            scatter=show_scatter,
            ax=ax,
            line_kws=line_kws,
        )
        return

    categories = df[covariate].unique()
    for position, category in enumerate(categories):
        sns.regplot(
            data=df.loc[df[covariate] == category],
            x=x,
            y=y,
            order=order,
            color=line_palette[position],
            scatter=show_scatter,
            ax=ax,
            label=category,
            line_kws=line_kws,
        )
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles, categories)


def _draw_scatter(
    ax: Axes,
    df: Any,
    x: str,
    y: str,
    *,
    color: str | None,
    cmap: Any,
    scatter_kws: Mapping[str, Any],
) -> None:
    """Overlay a scatter of ``y`` against ``x`` on ``ax``, colored by the design-matrix column ``color``."""
    # ``plt.scatter`` draws on the current axes and registers the result as the current
    # mappable (for a follow-up ``plt.colorbar()``); make sure ``ax`` is current.
    plt.sca(ax)

    if color is None:
        plt.scatter(data=df, x=x, y=y, color="grey", **scatter_kws)
    elif df[color].dtype.name == "category":
        # categorical coloring: one evenly spaced colormap entry per category
        categories = df[color].unique()
        scalar_map = cm.ScalarMappable(norm=colors.Normalize(vmin=0, vmax=len(categories)), cmap=cmap)
        for position, category in enumerate(categories):
            plt.scatter(
                data=df.loc[df[color] == category],
                x=x,
                y=y,
                color=scalar_map.to_rgba(position),
                **scatter_kws,
            )
    else:
        # continuous coloring: matplotlib resolves the column name ``color`` through ``data``
        plt.scatter(data=df, x=x, y=y, c=color, cmap=cmap, **scatter_kws)