Source code for squidpy.gr._sepal

from __future__ import annotations

from collections.abc import Sequence
from typing import Literal

import numpy as np
import pandas as pd
from anndata import AnnData
from numba import njit
from scanpy import logging as logg
from scipy.sparse import csc_matrix, csr_matrix, issparse, isspmatrix_csr, spmatrix
from sklearn.metrics import pairwise_distances
from spatialdata import SpatialData

from squidpy._constants._pkg_constants import Key
from squidpy._docs import d, inject_docs
from squidpy._utils import NDArrayA, _get_n_cores, deprecated_params, thread_map
from squidpy._validators import assert_non_empty_sequence
from squidpy.gr._utils import (
    _assert_connectivity_key,
    _assert_spatial_basis,
    _extract_expression,
    _save_data,
    extract_adata_if_sdata,
)

__all__ = ["sepal"]


[docs] @d.dedent @inject_docs(key=Key.obsp.spatial_conn()) @deprecated_params({"backend": "1.10.0"}) def sepal( adata: AnnData | SpatialData, max_neighs: Literal[4, 6], genes: str | Sequence[str] | None = None, n_iter: int | None = 30000, dt: float = 0.001, thresh: float = 1e-8, connectivity_key: str = Key.obsp.spatial_conn(), spatial_key: str = Key.obsm.spatial, layer: str | None = None, use_raw: bool = False, copy: bool = False, n_jobs: int | None = None, show_progress_bar: bool = True, *, table_key: str | None = None, ) -> pd.DataFrame | None: """ Identify spatially variable genes with *Sepal*. *Sepal* is a method that simulates a diffusion process to quantify spatial structure in tissue. See :cite:`andersson2021` for reference. Parameters ---------- %(adata)s %(table_key)s max_neighs Maximum number of neighbors of a node in the graph. Valid options are: - `4` - for a square-grid (ST, Dbit-seq). - `6` - for a hexagonal-grid (Visium). genes List of gene names, as stored in :attr:`anndata.AnnData.var_names`, used to compute sepal score. If `None`, it's computed :attr:`anndata.AnnData.var` ``['highly_variable']``, if present. Otherwise, it's computed for all genes. n_iter Maximum number of iterations for the diffusion simulation. If ``n_iter`` iterations are reached, the simulation will terminate even though convergence has not been achieved. dt Time step in diffusion simulation. thresh Entropy threshold for convergence of diffusion simulation. %(conn_key)s %(spatial_key)s layer Layer in :attr:`anndata.AnnData.layers` to use. If `None`, use :attr:`anndata.AnnData.X`. use_raw Whether to access :attr:`anndata.AnnData.raw`. %(copy)s %(n_jobs)s %(show_progress_bar)s Returns ------- If ``copy = True``, returns a :class:`pandas.DataFrame` with the sepal scores. Otherwise, modifies the ``adata`` with the following key: - :attr:`anndata.AnnData.uns` ``['sepal_score']`` - the sepal scores. Notes ----- If some genes in :attr:`anndata.AnnData.uns` ``['sepal_score']`` are `NaN`, consider re-running the function with increased ``n_iter``. """ adata = extract_adata_if_sdata(adata, table_key=table_key) _assert_connectivity_key(adata, connectivity_key) _assert_spatial_basis(adata, key=spatial_key) if max_neighs not in (4, 6): raise ValueError(f"Expected `max_neighs` to be either `4` or `6`, found `{max_neighs}`.") spatial = adata.obsm[spatial_key].astype(np.float64) if genes is None: genes = adata.var_names.values if "highly_variable" in adata.var.columns: genes = genes[adata.var["highly_variable"].values] genes = assert_non_empty_sequence(genes, name="genes") n_jobs = _get_n_cores(n_jobs) g = adata.obsp[connectivity_key] if not isspmatrix_csr(g): g = csr_matrix(g) g.eliminate_zeros() max_n = np.diff(g.indptr).max() if max_n != max_neighs: raise ValueError(f"Expected `max_neighs={max_neighs}`, found node with `{max_n}` neighbors.") # get saturated/unsaturated nodes sat, sat_idx, unsat, unsat_idx = _compute_idxs(g, spatial, max_neighs, "l1") # get counts vals, genes = _extract_expression(adata, genes=genes, use_raw=use_raw, layer=layer) start = logg.info(f"Calculating sepal score for `{len(genes)}` genes using `{n_jobs}` core(s)") use_hex = max_neighs == 6 if issparse(vals): vals = csc_matrix(vals) score = _diffusion_genes( vals, use_hex, n_iter, sat, sat_idx, unsat, unsat_idx, dt, thresh, n_jobs=n_jobs, show_progress_bar=show_progress_bar, ) key_added = "sepal_score" sepal_score = pd.DataFrame(score, index=genes, columns=[key_added]) if sepal_score[key_added].isna().any(): logg.warning("Found `NaN` in sepal scores, consider increasing `n_iter` to a higher value") sepal_score = sepal_score.sort_values(by=key_added, ascending=False) if copy: logg.info("Finish", time=start) return sepal_score _save_data(adata, attr="uns", key=key_added, data=sepal_score, time=start)
def _diffusion_genes( vals: NDArrayA | spmatrix, use_hex: bool, n_iter: int, sat: NDArrayA, sat_idx: NDArrayA, unsat: NDArrayA, unsat_idx: NDArrayA, dt: float, thresh: float, n_jobs: int, show_progress_bar: bool = True, ) -> NDArrayA: """Run diffusion for each gene column, parallelised across threads.""" sparse = issparse(vals) def _process_gene(i: int) -> float: if sparse: conc = np.ascontiguousarray(vals[:, i].toarray().ravel(), dtype=np.float64) else: conc = np.ascontiguousarray(vals[:, i], dtype=np.float64) time_iter = _diffusion( conc, use_hex, n_iter, sat, sat_idx, unsat, unsat_idx, dt, thresh, ) return dt * time_iter scores = thread_map( _process_gene, range(vals.shape[1]), n_jobs=n_jobs, show_progress_bar=show_progress_bar, unit="gene", ) return np.array(scores) @njit(fastmath=True, nogil=True) def _diffusion( conc: NDArrayA, use_hex: bool, n_iter: int, sat: NDArrayA, sat_idx: NDArrayA, unsat: NDArrayA, unsat_idx: NDArrayA, dt: float, thresh: float, ) -> float: """Simulate diffusion process on a regular graph.""" sat_shape = sat.shape[0] n_cells = conc.shape[0] entropy_arr = np.zeros(n_iter) nhood = np.zeros(sat_shape) dcdt = np.zeros(n_cells) prev_ent = 1.0 for i in range(n_iter): for j in range(sat_shape): nhood[j] = np.sum(conc[sat_idx[j]]) if use_hex: d2 = _laplacian_hex(conc[sat], nhood) else: d2 = _laplacian_rect(conc[sat], nhood) dcdt[:] = 0.0 dcdt[sat] = d2 conc[sat] += dcdt[sat] * dt conc[unsat] += dcdt[unsat_idx] * dt # set values below zero to 0 conc[conc < 0] = 0 # compute entropy ent = _entropy(conc[sat]) / sat_shape entropy_arr[i] = np.abs(ent - prev_ent) # estimate entropy difference prev_ent = ent if entropy_arr[i] <= thresh: break tmp = np.nonzero(entropy_arr <= thresh)[0] return float(tmp[0] if len(tmp) else np.nan) # taken from https://github.com/almaan/sepal/blob/master/sepal/models.py @njit(parallel=False, fastmath=True) def _laplacian_rect( centers: NDArrayA, nbrs: NDArrayA, ) -> NDArrayA: """ Five point stencil approximation on rectilinear grid. See `Wikipedia <https://en.wikipedia.org/wiki/Five-point_stencil>`_ for more information. """ d2f: NDArrayA = nbrs - 4 * centers return d2f # taken from https://github.com/almaan/sepal/blob/master/sepal/models.py @njit(fastmath=True) def _laplacian_hex( centers: NDArrayA, nbrs: NDArrayA, ) -> NDArrayA: """ Seven point stencil approximation on hexagonal grid. References ---------- Approximate Methods of Higher Analysis, Curtis D. Benster, L.V. Kantorovich, V.I. Krylov, ISBN-13: 978-0486821603. """ d2f: NDArrayA = (2.0 * nbrs - 12.0 * centers) / 3.0 return d2f # taken from https://github.com/almaan/sepal/blob/master/sepal/models.py @njit(fastmath=True) def _entropy( xx: NDArrayA, ) -> float: """Compute Shannon entropy of an array of probability values (in nats).""" xnz = xx[xx > 0] xs: np.float64 = np.sum(xnz) eps = np.finfo(np.float64).eps # ~2.22e-16 if xs < eps: # 0 because # xn represents probabilities # and p(x)=0 is taken as 0 entropy # see https://stats.stackexchange.com/a/433096 return 0.0 xn = xnz / xs xl = np.log(np.maximum(xn, eps)) return float((-xl * xn).sum()) def _compute_idxs( g: spmatrix, spatial: NDArrayA, sat_thresh: int, metric: str = "l1" ) -> tuple[NDArrayA, NDArrayA, NDArrayA, NDArrayA]: """Get saturated and unsaturated nodes and neighborhood indices.""" sat, unsat = _get_sat_unsat_idx(g.indptr, g.shape[0], sat_thresh) sat_idx, nearest_sat, un_unsat = _get_nhood_idx(sat, unsat, g.indptr, g.indices, sat_thresh) # compute dist btwn remaining unsat and all sat dist = pairwise_distances(spatial[un_unsat], spatial[sat], metric=metric) # assign closest sat to remaining nearest_sat nearest_sat[np.isnan(nearest_sat)] = sat[np.argmin(dist, axis=1)] return sat, sat_idx, unsat, nearest_sat.astype(np.int32) @njit def _get_sat_unsat_idx(g_indptr: NDArrayA, g_shape: int, sat_thresh: int) -> tuple[NDArrayA, NDArrayA]: """Get saturated and unsaturated nodes based on thresh.""" n_indices = np.diff(g_indptr) unsat = np.arange(g_shape)[n_indices < sat_thresh] sat = np.arange(g_shape)[n_indices == sat_thresh] return sat, unsat @njit def _get_nhood_idx( sat: NDArrayA, unsat: NDArrayA, g_indptr: NDArrayA, g_indices: NDArrayA, sat_thresh: int, ) -> tuple[NDArrayA, NDArrayA, NDArrayA]: """Get saturated and unsaturated neighborhood indices.""" # get saturated nhood indices sat_idx = np.zeros((sat.shape[0], sat_thresh)) for idx in range(sat.shape[0]): i = sat[idx] sat_idx[idx] = g_indices[g_indptr[i] : g_indptr[i + 1]] # get closest saturated of unsaturated nearest_sat = np.full_like(unsat, fill_value=np.nan, dtype=np.float64) for idx in range(unsat.shape[0]): i = unsat[idx] unsat_neigh = g_indices[g_indptr[i] : g_indptr[i + 1]] for u in unsat_neigh: if u in sat: # take the first saturated nhood nearest_sat[idx] = u break # some unsat still don't have a sat nhood # return them and compute distances in outer func un_unsat = unsat[np.isnan(nearest_sat)] return sat_idx.astype(np.int32), nearest_sat, un_unsat