# Source code for squidpy.gr._niche

from __future__ import annotations

import contextlib
from typing import Any, Literal

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sps
from anndata import AnnData
from numpy.typing import NDArray
from scipy.sparse import coo_matrix, hstack, issparse, spdiags
from scipy.spatial import distance
from sklearn.metrics import f1_score
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import normalize
from spatialdata import SpatialData
from spatialdata._logging import logger as logg

from squidpy._constants._constants import NicheDefinitions
from squidpy._docs import d, inject_docs
from squidpy._validators import assert_isinstance, assert_key_in_adata, assert_one_of
from squidpy.gr._utils import extract_adata_if_sdata

__all__ = ["calculate_niche"]


@d.dedent
@inject_docs(fla=NicheDefinitions)
def calculate_niche(
    data: AnnData | SpatialData,
    flavor: Literal["neighborhood", "utag", "cellcharter", "spatialleiden"],
    library_key: str | None = None,
    mask: pd.core.series.Series | None = None,
    groups: str | None = None,
    n_neighbors: int | None = None,
    resolutions: float | tuple[float, float] | list[float | tuple[float, float]] | None = None,
    min_niche_size: int | None = None,
    scale: bool = True,
    abs_nhood: bool = False,
    distance: int | None = None,
    n_hop_weights: list[float] | None = None,
    aggregation: str | None = None,
    n_components: int | None = None,
    random_state: int = 42,
    spatial_connectivities_key: str = "spatial_connectivities",
    latent_connectivities_key: str = "connectivities",
    layer_ratio: float = 1.0,
    n_iterations: int = -1,
    use_weights: bool | tuple[bool, bool] = True,
    use_rep: str | None = None,
    inplace: bool = True,
    *,
    table_key: str | None = None,
) -> AnnData | None:
    """
    Calculate niches (spatial clusters) based on a user-defined method in 'flavor'.

    The resulting niche labels will be stored in 'adata.obs'.

    Parameters
    ----------
    %(adata)s
    flavor
        Method to use for niche calculation. Available options are:

            - `{fla.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile.
            - `{fla.UTAG.s!r}` - use utag algorithm (matrix multiplication).
            - `{fla.SPATIALLEIDEN.s!r}` - cluster spatially resolved omics data using Multiplex Leiden.
            - `{fla.CELLCHARTER.s!r}` - a simplified version of CellCharter's approach, using PCA for
              dimensionality reduction. An arbitrary embedding can be used instead of PCA by setting the
              `use_rep` parameter which will try to find the embedding in `adata.obsm`.
    %(library_key)s
        If provided, niches will be calculated separately for each unique value in this column.
        Each niche will be prefixed with the library identifier.
    %(table_key)s
    mask
        Boolean array to filter cells which won't get assigned to a niche.
        Note that if you want to exclude these cells during neighborhood calculation already,
        you should subset your AnnData table before running 'sq.gr.spatial_neighbors'.
    groups
        Groups based on which to calculate neighborhood profile (E.g. columns of cell type annotations in adata.obs).
        Required if flavor == `{fla.NEIGHBORHOOD.s!r}`.
    n_neighbors
        Number of neighbors to use for 'scanpy.pp.neighbors' before clustering using leiden algorithm.
        Required if flavor == `{fla.NEIGHBORHOOD.s!r}` or flavor == `{fla.UTAG.s!r}`.
    resolutions
        List of resolutions to use for leiden clustering.
        In the case of spatialleiden you can pass a tuple. Resolution for the latent space and spatial layer,
        respectively. A single float applies to both layers.
        Required if flavor == `{fla.NEIGHBORHOOD.s!r}` or flavor == `{fla.UTAG.s!r}`.
        Optional if flavor == `{fla.SPATIALLEIDEN.s!r}`.
    min_niche_size
        Minimum required size of a niche. Niches with fewer cells will be labeled as 'not_a_niche'.
        Optional if flavor == `{fla.NEIGHBORHOOD.s!r}`.
    scale
        If 'True', compute z-scores of neighborhood profiles.
        Optional if flavor == `{fla.NEIGHBORHOOD.s!r}`.
    abs_nhood
        If 'True', calculate niches based on absolute neighborhood profile.
        Optional if flavor == `{fla.NEIGHBORHOOD.s!r}`.
    distance
        Maximum number of hops of the spatial graph to use, e.g. 3 for 1-hop, 2-hop and 3-hop neighbors.
        For flavor == `{fla.CELLCHARTER.s!r}` the 0-hop layer (the cell itself) is always included.
        Defaults to 3 for `{fla.CELLCHARTER.s!r}` and to 1 otherwise.
        Required if flavor == `{fla.CELLCHARTER.s!r}`.
        Optional if flavor == `{fla.NEIGHBORHOOD.s!r}`.
    n_hop_weights
        How to weight subsequent n-hop adjacency matrices. E.g. [1, 0.5, 0.25] for weights of 1-hop, 2-hop, 3-hop
        adjacency matrices respectively.
        Optional if flavor == `{fla.NEIGHBORHOOD.s!r}` and `distance` > 1.
    aggregation
        How to aggregate count matrices. Either 'mean' or 'variance'.
        Required if flavor == `{fla.CELLCHARTER.s!r}`.
    n_components
        Number of components to use for GMM.
        Required if flavor == `{fla.CELLCHARTER.s!r}`.
    random_state
        Random state to use for GMM or SpatialLeiden.
        Optional if flavor == `{fla.CELLCHARTER.s!r}` or flavor == `{fla.SPATIALLEIDEN.s!r}`.
    spatial_connectivities_key
        Key in `adata.obsp` where spatial connectivities are stored.
        Required if flavor == `{fla.SPATIALLEIDEN.s!r}`.
    latent_connectivities_key
        Key in `adata.obsp` where gene expression connectivities are stored.
        Required if flavor == `{fla.SPATIALLEIDEN.s!r}`.
    layer_ratio
        The ratio of the weighting of the layers; latent space vs spatial.
        A higher ratio will increase relevance of the spatial neighbors and lead to more spatially
        homogeneous clusters.
        Optional if flavor == `{fla.SPATIALLEIDEN.s!r}`.
    n_iterations
        Number of iterations to run the Leiden algorithm.
        If the number is negative it runs until convergence.
        Optional if flavor == `{fla.SPATIALLEIDEN.s!r}`.
    use_weights
        Whether to use weights for the edges for latent space and spatial neighbors, respectively.
        A single bool applies to both layers.
        Optional if flavor == `{fla.SPATIALLEIDEN.s!r}`.
    use_rep
        Key in `adata.obsm` where the embedding is stored. If provided, this embedding will be used instead of PCA
        for dimensionality reduction.
        Optional if flavor == `{fla.CELLCHARTER.s!r}`.
    inplace
        If 'True', perform the operation in place.
        If 'False', return a new AnnData object with the niche labels.

    Returns
    -------
    If ``inplace`` is 'False', a copy of the table with the niche label columns added to ``.obs``.
    Otherwise 'None'; the labels are written to the input object.
    """
    # Defaults that the cellcharter validation treats as required must be filled in first.
    if flavor == "cellcharter":
        if aggregation is None:
            aggregation = "mean"
        if distance is None:
            distance = 3
        if n_components is None:
            n_components = 10

    _validate_niche_args(
        data,
        flavor,
        library_key,
        table_key,
        groups,
        n_neighbors,
        resolutions,
        min_niche_size,
        scale,
        abs_nhood,
        distance,
        n_hop_weights,
        aggregation,
        n_components,
        random_state,
        spatial_connectivities_key,
        latent_connectivities_key,
        layer_ratio,
        n_iterations,
        use_weights,
        use_rep,
        inplace,
    )

    # Generic defaults are applied only AFTER validation, so that a value the user never
    # passed is not reported as an unused argument (e.g. 'distance' for flavor 'utag').
    if distance is None:
        distance = 1
    if resolutions is None:
        resolutions = [0.5]

    orig_adata = extract_adata_if_sdata(data, table_key=table_key)
    # All work happens on a copy; results are written back at the end when inplace=True.
    adata = orig_adata.copy()

    assert_key_in_adata(
        adata,
        spatial_connectivities_key,
        attr="obsp",
        extra_msg="If you haven't computed a spatial neighborhood graph yet, use `sq.gr.spatial_neighbors`.",
    )
    if flavor == "spatialleiden":
        assert_key_in_adata(
            adata,
            latent_connectivities_key,
            attr="obsp",
            extra_msg="If you haven't computed a latent neighborhood graph yet, use `sc.pp.neighbors`.",
        )

    result_columns = _get_result_columns(
        flavor=flavor,
        resolutions=resolutions,
        library_key=None,
        libraries=None,
    )

    if library_key is not None:
        assert_key_in_adata(adata, library_key, attr="obs")
        logg.info(f"Stratifying by library_key '{library_key}'")

        # Cells that never receive a label (e.g. masked out) keep this placeholder.
        for col in result_columns:
            adata.obs[col] = "not_a_niche"

        for lib_id in adata.obs[library_key].unique():
            logg.info(f"Processing library '{lib_id}'")
            lib_indices = adata.obs[adata.obs[library_key] == lib_id].index
            if len(lib_indices) == 0:
                logg.warning(f"Library '{lib_id}' contains no cells, skipping")
                continue

            lib_adata = adata[lib_indices].copy()
            lib_mask = None
            if mask is not None:
                lib_mask = mask[mask.index.isin(lib_indices)]

            lib_result = calculate_niche(
                lib_adata,
                flavor=flavor,
                library_key=None,
                mask=lib_mask,
                groups=groups,
                n_neighbors=n_neighbors,
                # Pass None where a filled-in default would otherwise trigger an
                # "unused parameter" warning for every library.
                resolutions=None if flavor == "cellcharter" else resolutions,
                min_niche_size=min_niche_size,
                scale=scale,
                abs_nhood=abs_nhood,
                distance=None if flavor == "utag" else distance,
                n_hop_weights=n_hop_weights,
                aggregation=aggregation,
                n_components=n_components,
                random_state=random_state,
                spatial_connectivities_key=spatial_connectivities_key,
                latent_connectivities_key=latent_connectivities_key,
                layer_ratio=layer_ratio,
                n_iterations=n_iterations,
                use_weights=use_weights,
                # Forward the embedding key; without it the per-library runs silently
                # fell back to PCA even though the caller asked for a specific embedding.
                use_rep=use_rep,
                inplace=False,
            )

            for col in result_columns:
                if col in lib_result.obs.columns:
                    prefixed_values = lib_result.obs[col].apply(
                        lambda x, lib=lib_id: f"lib={lib}_{x}" if x != "not_a_niche" else x
                    )
                    adata.obs.loc[lib_indices, col] = prefixed_values.values
    else:
        _calculate_niches(
            adata,
            mask,
            flavor,
            groups,
            n_neighbors,
            resolutions,
            min_niche_size,
            scale,
            abs_nhood,
            distance,
            n_hop_weights,
            aggregation,
            n_components,
            random_state,
            spatial_connectivities_key,
            latent_connectivities_key,
            layer_ratio,
            n_iterations,
            use_weights,
            use_rep,
        )

    if not inplace:
        return adata

    if isinstance(data, SpatialData):
        # For SpatialData, update the table directly
        data.tables[table_key] = adata
    else:
        # For AnnData, copy results back to original object
        for col in result_columns:
            if col in orig_adata.obs.columns:
                logg.info(f"Overwriting existing column '{col}'")
                with contextlib.suppress(KeyError):
                    del orig_adata.obs[col]
                if f"{col}_colors" in orig_adata.uns.keys():
                    with contextlib.suppress(KeyError):
                        del orig_adata.uns[f"{col}_colors"]
            orig_adata.obs[col] = adata.obs[col]

    return None
