Source code for squidpy.gr._nhood

"""Functions for neighborhood enrichment analysis (permutation test, centralities measures etc.)."""

from __future__ import annotations

from collections.abc import Iterable, Sequence
from functools import partial
from typing import Any, Callable

import networkx as nx
import numba.types as nt
import numpy as np
import pandas as pd
from anndata import AnnData
from numba import njit, prange  # noqa: F401
from pandas import CategoricalDtype
from scanpy import logging as logg
from spatialdata import SpatialData

from squidpy._constants._constants import Centrality
from squidpy._constants._pkg_constants import Key
from squidpy._docs import d, inject_docs
from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize
from squidpy.gr._utils import (
    _assert_categorical_obs,
    _assert_connectivity_key,
    _assert_positive,
    _save_data,
    _shuffle_group,
)

__all__ = ["nhood_enrichment", "centrality_scores", "interaction_matrix"]

dt = nt.uint32  # data type aliases (the numba and numpy dtypes must match)
ndt = np.uint32
_template = """
@njit(dt[:, :](dt[:], dt[:], dt[:]), parallel={parallel}, fastmath=True)
def _nenrich_{n_cls}_{parallel}(indices: NDArrayA, indptr: NDArrayA, clustering: NDArrayA) -> np.ndarray:
    '''
    Count how many times clusters :math:`i` and :math:`j` are connected.

    Parameters
    ----------
    indices
        :attr:`scipy.sparse.csr_matrix.indices`.
    indptr
        :attr:`scipy.sparse.csr_matrix.indptr`.
    clustering
        Array of shape ``(n_cells,)`` containing cluster labels ranging from `0` to `n_clusters - 1` inclusive.

    Returns
    -------
    :class:`numpy.ndarray`
        Array of shape ``(n_clusters, n_clusters)`` containing the pairwise counts.
    '''
    res = np.zeros((indptr.shape[0] - 1, {n_cls}), dtype=ndt)

    for i in prange(res.shape[0]):
        xs, xe = indptr[i], indptr[i + 1]
        cols = indices[xs:xe]
        for c in cols:
            res[i, clustering[c]] += 1
    {init}
    {loop}
    {finalize}
"""


def _create_function(n_cls: int, parallel: bool = False) -> Callable[[NDArrayA, NDArrayA, NDArrayA], NDArrayA]:
    """
    Create a :mod:`numba` function which counts the number of connections between clusters.

    Parameters
    ----------
    n_cls
        Number of clusters. We're assuming that cluster labels are `0`, `1`, ..., `n_cls - 1`.
    parallel
        Whether to enable :mod:`numba` parallelization.

    Returns
    -------
    The aforementioned function.
    """
    if n_cls <= 1:
        raise ValueError(f"Expected at least `2` clusters, found `{n_cls}`.")

    rng = range(n_cls)
    init = "".join(
        f"""
    g{i} = np.zeros(({n_cls},), dtype=ndt)"""
        for i in rng
    )

    loop_body = """
        if cl == 0:
            g0 += res[row]"""
    loop_body = loop_body + "".join(
        f"""
        elif cl == {i}:
            g{i} += res[row]"""
        for i in range(1, n_cls)
    )
    loop = f"""
    for row in prange(res.shape[0]):
        cl = clustering[row]
        {loop_body}
        else:
            assert False, "Unhandled case."
    """
    finalize = ", ".join(f"g{i}" for i in rng)
    finalize = f"return np.stack(({finalize}))"  # must really be a tuple

    fn_key = f"_nenrich_{n_cls}_{parallel}"
    if fn_key not in globals():
        template = _template.format(init=init, loop=loop, finalize=finalize, n_cls=n_cls, parallel=parallel)
        exec(compile(template, "", "exec"), globals())

    return globals()[fn_key]  # type: ignore[no-any-return]
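

# Hedged usage sketch for ``_create_function`` (illustrative only, not part of
# the original module): the generated kernel can be exercised directly on toy
# CSR arrays; the graph and labels are the same invented path graph as in
# ``_example_count_sketch`` above.
def _example_create_function() -> NDArrayA:
    fn = _create_function(2, parallel=False)  # compiles a kernel for 2 clusters
    indices = np.array([1, 0, 2, 1, 3, 2], dtype=ndt)
    indptr = np.array([0, 1, 3, 5, 6], dtype=ndt)
    clustering = np.array([0, 0, 1, 1], dtype=ndt)
    return fn(indices, indptr, clustering)  # same result as the sketch: [[2, 1], [1, 2]]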


@d.get_sections(base="nhood_ench", sections=["Parameters"])
@d.dedent
def nhood_enrichment(
    adata: AnnData | SpatialData,
    cluster_key: str,
    library_key: str | None = None,
    connectivity_key: str | None = None,
    n_perms: int = 1000,
    numba_parallel: bool = False,
    seed: int | None = None,
    copy: bool = False,
    n_jobs: int | None = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
) -> tuple[NDArrayA, NDArrayA] | None:
    """
    Compute neighborhood enrichment by permutation test.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(library_key)s
    %(conn_key)s
    %(n_perms)s
    %(numba_parallel)s
    %(seed)s
    %(copy)s
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns a :class:`tuple` with the z-score and the enrichment count.

    Otherwise, modifies the ``adata`` with the following keys:

        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['zscore']`` - the enrichment z-score.
        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['count']`` - the enrichment count.
    """
    if isinstance(adata, SpatialData):
        adata = adata.table
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)
    _assert_positive(n_perms, name="n_perms")

    adj = adata.obsp[connectivity_key]
    original_clust = adata.obs[cluster_key]
    clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)}  # map categories
    int_clust = np.array([clust_map[c] for c in original_clust], dtype=ndt)

    if library_key is not None:
        _assert_categorical_obs(adata, key=library_key)
        libraries: pd.Series | None = adata.obs[library_key]
    else:
        libraries = None

    indices, indptr = (adj.indices.astype(ndt), adj.indptr.astype(ndt))
    n_cls = len(clust_map)

    _test = _create_function(n_cls, parallel=numba_parallel)
    count = _test(indices, indptr, int_clust)

    n_jobs = _get_n_cores(n_jobs)
    start = logg.info(f"Calculating neighborhood enrichment using `{n_jobs}` core(s)")

    perms = parallelize(
        _nhood_enrichment_helper,
        collection=np.arange(n_perms).tolist(),
        extractor=np.vstack,
        n_jobs=n_jobs,
        backend=backend,
        show_progress_bar=show_progress_bar,
    )(callback=_test, indices=indices, indptr=indptr, int_clust=int_clust, libraries=libraries, n_cls=n_cls, seed=seed)
    zscore = (count - perms.mean(axis=0)) / perms.std(axis=0)

    if copy:
        return zscore, count

    _save_data(
        adata,
        attr="uns",
        key=Key.uns.nhood_enrichment(cluster_key),
        data={"zscore": zscore, "count": count},
        time=start,
    )
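

# Hedged usage sketch for ``nhood_enrichment`` (illustrative only, not part of
# the original module). The random graph below merely stands in for the output
# of ``squidpy.gr.spatial_neighbors``; the ``"spatial_connectivities"`` key and
# the ``"cluster"`` column are assumptions made for this example.
def _example_nhood_enrichment() -> tuple[NDArrayA, NDArrayA] | None:
    from scipy.sparse import random as sparse_random

    rng = np.random.default_rng(0)
    n_cells = 200
    adata = AnnData(X=rng.random((n_cells, 10)).astype(np.float32))
    adata.obs["cluster"] = pd.Categorical(rng.integers(0, 3, size=n_cells).astype(str))

    # random symmetric binary graph as a stand-in spatial neighbourhood graph
    graph = sparse_random(n_cells, n_cells, density=0.02, format="csr", random_state=0)
    adata.obsp["spatial_connectivities"] = ((graph + graph.T) > 0).astype(np.float64)

    # with ``copy=True`` the z-scores and raw counts are returned instead of
    # being written to ``adata.uns["cluster_nhood_enrichment"]``
    return nhood_enrichment(
        adata, cluster_key="cluster", n_perms=20, seed=0, n_jobs=1, show_progress_bar=False, copy=True
    )

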
@d.dedent
@inject_docs(c=Centrality)
def centrality_scores(
    adata: AnnData | SpatialData,
    cluster_key: str,
    score: str | Iterable[str] | None = None,
    connectivity_key: str | None = None,
    copy: bool = False,
    n_jobs: int | None = None,
    backend: str = "loky",
    show_progress_bar: bool = False,
) -> pd.DataFrame | None:
    """
    Compute centrality scores per cluster or cell type.

    Inspired by usage in Gene Regulatory Networks (GRNs) in :cite:`celloracle`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    score
        Centrality measures as described in :mod:`networkx.algorithms.centrality` :cite:`networkx`.
        If `None`, use all the options below. Valid options are:

            - `{c.CLOSENESS.s!r}` - measure of how close the group is to other nodes.
            - `{c.CLUSTERING.s!r}` - measure of the degree to which nodes cluster together.
            - `{c.DEGREE.s!r}` - fraction of non-group members connected to group members.
    %(conn_key)s
    %(copy)s
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns a :class:`pandas.DataFrame`.

    Otherwise, modifies the ``adata`` with the following key:

        - :attr:`anndata.AnnData.uns` ``['{{cluster_key}}_centrality_scores']`` - the centrality scores,
          as mentioned above.
    """
    if isinstance(adata, SpatialData):
        adata = adata.table
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)

    if isinstance(score, (str, Centrality)):
        centrality = [score]
    elif score is None:
        centrality = [c.s for c in Centrality]

    centralities = [Centrality(c) for c in centrality]

    graph = nx.Graph(adata.obsp[connectivity_key])

    cat = adata.obs[cluster_key].cat.categories.values
    clusters = adata.obs[cluster_key].values

    fun_dict = {}
    for c in centralities:
        if c == Centrality.CLOSENESS:
            fun_dict[c.s] = partial(nx.algorithms.centrality.group_closeness_centrality, graph)
        elif c == Centrality.DEGREE:
            fun_dict[c.s] = partial(nx.algorithms.centrality.group_degree_centrality, graph)
        elif c == Centrality.CLUSTERING:
            fun_dict[c.s] = partial(nx.algorithms.cluster.average_clustering, graph)
        else:
            raise NotImplementedError(f"Centrality `{c}` is not yet implemented.")

    n_jobs = _get_n_cores(n_jobs)
    start = logg.info(f"Calculating centralities `{centralities}` using `{n_jobs}` core(s)")

    res_list = []
    for k, v in fun_dict.items():
        df = parallelize(
            _centrality_scores_helper,
            collection=cat,
            extractor=pd.concat,
            n_jobs=n_jobs,
            backend=backend,
            show_progress_bar=show_progress_bar,
        )(clusters=clusters, fun=v, method=k)
        res_list.append(df)

    df = pd.concat(res_list, axis=1)

    if copy:
        return df
    _save_data(adata, attr="uns", key=Key.uns.centrality_scores(cluster_key), data=df, time=start)
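

# Hedged usage sketch for ``centrality_scores`` (illustrative only, not part of
# the original module); the tiny two-cluster graph below is invented and stands
# in for the output of ``squidpy.gr.spatial_neighbors``.
def _example_centrality_scores() -> pd.DataFrame | None:
    from scipy.sparse import csr_matrix

    # two triangles (cells 0-2 vs. cells 3-5) joined by a single bridge edge 2-3
    edges = [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]
    dense = np.zeros((6, 6), dtype=float)
    for i, j in edges:
        dense[i, j] = dense[j, i] = 1.0

    adata = AnnData(X=np.ones((6, 1), dtype=np.float32))
    adata.obs["cluster"] = pd.Categorical(["a", "a", "a", "b", "b", "b"])
    adata.obsp["spatial_connectivities"] = csr_matrix(dense)

    # one row per cluster, one column per centrality measure; with ``copy=True``
    # the DataFrame is returned instead of being stored in ``adata.uns``
    return centrality_scores(adata, cluster_key="cluster", n_jobs=1, copy=True)

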
@d.dedent
def interaction_matrix(
    adata: AnnData | SpatialData,
    cluster_key: str,
    connectivity_key: str | None = None,
    normalized: bool = False,
    copy: bool = False,
    weights: bool = False,
) -> NDArrayA | None:
    """
    Compute interaction matrix for clusters.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(conn_key)s
    normalized
        If `True`, each row is normalized to sum to 1.
    %(copy)s
    weights
        Whether to use edge weights or binarize.

    Returns
    -------
    If ``copy = True``, returns the interaction matrix.

    Otherwise, modifies the ``adata`` with the following key:

        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_interactions']`` - the interaction matrix.
    """
    if isinstance(adata, SpatialData):
        adata = adata.table
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)

    cats = adata.obs[cluster_key]
    mask = ~pd.isnull(cats).values
    cats = cats.loc[mask]
    if not len(cats):
        raise RuntimeError(f"After removing NaNs in `adata.obs[{cluster_key!r}]`, none remain.")

    g = adata.obsp[connectivity_key]
    g = g[mask, :][:, mask]
    n_cats = len(cats.cat.categories)

    g_data = g.data if weights else np.broadcast_to(1, shape=len(g.data))
    dtype = int if pd.api.types.is_bool_dtype(g.dtype) or pd.api.types.is_integer_dtype(g.dtype) else float
    output = np.zeros((n_cats, n_cats), dtype=dtype)  # type: ignore[var-annotated]

    _interaction_matrix(g_data, g.indices, g.indptr, cats.cat.codes.to_numpy(), output)

    if normalized:
        output = output / output.sum(axis=1).reshape((-1, 1))

    if copy:
        return output

    _save_data(adata, attr="uns", key=Key.uns.interaction_matrix(cluster_key), data=output)
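

# Hedged usage sketch for ``interaction_matrix`` (illustrative only, not part of
# the original module); the adjacency below is an invented stand-in for a real
# spatial neighbourhood graph.
def _example_interaction_matrix() -> NDArrayA | None:
    from scipy.sparse import csr_matrix

    # path graph 0-1-2-3; cells 0-1 belong to cluster "a", cells 2-3 to "b"
    dense = np.array(
        [
            [0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0],
        ]
    )
    adata = AnnData(X=np.ones((4, 1), dtype=np.float32))
    adata.obs["cluster"] = pd.Categorical(["a", "a", "b", "b"])
    adata.obsp["spatial_connectivities"] = csr_matrix(dense)

    # rows/columns follow the category order; ``normalized=True`` would rescale
    # each row to sum to 1 instead of returning raw edge counts
    return interaction_matrix(adata, cluster_key="cluster", copy=True)  # expected [[2., 1.], [1., 2.]]

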
@njit
def _interaction_matrix(
    data: NDArrayA, indices: NDArrayA, indptr: NDArrayA, cats: NDArrayA, output: NDArrayA
) -> NDArrayA:
    indices_list = np.split(indices, indptr[1:-1])
    data_list = np.split(data, indptr[1:-1])
    for i in range(len(data_list)):
        cur_row = cats[i]
        cur_indices = indices_list[i]
        cur_data = data_list[i]
        for j, val in zip(cur_indices, cur_data):
            cur_col = cats[j]
            output[cur_row, cur_col] += val
    return output


def _centrality_scores_helper(
    cat: Iterable[Any],
    clusters: Sequence[str],
    fun: Callable[..., float],
    method: str,
    queue: SigQueue | None = None,
) -> pd.DataFrame:
    res_list = []
    for c in cat:
        idx = np.where(clusters == c)[0]
        res = fun(idx)
        res_list.append(res)

        if queue is not None:
            queue.put(Signal.UPDATE)

    if queue is not None:
        queue.put(Signal.FINISH)

    return pd.DataFrame(res_list, columns=[method], index=cat)


def _nhood_enrichment_helper(
    ixs: NDArrayA,
    callback: Callable[[NDArrayA, NDArrayA, NDArrayA], NDArrayA],
    indices: NDArrayA,
    indptr: NDArrayA,
    int_clust: NDArrayA,
    libraries: pd.Series[CategoricalDtype] | None,
    n_cls: int,
    seed: int | None = None,
    queue: SigQueue | None = None,
) -> NDArrayA:
    perms = np.empty((len(ixs), n_cls, n_cls), dtype=np.float64)
    int_clust = int_clust.copy()  # threading
    rs = np.random.RandomState(seed=None if seed is None else seed + ixs[0])

    for i in range(len(ixs)):
        if libraries is not None:
            int_clust = _shuffle_group(int_clust, libraries, rs)
        else:
            rs.shuffle(int_clust)
        perms[i, ...] = callback(indices, indptr, int_clust)

        if queue is not None:
            queue.put(Signal.UPDATE)

    if queue is not None:
        queue.put(Signal.FINISH)

    return perms