"""Permutation test function as described in CellPhoneDB 2.0."""

from __future__ import annotations

from abc import ABC
from collections import namedtuple
from functools import partial
from itertools import product
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping, Sequence, Union

import numpy as np
import pandas as pd
from anndata import AnnData
from numba import njit, prange  # noqa: F401
from scanpy import logging as logg
from scipy.sparse import csc_matrix
from spatialdata import SpatialData

from squidpy._constants._constants import ComplexPolicy, CorrAxis
from squidpy._constants._pkg_constants import Key
from squidpy._docs import d, inject_docs
from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize
from squidpy.gr._utils import (
    _assert_categorical_obs,
    _assert_positive,
    _check_tuple_needles,
    _create_sparse_df,
    _genesymbols,
    _save_data,
)

__all__ = ["ligrec", "PermutationTest"]

StrSeq = Sequence[str]
SeqTuple = Sequence[tuple[str, str]]
Interaction_t = Union[pd.DataFrame, Mapping[str, StrSeq], StrSeq, tuple[StrSeq, StrSeq], SeqTuple]
Cluster_t = Union[StrSeq, tuple[StrSeq, StrSeq], SeqTuple]

SOURCE = "source"
TARGET = "target"

TempResult = namedtuple("TempResult", ["means", "pvalues"])

_template = """
@njit(parallel={parallel}, cache=False, fastmath=False)
def _test_{n_cls}_{ret_means}_{parallel}(
    interactions: NDArrayA,  # [np.uint32],
    interaction_clusters: NDArrayA,  # [np.uint32],
    data: NDArrayA,  # [np.float64],
    clustering: NDArrayA,  # [np.uint32],
    mean: NDArrayA,  # [np.float64],
    mask: NDArrayA,  # [np.bool_],
    res: NDArrayA,  # [np.float64],
    {args}
) -> None:

    {init}
    {loop}
    {finalize}

    for i in prange(len(interactions)):
        rec, lig = interactions[i]
        for j in prange(len(interaction_clusters)):
            c1, c2 = interaction_clusters[j]
            m1, m2 = mean[rec, c1], mean[lig, c2]

            if np.isnan(res[i, j]):
                continue

            if m1 > 0 and m2 > 0:
                {set_means}
                if mask[rec, c1] and mask[lig, c2]:
                    # both rec, lig are sufficiently expressed in c1, c2
                    res[i, j] += (groups[c1, rec] + groups[c2, lig]) > (m1 + m2)
                else:
                    res[i, j] = np.nan
            else:
                # res_means is initialized with 0s
                res[i, j] = np.nan
"""


def _create_template(n_cls: int, return_means: bool = False, parallel: bool = True) -> str:
    if n_cls <= 0:
        raise ValueError(f"Expected number of clusters to be positive, found `{n_cls}`.")

    rng = range(n_cls)
    init = "".join(
        f"""
    g{i} = np.zeros((data.shape[1],), dtype=np.float64); s{i} = 0"""
        for i in rng
    )

    loop_body = """
        if cl == 0:
            g0 += data[row]
            s0 += 1"""
    loop_body = loop_body + "".join(
        f"""
        elif cl == {i}:
            g{i} += data[row]
            s{i} += 1"""
        for i in range(1, n_cls)
    )
    loop = f"""
    for row in prange(data.shape[0]):
        cl = clustering[row]
        {loop_body}
        else:
            assert False, "Unhandled case."
    """
    finalize = ", ".join(f"g{i} / s{i}" for i in rng)
    finalize = f"groups = np.stack(({finalize}))"

    if return_means:
        args = "res_means: NDArrayA,  # [np.float64]"
        set_means = "res_means[i, j] = (m1 + m2) / 2.0"
    else:
        args = set_means = ""

    return _template.format(
        n_cls=n_cls,
        parallel=bool(parallel),
        ret_means=int(return_means),
        args=args,
        init=init,
        loop=loop,
        finalize=finalize,
        set_means=set_means,
    )
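

# For illustration, `_create_template(2)` renders roughly the following source, which
# `_analysis_helper` below compiles via `exec` into `globals()`:
#
#     @njit(parallel=True, cache=False, fastmath=False)
#     def _test_2_0_True(interactions, interaction_clusters, data, clustering, mean, mask, res):
#         g0 = np.zeros((data.shape[1],), dtype=np.float64); s0 = 0
#         g1 = np.zeros((data.shape[1],), dtype=np.float64); s1 = 0
#         for row in prange(data.shape[0]):
#             cl = clustering[row]
#             if cl == 0:
#                 g0 += data[row]; s0 += 1
#             elif cl == 1:
#                 g1 += data[row]; s1 += 1
#             else:
#                 assert False, "Unhandled case."
#         groups = np.stack((g0 / s0, g1 / s1))
#         for i in prange(len(interactions)):
#             ...  # the permutation statistic from `_template` above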


def _fdr_correct(
    pvals: pd.DataFrame,
    corr_method: str,
    corr_axis: Literal["interactions", "clusters"] | CorrAxis,
    alpha: float = 0.05,
) -> pd.DataFrame:
    """Correct p-values for FDR along specific axis in ``pvals``."""
    from pandas.core.arrays.sparse import SparseArray
    from statsmodels.stats.multitest import multipletests

    def fdr(pvals: pd.Series) -> SparseArray:
        _, qvals, _, _ = multipletests(
            np.nan_to_num(pvals.values, copy=True, nan=1.0),
            method=corr_method,
            alpha=alpha,
            is_sorted=False,
            returnsorted=False,
        )
        qvals[np.isnan(pvals.values)] = np.nan

        return SparseArray(qvals, dtype=qvals.dtype, fill_value=np.nan)

    corr_axis = CorrAxis(corr_axis)

    if corr_axis == CorrAxis.CLUSTERS:
        # clusters are in columns
        pvals = pvals.apply(fdr)
    elif corr_axis == CorrAxis.INTERACTIONS:
        pvals = pvals.T.apply(fdr).T
    else:
        raise NotImplementedError(f"FDR correction for `{corr_axis}` is not implemented.")

    return pvals
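

# Example (a minimal sketch with made-up p-values; the MultiIndex layout mirrors the
# `pvalues` frame assembled in `PermutationTestABC.test` below):
#
#     pvals = pd.DataFrame(
#         [[0.01, 0.20], [0.03, np.nan]],
#         index=pd.MultiIndex.from_tuples([("L1", "R1"), ("L2", "R2")], names=[SOURCE, TARGET]),
#         columns=pd.MultiIndex.from_tuples([("A", "B"), ("B", "A")], names=["cluster_1", "cluster_2"]),
#     )
#     qvals = _fdr_correct(pvals, corr_method="fdr_bh", corr_axis="interactions")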


@d.get_full_description(base="PT")
@d.get_sections(base="PT", sections=["Parameters"])
@d.dedent
class PermutationTestABC(ABC):
    """
    Class for receptor-ligand interaction testing.

    The expected workflow is::

        pt = PermutationTest(adata).prepare()
        res = pt.test("clusters")

    Parameters
    ----------
    %(adata)s
    use_raw
        Whether to access :attr:`anndata.AnnData.raw`.
    """

    def __init__(self, adata: AnnData, use_raw: bool = True):
        if not isinstance(adata, AnnData):
            raise TypeError(f"Expected `adata` to be of type `anndata.AnnData`, found `{type(adata).__name__}`.")
        if not adata.n_obs:
            raise ValueError("No cells are in `adata.obs_names`.")
        if not adata.n_vars:
            raise ValueError("No genes are in `adata.var_names`.")

        self._adata = adata
        if use_raw:
            if adata.raw is None:
                raise AttributeError("No `.raw` attribute found. Try specifying `use_raw=False`.")
            if adata.raw.n_obs != adata.n_obs:
                raise ValueError(f"Expected `{adata.n_obs}` cells in `.raw` object, found `{adata.raw.n_obs}`.")
            adata = adata.raw

        self._data = pd.DataFrame.sparse.from_spmatrix(
            csc_matrix(adata.X), index=adata.obs_names, columns=adata.var_names
        )

        self._interactions: pd.DataFrame | None = None
        self._filtered_data: pd.DataFrame | None = None

    @d.get_full_description(base="PT_prepare")
    @d.get_sections(base="PT_prepare", sections=["Parameters", "Returns"])
    @inject_docs(src=SOURCE, tgt=TARGET, cp=ComplexPolicy)
    def prepare(
        self, interactions: Interaction_t, complex_policy: Literal["min", "all"] | ComplexPolicy = ComplexPolicy.MIN.v
    ) -> PermutationTestABC:
        """
        Prepare self for running the permutation test.

        Parameters
        ----------
        interactions
            Interactions to test. The type can be one of:

                - :class:`pandas.DataFrame` - must contain at least 2 columns named `{src!r}` and `{tgt!r}`.
                - :class:`dict` - dictionary with at least 2 keys named `{src!r}` and `{tgt!r}`.
                - :class:`typing.Sequence` - Either a sequence of :class:`str`, in which case all combinations are
                  produced, or a sequence of :class:`tuple` of 2 :class:`str` or a :class:`tuple` of 2 sequences.

            If `None`, the interactions are extracted from :mod:`omnipath`. Protein complexes can be specified by
            delimiting the components with `'_'`, such as `'alpha_beta_gamma'`.
        complex_policy
            Policy on how to handle complexes. Valid options are:

                - `{cp.MIN.s!r}` - select gene with the minimum average expression. This is the same as in
                  :cite:`cellphonedb`.
                - `{cp.ALL.s!r}` - select all possible combinations between `{src!r}` and `{tgt!r}` complexes.

        Returns
        -------
        Sets the following attributes and returns :attr:`self`:

            - :attr:`interactions` - filtered interactions whose `{src!r}` and `{tgt!r}` are both in the data.
        """
        complex_policy = ComplexPolicy(complex_policy)

        if isinstance(interactions, Mapping):
            interactions = pd.DataFrame(interactions)

        if isinstance(interactions, pd.DataFrame):
            if SOURCE not in interactions.columns:
                raise KeyError(f"Column `{SOURCE!r}` is not in `interactions`.")
            if TARGET not in interactions.columns:
                raise KeyError(f"Column `{TARGET!r}` is not in `interactions`.")

            self._interactions = interactions.copy()
        elif isinstance(interactions, Iterable):
            interactions = tuple(interactions)
            if not len(interactions):
                raise ValueError("No interactions were specified.")

            if isinstance(interactions[0], str):
                interactions = list(product(interactions, repeat=2))
            elif len(interactions) == 2:
                interactions = tuple(zip(*interactions))

            if not all(len(i) == 2 for i in interactions):
                raise ValueError("Not all interactions are of length `2`.")

            self._interactions = pd.DataFrame(interactions, columns=[SOURCE, TARGET])
        else:
            raise TypeError(
                f"Expected either a `pandas.DataFrame`, `dict` or `iterable`, found `{type(interactions).__name__}`."
            )
        if TYPE_CHECKING:
            assert isinstance(self.interactions, pd.DataFrame)

        if self.interactions.empty:
            raise ValueError("The interactions are empty")

        # first uppercase, then drop duplicates
        self._data.columns = self._data.columns.str.upper()
        self.interactions[SOURCE] = self.interactions[SOURCE].str.upper()
        self.interactions[TARGET] = self.interactions[TARGET].str.upper()

        logg.debug("DEBUG: Removing duplicate interactions")
        self.interactions.dropna(subset=(SOURCE, TARGET), inplace=True, how="any")
        self.interactions.drop_duplicates(subset=(SOURCE, TARGET), inplace=True, keep="first")

        logg.debug("DEBUG: Removing duplicate genes in the data")
        n_genes_prior = self._data.shape[1]
        self._data = self._data.loc[:, ~self._data.columns.duplicated()]
        if self._data.shape[1] != n_genes_prior:
            logg.warning(f"Removed `{n_genes_prior - self._data.shape[1]}` duplicate gene(s)")

        self._filter_interactions_complexes(complex_policy)
        self._filter_interactions_by_genes()
        self._trim_data()

        # this is necessary because of complexes
        self.interactions.drop_duplicates(subset=(SOURCE, TARGET), inplace=True, keep="first")

        return self
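
    # Accepted `interactions` formats (gene names are hypothetical and must be present
    # in the data after uppercasing):
    #
    #     pt.prepare(pd.DataFrame({"source": ["TNF"], "target": ["TNFRSF1A"]}))
    #     pt.prepare({"source": ["TNF"], "target": ["TNFRSF1A"]})
    #     pt.prepare([("TNF", "TNFRSF1A"), ("IL6", "IL6R"), ("IL2", "IL2RA")])  # explicit pairs
    #     pt.prepare((["TNF", "IL6"], ["TNFRSF1A", "IL6R"]))  # two sequences, zipped pairwise
    #     pt.prepare(["TNF", "TNFRSF1A"])  # all ordered pairs are produced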

    @d.get_full_description(base="PT_test")
    @d.get_sections(base="PT_test", sections=["Parameters"])
    @d.dedent
    @inject_docs(src=SOURCE, tgt=TARGET, fa=CorrAxis)
    def test(
        self,
        cluster_key: str,
        clusters: Cluster_t | None = None,
        n_perms: int = 1000,
        threshold: float = 0.01,
        seed: int | None = None,
        corr_method: str | None = None,
        corr_axis: Literal["interactions", "clusters"] | CorrAxis = CorrAxis.INTERACTIONS.v,
        alpha: float = 0.05,
        copy: bool = False,
        key_added: str | None = None,
        numba_parallel: bool | None = None,
        **kwargs: Any,
    ) -> Mapping[str, pd.DataFrame] | None:
        """
        Perform the permutation test as described in :cite:`cellphonedb`.

        Parameters
        ----------
        %(cluster_key)s
        clusters
            Clusters from :attr:`anndata.AnnData.obs` ``['{{cluster_key}}']``. Can be specified either as a sequence
            of :class:`tuple` or just a sequence of cluster names, in which case all combinations are considered.
        %(n_perms)s
        threshold
            Do not perform the permutation test if any of the interacting components is expressed
            in less than ``threshold`` percent of cells within a given cluster.
        %(seed)s
        %(corr_method)s
        corr_axis
            Axis over which to perform the FDR correction. Only used when ``corr_method != None``. Valid options are:

                - `{fa.INTERACTIONS.s!r}` - correct interactions by performing FDR correction across the clusters.
                - `{fa.CLUSTERS.s!r}` - correct clusters by performing FDR correction across the interactions.
        alpha
            Significance level for FDR correction. Only used when ``corr_method != None``.
        %(copy)s
        key_added
            Key in :attr:`anndata.AnnData.uns` where the result is stored if ``copy = False``.
            If `None`, ``'{{cluster_key}}_ligrec'`` will be used.
        %(numba_parallel)s
        %(parallelize)s

        Returns
        -------
        %(ligrec_test_returns)s
        """
        _assert_positive(n_perms, name="n_perms")
        _assert_categorical_obs(self._adata, key=cluster_key)

        if corr_method is not None:
            corr_axis = CorrAxis(corr_axis)
        if TYPE_CHECKING:
            assert isinstance(corr_axis, CorrAxis)

        if len(self._adata.obs[cluster_key].cat.categories) <= 1:
            raise ValueError(
                f"Expected at least `2` clusters, found `{len(self._adata.obs[cluster_key].cat.categories)}`."
            )
        if TYPE_CHECKING:
            assert isinstance(self.interactions, pd.DataFrame)
            assert isinstance(self._filtered_data, pd.DataFrame)

        interactions = self.interactions[[SOURCE, TARGET]]
        self._filtered_data["clusters"] = self._adata.obs.copy()[cluster_key].astype("string").astype("category").values

        if clusters is None:
            clusters = list(map(str, self._adata.obs[cluster_key].cat.categories))
        if all(isinstance(c, str) for c in clusters):
            clusters = list(product(clusters, repeat=2))  # type: ignore[assignment]
        clusters = sorted(
            _check_tuple_needles(
                clusters,  # type: ignore[arg-type]
                self._filtered_data["clusters"].cat.categories,
                msg="Invalid cluster `{0!r}`.",
                reraise=True,
            )
        )
        clusters_flat = list({c for cs in clusters for c in cs})

        data = self._filtered_data.loc[np.isin(self._filtered_data["clusters"], clusters_flat), :]
        data["clusters"] = data["clusters"].cat.remove_unused_categories()
        cat = data["clusters"].cat

        cluster_mapper = dict(zip(cat.categories, range(len(cat.categories))))
        gene_mapper = dict(zip(data.columns[:-1], range(len(data.columns) - 1)))  # -1 for 'clusters'

        data.columns = [gene_mapper[c] if c != "clusters" else c for c in data.columns]
        clusters_ = np.array([[cluster_mapper[c1], cluster_mapper[c2]] for c1, c2 in clusters], dtype=np.uint32)

        data["clusters"] = cat.rename_categories(cluster_mapper)
        # much faster than applymap (tested on 1M interactions)
        interactions_ = np.vectorize(lambda g: gene_mapper[g])(interactions.values)

        n_jobs = _get_n_cores(kwargs.pop("n_jobs", None))
        start = logg.info(
            f"Running `{n_perms}` permutations on `{len(interactions)}` interactions "
            f"and `{len(clusters)}` cluster combinations using `{n_jobs}` core(s)"
        )
        res = _analysis(
            data,
            interactions_,
            clusters_,
            threshold=threshold,
            n_perms=n_perms,
            seed=seed,
            n_jobs=n_jobs,
            numba_parallel=numba_parallel,
            **kwargs,
        )
        res = {
            "means": _create_sparse_df(
                res.means,
                index=pd.MultiIndex.from_frame(interactions, names=[SOURCE, TARGET]),
                columns=pd.MultiIndex.from_tuples(clusters, names=["cluster_1", "cluster_2"]),
                fill_value=0,
            ),
            "pvalues": _create_sparse_df(
                res.pvalues,
                index=pd.MultiIndex.from_frame(interactions, names=[SOURCE, TARGET]),
                columns=pd.MultiIndex.from_tuples(clusters, names=["cluster_1", "cluster_2"]),
                fill_value=np.nan,
            ),
            "metadata": self.interactions[self.interactions.columns.difference([SOURCE, TARGET])],
        }
        res["metadata"].index = res["means"].index.copy()

        if TYPE_CHECKING:
            assert isinstance(res, dict)

        if corr_method is not None:
            logg.info(
                f"Performing FDR correction across the `{corr_axis.v}` "
                f"using method `{corr_method}` at level `{alpha}`"
            )
            res["pvalues"] = _fdr_correct(res["pvalues"], corr_method, corr_axis, alpha=alpha)

        if copy:
            logg.info("Finish", time=start)
            return res

        _save_data(self._adata, attr="uns", key=Key.uns.ligrec(cluster_key, key_added), data=res, time=start)
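
    # Example (a minimal sketch; assumes `pt` has been `prepare`d and that
    # `adata.obs["leiden"]` is categorical):
    #
    #     res = pt.test("leiden", n_perms=100, seed=0, corr_method="fdr_bh", copy=True)
    #     res["pvalues"].head()  # sparse frame: interactions x cluster combinations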

    def _trim_data(self) -> None:
        """Subset genes :attr:`_data` to those present in interactions."""
        if TYPE_CHECKING:
            assert isinstance(self._data, pd.DataFrame)
            assert isinstance(self.interactions, pd.DataFrame)

        logg.debug("DEBUG: Removing genes not in any interaction")
        self._filtered_data = self._data.loc[:, list(set(self.interactions[SOURCE]) | set(self.interactions[TARGET]))]

    def _filter_interactions_by_genes(self) -> None:
        """Subset :attr:`interactions` to only those for which we have the data."""
        if TYPE_CHECKING:
            assert isinstance(self.interactions, pd.DataFrame)

        logg.debug("DEBUG: Removing interactions with no genes in the data")
        self._interactions = self.interactions[
            self.interactions[SOURCE].isin(self._data.columns) & self.interactions[TARGET].isin(self._data.columns)
        ]

        if self.interactions.empty:
            raise ValueError("After filtering by genes, no interactions remain.")

    @inject_docs(src=SOURCE, tgt=TARGET, cp=ComplexPolicy)
    def _filter_interactions_complexes(self, complex_policy: ComplexPolicy) -> None:
        """
        Filter the :attr:`interactions` by extracting genes from complexes.

        Parameters
        ----------
        complex_policy
            Policy on how to handle complexes. Valid options are:

                - `{cp.MIN.s!r}` - select gene with the minimum average expression. This is the same as in
                  :cite:`cellphonedb`.
                - `{cp.ALL.s!r}` - select all possible combinations between `{src!r}` and `{tgt!r}` complexes.

        Returns
        -------
        Nothing, just updates the following fields:

            - :attr:`interactions` - filtered interactions whose `{src!r}` and `{tgt!r}` are both in the data.

        Note that for ``complex_policy={cp.ALL.s!r}``, all pairwise comparisons within a complex are created,
        but no filtering happens at this stage - genes not present in the data are filtered at a later stage.
        """

        def find_min_gene_in_complex(_complex: str | None) -> str | None:
            # TODO(michalk8): how can this happen?
            if _complex is None:
                return None
            if "_" not in _complex:
                return _complex
            complexes = [c for c in _complex.split("_") if c in self._data.columns]
            if not len(complexes):
                return None
            if len(complexes) == 1:
                return complexes[0]

            df = self._data[complexes].mean()
            try:
                return str(df.index[df.argmin()])
            except ValueError as e:
                if "attempt to get argmin of an empty sequence" in str(e):
                    return str(df.index[0])
                else:
                    logg.error(e)

        if TYPE_CHECKING:
            assert isinstance(self._interactions, pd.DataFrame)
            assert isinstance(self.interactions, pd.DataFrame)

        if complex_policy == ComplexPolicy.MIN:
            logg.debug("DEBUG: Selecting genes from complexes based on minimum average expression")
            self.interactions[SOURCE] = self.interactions[SOURCE].apply(find_min_gene_in_complex)
            self.interactions[TARGET] = self.interactions[TARGET].apply(find_min_gene_in_complex)
        elif complex_policy == ComplexPolicy.ALL:
            logg.debug("DEBUG: Creating all gene combinations within complexes")
            src = self.interactions.pop(SOURCE).apply(lambda s: str(s).split("_")).explode()
            src.name = SOURCE
            tgt = self.interactions.pop(TARGET).apply(lambda s: str(s).split("_")).explode()
            tgt.name = TARGET

            self._interactions = pd.merge(self.interactions, src, how="left", left_index=True, right_index=True)
            self._interactions = pd.merge(self.interactions, tgt, how="left", left_index=True, right_index=True)
        else:
            raise NotImplementedError(f"Complex policy {complex_policy!r} is not implemented.")
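
    # Effect of the two policies on a hypothetical complex, assuming both components
    # are columns of the data:
    #
    #     'ITGA1_ITGB1' -> 'ITGB1'             # 'min': component with lower average expression
    #     'ITGA1_ITGB1' -> ['ITGA1', 'ITGB1']  # 'all': exploded into one row per component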

    @property
    def interactions(self) -> pd.DataFrame | None:
        """The interactions."""  # noqa: D401
        return self._interactions

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__name__}"
            f"[n_interaction={len(self.interactions) if self.interactions is not None else None}]>"
        )

    def __str__(self) -> str:
        return repr(self)


@d.dedent
class PermutationTest(PermutationTestABC):
    """
    %(PT.full_desc)s

    Parameters
    ----------
    %(PT.parameters)s
    """  # noqa: D400

    @d.get_sections(base="PT_prepare_full", sections=["Parameters"])
    @d.dedent
    def prepare(
        self,
        interactions: Interaction_t | None = None,
        complex_policy: Literal["min", "all"] = ComplexPolicy.MIN.v,
        interactions_params: Mapping[str, Any] = MappingProxyType({}),
        transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}),
        receiver_params: Mapping[str, Any] = MappingProxyType({"categories": "receptor"}),
        **_: Any,
    ) -> PermutationTest:
        """
        %(PT_prepare.full_desc)s

        Parameters
        ----------
        %(PT_prepare.parameters)s
        interactions_params
            Keyword arguments for :func:`omnipath.interactions.import_intercell_network` defining the interactions.
            These datasets from :cite:`omnipath` are used by default: `omnipath`, `pathwayextra`, `kinaseextra` and
            `ligrecextra`.
        transmitter_params
            Keyword arguments for :func:`omnipath.interactions.import_intercell_network` defining the transmitter
            side of intercellular connections.
        receiver_params
            Keyword arguments for :func:`omnipath.interactions.import_intercell_network` defining the receiver
            side of intercellular connections.

        Returns
        -------
        %(PT_prepare.returns)s
        """  # noqa: D400
        if interactions is None:
            from omnipath.interactions import import_intercell_network

            start = logg.info("Fetching interactions from `omnipath`")
            interactions = import_intercell_network(
                interactions_params=interactions_params,
                transmitter_params=transmitter_params,
                receiver_params=receiver_params,
            )
            if TYPE_CHECKING:
                assert isinstance(interactions, pd.DataFrame)

            logg.info(f"Fetched `{len(interactions)}` interactions\n    Finish", time=start)

            # we don't really care about these
            if SOURCE in interactions.columns:
                interactions.pop(SOURCE)
            if TARGET in interactions.columns:
                interactions.pop(TARGET)
            interactions.rename(
                columns={"genesymbol_intercell_source": SOURCE, "genesymbol_intercell_target": TARGET}, inplace=True
            )

            interactions[SOURCE] = interactions[SOURCE].str.replace("^COMPLEX:", "", regex=True)
            interactions[TARGET] = interactions[TARGET].str.replace("^COMPLEX:", "", regex=True)

        _ = super().prepare(interactions, complex_policy=complex_policy)
        return self
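
    # Example (a minimal sketch; requires the optional `omnipath` dependency and
    # network access):
    #
    #     pt = PermutationTest(adata).prepare()  # ligand-receptor pairs from omnipath
    #     pt.interactions.head()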


@d.dedent
def ligrec(
    adata: AnnData | SpatialData,
    cluster_key: str,
    interactions: Interaction_t | None = None,
    complex_policy: Literal["min", "all"] = ComplexPolicy.MIN.v,
    threshold: float = 0.01,
    corr_method: str | None = None,
    corr_axis: Literal["interactions", "clusters"] = CorrAxis.CLUSTERS.v,
    use_raw: bool = True,
    copy: bool = False,
    key_added: str | None = None,
    gene_symbols: str | None = None,
    **kwargs: Any,
) -> Mapping[str, pd.DataFrame] | None:
    """
    %(PT_test.full_desc)s

    Parameters
    ----------
    %(PT.parameters)s
    %(PT_prepare_full.parameters)s
    %(PT_test.parameters)s
    gene_symbols
        Key in :attr:`anndata.AnnData.var` to use instead of :attr:`anndata.AnnData.var_names`.

    Returns
    -------
    %(ligrec_test_returns)s
    """  # noqa: D400
    if isinstance(adata, SpatialData):
        adata = adata.table
    with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False):
        return (  # type: ignore[no-any-return]
            PermutationTest(adata, use_raw=use_raw)
            .prepare(interactions, complex_policy=complex_policy, **kwargs)
            .test(
                cluster_key=cluster_key,
                threshold=threshold,
                corr_method=corr_method,
                corr_axis=corr_axis,
                copy=copy,
                key_added=key_added,
                **kwargs,
            )
        )
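

# Example (a minimal sketch; the dataset and cluster key are for illustration only):
#
#     import squidpy as sq
#
#     adata = sq.datasets.seqfish()
#     res = sq.gr.ligrec(
#         adata, cluster_key="celltype_mapped_refined", n_perms=100, seed=0,
#         use_raw=False, copy=True,
#     )
#     res["pvalues"].head()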


@d.dedent
def _analysis(
    data: pd.DataFrame,
    interactions: NDArrayA,
    interaction_clusters: NDArrayA,
    threshold: float = 0.1,
    n_perms: int = 1000,
    seed: int | None = None,
    n_jobs: int = 1,
    numba_parallel: bool | None = None,
    **kwargs: Any,
) -> TempResult:
    """
    Run the analysis as described in :cite:`cellphonedb`.

    This function runs the mean, percent and shuffled analysis.

    Parameters
    ----------
    data
        Array of shape `(n_cells, n_genes)`.
    interactions
        Array of shape `(n_interactions, 2)`.
    interaction_clusters
        Array of shape `(n_interaction_clusters, 2)`.
    threshold
        Percentage threshold for removing lowly expressed genes in clusters.
    %(n_perms)s
    %(seed)s
    n_jobs
        Number of parallel jobs to launch.
    numba_parallel
        Whether to use :class:`numba.prange` or not. If `None`, it's determined automatically.
    kwargs
        Keyword arguments for :func:`squidpy._utils.parallelize`, such as ``n_jobs`` or ``backend``.

    Returns
    -------
    Tuple of the following format:

        - `'means'` - array of shape `(n_interactions, n_interaction_clusters)` containing the means.
        - `'pvalues'` - array of shape `(n_interactions, n_interaction_clusters)` containing the p-values.
    """

    def extractor(res: Sequence[TempResult]) -> TempResult:
        assert len(res) == n_jobs, f"Expected to find `{n_jobs}` results, found `{len(res)}`."

        meanss: list[NDArrayA] = [r.means for r in res if r.means is not None]
        assert len(meanss) == 1, f"Only `1` job should've calculated the means, but found `{len(meanss)}`."
        means = meanss[0]
        if TYPE_CHECKING:
            assert isinstance(means, np.ndarray)

        pvalues = np.sum([r.pvalues for r in res if r.pvalues is not None], axis=0) / float(n_perms)
        assert means.shape == pvalues.shape, f"Means and p-values differ in shape: `{means.shape}`, `{pvalues.shape}`."

        return TempResult(means=means, pvalues=pvalues)

    groups = data.groupby("clusters")
    clustering = np.array(data["clusters"].values, dtype=np.int32)

    mean = groups.mean().values.T  # (n_genes, n_clusters)
    mask = groups.apply(lambda c: ((c > 0).sum() / len(c)) >= threshold).values.T  # (n_genes, n_clusters)
    # (n_cells, n_genes)
    data = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C")
    # all 3 should be C contiguous

    return parallelize(  # type: ignore[no-any-return]
        _analysis_helper,
        np.arange(n_perms, dtype=np.int32).tolist(),
        n_jobs=n_jobs,
        unit="permutation",
        extractor=extractor,
        **kwargs,
    )(
        data,
        mean,
        mask,
        interactions,
        interaction_clusters=interaction_clusters,
        clustering=clustering,
        seed=seed,
        numba_parallel=numba_parallel,
    )


def _analysis_helper(
    perms: NDArrayA,
    data: NDArrayA,
    mean: NDArrayA,
    mask: NDArrayA,
    interactions: NDArrayA,
    interaction_clusters: NDArrayA,
    clustering: NDArrayA,
    seed: int | None = None,
    numba_parallel: bool | None = None,
    queue: SigQueue | None = None,
) -> TempResult:
    """
    Run the mean, percent and shuffled analysis on a chunk of the permutations.

    Parameters
    ----------
    perms
        Permutation indices. Only used to set the ``seed``.
    data
        Array of shape `(n_cells, n_genes)`.
    mean
        Array of shape `(n_genes, n_clusters)` representing mean expression per cluster.
    mask
        Array of shape `(n_genes, n_clusters)` containing `True` if a gene within a cluster is expressed
        in at least ``threshold`` percent of cells.
    interactions
        Array of shape `(n_interactions, 2)`.
    interaction_clusters
        Array of shape `(n_interaction_clusters, 2)`.
    clustering
        Array of shape `(n_cells,)` containing the original clustering.
    seed
        Random seed for :class:`numpy.random.RandomState`.
    numba_parallel
        Whether to use :class:`numba.prange` or not. If `None`, it's determined automatically.
    queue
        Signalling queue used to update the progress bar.

    Returns
    -------
    Tuple of the following format:

        - `'means'` - array of shape `(n_interactions, n_interaction_clusters)` containing the true test
          statistic. It is `None` if ``min(perms) != 0`` so that only 1 worker calculates it.
        - `'pvalues'` - array of shape `(n_interactions, n_interaction_clusters)` containing `np.sum(T0 > T)`,
          where `T0` is the test statistic under the null hypothesis and `T` is the true test statistic.
    """
    rs = np.random.RandomState(None if seed is None else perms[0] + seed)

    clustering = clustering.copy()
    n_cls = mean.shape[1]
    return_means = np.min(perms) == 0

    # ideally, these would both be sparse arrays, but there is no numba impl.
    # (sparse.COO is read-only and very limited)
    # keep it f64, because we're setting NaN
    res = np.zeros((len(interactions), len(interaction_clusters)), dtype=np.float64)
    numba_parallel = (
        (np.prod(res.shape) >= 2**20 or clustering.shape[0] >= 2**15)
        if numba_parallel is None
        else numba_parallel  # type: ignore[assignment]
    )

    fn_key = f"_test_{n_cls}_{int(return_means)}_{bool(numba_parallel)}"
    if fn_key not in globals():
        exec(
            compile(_create_template(n_cls, return_means=return_means, parallel=numba_parallel), "", "exec"),
            globals(),  # type: ignore[arg-type]
        )
    _test = globals()[fn_key]

    if return_means:
        res_means: NDArrayA | None = np.zeros((len(interactions), len(interaction_clusters)), dtype=np.float64)
        test = partial(_test, res_means=res_means)
    else:
        res_means = None
        test = _test

    for _ in perms:
        rs.shuffle(clustering)
        test(interactions, interaction_clusters, data, clustering, mean, mask, res=res)
        if queue is not None:
            queue.put(Signal.UPDATE)

    if queue is not None:
        queue.put(Signal.FINISH)

    return TempResult(means=res_means, pvalues=res)
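

# Conceptually, the p-value assembled across workers is (a sketch, not module code):
#
#     T    = mean(rec in c1) + mean(lig in c2)                  # true statistic
#     T0_k = the same statistic after shuffling cluster labels  # k = 1, ..., n_perms
#     pvalue[i, j] = sum_k(T0_k > T) / n_perms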