def _get_result_columns(
    flavor: str,
    resolutions: float | tuple[float, float] | list[float | tuple[float, float]],
    library_key: str | None,
    libraries: list[str] | None,
) -> list[str]:
    """Get the column names that will be populated based on flavor and resolutions."""
    library_str = f"_{library_key}" if library_key is not None else ""

    if flavor == "cellcharter":
        base_column = "cellcharter_niche"
        if library_key is None:
            return [base_column]
        elif libraries is not None and len(libraries) > 0:
            return [f"{base_column}_{lib}" for lib in libraries]

    # For neighborhood, utag and spatialleiden, we need to handle resolutions
    if not isinstance(resolutions, list):
        resolutions = [resolutions]

    base_prefixes = {
        "neighborhood": "nhood_niche",
        "utag": "utag_niche",
        "spatialleiden": "spatialleiden",
    }
    base = base_prefixes.get(flavor)
    prefix = f"{base}{library_str}"

    if library_key is None:
        if base is None:
            # Previously this path failed with an unbound local variable.
            raise ValueError(f"Unknown flavor '{flavor}'")
        return [f"{prefix}_res={res}" for res in resolutions]

    assert isinstance(libraries, list)  # for mypy
    return [f"{prefix}_{lib}_res={res}" for lib in libraries for res in resolutions]


