from __future__ import annotations

from collections.abc import Sequence
from typing import Callable, Literal

import numpy as np
import pandas as pd
from anndata import AnnData
from numba import njit
from scanpy import logging as logg
from scipy.sparse import csr_matrix, issparse, isspmatrix_csr, spmatrix
from sklearn.metrics import pairwise_distances
from spatialdata import SpatialData

from squidpy._constants._pkg_constants import Key
from squidpy._docs import d, inject_docs
from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize
from squidpy.gr._utils import (
    _assert_connectivity_key,
    _assert_non_empty_sequence,
    _assert_spatial_basis,
    _extract_expression,
    _save_data,
)

# Names exported by ``from <this module> import *``; only ``sepal`` is public.
__all__ = ["sepal"]

@d.dedent
@inject_docs(key=Key.obsp.spatial_conn())
def sepal(
    adata: AnnData | SpatialData,
    max_neighs: Literal[4, 6],
    genes: str | Sequence[str] | None = None,
    n_iter: int | None = 30000,
    dt: float = 0.001,
    thresh: float = 1e-8,
    connectivity_key: str = Key.obsp.spatial_conn(),
    spatial_key: str = Key.obsm.spatial,
    layer: str | None = None,
    use_raw: bool = False,
    copy: bool = False,
    n_jobs: int | None = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
) -> pd.DataFrame | None:
    """
    Identify spatially variable genes with *Sepal*.

    *Sepal* is a method that simulates a diffusion process to quantify spatial structure in tissue.
    See :cite:`andersson2021` for reference.

    Parameters
    ----------
    %(adata)s
    max_neighs
        Maximum number of neighbors of a node in the graph. Valid options are:

            - `4` - for a square-grid (ST, Dbit-seq).
            - `6` - for a hexagonal-grid (Visium).
    genes
        List of gene names, as stored in :attr:`anndata.AnnData.var_names`, used to compute sepal score.

        If `None`, it's computed :attr:`anndata.AnnData.var` ``['highly_variable']``, if present.
        Otherwise, it's computed for all genes.
    n_iter
        Maximum number of iterations for the diffusion simulation.
        If ``n_iter`` iterations are reached, the simulation will terminate
        even though convergence has not been achieved.
    dt
        Time step in diffusion simulation.
    thresh
        Entropy threshold for convergence of diffusion simulation.
    %(conn_key)s
    %(spatial_key)s
    layer
        Layer in :attr:`anndata.AnnData.layers` to use. If `None`, use :attr:`anndata.AnnData.X`.
    use_raw
        Whether to access :attr:`anndata.AnnData.raw`.
    %(copy)s
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns a :class:`pandas.DataFrame` with the sepal scores.

    Otherwise, modifies the ``adata`` with the following key:

        - :attr:`anndata.AnnData.uns` ``['sepal_score']`` - the sepal scores.

    Raises
    ------
    ValueError
        If ``max_neighs`` is neither `4` nor `6`, or if the largest node degree found in the
        connectivity graph differs from ``max_neighs``.

    Notes
    -----
    If some genes in :attr:`anndata.AnnData.uns` ``['sepal_score']`` are `NaN`,
    consider re-running the function with increased ``n_iter``.
    """
    if isinstance(adata, SpatialData):
        adata = adata.table
    _assert_connectivity_key(adata, connectivity_key)
    _assert_spatial_basis(adata, key=spatial_key)
    if max_neighs not in (4, 6):
        raise ValueError(f"Expected `max_neighs` to be either `4` or `6`, found `{max_neighs}`.")

    # `np.float_` was removed in NumPy 2.0; `np.float64` is the same dtype on every version.
    spatial = adata.obsm[spatial_key].astype(np.float64)

    if genes is None:
        genes = adata.var_names.values
        if "highly_variable" in adata.var.columns:
            genes = genes[adata.var["highly_variable"].values]
    genes = _assert_non_empty_sequence(genes, name="genes")

    n_jobs = _get_n_cores(n_jobs)

    g = adata.obsp[connectivity_key]
    if not isspmatrix_csr(g):
        g = csr_matrix(g)
    # explicit zeros would be counted as neighbors by the `indptr`-based degree below
    g.eliminate_zeros()

    max_n = np.diff(g.indptr).max()
    if max_n != max_neighs:
        raise ValueError(f"Expected `max_neighs={max_neighs}`, found node with `{max_n}` neighbors.")

    # get saturated/unsaturated nodes
    sat, sat_idx, unsat, unsat_idx = _compute_idxs(g, spatial, max_neighs, "l1")

    # get counts
    vals, genes = _extract_expression(adata, genes=genes, use_raw=use_raw, layer=layer)
    start = logg.info(f"Calculating sepal score for `{len(genes)}` genes using `{n_jobs}` core(s)")

    score = parallelize(
        _score_helper,
        collection=np.arange(len(genes)).tolist(),
        extractor=np.hstack,
        use_ixs=False,
        n_jobs=n_jobs,
        backend=backend,
        show_progress_bar=show_progress_bar,
    )(
        vals=vals,
        max_neighs=max_neighs,
        n_iter=n_iter,
        sat=sat,
        sat_idx=sat_idx,
        unsat=unsat,
        unsat_idx=unsat_idx,
        dt=dt,
        thresh=thresh,
    )

    key_added = "sepal_score"
    sepal_score = pd.DataFrame(score, index=genes, columns=[key_added])

    # a `NaN` score means the simulation for that gene did not converge within `n_iter` steps
    if sepal_score[key_added].isna().any():
        logg.warning("Found `NaN` in sepal scores, consider increasing `n_iter` to a higher value")
    sepal_score.sort_values(by=key_added, ascending=False, inplace=True)

    if copy:
        logg.info("Finish", time=start)
        return sepal_score

    _save_data(adata, attr="uns", key=key_added, data=sepal_score, time=start)
