from __future__ import annotations

from typing import Union  # noqa: F401
from typing import Literal  # < 3.8
from typing import Callable, Sequence

from scanpy import logging as logg
from anndata import AnnData

from numba import njit
from scipy.sparse import issparse, spmatrix, csr_matrix, isspmatrix_csr
from sklearn.metrics import pairwise_distances
import numpy as np
import pandas as pd

from squidpy._docs import d, inject_docs
from squidpy._utils import Signal, NDArrayA, SigQueue, parallelize, _get_n_cores
from import (
from squidpy._constants._pkg_constants import Key

__all__ = ["sepal"]

@d.dedent
@inject_docs(key=Key.obsp.spatial_conn())
def sepal(
    adata: AnnData,
    max_neighs: Literal[4, 6],
    genes: str | Sequence[str] | None = None,
    n_iter: int | None = 30000,
    dt: float = 0.001,
    thresh: float = 1e-8,
    connectivity_key: str = Key.obsp.spatial_conn(),
    spatial_key: str = Key.obsm.spatial,
    layer: str | None = None,
    use_raw: bool = False,
    copy: bool = False,
    n_jobs: int | None = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
) -> pd.DataFrame | None:
    """
    Identify spatially variable genes with *Sepal*.

    *Sepal* is a method that simulates a diffusion process to quantify spatial structure in tissue.
    See :cite:`andersson2021` for reference.

    Parameters
    ----------
    %(adata)s
    max_neighs
        Maximum number of neighbors of a node in the graph. Valid options are:

            - `4` - for a square-grid (ST, Dbit-seq).
            - `6` - for a hexagonal-grid (Visium).
    genes
        List of gene names, as stored in :attr:`anndata.AnnData.var_names`, used to compute sepal score.

        If `None`, it's computed for the genes flagged in :attr:`anndata.AnnData.var` ``['highly_variable']``,
        if present. Otherwise, it's computed for all genes.
    n_iter
        Maximum number of iterations for the diffusion simulation.
        If ``n_iter`` iterations are reached, the simulation will terminate
        even though convergence has not been achieved.
    dt
        Time step in diffusion simulation.
    thresh
        Entropy threshold for convergence of diffusion simulation.
    %(conn_key)s
    %(spatial_key)s
    layer
        Layer in :attr:`anndata.AnnData.layers` to use. If `None`, use :attr:`anndata.AnnData.X`.
    use_raw
        Whether to access :attr:`anndata.AnnData.raw`.
    %(copy)s
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns a :class:`pandas.DataFrame` with the sepal scores.

    Otherwise, modifies the ``adata`` with the following key:

        - :attr:`anndata.AnnData.uns` ``['sepal_score']`` - the sepal scores.

    Raises
    ------
    ValueError
        If ``max_neighs`` is neither `4` nor `6`, or if the largest neighborhood found in the
        connectivity graph differs from ``max_neighs``.

    Notes
    -----
    If some genes in :attr:`anndata.AnnData.uns` ``['sepal_score']`` are `NaN`,
    consider re-running the function with increased ``n_iter``.
    """
    _assert_connectivity_key(adata, connectivity_key)
    _assert_spatial_basis(adata, key=spatial_key)
    if max_neighs not in (4, 6):
        raise ValueError(f"Expected `max_neighs` to be either `4` or `6`, found `{max_neighs}`.")

    # `np.float64` instead of the `np.float_` alias, which no longer exists in NumPy >= 2.0
    spatial = adata.obsm[spatial_key].astype(np.float64)

    if genes is None:
        genes = adata.var_names.values
        if "highly_variable" in adata.var.columns:
            genes = genes[adata.var["highly_variable"].values]
    genes = _assert_non_empty_sequence(genes, name="genes")

    n_jobs = _get_n_cores(n_jobs)

    g = adata.obsp[connectivity_key]
    if not isspmatrix_csr(g):
        g = csr_matrix(g)
    # explicit zeros would otherwise be counted as neighbors by the `indptr` difference below
    g.eliminate_zeros()

    max_n = np.diff(g.indptr).max()
    if max_n != max_neighs:
        raise ValueError(f"Expected `max_neighs={max_neighs}`, found node with `{max_n}` neighbors.")

    # get saturated/unsaturated nodes
    sat, sat_idx, unsat, unsat_idx = _compute_idxs(g, spatial, max_neighs, "l1")

    # get counts
    vals, genes = _extract_expression(adata, genes=genes, use_raw=use_raw, layer=layer)

    # the value returned by the logger is handed back as `time=` once the computation finishes
    start = logg.info(f"Calculating sepal score for `{len(genes)}` genes using `{n_jobs}` core(s)")

    score = parallelize(
        _score_helper,
        collection=np.arange(len(genes)),
        extractor=np.hstack,
        use_ixs=False,
        n_jobs=n_jobs,
        backend=backend,
        show_progress_bar=show_progress_bar,
    )(
        vals=vals,
        max_neighs=max_neighs,
        n_iter=n_iter,
        sat=sat,
        sat_idx=sat_idx,
        unsat=unsat,
        unsat_idx=unsat_idx,
        dt=dt,
        thresh=thresh,
    )

    key_added = "sepal_score"
    sepal_score = pd.DataFrame(score, index=genes, columns=[key_added])

    # a `NaN` score means the simulation for that gene did not converge within `n_iter` steps
    if sepal_score[key_added].isna().any():
        logg.warning("Found `NaN` in sepal scores, consider increasing `n_iter` to a higher value")
    sepal_score.sort_values(by=key_added, ascending=False, inplace=True)

    if copy:
        logg.info("Finish", time=start)
        return sepal_score

    _save_data(adata, attr="uns", key=key_added, data=sepal_score, time=start)
def _score_helper(
    ixs: Sequence[int],
    vals: spmatrix | NDArrayA,
    max_neighs: int,
    n_iter: int,
    sat: NDArrayA,
    sat_idx: NDArrayA,
    unsat: NDArrayA,
    unsat_idx: NDArrayA,
    dt: np.float64,
    thresh: np.float64,
    queue: SigQueue | None = None,
) -> NDArrayA:
    """
    Compute the sepal score for a subset of genes.

    Parameters
    ----------
    ixs
        Column indices into ``vals`` of the genes to score.
    vals
        Expression values, sparse or dense; one column is simulated per entry of ``ixs``.
    max_neighs
        Number of neighbors of a saturated node; selects the Laplacian stencil (`4` or `6`).
    n_iter
        Maximum number of diffusion iterations.
    sat
        Indices of the saturated nodes.
    sat_idx
        Neighbor indices of every saturated node.
    unsat
        Indices of the unsaturated nodes.
    unsat_idx
        For every unsaturated node, the saturated node assigned to it.
    dt
        Time step of the simulation.
    thresh
        Convergence threshold on the entropy difference.
    queue
        Optional queue that receives progress signals.

    Returns
    -------
    Array with one score per index in ``ixs`` (``dt`` times the iteration at which convergence was
    reached, or `NaN` if it was not reached).

    Raises
    ------
    NotImplementedError
        If ``max_neighs`` is neither `4` nor `6`.
    """
    if max_neighs == 4:
        fun = _laplacian_rect
    elif max_neighs == 6:
        fun = _laplacian_hex
    else:
        raise NotImplementedError(f"Laplacian for `{max_neighs}` neighbors is not yet implemented.")

    score, sparse = [], issparse(vals)
    for i in ixs:
        # `.toarray()` instead of the `.A` shorthand, which recent SciPy releases no longer provide;
        # both branches yield a fresh array because `_diffusion` modifies `conc` in place
        conc = vals[:, i].toarray().flatten() if sparse else vals[:, i].copy()
        time_iter = _diffusion(conc, fun, n_iter, sat, sat_idx, unsat, unsat_idx, dt=dt, thresh=thresh)
        score.append(dt * time_iter)

        if queue is not None:
            queue.put(Signal.UPDATE)

    if queue is not None:
        queue.put(Signal.FINISH)

    return np.array(score)


@njit(fastmath=True)
def _diffusion(
    conc: NDArrayA,
    laplacian: Callable[[NDArrayA, NDArrayA, NDArrayA], np.float64],
    n_iter: int,
    sat: NDArrayA,
    sat_idx: NDArrayA,
    unsat: NDArrayA,
    unsat_idx: NDArrayA,
    dt: float = 0.001,
    D: float = 1.0,
    thresh: float = 1e-8,
) -> float:
    """
    Simulate diffusion process on a regular graph.

    ``conc`` is updated in place. Returns the index of the first iteration at which the absolute
    change of the (normalized) entropy over the saturated nodes is at most ``thresh``, or `NaN`
    if that never happens within ``n_iter`` iterations.
    """
    sat_shape, conc_shape = sat.shape[0], conc.shape[0]
    entropy_arr = np.zeros(n_iter)
    prev_ent = 1.0
    nhood = np.zeros(sat_shape)
    # passed to the Laplacian as the grid spacing `h`, i.e. unit spacing for every node
    weights = np.ones(sat_shape)

    for i in range(n_iter):
        # sum of the concentrations in the neighborhood of every saturated node
        for j in range(sat_shape):
            nhood[j] = np.sum(conc[sat_idx[j]])
        d2 = laplacian(conc[sat], nhood, weights)

        dcdt = np.zeros(conc_shape)
        dcdt[sat] = D * d2
        conc[sat] += dcdt[sat] * dt
        # unsaturated nodes follow the update of the saturated node assigned to them
        conc[unsat] += dcdt[unsat_idx] * dt
        # set values below zero to 0
        conc[conc < 0] = 0
        # compute entropy
        ent = _entropy(conc[sat]) / sat_shape
        entropy_arr[i] = np.abs(ent - prev_ent)  # estimate entropy difference
        prev_ent = ent
        if entropy_arr[i] <= thresh:
            break

    # entries after an early `break` are still 0, so the first hit is the iteration of convergence;
    # without convergence every entry exceeds `thresh` and there is no hit at all
    tmp = np.nonzero(entropy_arr <= thresh)[0]
    return float(tmp[0] if len(tmp) else np.nan)


# taken from the reference Sepal implementation (the source link is missing in this copy)
@njit(parallel=False, fastmath=True)
def _laplacian_rect(
    centers: NDArrayA,
    nbrs: NDArrayA,
    h: float,
) -> NDArrayA:
    """
    Five point stencil approximation on rectilinear grid.

    See the Wikipedia article on the five-point stencil for more information.

    ``centers`` holds the values at the nodes, ``nbrs`` the sum over each node's neighbors and
    ``h`` the grid spacing (`_diffusion` passes an array of ones).
    """
    d2f: NDArrayA = nbrs - 4 * centers
    d2f = d2f / h**2

    return d2f


# taken from the reference Sepal implementation (the source link is missing in this copy)
@njit(fastmath=True)
def _laplacian_hex(
    centers: NDArrayA,
    nbrs: NDArrayA,
    h: float,
) -> NDArrayA:
    """
    Seven point stencil approximation on hexagonal grid.

    ``centers`` holds the values at the nodes, ``nbrs`` the sum over each node's neighbors and
    ``h`` the grid spacing (`_diffusion` passes an array of ones).

    References
    ----------
    Approximate Methods of Higher Analysis, Curtis D. Benster, L.V. Kantorovich, V.I. Krylov,
    ISBN-13: 978-0486821603.
    """
    d2f: NDArrayA = nbrs - 6 * centers
    d2f = d2f / h**2
    d2f = (d2f * 2) / 3  # type: ignore[assignment]

    return d2f


# taken from the reference Sepal implementation (the source link is missing in this copy)
@njit(fastmath=True)
def _entropy(
    xx: NDArrayA,
) -> float:
    """Get entropy of an array, using only its strictly positive entries normalized to sum to 1."""
    xnz = xx[xx > 0]
    xs = np.sum(xnz)
    xn = xnz / xs
    xl = np.log(xn)
    return float((-xl * xn).sum())


def _compute_idxs(
    g: spmatrix, spatial: NDArrayA, sat_thresh: int, metric: str = "l1"
) -> tuple[NDArrayA, NDArrayA, NDArrayA, NDArrayA]:
    """
    Get saturated and unsaturated nodes and neighborhood indices.

    Parameters
    ----------
    g
        Connectivity graph; its ``indptr`` and ``indices`` attributes are read.
    spatial
        Coordinates of the nodes, indexed by node.
    sat_thresh
        Number of neighbors a node needs in order to count as saturated.
    metric
        Metric passed to :func:`sklearn.metrics.pairwise_distances`.

    Returns
    -------
    Indices of the saturated nodes, neighbor indices of every saturated node, indices of the
    unsaturated nodes and, for every unsaturated node, the saturated node assigned to it.
    """
    sat, unsat = _get_sat_unsat_idx(g.indptr, g.shape[0], sat_thresh)

    sat_idx, nearest_sat, un_unsat = _get_nhood_idx(sat, unsat, g.indptr, g.indices, sat_thresh)

    # `pairwise_distances` refuses inputs with 0 samples, so only fall back to distances
    # when some unsaturated node is left without a saturated neighbor
    if un_unsat.shape[0] > 0:
        # compute dist btwn remaining unsat and all sat
        dist = pairwise_distances(spatial[un_unsat], spatial[sat], metric=metric)
        # assign closest sat to remaining nearest_sat
        nearest_sat[np.isnan(nearest_sat)] = sat[np.argmin(dist, axis=1)]

    return sat, sat_idx, unsat, nearest_sat.astype(np.int32)


@njit
def _get_sat_unsat_idx(g_indptr: NDArrayA, g_shape: int, sat_thresh: int) -> tuple[NDArrayA, NDArrayA]:
    """
    Get saturated and unsaturated nodes based on thresh.

    A node is saturated if it has exactly ``sat_thresh`` neighbors and unsaturated if it has fewer.
    """
    n_indices = np.diff(g_indptr)
    unsat = np.arange(g_shape)[n_indices < sat_thresh]
    sat = np.arange(g_shape)[n_indices == sat_thresh]

    return sat, unsat


@njit
def _get_nhood_idx(
    sat: NDArrayA, unsat: NDArrayA, g_indptr: NDArrayA, g_indices: NDArrayA, sat_thresh: int
) -> tuple[NDArrayA, NDArrayA, NDArrayA]:
    """
    Get saturated and unsaturated neighborhood indices.

    Returns the neighbor indices of every saturated node, the first saturated neighbor of every
    unsaturated node (`NaN` where there is none) and the unsaturated nodes without such a neighbor.
    """
    # get saturated nhood indices
    sat_idx = np.zeros((sat.shape[0], sat_thresh))
    for idx in range(sat.shape[0]):
        i = sat[idx]
        sat_idx[idx] = g_indices[g_indptr[i] : g_indptr[i + 1]]

    # get closest saturated of unsaturated
    nearest_sat = np.full_like(unsat, fill_value=np.nan, dtype=np.float64)
    for idx in range(unsat.shape[0]):
        i = unsat[idx]
        unsat_neigh = g_indices[g_indptr[i] : g_indptr[i + 1]]
        for u in unsat_neigh:
            if u in sat:  # take the first saturated nhood
                nearest_sat[idx] = u
                break

    # some unsat still don't have a sat nhood
    # return them and compute distances in outer func
    un_unsat = unsat[np.isnan(nearest_sat)]

    return sat_idx.astype(np.int32), nearest_sat, un_unsat