def _calculate_niches(
    adata: AnnData,
    mask: pd.core.series.Series | None,
    flavor: str,
    groups: str | None,
    n_neighbors: int | None,
    resolutions: float | tuple[float, float] | list[float | tuple[float, float]],
    min_niche_size: int | None,
    scale: bool,
    abs_nhood: bool,
    distance: int,
    n_hop_weights: list[float] | None,
    aggregation: str | None,
    n_components: int | None,
    random_state: int,
    spatial_connectivities_key: str,
    latent_connectivities_key: str,
    layer_ratio: float,
    n_iterations: int,
    use_weights: bool | tuple[bool, bool],
    use_rep: str | None,
) -> None:
    """Dispatch to the flavor-specific niche implementation; results are written to ``adata.obs``."""
    if flavor == "neighborhood":
        assert isinstance(resolutions, float | list)
        _get_nhood_profile_niches(
            adata,
            mask,
            groups,
            n_neighbors,
            resolutions,
            min_niche_size,
            scale,
            abs_nhood,
            distance,
            n_hop_weights,
            spatial_connectivities_key,
        )
    elif flavor == "utag":
        assert isinstance(resolutions, float | list)
        _get_utag_niches(adata, n_neighbors, resolutions, spatial_connectivities_key)
    elif flavor == "cellcharter":
        assert isinstance(aggregation, str)  # for mypy
        assert isinstance(n_components, int)  # for mypy
        _get_cellcharter_niches(
            adata,
            distance,
            aggregation,
            n_components,
            random_state,
            spatial_connectivities_key,
            use_rep,
        )
    elif flavor == "spatialleiden":
        _get_spatialleiden_domains(
            adata,
            spatial_connectivities_key,
            latent_connectivities_key,
            resolutions,
            layer_ratio,
            use_weights,
            n_iterations,
            random_state,
        )