def _score_helper(
    ixs: Sequence[int],
    vals: spmatrix | NDArrayA,
    max_neighs: int,
    n_iter: int,
    sat: NDArrayA,
    sat_idx: NDArrayA,
    unsat: NDArrayA,
    unsat_idx: NDArrayA,
    dt: float,
    thresh: float,
    queue: SigQueue | None = None,
) -> NDArrayA:
    """
    Compute the sepal score for a batch of genes.

    Parameters
    ----------
    ixs
        Column indices into ``vals`` of the genes to score.
    vals
        Expression values; column ``i`` is used as the initial concentration for gene ``i``.
    max_neighs
        Grid type, `4` selects the rectilinear Laplacian and `6` the hexagonal one.
    n_iter
        Maximum number of diffusion iterations.
    sat, sat_idx, unsat, unsat_idx
        Node index arrays as returned by :func:`_compute_idxs`.
    dt
        Time step of the simulation.
    thresh
        Convergence threshold on the entropy difference.
    queue
        Optional queue that receives ``Signal.UPDATE`` after each gene and ``Signal.FINISH`` at the end.

    Returns
    -------
    Array with one score per entry of ``ixs``; ``dt`` times the iteration at which the simulation
    converged, or `NaN` if it did not.

    Raises
    ------
    NotImplementedError
        If ``max_neighs`` is neither `4` nor `6`.
    """
    if max_neighs == 4:
        fun = _laplacian_rect
    elif max_neighs == 6:
        fun = _laplacian_hex
    else:
        raise NotImplementedError(f"Laplacian for `{max_neighs}` neighbors is not yet implemented.")

    score, sparse = [], issparse(vals)
    for i in ixs:
        # `_diffusion` updates `conc` in place, so it must get its own array.
        # `.toarray()` replaces the deprecated sparse `.A` shorthand.
        conc = vals[:, i].toarray().flatten() if sparse else vals[:, i].copy()  # type: ignore[union-attr]
        time_iter = _diffusion(conc, fun, n_iter, sat, sat_idx, unsat, unsat_idx, dt=dt, thresh=thresh)
        score.append(dt * time_iter)

        if queue is not None:
            queue.put(Signal.UPDATE)

    if queue is not None:
        queue.put(Signal.FINISH)

    return np.array(score)


@njit(fastmath=True)
def _diffusion(
    conc: NDArrayA,
    laplacian: Callable[[NDArrayA, NDArrayA, NDArrayA], NDArrayA],
    n_iter: int,
    sat: NDArrayA,
    sat_idx: NDArrayA,
    unsat: NDArrayA,
    unsat_idx: NDArrayA,
    dt: float = 0.001,
    D: float = 1.0,
    thresh: float = 1e-8,
) -> float:
    """
    Simulate diffusion process on a regular graph.

    ``conc`` is modified in place. Returns the index of the first iteration whose absolute
    entropy change is at most ``thresh``, or `NaN` if that never happens within ``n_iter`` iterations.
    """
    sat_shape, conc_shape = sat.shape[0], conc.shape[0]
    entropy_arr = np.zeros(n_iter)
    prev_ent = 1.0
    nhood = np.zeros(sat_shape)
    # grid spacing handed to the Laplacian as `h`; all ones
    weights = np.ones(sat_shape)

    for i in range(n_iter):
        # summed concentration over the neighbors of every saturated node
        for j in range(sat_shape):
            nhood[j] = np.sum(conc[sat_idx[j]])
        d2 = laplacian(conc[sat], nhood, weights)

        dcdt = np.zeros(conc_shape)
        dcdt[sat] = D * d2
        conc[sat] += dcdt[sat] * dt
        # unsaturated nodes follow the update of the saturated node assigned to them
        conc[unsat] += dcdt[unsat_idx] * dt
        # set values below zero to 0
        conc[conc < 0] = 0
        # compute entropy
        ent = _entropy(conc[sat]) / sat_shape
        entropy_arr[i] = np.abs(ent - prev_ent)  # estimate entropy difference
        prev_ent = ent
        if entropy_arr[i] <= thresh:
            break

    tmp = np.nonzero(entropy_arr <= thresh)[0]
    return float(tmp[0] if len(tmp) else np.nan)


