# Source code for squidpy.gr.neighbors

"""Graph construction strategies for spatial neighbor graphs.

See the :doc:`/extensibility` guide for how to implement a custom builder.
"""

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any, Generic, TypeVar, cast

import numpy as np
from fast_array_utils import stats as fau_stats
from numba import njit, prange
from scipy.sparse import (
    SparseEfficiencyWarning,
    block_diag,
    csr_array,
    csr_matrix,
    isspmatrix_csr,
    spmatrix,
)
from scipy.spatial import Delaunay
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors

from squidpy._constants._constants import CoordType, Transform
from squidpy._utils import NDArrayA
from squidpy._validators import assert_positive

# Public names exported by this module.
__all__ = [
    "GraphMatrixT",
    "GraphBuilder",
    "GraphBuilderCSR",
    "GraphPostprocessor",
    "DistanceIntervalPostprocessor",
    "PercentilePostprocessor",
    "TransformPostprocessor",
    "KNNBuilder",
    "RadiusBuilder",
    "DelaunayBuilder",
    "GridBuilder",
]


# Type of the coordinates a builder consumes (``NDArrayA`` for the CSR builders).
CoordT = TypeVar("CoordT")
# Type of the adjacency/distance matrices a builder produces (``csr_matrix`` for the CSR builders).
GraphMatrixT = TypeVar("GraphMatrixT")
# A postprocessing step: receives ``(adj, dst)`` and returns the (possibly modified) pair.
GraphPostprocessor = Callable[[GraphMatrixT, GraphMatrixT], tuple[GraphMatrixT, GraphMatrixT]]