def _get_nhood_profile_niches(
    adata: AnnData,
    mask: pd.core.series.Series | None,
    groups: str | None,
    n_neighbors: int | None,
    resolutions: float | tuple[float, float] | list[float | tuple[float, float]],
    min_niche_size: int | None,
    scale: bool,
    abs_nhood: bool,
    distance: int,
    n_hop_weights: list[float] | None,
    spatial_connectivities_key: str,
) -> None:
    """
    Cluster per-cell neighborhood profiles with leiden and store the labels in ``adata.obs``.

    One column ``nhood_niche_res=<res>`` is written per resolution. Cells excluded by ``mask`` or
    belonging to a niche smaller than ``min_niche_size`` are labeled 'not_a_niche'.

    adapted from https://github.com/immunitastx/monkeybread/blob/main/src/monkeybread/calc/_neighborhood_profile.py
    """
    connectivities = adata.obsp[spatial_connectivities_key]

    # obs x category matrix: absolute/relative frequency of each category among the 1-hop neighbors
    nhood_profile = _calculate_neighborhood_profile(adata, groups, connectivities.tocoo(), abs_nhood)

    # Additionally use n-hop neighbors if distance > 1.
    # This sums up the (weighted) neighborhood profiles of all n-hop neighbors.
    if distance > 1:
        if not n_hop_weights:
            # no (or empty) weights: every hop counts equally
            n_hop_weights = [1] * distance
        elif len(n_hop_weights) < distance:
            # too few weights: repeat the last one for the remaining hops
            n_hop_weights = n_hop_weights + [n_hop_weights[-1]] * (distance - len(n_hop_weights))
            logg.debug(f"Extended weights to match distance: {n_hop_weights}")
        # Only the first `distance` weights are ever applied, so only those may enter the normalization.
        hop_weights = list(n_hop_weights[:distance])

        weighted_profile = hop_weights[0] * nhood_profile

        # Repeatedly multiply by the 1-hop adjacency to obtain the (n+1)-hop adjacency.
        n_hop_adjacency_matrix = connectivities.copy()
        for n_hop in range(1, distance):
            logg.debug(f"Calculating {n_hop + 1}-hop neighbors")
            n_hop_adjacency_matrix = n_hop_adjacency_matrix @ connectivities
            hop_profile = _calculate_neighborhood_profile(adata, groups, n_hop_adjacency_matrix.tocoo(), abs_nhood)
            weighted_profile = weighted_profile + hop_weights[n_hop] * hop_profile

        if not abs_nhood:
            weighted_profile = weighted_profile / sum(hop_weights)

        nhood_profile = weighted_profile

    # create AnnData object from neighborhood profile to perform scanpy functions
    # Use .to_numpy(copy=True) to ensure the array is writeable (required for pandas CoW compatibility)
    # Preserve the DataFrame index for later matching with adata
    adata_neighborhood = ad.AnnData(
        X=nhood_profile.to_numpy(copy=True),
        obs=pd.DataFrame(index=nhood_profile.index),
    )

    # reason for scaling see https://monkeybread.readthedocs.io/en/latest/notebooks/tutorial.html#niche-analysis
    if scale:
        sc.pp.scale(adata_neighborhood, zero_center=True)

    # mask obs to exclude cells for which no niche shall be assigned
    if mask is not None:
        mask = mask[mask.index.isin(adata_neighborhood.obs.index)]
        # materialize the subset so scanpy writes into a real object rather than a view
        adata_neighborhood = adata_neighborhood[mask].copy()

    # required for leiden clustering (note: no dim reduction performed in original implementation)
    sc.pp.neighbors(adata_neighborhood, n_neighbors=n_neighbors, use_rep="X")

    resolutions = resolutions if isinstance(resolutions, list) else [resolutions]

    # For each resolution, apply leiden on neighborhood profile. Each cluster label equals to a niche label
    for res in resolutions:
        niche_key = f"nhood_niche_res={res}"

        # drop stale results (and their plotting colors) from a previous run
        if niche_key in adata.obs.columns:
            del adata.obs[niche_key]
        if f"{niche_key}_colors" in adata.uns.keys():
            del adata.uns[f"{niche_key}_colors"]

        sc.tl.leiden(
            adata_neighborhood,
            resolution=res,
            key_added=niche_key,
        )

        # Map cluster labels back by obs name; cells missing from the clustered object were masked out.
        neighborhood_clusters = dict(
            zip(adata_neighborhood.obs.index, adata_neighborhood.obs[niche_key], strict=False)
        )
        labels = pd.Series(
            [neighborhood_clusters.get(idx, "not_a_niche") for idx in adata.obs.index],
            index=adata.obs.index,
            dtype=object,
        )

        # filter niches with n_cells < min_niche_size
        if min_niche_size is not None:
            counts_by_niche = labels.value_counts()
            to_filter = counts_by_niche[counts_by_niche < min_niche_size].index
            labels = labels.where(~labels.isin(to_filter), "not_a_niche")

        adata.obs[niche_key] = labels.to_numpy()

    return


def _get_utag_niches(
    adata: AnnData,
    n_neighbors: int | None,
    resolutions: float | tuple[float, float] | list[float | tuple[float, float]],
    spatial_connectivities_key: str,
) -> None:
    """
    Cluster neighbor-aggregated features with leiden and store ``utag_niche_res=<res>`` in ``adata.obs``.

    Adapted from https://github.com/ElementoLab/utag/blob/main/utag/segmentation.py
    """
    new_feature_matrix = _utag(adata, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key)
    adata_utag = ad.AnnData(X=new_feature_matrix)
    sc.tl.pca(adata_utag)  # note: unlike with flavor 'neighborhood' dim reduction is performed here
    sc.pp.neighbors(adata_utag, n_neighbors=n_neighbors, use_rep="X_pca")

    if not isinstance(resolutions, list):
        resolutions = [resolutions]

    # For each resolution, apply leiden on the aggregated features. Each cluster label equals to a niche label
    for res in resolutions:
        niche_key = f"utag_niche_res={res}"
        sc.tl.leiden(adata_utag, resolution=res, key_added=niche_key)
        # adata_utag has the same row order as adata, so labels are copied positionally
        adata.obs[niche_key] = adata_utag.obs[niche_key].values
    return


