"""Functions for neighborhood enrichment analysis (permutation test, centralities measures etc.)."""
from __future__ import annotations
from functools import partial
from typing import (
Any,
Callable,
Iterable,
Sequence,
Union, # noqa: F401
)
import networkx as nx
import numba.types as nt
import numpy as np
import pandas as pd
from anndata import AnnData
from numba import njit, prange # noqa: F401
from pandas import CategoricalDtype
from scanpy import logging as logg
from spatialdata import SpatialData
from squidpy._constants._constants import Centrality
from squidpy._constants._pkg_constants import Key
from squidpy._docs import d, inject_docs
from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize
from squidpy.gr._utils import (
_assert_categorical_obs,
_assert_connectivity_key,
_assert_positive,
_save_data,
_shuffle_group,
)
__all__ = ["nhood_enrichment", "centrality_scores", "interaction_matrix"]
dt = nt.uint32  # data type aliases (the numba and numpy dtypes must match)
ndt = np.uint32
_template = """
@njit(dt[:, :](dt[:], dt[:], dt[:]), parallel={parallel}, fastmath=True)
def _nenrich_{n_cls}_{parallel}(indices: NDArrayA, indptr: NDArrayA, clustering: NDArrayA) -> np.ndarray:
'''
Count how many times clusters :math:`i` and :math:`j` are connected.
Parameters
----------
indices
:attr:`scipy.sparse.csr_matrix.indices`.
indptr
:attr:`scipy.sparse.csr_matrix.indptr`.
clustering
        Array of shape ``(n_cells,)`` containing cluster labels ranging from `0` to `n_clusters - 1` inclusive.
Returns
-------
:class:`numpy.ndarray`
Array of shape ``(n_clusters, n_clusters)`` containing the pairwise counts.
'''
res = np.zeros((indptr.shape[0] - 1, {n_cls}), dtype=ndt)
for i in prange(res.shape[0]):
xs, xe = indptr[i], indptr[i + 1]
cols = indices[xs:xe]
for c in cols:
res[i, clustering[c]] += 1
{init}
{loop}
{finalize}
"""
def _create_function(n_cls: int, parallel: bool = False) -> Callable[[NDArrayA, NDArrayA, NDArrayA], NDArrayA]:
"""
Create a :mod:`numba` function which counts the number of connections between clusters.
Parameters
----------
n_cls
Number of clusters. We're assuming that cluster labels are `0`, `1`, ..., `n_cls - 1`.
parallel
Whether to enable :mod:`numba` parallelization.
Returns
-------
The aforementioned function.
"""
if n_cls <= 1:
raise ValueError(f"Expected at least `2` clusters, found `{n_cls}`.")
rng = range(n_cls)
init = "".join(
f"""
g{i} = np.zeros(({n_cls},), dtype=ndt)"""
for i in rng
)
loop_body = """
if cl == 0:
g0 += res[row]"""
loop_body = loop_body + "".join(
f"""
elif cl == {i}:
g{i} += res[row]"""
for i in range(1, n_cls)
)
loop = f"""
for row in prange(res.shape[0]):
cl = clustering[row]
{loop_body}
else:
assert False, "Unhandled case."
"""
finalize = ", ".join(f"g{i}" for i in rng)
finalize = f"return np.stack(({finalize}))" # must really be a tuple
fn_key = f"_nenrich_{n_cls}_{parallel}"
if fn_key not in globals():
template = _template.format(init=init, loop=loop, finalize=finalize, n_cls=n_cls, parallel=parallel)
exec(compile(template, "", "exec"), globals())
return globals()[fn_key] # type: ignore[no-any-return]
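# For illustration, `_create_function(2)` expands `_template` into roughly the
# following source (a sketch, not the verbatim generated code):
#
#     @njit(dt[:, :](dt[:], dt[:], dt[:]), parallel=False, fastmath=True)
#     def _nenrich_2_False(indices, indptr, clustering):
#         # res[i, j] = number of neighbors of cell `i` that belong to cluster `j`
#         res = np.zeros((indptr.shape[0] - 1, 2), dtype=ndt)
#         for i in prange(res.shape[0]):
#             xs, xe = indptr[i], indptr[i + 1]
#             for c in indices[xs:xe]:
#                 res[i, clustering[c]] += 1
#         g0 = np.zeros((2,), dtype=ndt)
#         g1 = np.zeros((2,), dtype=ndt)
#         # aggregate the per-cell rows by the cluster of the source cell
#         for row in prange(res.shape[0]):
#             cl = clustering[row]
#             if cl == 0:
#                 g0 += res[row]
#             elif cl == 1:
#                 g1 += res[row]
#             else:
#                 assert False, "Unhandled case."
#         return np.stack((g0, g1))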
@d.get_sections(base="nhood_ench", sections=["Parameters"])
@d.dedent
def nhood_enrichment(
adata: AnnData | SpatialData,
cluster_key: str,
library_key: str | None = None,
connectivity_key: str | None = None,
n_perms: int = 1000,
numba_parallel: bool = False,
seed: int | None = None,
copy: bool = False,
n_jobs: int | None = None,
backend: str = "loky",
show_progress_bar: bool = True,
) -> tuple[NDArrayA, NDArrayA] | None:
"""
Compute neighborhood enrichment by permutation test.
Parameters
----------
%(adata)s
%(cluster_key)s
%(library_key)s
%(conn_key)s
%(n_perms)s
%(numba_parallel)s
%(seed)s
%(copy)s
%(parallelize)s
Returns
-------
If ``copy = True``, returns a :class:`tuple` with the z-score and the enrichment count.
Otherwise, modifies the ``adata`` with the following keys:
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['zscore']`` - the enrichment z-score.
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['count']`` - the enrichment count.
"""
if isinstance(adata, SpatialData):
adata = adata.table
connectivity_key = Key.obsp.spatial_conn(connectivity_key)
_assert_categorical_obs(adata, cluster_key)
_assert_connectivity_key(adata, connectivity_key)
_assert_positive(n_perms, name="n_perms")
adj = adata.obsp[connectivity_key]
original_clust = adata.obs[cluster_key]
clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)} # map categories
int_clust = np.array([clust_map[c] for c in original_clust], dtype=ndt)
if library_key is not None:
_assert_categorical_obs(adata, key=library_key)
libraries: pd.Series | None = adata.obs[library_key]
else:
libraries = None
indices, indptr = (adj.indices.astype(ndt), adj.indptr.astype(ndt))
n_cls = len(clust_map)
_test = _create_function(n_cls, parallel=numba_parallel)
count = _test(indices, indptr, int_clust)
n_jobs = _get_n_cores(n_jobs)
start = logg.info(f"Calculating neighborhood enrichment using `{n_jobs}` core(s)")
perms = parallelize(
_nhood_enrichment_helper,
collection=np.arange(n_perms).tolist(),
extractor=np.vstack,
n_jobs=n_jobs,
backend=backend,
show_progress_bar=show_progress_bar,
)(callback=_test, indices=indices, indptr=indptr, int_clust=int_clust, libraries=libraries, n_cls=n_cls, seed=seed)
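    # z-score of the observed counts against the permutation null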
zscore = (count - perms.mean(axis=0)) / perms.std(axis=0)
if copy:
return zscore, count
_save_data(
adata,
attr="uns",
key=Key.uns.nhood_enrichment(cluster_key),
data={"zscore": zscore, "count": count},
time=start,
)
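# Example (a sketch; assumes `adata` already has a spatial neighborhood graph,
# e.g. from `squidpy.gr.spatial_neighbors`, and a categorical `adata.obs["leiden"]`):
#
#     import squidpy as sq
#     sq.gr.spatial_neighbors(adata)
#     zscore, count = sq.gr.nhood_enrichment(adata, cluster_key="leiden", copy=True)
#     # zscore[i, j] > 0: clusters `i` and `j` are neighbors more often than
#     # expected under the permutation null; zscore[i, j] < 0: less often.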
@d.dedent
@inject_docs(c=Centrality)
def centrality_scores(
adata: AnnData | SpatialData,
cluster_key: str,
score: str | Iterable[str] | None = None,
connectivity_key: str | None = None,
copy: bool = False,
n_jobs: int | None = None,
backend: str = "loky",
show_progress_bar: bool = False,
) -> pd.DataFrame | None:
"""
Compute centrality scores per cluster or cell type.
Inspired by usage in Gene Regulatory Networks (GRNs) in :cite:`celloracle`.
Parameters
----------
%(adata)s
%(cluster_key)s
score
Centrality measures as described in :mod:`networkx.algorithms.centrality` :cite:`networkx`.
If `None`, use all the options below. Valid options are:
- `{c.CLOSENESS.s!r}` - measure of how close the group is to other nodes.
- `{c.CLUSTERING.s!r}` - measure of the degree to which nodes cluster together.
- `{c.DEGREE.s!r}` - fraction of non-group members connected to group members.
%(conn_key)s
%(copy)s
%(parallelize)s
Returns
-------
If ``copy = True``, returns a :class:`pandas.DataFrame`. Otherwise, modifies the ``adata`` with the following key:
- :attr:`anndata.AnnData.uns` ``['{{cluster_key}}_centrality_scores']`` - the centrality scores,
as mentioned above.
"""
if isinstance(adata, SpatialData):
adata = adata.table
connectivity_key = Key.obsp.spatial_conn(connectivity_key)
_assert_categorical_obs(adata, cluster_key)
_assert_connectivity_key(adata, connectivity_key)
    if isinstance(score, (str, Centrality)):
        centrality = [score]
    elif score is None:
        centrality = [c.s for c in Centrality]
    else:  # an iterable of centrality names
        centrality = list(score)
    centralities = [Centrality(c) for c in centrality]
graph = nx.Graph(adata.obsp[connectivity_key])
cat = adata.obs[cluster_key].cat.categories.values
clusters = adata.obs[cluster_key].values
fun_dict = {}
for c in centralities:
if c == Centrality.CLOSENESS:
fun_dict[c.s] = partial(nx.algorithms.centrality.group_closeness_centrality, graph)
elif c == Centrality.DEGREE:
fun_dict[c.s] = partial(nx.algorithms.centrality.group_degree_centrality, graph)
elif c == Centrality.CLUSTERING:
fun_dict[c.s] = partial(nx.algorithms.cluster.average_clustering, graph)
else:
raise NotImplementedError(f"Centrality `{c}` is not yet implemented.")
n_jobs = _get_n_cores(n_jobs)
start = logg.info(f"Calculating centralities `{centralities}` using `{n_jobs}` core(s)")
res_list = []
for k, v in fun_dict.items():
df = parallelize(
_centrality_scores_helper,
collection=cat,
extractor=pd.concat,
n_jobs=n_jobs,
backend=backend,
show_progress_bar=show_progress_bar,
)(clusters=clusters, fun=v, method=k)
res_list.append(df)
df = pd.concat(res_list, axis=1)
if copy:
return df
_save_data(adata, attr="uns", key=Key.uns.centrality_scores(cluster_key), data=df, time=start)
@d.dedent
def interaction_matrix(
adata: AnnData | SpatialData,
cluster_key: str,
connectivity_key: str | None = None,
normalized: bool = False,
copy: bool = False,
weights: bool = False,
) -> NDArrayA | None:
"""
Compute interaction matrix for clusters.
Parameters
----------
%(adata)s
%(cluster_key)s
%(conn_key)s
normalized
If `True`, each row is normalized to sum to 1.
%(copy)s
weights
Whether to use edge weights or binarize.
Returns
-------
If ``copy = True``, returns the interaction matrix.
Otherwise, modifies the ``adata`` with the following key:
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_interactions']`` - the interaction matrix.
"""
if isinstance(adata, SpatialData):
adata = adata.table
connectivity_key = Key.obsp.spatial_conn(connectivity_key)
_assert_categorical_obs(adata, cluster_key)
_assert_connectivity_key(adata, connectivity_key)
cats = adata.obs[cluster_key]
mask = ~pd.isnull(cats).values
cats = cats.loc[mask]
if not len(cats):
raise RuntimeError(f"After removing NaNs in `adata.obs[{cluster_key!r}]`, none remain.")
g = adata.obsp[connectivity_key]
g = g[mask, :][:, mask]
n_cats = len(cats.cat.categories)
g_data = g.data if weights else np.broadcast_to(1, shape=len(g.data))
dtype = int if pd.api.types.is_bool_dtype(g.dtype) or pd.api.types.is_integer_dtype(g.dtype) else float
output = np.zeros((n_cats, n_cats), dtype=dtype) # type: ignore[var-annotated]
_interaction_matrix(g_data, g.indices, g.indptr, cats.cat.codes.to_numpy(), output)
if normalized:
output = output / output.sum(axis=1).reshape((-1, 1))
if copy:
return output
_save_data(adata, attr="uns", key=Key.uns.interaction_matrix(cluster_key), data=output)
@njit
def _interaction_matrix(
    data: NDArrayA, indices: NDArrayA, indptr: NDArrayA, cats: NDArrayA, output: NDArrayA
) -> NDArrayA:
    # Traverse the CSR graph row by row, accumulating each edge's weight into
    # the (source cluster, target cluster) cell of the output matrix.
    indices_list = np.split(indices, indptr[1:-1])
    data_list = np.split(data, indptr[1:-1])
    for i in range(len(data_list)):
        cur_row = cats[i]  # cluster label of the source cell
        cur_indices = indices_list[i]
        cur_data = data_list[i]
        for j, val in zip(cur_indices, cur_data):
            cur_col = cats[j]  # cluster label of the neighboring cell
            output[cur_row, cur_col] += val
    return output
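# Toy example of the accumulation above: three cells with clusters [0, 0, 1]
# and symmetric edges 0-1 and 1-2 give, in CSR form,
# indptr = [0, 1, 3, 4], indices = [1, 0, 2, 1], data = [1, 1, 1, 1],
# and the resulting interaction matrix is [[2, 1], [1, 0]].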
def _centrality_scores_helper(
cat: Iterable[Any],
clusters: Sequence[str],
fun: Callable[..., float],
method: str,
queue: SigQueue | None = None,
) -> pd.DataFrame:
    res_list = []
    for c in cat:
        idx = np.where(clusters == c)[0]  # node indices belonging to cluster `c`
        res = fun(idx)  # `fun` is a partially applied networkx measure, see `fun_dict` above
        res_list.append(res)
if queue is not None:
queue.put(Signal.UPDATE)
if queue is not None:
queue.put(Signal.FINISH)
return pd.DataFrame(res_list, columns=[method], index=cat)
def _nhood_enrichment_helper(
ixs: NDArrayA,
callback: Callable[[NDArrayA, NDArrayA, NDArrayA], NDArrayA],
indices: NDArrayA,
indptr: NDArrayA,
int_clust: NDArrayA,
libraries: pd.Series[CategoricalDtype] | None,
n_cls: int,
seed: int | None = None,
queue: SigQueue | None = None,
) -> NDArrayA:
perms = np.empty((len(ixs), n_cls, n_cls), dtype=np.float64)
    int_clust = int_clust.copy()  # copy so that shuffling does not mutate state shared across workers
    # offset the seed by the first permutation index so each worker draws a distinct random stream
    rs = np.random.RandomState(seed=None if seed is None else seed + ixs[0])
for i in range(len(ixs)):
if libraries is not None:
int_clust = _shuffle_group(int_clust, libraries, rs)
else:
rs.shuffle(int_clust)
perms[i, ...] = callback(indices, indptr, int_clust)
if queue is not None:
queue.put(Signal.UPDATE)
if queue is not None:
queue.put(Signal.FINISH)
return perms