# NOTE: the original "taken from <url>" attribution lost its link in this copy of the file
@njit(parallel=False, fastmath=True)
def _laplacian_rect(
    centers: NDArrayA,
    nbrs: NDArrayA,
    h: float,
) -> NDArrayA:
    """
    Five point stencil approximation on rectilinear grid.

    See `Wikipedia <https://en.wikipedia.org/wiki/Five-point_stencil>`_ for more information.
    """
    d2f: NDArrayA = nbrs - 4 * centers
    d2f = d2f / h**2

    return d2f


# NOTE: the original "taken from <url>" attribution lost its link in this copy of the file
@njit(fastmath=True)
def _laplacian_hex(
    centers: NDArrayA,
    nbrs: NDArrayA,
    h: float,
) -> NDArrayA:
    """
    Seven point stencil approximation on hexagonal grid.

    References
    ----------
    Approximate Methods of Higher Analysis, Curtis D. Benster, L.V. Kantorovich, V.I. Krylov,
    ISBN-13: 978-0486821603.
    """
    d2f: NDArrayA = nbrs - 6 * centers
    d2f = d2f / h**2
    d2f = (d2f * 2) / 3

    return d2f


# NOTE: the original "taken from <url>" attribution lost its link in this copy of the file
@njit(fastmath=True)
def _entropy(
    xx: NDArrayA,
) -> float:
    """Get entropy of an array, using only its strictly positive entries normalized to sum to 1."""
    xnz = xx[xx > 0]
    xs = np.sum(xnz)
    xn = xnz / xs
    xl = np.log(xn)
    return float((-xl * xn).sum())


def _compute_idxs(
    g: spmatrix, spatial: NDArrayA, sat_thresh: int, metric: str = "l1"
) -> tuple[NDArrayA, NDArrayA, NDArrayA, NDArrayA]:
    """
    Get saturated and unsaturated nodes and neighborhood indices.

    Returns ``(sat, sat_idx, unsat, nearest_sat)``: the saturated nodes, their neighbor indices,
    the unsaturated nodes and, for each unsaturated node, the saturated node assigned to it.
    """
    sat, unsat = _get_sat_unsat_idx(g.indptr, g.shape[0], sat_thresh)

    sat_idx, nearest_sat, un_unsat = _get_nhood_idx(sat, unsat, g.indptr, g.indices, sat_thresh)

    # `pairwise_distances` rejects empty input, so only run the spatial fallback when some
    # unsaturated nodes are still left without a saturated neighbor
    if un_unsat.shape[0] > 0:
        # compute dist btwn remaining unsat and all sat
        dist = pairwise_distances(spatial[un_unsat], spatial[sat], metric=metric)
        # assign closest sat to remaining nearest_sat
        nearest_sat[np.isnan(nearest_sat)] = sat[np.argmin(dist, axis=1)]

    return sat, sat_idx, unsat, nearest_sat.astype(np.int32)


@njit
def _get_sat_unsat_idx(g_indptr: NDArrayA, g_shape: int, sat_thresh: int) -> tuple[NDArrayA, NDArrayA]:
    """
    Get saturated and unsaturated nodes based on thresh.

    A node is saturated when it has exactly ``sat_thresh`` neighbors and unsaturated when it has fewer.
    """
    n_indices = np.diff(g_indptr)
    unsat = np.arange(g_shape)[n_indices < sat_thresh]
    sat = np.arange(g_shape)[n_indices == sat_thresh]

    return sat, unsat


@njit
def _get_nhood_idx(
    sat: NDArrayA, unsat: NDArrayA, g_indptr: NDArrayA, g_indices: NDArrayA, sat_thresh: int
) -> tuple[NDArrayA, NDArrayA, NDArrayA]:
    """
    Get saturated and unsaturated neighborhood indices.

    Returns ``(sat_idx, nearest_sat, un_unsat)``: the neighbor indices of every saturated node,
    the first saturated neighbor of every unsaturated node (`NaN` if it has none) and the
    unsaturated nodes that have no saturated neighbor.
    """
    # get saturated nhood indices
    sat_idx = np.zeros((sat.shape[0], sat_thresh))
    for idx in range(sat.shape[0]):
        i = sat[idx]
        sat_idx[idx] = g_indices[g_indptr[i] : g_indptr[i + 1]]

    # get closest saturated of unsaturated
    nearest_sat = np.full_like(unsat, fill_value=np.nan, dtype=np.float64)
    for idx in range(unsat.shape[0]):
        i = unsat[idx]
        unsat_neigh = g_indices[g_indptr[i] : g_indptr[i + 1]]
        for u in unsat_neigh:
            if u in sat:  # take the first saturated nhood
                nearest_sat[idx] = u
                break

    # some unsat still don't have a sat nhood
    # return them and compute distances in outer func
    un_unsat = unsat[np.isnan(nearest_sat)]

    return sat_idx.astype(np.int32), nearest_sat, un_unsat