def _get_cellcharter_niches(
    adata: AnnData,
    distance: int,
    aggregation: str,
    n_components: int,
    random_state: int,
    spatial_connectivities_key: str,
    use_rep: str | None = None,
) -> None:
    """
    Cluster cells with a Gaussian mixture model and store the labels in ``adata.obs['cellcharter_niche']``.

    Without ``use_rep``, the features are the PCA of ``adata.X`` concatenated with its 1..``distance``-hop
    neighbor aggregates. With ``use_rep``, the first ``n_components`` columns of ``adata.obsm[use_rep]``
    are clustered directly.

    adapted from https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/gr/_aggr.py
    and https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/tl/_gmm.py
    """
    if use_rep is not None:
        # Use provided embedding from adata.obsm
        # NOTE(review): in this branch the spatial aggregation never influences the result; it was
        # previously computed and discarded, so it is now skipped. Confirm that clustering the raw
        # embedding without neighbor aggregation is the intended behavior.
        assert_key_in_adata(adata, use_rep, attr="obsm")
        embedding = adata.obsm[use_rep]
        # Ensure embedding has the right number of components
        if embedding.shape[1] < n_components:
            raise ValueError(
                f"Embedding has {embedding.shape[1]} components, but n_components={n_components}. "
                f"Please provide an embedding with at least {n_components} components."
            )
        # Use only the first n_components
        embedding = embedding[:, :n_components]
    else:
        logg.warning(
            "CellCharter recommends to use a dimensionality reduced embedding of the data, e.g. a scVI embedding. "
            "Since 'use_rep' is not provided, PCA will be used as proxy - performance may be suboptimal."
        )

        adjacency_matrix = adata.obsp[spatial_connectivities_key]
        adj_hop = _setdiag(adjacency_matrix, 0)  # Remove self-loops, set diagonal to 0
        adj_visited = _setdiag(adjacency_matrix.copy(), 1)  # Track visited neighbors

        # layer 0 is the original (not aggregated) count matrix
        aggregated_matrices = [adata.X]
        for k in range(1, distance + 1):
            # get adjacency matrix for k-hop (neighbor of neighbor of ...) and aggregate counts over it
            if k > 1:
                adj_hop, adj_visited = _hop(adj_hop, adjacency_matrix, adj_visited)
            adj_hop_norm = _normalize(adj_hop)
            aggregated_matrices.append(_aggregate(adata, adj_hop_norm, aggregation))

        concatenated_matrix = hstack(aggregated_matrices)  # Stack all matrices horizontally
        arr = concatenated_matrix.toarray()  # Densify

        arr_ad = ad.AnnData(X=arr)
        sc.tl.pca(arr_ad)
        embedding = arr_ad.obsm["X_pca"]

    # cluster embedding with GMM, each cluster label equals to a niche label
    niches = _get_GMM_clusters(embedding, n_components, random_state)
    adata.obs["cellcharter_niche"] = pd.Categorical(niches)
    return


def _calculate_neighborhood_profile(
    adata: AnnData,
    groups: str | None,
    matrix: coo_matrix,
    abs_nhood: bool,
) -> pd.DataFrame:
    """
    Return an obs x category matrix with the frequency of each category in every cell's neighborhood.

    Every stored entry ``(i, j)`` of ``matrix`` counts as "j is a neighbor of i". Columns follow the
    sorted unique values of ``adata.obs[groups]``. With ``abs_nhood`` the raw counts are returned;
    otherwise counts are divided by the largest number of neighbors found in any row.
    """
    category_values = np.asarray(adata.obs[groups].values)
    # sorted unique categories plus, for each observation, the column index of its category
    unique_categories, category_codes = np.unique(category_values, return_inverse=True)
    category_codes = np.asarray(category_codes).ravel()

    n_obs = matrix.shape[0]
    rows = np.asarray(matrix.row)
    cols = np.asarray(matrix.col)

    # Count neighbors per (cell, category) directly from the stored entries. The previous
    # implementation padded ragged neighbor lists and then indexed the padding as -1, which
    # wrongly added the LAST observation's category to cells with fewer neighbors.
    abs_freq = np.zeros((n_obs, len(unique_categories)), dtype=int)
    np.add.at(abs_freq, (rows, category_codes[cols]), 1)

    if abs_nhood:
        return pd.DataFrame(abs_freq, index=adata.obs.index)

    # normalize by n_neighbors to get relative frequency of each category
    # NOTE(review): k is the maximum neighbor count over all rows (as before), so rows with fewer
    # neighbors sum to less than 1 - confirm this is preferred over per-row normalization.
    k = int(np.bincount(rows, minlength=n_obs).max()) if rows.size else 0
    rel_freq = abs_freq / k if k > 0 else np.zeros(abs_freq.shape, dtype=float)
    return pd.DataFrame(rel_freq, index=adata.obs.index)


def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> Any:
    """
    Multiply the (optionally row-normalized) adjacency matrix with the feature matrix.

    Each observation thereby inherits features from its immediate neighbors as described in the
    UTAG paper. Returns the aggregated feature matrix (not an AnnData object).
    """
    adjacency_matrix = adata.obsp[spatial_connectivity_key]

    if normalize_adj:
        return normalize(adjacency_matrix, norm="l1", axis=1) @ adata.X
    else:
        return adjacency_matrix @ adata.X