class GraphBuilder(ABC, Generic[CoordT, GraphMatrixT]):
    """Abstract strategy that turns coordinates into a spatial graph.

    A builder produces an ``(adjacency, distances)`` pair. Subclasses must implement
    the abstract methods :meth:`build_graph` and :meth:`uns_params`; overriding
    :meth:`postprocessors` and :meth:`combine` is optional.

    Postprocessing steps are either passed to ``__init__`` or supplied by overriding
    :meth:`postprocessors`.

    Parameters
    ----------
    transform
        Adjacency transform. ``None`` is stored as ``Transform.NONE``; any other
        value is converted with ``Transform(transform)``.
    set_diag
        Stored on the instance for subclasses to use.
    percentile
        Stored on the instance for subclasses to use.
    postprocessors
        Steps that :meth:`build` applies, in order, to ``(adj, dst)``. Copied into
        a private list.
    """

    def __init__(
        self,
        transform: str | Transform | None = None,
        set_diag: bool = False,
        percentile: float | None = None,
        postprocessors: Sequence[GraphPostprocessor[GraphMatrixT]] = (),
    ) -> None:
        if transform is None:
            self.transform = Transform.NONE
        else:
            self.transform = Transform(transform)
        self.set_diag = set_diag
        self.percentile = percentile
        # Own copy, so later changes to the caller's sequence do not leak in.
        self._postprocessors: list[GraphPostprocessor[GraphMatrixT]] = list(postprocessors)

    def build(self, coords: CoordT) -> tuple[GraphMatrixT, GraphMatrixT]:
        """Build the raw graph, then pass it through every postprocessor in order.

        Parameters
        ----------
        coords
            Coordinates handed unchanged to :meth:`build_graph`.

        Returns
        -------
        The ``(adj, dst)`` pair produced by the last postprocessor, or by
        :meth:`build_graph` when there are no postprocessors.
        """
        graph_adj, graph_dst = self.build_graph(coords)
        for step in self.postprocessors():
            graph_adj, graph_dst = step(graph_adj, graph_dst)
        return graph_adj, graph_dst

    @abstractmethod
    def build_graph(self, coords: CoordT) -> tuple[GraphMatrixT, GraphMatrixT]:
        """Return the raw ``(adj, dst)`` pair for ``coords``, before any postprocessing."""

    def postprocessors(self) -> Sequence[GraphPostprocessor[GraphMatrixT]]:
        """Return the steps that :meth:`build` applies to ``(adj, dst)``.

        The default returns the list collected in ``__init__`` (the stored list itself).
        """
        return self._postprocessors

    @abstractmethod
    def uns_params(self) -> dict[str, Any]:
        """Return the parameters recorded in :attr:`anndata.AnnData.uns` after graph construction."""

    def combine(
        self,
        mats: Sequence[tuple[GraphMatrixT, GraphMatrixT]],
        ixs: Sequence[int],
    ) -> tuple[GraphMatrixT, GraphMatrixT]:
        """Merge per-library ``(adj, dst)`` results into a single graph.

        The base class does not support this. Builders that should work with
        ``library_key`` must override it.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError("Using `library_key` with this graph builder is not implemented yet.")
class GraphBuilderCSR(GraphBuilder[NDArrayA, csr_matrix], ABC):
    """Graph builder specialised to array coordinates and CSR matrix output.

    On top of :class:`GraphBuilder`, this class does two things:

    - it silences :class:`~scipy.sparse.SparseEfficiencyWarning` while building;
    - it implements :meth:`combine`, so multi-library construction via
      ``library_key`` works.

    The built-in builders derive from this class and assemble their postprocessors
    in ``__init__`` from the reusable public postprocessor classes. The built-in
    builders are :class:`KNNBuilder`, :class:`RadiusBuilder`,
    :class:`DelaunayBuilder` and :class:`GridBuilder`.

    Derive from this class, not from the generic :class:`GraphBuilder`, when
    writing a builder that returns CSR matrices.

    See Also
    --------
    GraphBuilder : Generic builder interface for custom coordinate/matrix types.
    KNNBuilder : Example of a concrete CSR-based builder.
    """

    def build(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]:
        """Run :meth:`GraphBuilder.build` with ``SparseEfficiencyWarning`` suppressed."""
        # Graph construction and postprocessing change CSR sparsity in place
        # (e.g. ``setdiag``, masked assignment). scipy may flag that as inefficient,
        # which is expected here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", SparseEfficiencyWarning)
            return super().build(coords)

    @abstractmethod
    def build_graph(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]:
        """Return the raw CSR ``(adj, dst)`` pair for ``coords``, before postprocessing."""

    def combine(
        self,
        mats: Sequence[tuple[csr_matrix, csr_matrix]],
        ixs: Sequence[int],
    ) -> tuple[csr_matrix, csr_matrix]:
        """Stack per-library graphs block-diagonally and restore observation order.

        Parameters
        ----------
        mats
            Per-library ``(adj, dst)`` pairs, in library order.
        ixs
            Original observation position of each row of the stacked matrices.
            Presumably a permutation of ``range(n_obs)``; confirm against the caller.

        Returns
        -------
        Combined CSR ``(adj, dst)``, with rows and columns in original observation order.
        """
        adj_blocks = [pair[0] for pair in mats]
        dst_blocks = [pair[1] for pair in mats]
        adj = block_diag(adj_blocks, format="csr")
        dst = block_diag(dst_blocks, format="csr")

        # Rows follow library order after stacking. When libraries are contiguous,
        # ``ixs`` is already ascending, so two costly fancy-index copies of a
        # possibly huge sparse matrix can be skipped.
        positions = np.asarray(ixs)
        interleaved = positions.size > 0 and bool((np.diff(positions) < 0).any())
        if interleaved:
            perm = np.argsort(positions)
            adj = adj[perm, :][:, perm]
            dst = dst[perm, :][:, perm]
        return cast(csr_matrix, adj), cast(csr_matrix, dst)
class KNNBuilder(GraphBuilderCSR):
    """Generic k-nearest-neighbor spatial graph.

    Every observation gets edges to its ``n_neighs`` nearest other observations
    (Euclidean distance). See :func:`~squidpy.gr.spatial_neighbors_knn` for the
    user-facing API or :func:`~squidpy.gr.spatial_neighbors_from_builder` for
    direct builder usage.

    Parameters
    ----------
    n_neighs
        Number of neighbors per observation; must be positive.
    transform
        Adjacency transform applied as the last postprocessing step.
    set_diag
        Whether to put ``1.0`` on the adjacency diagonal.
    percentile
        If given, a :class:`PercentilePostprocessor` runs before the transform.
    """

    def __init__(
        self,
        n_neighs: int = 6,
        transform: str | Transform | None = None,
        set_diag: bool = False,
        percentile: float | None = None,
    ) -> None:
        assert_positive(n_neighs, name="n_neighs")
        steps: list[GraphPostprocessor[csr_matrix]] = (
            [] if percentile is None else [PercentilePostprocessor(percentile)]
        )
        resolved = Transform.NONE if transform is None else Transform(transform)
        steps.append(TransformPostprocessor(resolved))
        super().__init__(
            transform=transform,
            set_diag=set_diag,
            percentile=percentile,
            postprocessors=steps,
        )
        self.n_neighs = n_neighs

    def uns_params(self) -> dict[str, Any]:
        """Return the KNN parameters recorded in ``uns``."""
        params: dict[str, Any] = {}
        params["coord_type"] = CoordType.GENERIC.v
        params["n_neighbors"] = self.n_neighs
        params["transform"] = self.transform.v
        return params

    def build_graph(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]:
        """Return the raw KNN adjacency (``float32`` ones) and Euclidean distances."""
        n_obs = coords.shape[0]
        k = self.n_neighs
        nn = NearestNeighbors(n_neighbors=k, radius=1, metric="euclidean").fit(coords)
        # Querying without ``X`` excludes each point from its own neighbor list.
        distances, neighbors = nn.kneighbors()
        cols = neighbors.ravel()
        vals = distances.ravel()
        rows = np.arange(n_obs).repeat(k)

        shape = (n_obs, n_obs)
        adj = csr_matrix((np.ones(rows.shape, dtype=np.float32), (rows, cols)), shape=shape)
        dst = csr_matrix((vals, (rows, cols)), shape=shape)

        # Writing the diagonal on both matrices keeps their sparsity structure aligned,
        # which ``_filter_by_radius_interval`` relies on.
        diag_values = 1.0 if self.set_diag else adj.diagonal()
        adj.setdiag(diag_values)
        dst.setdiag(0.0)
        return adj, dst
class RadiusBuilder(GraphBuilderCSR):
    """Build a generic radius-based spatial graph.

    Two observations are connected when their Euclidean distance falls within the
    specified radius. ``radius`` may be a scalar, including NumPy scalars such as
    ``np.int64`` or ``np.float32``. It may also be a ``(min, max)`` pair. Any
    non-scalar value is treated as an interval: neighbors are searched within its
    maximum and then pruned to the interval.

    See :func:`~squidpy.gr.spatial_neighbors_radius` for the user-facing API or
    :func:`~squidpy.gr.spatial_neighbors_from_builder` for direct builder usage.

    Parameters
    ----------
    radius
        Maximum neighbor distance, or a ``(min, max)`` distance interval.
    transform
        Adjacency transform applied as the last postprocessing step.
    set_diag
        Whether to put ``1.0`` on the adjacency diagonal.
    percentile
        If given, a :class:`PercentilePostprocessor` runs after interval pruning.
    """

    def __init__(
        self,
        radius: float | tuple[float, float],
        transform: str | Transform | None = None,
        set_diag: bool = False,
        percentile: float | None = None,
    ) -> None:
        postprocessors: list[GraphPostprocessor[csr_matrix]] = []
        # ``np.isscalar`` (rather than ``isinstance(radius, tuple)``) accepts NumPy
        # scalars as scalars. It also treats any sequence (tuple, list, array) as
        # an interval. This matches how ``build_graph`` picks the search radius.
        if not np.isscalar(radius):
            postprocessors.append(DistanceIntervalPostprocessor(tuple(sorted(radius))))
        if percentile is not None:
            postprocessors.append(PercentilePostprocessor(percentile))
        postprocessors.append(TransformPostprocessor(Transform.NONE if transform is None else Transform(transform)))
        super().__init__(
            transform=transform,
            set_diag=set_diag,
            percentile=percentile,
            postprocessors=postprocessors,
        )
        self.radius = radius

    def uns_params(self) -> dict[str, Any]:
        """Return the radius-graph parameters recorded in ``uns``."""
        return {
            "coord_type": CoordType.GENERIC.v,
            "radius": self.radius,
            "transform": self.transform.v,
        }

    def build_graph(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]:
        """Return the raw radius adjacency (``float32`` ones) and Euclidean distances.

        For an interval radius the search uses its maximum. The lower bound is
        applied later by :class:`DistanceIntervalPostprocessor`.
        """
        N = coords.shape[0]
        # Same scalar test as ``__init__``. The old ``isinstance(..., int | float)``
        # rejected NumPy integer/float32 scalars and then failed in ``max()``.
        r = self.radius if np.isscalar(self.radius) else max(self.radius)
        tree = NearestNeighbors(radius=r, metric="euclidean")
        tree.fit(coords)
        # Without ``X`` each point is excluded from its own neighbor list.
        dists, col_indices = tree.radius_neighbors()
        row_indices = np.repeat(np.arange(N), [len(x) for x in col_indices])
        dists = np.concatenate(dists)
        col_indices = np.concatenate(col_indices)

        adj = csr_matrix(
            (np.ones_like(row_indices, dtype=np.float32), (row_indices, col_indices)),
            shape=(N, N),
        )
        dst = csr_matrix((dists, (row_indices, col_indices)), shape=(N, N))

        # Writing the diagonal on both matrices keeps their sparsity structure aligned,
        # which ``_filter_by_radius_interval`` relies on.
        adj.setdiag(1.0 if self.set_diag else adj.diagonal())
        dst.setdiag(0.0)
        return adj, dst
class DelaunayBuilder(GraphBuilderCSR):
    """Build a generic point-cloud graph from a Delaunay triangulation.

    Delaunay triangulation connects observations into triangles such that no other
    observation lies inside the circumcircle of each triangle. Unlike
    ``GridBuilder(delaunay=True)``, this builder uses geometry-based connectivity and
    stores real Euclidean edge lengths.

    ``radius`` only controls post-construction edge pruning; the triangulation
    itself is unchanged:

    - a tuple ``(min, max)`` keeps edges with Euclidean length in that interval;
    - a scalar ``r`` (Python or NumPy scalar) is shorthand for ``(0.0, r)``, i.e.
      keep edges with length at most ``r``;
    - ``None`` keeps every edge.

    See :func:`~squidpy.gr.spatial_neighbors_delaunay` for the user-facing API or
    :func:`~squidpy.gr.spatial_neighbors_from_builder` for direct builder usage.
    """

    def __init__(
        self,
        radius: float | tuple[float, float] | None = None,
        transform: str | Transform | None = None,
        set_diag: bool = False,
        percentile: float | None = None,
    ) -> None:
        # ``np.isscalar`` also recognises NumPy scalars (e.g. ``np.float32``,
        # ``np.int64``). The previous ``isinstance(radius, int | float)`` missed
        # them, and ``tuple(sorted(radius))`` then raised ``TypeError``.
        if radius is not None and np.isscalar(radius):
            radius = (0.0, float(radius))
        postprocessors: list[GraphPostprocessor[csr_matrix]] = []
        if radius is not None:
            postprocessors.append(DistanceIntervalPostprocessor(tuple(sorted(radius))))
        if percentile is not None:
            postprocessors.append(PercentilePostprocessor(percentile))
        postprocessors.append(TransformPostprocessor(Transform.NONE if transform is None else Transform(transform)))
        super().__init__(
            transform=transform,
            set_diag=set_diag,
            percentile=percentile,
            postprocessors=postprocessors,
        )
        self.radius = radius

    def uns_params(self) -> dict[str, Any]:
        """Return the Delaunay-graph parameters recorded in ``uns``."""
        return {
            "coord_type": CoordType.GENERIC.v,
            "radius": self.radius,
            "transform": self.transform.v,
        }

    def build_graph(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]:
        """Return the triangulation adjacency (``float32`` ones) and Euclidean edge lengths."""
        N = coords.shape[0]
        tri = Delaunay(coords)
        # ``vertex_neighbor_vertices`` is already in CSR (indptr, indices) form.
        indptr, indices = tri.vertex_neighbor_vertices
        adj = csr_matrix((np.ones_like(indices, dtype=np.float32), indices, indptr), shape=(N, N))

        rows = np.repeat(np.arange(N), np.diff(indptr))
        dists = np.linalg.norm(coords[rows] - coords[indices], axis=1)
        dst = csr_matrix((dists, indices, indptr), shape=(N, N))

        adj.setdiag(1.0 if self.set_diag else adj.diagonal())
        dst.setdiag(0.0)
        return adj, dst
class GridBuilder(GraphBuilderCSR):
    """Build a grid-based spatial graph.

    Assumes observations lie on an approximately regular lattice (e.g., Visium).
    When ``delaunay=True``, Delaunay triangulation is used only to derive the base
    connectivity; the distance matrix still encodes grid/ring distances, not
    Euclidean lengths.

    See :func:`~squidpy.gr.spatial_neighbors_grid` for the user-facing API or
    :func:`~squidpy.gr.spatial_neighbors_from_builder` for direct builder usage.

    Parameters
    ----------
    n_neighs
        Neighbors queried per observation for the base (ring-1) adjacency when
        ``delaunay=False``; must be positive.
    n_rings
        Number of rings of neighbors to include; must be positive.
    delaunay
        Derive the base adjacency from a Delaunay triangulation instead of KNN.
    transform
        Adjacency transform applied as the only postprocessing step.
    set_diag
        Whether to put ``1.0`` on the adjacency diagonal.
    """

    def __init__(
        self,
        n_neighs: int = 6,
        n_rings: int = 1,
        delaunay: bool = False,
        transform: str | Transform | None = None,
        set_diag: bool = False,
    ) -> None:
        assert_positive(n_neighs, name="n_neighs")
        assert_positive(n_rings, name="n_rings")
        postprocessors: list[GraphPostprocessor[csr_matrix]] = [
            TransformPostprocessor(Transform.NONE if transform is None else Transform(transform))
        ]
        super().__init__(transform=transform, set_diag=set_diag, percentile=None, postprocessors=postprocessors)
        self.n_neighs = n_neighs
        self.n_rings = n_rings
        self.delaunay = delaunay

    def uns_params(self) -> dict[str, Any]:
        """Return the grid-graph parameters recorded in ``uns``."""
        return {
            "coord_type": CoordType.GRID.v,
            "n_neighbors": self.n_neighs,
            "n_rings": self.n_rings,
            "delaunay": self.delaunay,
            "transform": self.transform.v,
        }

    def build_graph(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]:
        """Return the grid adjacency and the ring-index "distance" matrix.

        With ``n_rings > 1`` an entry ``dst[i, j] = k`` means ``j`` is first reached
        from ``i`` in ring ``k``, and every stored adjacency entry is ``1.0``. The
        diagonal of ``dst`` is always ``0`` (self-distance), whatever ``n_rings``
        and ``set_diag`` are.
        """
        if self.n_rings > 1:
            # Self-loops on the base graph make ``adj^k`` reach every node within
            # ``k`` hops, including the node itself.
            adj = self._base_adjacency(coords, set_diag=True)
            res, walk = adj, adj
            for i in range(self.n_rings - 1):
                walk = walk @ adj
                # Drop nodes already reached in an earlier ring ...
                walk[res.nonzero()] = 0.0
                walk.eliminate_zeros()
                # ... and label the newly reached ones with their ring index (2, 3, ...).
                walk.data[:] = i + 2.0
                res = res + walk
            adj = res
            adj.setdiag(float(self.set_diag))
            adj.eliminate_zeros()
            dst = adj.copy()
            adj.data[:] = 1.0
        else:
            adj = self._base_adjacency(coords, set_diag=self.set_diag)
            dst = adj.copy()
        # Applied in both branches. Otherwise, with ``n_rings > 1`` and ``set_diag=True``,
        # the copied diagonal (1.0) would be reported as a self-distance.
        dst.setdiag(0.0)
        return adj, dst

    def _base_adjacency(self, coords: NDArrayA, *, set_diag: bool) -> csr_matrix:
        """Ring-1 adjacency: Delaunay neighbors, or KNN pruned at 1.3x the median distance."""
        N = coords.shape[0]
        if self.delaunay:
            tri = Delaunay(coords)
            indptr, indices = tri.vertex_neighbor_vertices
            adj = csr_matrix((np.ones_like(indices, dtype=np.float32), indices, indptr), shape=(N, N))
        else:
            tree = NearestNeighbors(n_neighbors=self.n_neighs, radius=1, metric="euclidean")
            tree.fit(coords)
            dists, col_indices = tree.kneighbors()
            dists, col_indices = dists.reshape(-1), col_indices.reshape(-1)
            row_indices = np.repeat(np.arange(N), self.n_neighs)
            # Drop queried neighbors that are clearly farther than the lattice spacing.
            dist_cutoff = np.median(dists) * 1.3
            mask = dists < dist_cutoff
            row_indices, col_indices = row_indices[mask], col_indices[mask]
            adj = csr_matrix(
                (np.ones_like(row_indices, dtype=np.float32), (row_indices, col_indices)),
                shape=(N, N),
            )
        adj.setdiag(1.0 if set_diag else adj.diagonal())
        return adj
# ---------------------------------------------------------------------------
# Private helpers used by the builder classes
# ---------------------------------------------------------------------------


def _filter_by_radius_interval(
    adj: csr_matrix,
    dst: csr_matrix,
    radius: tuple[float, float],
) -> None:
    """Zero, in place, every edge whose distance lies outside ``[radius[0], radius[1]]``.

    ``adj`` and ``dst`` must share the same sparsity structure: the mask is built
    from ``dst.data`` and applied position-wise to ``adj.data``. The adjacency
    diagonal is restored afterwards, so self-loops survive even though their
    distance is ``0``. Zeros are left stored (not eliminated).
    """
    lower, upper = radius
    outside = np.logical_or(dst.data < lower, dst.data > upper)
    # Snapshot before masking: diagonal distances are 0 and may fall below ``lower``.
    saved_diag = adj.diagonal()
    adj.data[outside] = 0.0
    dst.data[outside] = 0.0
    adj.setdiag(saved_diag)
@dataclass(frozen=True)
class DistanceIntervalPostprocessor:
    """Postprocessor that prunes edges whose distance lies outside a closed interval.

    Edges with ``interval[0] <= distance <= interval[1]`` are kept. All others have
    their ``adj`` and ``dst`` entries set to ``0`` in place, while the adjacency
    diagonal is preserved. Explicit zeros are not removed by this step.

    Attributes
    ----------
    interval
        ``(min, max)`` distance bounds, expected in ascending order.
    """

    interval: tuple[float, float]

    def __call__(self, adj: csr_matrix, dst: csr_matrix) -> tuple[csr_matrix, csr_matrix]:
        """Apply the interval filter in place and return the same ``(adj, dst)`` objects."""
        _filter_by_radius_interval(adj, dst, self.interval)
        return (adj, dst)
@dataclass(frozen=True)
class PercentilePostprocessor:
    """Postprocessor that drops edges longer than a percentile of the stored distances.

    The threshold is ``np.percentile(dst.data, percentile)``. It is taken over every
    *stored* entry of ``dst``, which may include explicit zeros, e.g. a diagonal
    written by ``setdiag(0.0)`` or entries zeroed by an earlier interval filter.
    Wherever ``dst`` exceeds the threshold, both ``adj`` and ``dst`` are set to
    ``0`` in place. Zeros are not eliminated here.

    Attributes
    ----------
    percentile
        Percentile in ``[0, 100]`` passed to :func:`numpy.percentile`.
    """

    percentile: float

    def __call__(self, adj: csr_matrix, dst: csr_matrix) -> tuple[csr_matrix, csr_matrix]:
        """Prune edges in place and return the same ``(adj, dst)`` objects.

        If ``dst`` has no stored entries, both matrices are returned unchanged.
        """
        # ``np.percentile`` has no meaningful value for an empty array, and with no
        # stored distances there is nothing to prune anyway.
        if dst.data.size == 0:
            return adj, dst
        threshold = np.percentile(dst.data, self.percentile)
        # Build the sparse boolean mask once and reuse it for both matrices.
        too_far = dst > threshold
        adj[too_far] = 0.0
        dst[too_far] = 0.0
        return adj, dst
[docs] @dataclass(frozen=True) class TransformPostprocessor: transform: Transform
[docs] def __call__(self, adj: csr_matrix, dst: csr_matrix) -> tuple[csr_matrix, csr_matrix]: adj.eliminate_zeros() dst.eliminate_zeros() if self.transform == Transform.SPECTRAL: return cast(csr_matrix, _transform_a_spectral(adj)), dst if self.transform == Transform.COSINE: return cast(csr_matrix, _transform_a_cosine(adj)), dst if self.transform == Transform.NONE: return adj, dst raise NotImplementedError(f"Transform `{self.transform}` is not yet implemented.")
@njit
def _csr_bilateral_diag_scale_helper(
    mat: csr_array | csr_matrix,
    degrees: NDArrayA,
) -> NDArrayA:
    """
    Scale each stored entry of a CSR matrix by the factors of its row and column.

    For the k-th stored entry ``(i, j)`` of ``mat`` (in CSR order), the result holds
    ``degrees[i] * degrees[j] * mat.data[k]``.

    Parameters
    ----------
    mat : scipy.sparse.csr_array or scipy.sparse.csr_matrix
        Matrix whose ``data``, ``indices`` and ``indptr`` are read (not modified).
    degrees : array of float
        Scaling vector, indexed by both the row and the column indices of ``mat``.

    Returns
    -------
    array of float32
        Aligned with ``mat.data`` (same length, same order).
    """
    res = np.empty_like(mat.data, dtype=np.float32)
    # Without ``parallel=True`` numba runs ``prange`` like a plain ``range``.
    for i in prange(len(mat.indptr) - 1):
        start = mat.indptr[i]
        stop = mat.indptr[i + 1]
        res[start:stop] = degrees[i] * degrees[mat.indices[start:stop]] * mat.data[start:stop]
    return res


def symmetric_normalize_csr(adj: spmatrix) -> csr_matrix:
    """
    Return ``D^{-1/2} * A * D^{-1/2}``, where ``A = adj`` and ``D`` holds the column sums of ``A``.

    If a column sum is zero, that node's scaling factor is set to ``0`` instead of
    ``inf``, i.e. the pseudo-inverse of ``D`` is used. This happens, e.g., for a node
    that no other node lists as a neighbor in an asymmetric KNN graph. Entries
    touching such a node therefore become explicit zeros rather than ``inf``.

    Parameters
    ----------
    adj : scipy.sparse.csr_matrix
        Square CSR matrix.

    Returns
    -------
    scipy.sparse.csr_matrix
        Same sparsity structure as ``adj`` (``indices``/``indptr`` are reused).

    Raises
    ------
    ValueError
        If the number of column sums differs from the number of rows of ``adj``.
    """
    # ``reshape(-1)`` (not ``squeeze``) keeps a 1-D vector even for a 1x1 matrix,
    # where ``squeeze`` produced a 0-d array and ``len()`` below raised ``TypeError``.
    col_sums = np.asarray(fau_stats.sum(adj, axis=0)).reshape(-1)
    with np.errstate(divide="ignore"):
        degrees = np.sqrt(1.0 / col_sums)
    degrees[col_sums == 0] = 0.0
    if adj.shape[0] != len(degrees):
        raise ValueError("len(degrees) must equal number of rows of adj")
    res_data = _csr_bilateral_diag_scale_helper(adj, degrees)
    return csr_matrix((res_data, adj.indices, adj.indptr), shape=adj.shape)


def _transform_a_spectral(a: spmatrix) -> spmatrix:
    """Symmetrically normalize ``a`` (converted to CSR if needed); an empty matrix is returned as-is."""
    if not isspmatrix_csr(a):
        a = a.tocsr()
    if not a.nnz:
        return a
    return symmetric_normalize_csr(a)


def _transform_a_cosine(a: spmatrix) -> spmatrix:
    """Return the pairwise cosine similarity between the rows of ``a``, kept sparse."""
    return cosine_similarity(a, dense_output=False)