def _setdiag(adjacency_matrix: sps.spmatrix, value: int) -> sps.spmatrix:
    """
    Return a CSR copy of ``adjacency_matrix`` with its diagonal set to ``value``.

    ``value=0`` removes self-loops (the zeroed entries are dropped from the sparsity structure).
    """
    if not issparse(adjacency_matrix):
        adjacency_matrix = sps.csr_matrix(adjacency_matrix)
    # LIL supports cheap structural changes; tolil() also copies, leaving the input untouched
    adjacency_matrix = adjacency_matrix.tolil()
    adjacency_matrix.setdiag(value)
    adjacency_matrix = adjacency_matrix.tocsr()
    if value == 0:
        adjacency_matrix.eliminate_zeros()
    return adjacency_matrix


def _hop(
    adj_hop: sps.spmatrix, adj: sps.spmatrix, adj_visited: sps.spmatrix = None
) -> tuple[sps.spmatrix, sps.spmatrix]:
    """
    Advance the hop adjacency by one step (neighbors of neighbors).

    If ``adj_visited`` is given, only entries whose value exceeds the corresponding visited entry
    are kept, and the visited matrix is updated with them.
    """
    adj_hop = adj_hop @ adj

    if adj_visited is not None:
        adj_hop = adj_hop > adj_visited
        adj_visited = adj_visited + adj_hop

    return adj_hop, adj_visited


def _normalize(adj: sps.spmatrix) -> sps.spmatrix:
    """Row-normalize the adjacency matrix so high-degree nodes don't disproportionately affect aggregation."""
    # ravel (not squeeze) keeps a 1-d array even for a single-row matrix
    deg = np.asarray(adj.sum(axis=1), dtype=float).ravel()
    with np.errstate(divide="ignore"):
        deg_inv = 1.0 / deg
    # isolated nodes (degree 0) keep an all-zero row instead of inf
    deg_inv[np.isinf(deg_inv)] = 0

    # `@` is matrix multiplication for both sparse matrices and sparse arrays
    return spdiags(deg_inv, 0, len(deg_inv), len(deg_inv)) @ adj


def _aggregate(adata: AnnData, normalized_adjacency_matrix: sps.spmatrix, aggregation: str = "mean") -> Any:
    """
    Aggregate ``adata.X`` over neighbors, either by mean or by variance.

    'mean' returns ``A @ X``; 'variance' returns the dense ``A @ X**2 - (A @ X)**2``.
    Raises ``ValueError`` for any other aggregation.
    """
    # TODO: add support for other aggregation methods
    X = adata.X
    if aggregation == "mean":
        return normalized_adjacency_matrix @ X
    if aggregation == "variance":
        # accept both sparse and dense count matrices
        X_dense = X.toarray() if issparse(X) else np.asarray(X)
        mean_matrix = normalized_adjacency_matrix @ X_dense
        mean_squared_matrix = normalized_adjacency_matrix @ (X_dense * X_dense)
        return mean_squared_matrix - mean_matrix * mean_matrix
    raise ValueError(f"Invalid aggregation method '{aggregation}'. Please choose either 'mean' or 'variance'.")


def _get_GMM_clusters(A: NDArray[np.float64], n_components: int, random_state: int) -> Any:
    """
    Return niche labels generated by GMM clustering.

    Compared to cellcharter this approach is simplified by using sklearn's GaussianMixture model
    without stability analysis.
    """
    gmm = GaussianMixture(
        n_components=n_components,
        random_state=random_state,
        init_params="random_from_data",
    )
    gmm.fit(A)
    labels = gmm.predict(A)

    return labels


def _get_spatialleiden_domains(
    adata: AnnData,
    spatial_connectivities_key: str,
    latent_connectivities_key: str,
    resolutions: float | tuple[float, float] | list[float | tuple[float, float]],
    layer_ratio: float,
    use_weights: bool | tuple[bool, bool],
    n_iterations: int,
    random_state: int,
) -> None:
    """
    Perform SpatialLeiden clustering.

    This is a wrapper around :py:func:`spatialleiden.multiplex_leiden` that uses :py:class:`anndata.AnnData`
    as input and works with two layers; one latent space and one spatial layer. One result per resolution
    is requested under the key ``spatialleiden_res=<res>``.

    Adapted from https://github.com/HiDiHlabs/SpatialLeiden/.
    """
    # optional dependency, imported lazily so the other flavors work without it
    try:
        import spatialleiden as sl
    except ImportError as e:
        msg = "Please install the spatialleiden algorithm: `pip install squidpy[leiden]` or `conda install bioconda::spatialleiden` or `pip install spatialleiden`."
        raise ImportError(msg) from e

    if not isinstance(resolutions, list):
        resolutions = [resolutions]

    for res in resolutions:
        sl.spatialleiden(
            adata,
            resolution=res,
            use_weights=use_weights,
            n_iterations=n_iterations,
            layer_ratio=layer_ratio,
            latent_neighbors_key=latent_connectivities_key,
            spatial_neighbors_key=spatial_connectivities_key,
            random_state=random_state,
            directed=False,
            key_added=f"spatialleiden_res={res}",
        )

    return


def _fide_score(
    adata: AnnData,
    niche_key: str,
    average: bool,
    spatial_connectivities_key: str = "spatial_connectivities",
) -> Any:
    """
    F1-score of intra-domain edges (FIDE). A high score indicates a great domain continuity.

    For every stored edge of ``adata.obsp[spatial_connectivities_key]`` the niche label of the source
    cell is compared with the label of the target cell. With ``average`` the macro-averaged F1-score
    is returned, otherwise one F1-score per class.
    """
    i, j = adata.obsp[spatial_connectivities_key].nonzero()  # row and column indices of non-zero elements
    niche_labels = adata.obs.iloc[i][niche_key]
    neighbor_niche_labels = adata.obs.iloc[j][niche_key]

    return f1_score(niche_labels, neighbor_niche_labels, average="macro" if average else None)


def _jensen_shannon_divergence(adata: AnnData, niche_key: str, library_key: str) -> Any:
    """
    Calculate Jensen-Shannon divergence (JSD) over all slides.

    This metric measures how well niche label distributions match across different slides. The
    generalized, equally weighted JSD ``H(mean_i P_i) - mean_i H(P_i)`` (natural logarithm) is
    returned: 0 means identical distributions. For two slides this equals the square of
    ``scipy.spatial.distance.jensenshannon(p, q)``.
    """
    niche_labels = sorted(adata.obs[niche_key].unique())
    label_distributions = []

    # observed=True skips unused categories, which would otherwise produce empty slides
    for _, slide in adata.obs.groupby(library_key, observed=True):
        counts = slide[niche_key].value_counts(normalize=True)
        label_distributions.append([counts.get(label, 0) for label in niche_labels])

    probs = np.asarray(label_distributions, dtype=float)
    if probs.size == 0:
        return 0.0

    def _entropy(p: NDArray[np.float64]) -> float:
        """Shannon entropy (nats) with the convention 0 * log(0) = 0."""
        nonzero = p[p > 0]
        return float(-(nonzero * np.log(nonzero)).sum())

    # scipy's `jensenshannon` only compares two distributions, so the multi-slide form is computed here
    divergence = _entropy(probs.mean(axis=0)) - float(np.mean([_entropy(p) for p in probs]))
    # clip tiny negative values caused by floating point round-off
    return max(divergence, 0.0)


def _validate_niche_args(
    data: AnnData | SpatialData,
    flavor: Literal["neighborhood", "utag", "cellcharter", "spatialleiden"],
    library_key: str | None,
    table_key: str | None,
    groups: str | None,
    n_neighbors: int | None,
    resolutions: float | tuple[float, float] | list[float | tuple[float, float]] | None,
    min_niche_size: int | None,
    scale: bool,
    abs_nhood: bool,
    distance: int | None,
    n_hop_weights: list[float] | None,
    aggregation: str | None,
    n_components: int | None,
    random_state: int,
    spatial_connectivities_key: str,
    latent_connectivities_key: str,
    layer_ratio: float,
    n_iterations: int,
    use_weights: bool | tuple[bool, bool],
    use_rep: str | None,
    inplace: bool,
) -> None:
    """
    Validate whether necessary arguments are provided for a given niche flavor.

    Also warns whether unnecessary optional arguments are supplied.

    Raises
    ------
    ValueError
        If required arguments for the specified flavor are missing or have incorrect values.
    TypeError
        If arguments are of incorrect type.
    """
    assert_isinstance(data, (AnnData, SpatialData), name="data")
    assert_one_of(flavor, ["neighborhood", "utag", "cellcharter", "spatialleiden"], name="flavor")

    if isinstance(data, SpatialData) and table_key is None:
        raise TypeError("missing required keyword-only argument: 'table_key'")

    if library_key is not None:
        assert_isinstance(library_key, str, name="library_key")
        adata = extract_adata_if_sdata(data, table_key=table_key)
        if library_key not in adata.obs.columns:
            raise ValueError(f"'library_key' must be a column in 'adata.obs', got {library_key}")

    if n_neighbors is not None:
        assert_isinstance(n_neighbors, int, name="n_neighbors")

    if resolutions is not None:
        if not isinstance(resolutions, float | tuple | list):
            raise TypeError(
                f"'resolutions' must be a float, a tuple of floats, a list of floats, or a list containing floats and/or tuples of floats, got {type(resolutions).__name__}"
            )
        if isinstance(resolutions, tuple):
            if not all(isinstance(x, float) for x in resolutions):
                raise TypeError("All elements in the tuple 'resolutions' must be floats.")
        elif isinstance(resolutions, list):
            for item in resolutions:
                if not (
                    isinstance(item, float) or (isinstance(item, tuple) and all(isinstance(i, float) for i in item))
                ):
                    raise TypeError("Each item in the list 'resolutions' must be a float or a tuple of floats.")

    if n_hop_weights is not None:
        assert_isinstance(n_hop_weights, list, name="n_hop_weights")

    assert_isinstance(scale, bool, name="scale")
    assert_isinstance(abs_nhood, bool, name="abs_nhood")

    # Define parameters used by each flavor
    flavor_param_specs = {
        "neighborhood": {
            "required": ["groups", "n_neighbors", "resolutions"],
            "optional": [
                "min_niche_size",
                "scale",
                "abs_nhood",
                "distance",
                "n_hop_weights",
            ],
            "unused": ["aggregation", "n_components", "random_state"],
        },
        "utag": {
            "required": ["n_neighbors", "resolutions"],
            "optional": [],
            "unused": [
                "groups",
                "min_niche_size",
                "scale",
                "abs_nhood",
                "distance",
                "n_hop_weights",
                "aggregation",
                "n_components",
                "random_state",
            ],
        },
        "cellcharter": {
            "required": ["distance", "aggregation", "random_state"],
            "optional": ["n_components", "use_rep"],
            "unused": [
                "groups",
                "min_niche_size",
                "scale",
                "abs_nhood",
                "n_neighbors",
                "resolutions",
                "n_hop_weights",
            ],
        },
        "spatialleiden": {
            "required": ["latent_connectivities_key", "spatial_connectivities_key"],
            "optional": [
                "resolutions",
                "layer_ratio",
                "n_iterations",
                "use_weights",
                "random_state",
            ],
            "unused": [
                "groups",
                "min_niche_size",
                "scale",
                "abs_nhood",
                "n_neighbors",
                "n_hop_weights",
            ],
        },
    }

    # Flavor-relevant arguments by name; this is what gets checked for "unused" warnings.
    checked_params = {
        "groups": groups,
        "n_neighbors": n_neighbors,
        "resolutions": resolutions,
        "min_niche_size": min_niche_size,
        "scale": scale,
        "abs_nhood": abs_nhood,
        "distance": distance,
        "n_hop_weights": n_hop_weights,
        "aggregation": aggregation,
        "n_components": n_components,
        "random_state": random_state,
        "use_rep": use_rep,
    }
    # Explicit name -> value lookup for the required-argument check (instead of `locals()`).
    supplied_params = {
        **checked_params,
        "spatial_connectivities_key": spatial_connectivities_key,
        "latent_connectivities_key": latent_connectivities_key,
    }

    for param_name in flavor_param_specs[flavor]["required"]:
        if supplied_params[param_name] is None:
            raise ValueError(f"'{param_name}' is required for flavor '{flavor}'")

    _check_unnecessary_args(flavor, checked_params, flavor_param_specs[flavor])

    # Flavor-specific validations
    if flavor == "neighborhood":
        assert_isinstance(groups, str, name="groups")
        if min_niche_size is not None:
            assert_isinstance(min_niche_size, int, name="min_niche_size")
        if distance is not None:
            # same type check as for cellcharter; 'distance' is used as an integer hop count
            assert_isinstance(distance, int, name="distance")
            if distance < 1:
                raise ValueError(f"'distance' must be at least 1, got {distance}")

    elif flavor == "cellcharter":
        if distance is not None:
            assert_isinstance(distance, int, name="distance")
            if distance < 1:
                raise ValueError(f"'distance' must be at least 1, got {distance}")
        if aggregation is not None:
            assert_isinstance(aggregation, str, name="aggregation")
            assert_one_of(aggregation, ["mean", "variance"], name="aggregation")
        assert_isinstance(n_components, int, name="n_components")
        if n_components < 1:
            raise ValueError(f"'n_components' must be at least 1, got {n_components}")
        assert_isinstance(random_state, int, name="random_state")
        if use_rep is not None:
            assert_isinstance(use_rep, str, name="use_rep")

    elif flavor == "spatialleiden":
        assert_isinstance(latent_connectivities_key, str, name="latent_connectivities_key")
        assert_isinstance(spatial_connectivities_key, str, name="spatial_connectivities_key")
        assert_isinstance(layer_ratio, (float, int), name="layer_ratio")
        assert_isinstance(n_iterations, int, name="n_iterations")
        if not (
            isinstance(use_weights, bool)
            or (
                isinstance(use_weights, tuple)
                and len(use_weights) == 2
                and all(isinstance(x, bool) for x in use_weights)
            )
        ):
            raise TypeError(f"'use_weights' must be a bool or a tuple of two bools, got {use_weights!r}")
        assert_isinstance(random_state, int, name="random_state")

    assert_isinstance(inplace, bool, name="inplace")


def _check_unnecessary_args(flavor: str, param_dict: dict[str, Any], param_specs: dict[str, Any]) -> None:
    """
    Check for unnecessary arguments that were provided but not used by the given flavor.

    A single warning listing all such arguments is logged. Arguments still at their default value
    (``None``, ``scale=True``, ``abs_nhood=False``, ``random_state=42``) are not reported.

    Parameters
    ----------
    flavor
        The flavor being used ('neighborhood', 'utag', 'cellcharter', or 'spatialleiden')
    param_dict
        Dictionary of parameter names to their values
    param_specs
        Dictionary with 'required', 'optional', and 'unused' parameter lists for the flavor
    """
    # non-None defaults that must not count as "provided by the user"
    non_none_defaults = {"scale": True, "abs_nhood": False, "random_state": 42}

    unnecessary_args = []
    for param_name in param_specs["unused"]:
        param_value = param_dict.get(param_name)
        if param_value is None:
            continue

        if param_name == "random_state":
            if param_value == 42:
                continue
        elif param_name in non_none_defaults and param_value is non_none_defaults[param_name]:
            continue

        unnecessary_args.append(param_name)

    if unnecessary_args:
        logg.warning(
            f"Parameters {', '.join(unnecessary_args)} are not used for flavor '{flavor}'.",
        )