Predict cluster labels of spots using TensorFlow

In this tutorial, we show how you can use the squidpy.im.ImageContainer object to train a ResNet model to predict the cluster labels of spots.

This is a general approach that can easily be extended to a variety of supervised, self-supervised, or unsupervised tasks. We aim to highlight how the flexibility provided by the image container, and its seamless integration with AnnData, makes it easy to interface your data with modern deep learning frameworks such as TensorFlow.

Furthermore, we show how you can leverage such a ResNet model to generate a new set of features that can provide useful insights into spot similarity based on image morphology.

First, we’ll load some libraries. Note that TensorFlow is not a dependency of Squidpy, so you’ll have to install it separately in your conda environment. Have a look at the TensorFlow installation instructions. This of course applies to any deep learning framework of your choice.

[1]:
import scanpy as sc
import squidpy as sq
from squidpy.im import ImageContainer
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from anndata import AnnData
from sklearn.model_selection import (
    train_test_split,
)  # we'll use this function to split our dataset in train and test set
import tensorflow as tf
from tensorflow.keras.layers.experimental import (
    preprocessing,
)  # let's use the new pre-processing layers for resizing and data augmentation tasks

sc.logging.print_header()
print(f"squidpy=={sq.__version__}")
print(f"tensorflow=={tf.__version__}")
scanpy==1.8.0.dev90+gbcfa925f anndata==0.7.5 umap==0.5.1 numpy==1.19.5 scipy==1.6.1 pandas==1.2.2 scikit-learn==0.24.1 statsmodels==0.12.2 python-igraph==0.9.0 leidenalg==0.8.3 pynndescent==0.5.2
squidpy==1.0.0
tensorflow==2.4.1

We will load the public data available in Squidpy.

[2]:
adata = sq.datasets.visium_hne_adata()
img = sq.datasets.visium_hne_image()

Create train-test split

We create a vector of the labels with which to train the classifier. In this case, we will train a classifier to predict cluster labels obtained from gene expression. We’ll create a one-hot encoded array with the convenience function tf.one_hot. Furthermore, we’ll split the vector indices to get train and test sets. Note that we pass the cluster labels as the stratify argument, to make sure that the cluster labels are balanced in each split.
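As an aside, one-hot encoding simply maps each integer class index to a binary indicator vector. A minimal NumPy sketch of what tf.one_hot(labels, depth=n) produces, using made-up labels for illustration:

```python
import numpy as np

# hypothetical integer-encoded cluster labels for 4 spots and 3 classes
labels = np.array([0, 2, 1, 2])
depth = 3

# equivalent of tf.one_hot(labels, depth=3): look up rows of the identity matrix
labels_ohe = np.eye(depth, dtype=np.float32)[labels]
# row i is all zeros except a 1 at position labels[i]
print(labels_ohe)
```

In the tutorial itself, the same encoding is produced by tf.one_hot inside the get_ohe helper defined below.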

[3]:
# get train,test split stratified by cluster labels
train_idx, test_idx = train_test_split(
    adata.obs_names.values,
    test_size=0.2,
    stratify=adata.obs["cluster"],
    shuffle=True,
    random_state=42,
)
[4]:
print(
    f"Train set : \n {adata[train_idx, :].obs.cluster.value_counts()} \n \n Test set: \n {adata[test_idx, :].obs.cluster.value_counts()}"
)
Train set :
 Cortex_1                         227
Thalamus_1                       209
Cortex_2                         206
Cortex_3                         195
Fiber_tract                      181
Hippocampus                      178
Hypothalamus_1                   166
Thalamus_2                       154
Cortex_4                         131
Striatum                         122
Hypothalamus_2                   106
Cortex_5                         103
Lateral_ventricle                 84
Pyramidal_layer_dentate_gyrus     54
Pyramidal_layer                   34
Name: cluster, dtype: int64

 Test set:
 Cortex_1                         57
Thalamus_1                       52
Cortex_2                         51
Cortex_3                         49
Fiber_tract                      45
Hippocampus                      44
Hypothalamus_1                   42
Thalamus_2                       38
Cortex_4                         33
Striatum                         31
Hypothalamus_2                   27
Cortex_5                         26
Lateral_ventricle                21
Pyramidal_layer_dentate_gyrus    14
Pyramidal_layer                   8
Name: cluster, dtype: int64
/home/icb/giovanni.palla/miniconda3/envs/spatial/lib/python3.8/site-packages/pandas/core/arrays/categorical.py:2487: FutureWarning: The `inplace` parameter in pandas.Categorical.remove_unused_categories is deprecated and will be removed in a future version.
  res = method(*args, **kwargs)

Create datasets and train the model

Next, we’ll create a TensorFlow dataset that will serve as the data loader for model training. A key aspect of this step is how the ImageContainer makes it easy to relate spot information to the underlying image. In particular, we will make use of img.generate_spot_crops, a method that creates a generator yielding the crop of the tissue image corresponding to each spot. In just one line of code you can create this generator as well as specify the size of the crops. You might want to increase the size to include some neighborhood morphology information.

We won’t go into much detail on the additional arguments and steps related to the TensorFlow Dataset objects; you can familiarize yourself with TensorFlow datasets here.

[5]:
def get_ohe(adata: AnnData, cluster_key: str, obs_names: np.ndarray):
    # one-hot encode the cluster labels of the selected observations
    cluster_labels = adata[obs_names, :].obs[cluster_key]
    classes = cluster_labels.unique().shape[0]
    cluster_map = {v: i for i, v in enumerate(cluster_labels.cat.categories.values)}
    labels = np.array([cluster_map[c] for c in cluster_labels], dtype=np.uint8)
    labels_ohe = tf.one_hot(labels, depth=classes, dtype=tf.float32)
    return labels_ohe


def create_dataset(
    adata: AnnData,
    img: ImageContainer,
    obs_names: np.ndarray,
    cluster_key: str,
    augment: bool,
    shuffle: bool,
):
    # image dataset
    spot_generator = img.generate_spot_crops(
        adata,
        obs_names=obs_names,  # this argument specifies the observation names
        scale=1.5,  # this argument specifies that we will consider some additional context under each spot. Scale=1 would crop the spot with exact coordinates
        as_array="image",  # this line specifies that we will crop from the "image" layer. You can specify multiple layers to obtain crops from multiple pre-processing steps.
        return_obs=False,
    )
    image_dataset = tf.data.Dataset.from_tensor_slices([x for x in spot_generator])

    # label dataset
    lab = get_ohe(adata, cluster_key, obs_names)
    lab_dataset = tf.data.Dataset.from_tensor_slices(lab)

    ds = tf.data.Dataset.zip((image_dataset, lab_dataset))

    if shuffle:  # if you want to shuffle the dataset during training
        ds = ds.shuffle(1000, reshuffle_each_iteration=True)
    ds = ds.batch(64)  # batch
    processing_layers = [
        preprocessing.Resizing(128, 128),
        preprocessing.Rescaling(1.0 / 255),
    ]
    augment_layers = [
        preprocessing.RandomFlip(),
        preprocessing.RandomContrast(0.8),
    ]
    if augment:  # if you want to augment the image crops during training
        processing_layers.extend(augment_layers)

    data_processing = tf.keras.Sequential(processing_layers)

    ds = ds.map(lambda x, y: (data_processing(x), y))  # add processing to dataset
    return ds
[6]:
train_ds = create_dataset(adata, img, train_idx, "cluster", augment=True, shuffle=True)
test_ds = create_dataset(adata, img, test_idx, "cluster", augment=True, shuffle=True)
/home/icb/giovanni.palla/miniconda3/envs/spatial/lib/python3.8/site-packages/pandas/core/arrays/categorical.py:2487: FutureWarning: The `inplace` parameter in pandas.Categorical.remove_unused_categories is deprecated and will be removed in a future version.
  res = method(*args, **kwargs)
/home/icb/giovanni.palla/miniconda3/envs/spatial/lib/python3.8/site-packages/pandas/core/arrays/categorical.py:2487: FutureWarning: The `inplace` parameter in pandas.Categorical.remove_unused_categories is deprecated and will be removed in a future version.
  res = method(*args, **kwargs)

Here, we instantiate the model. We’ll use a ResNet pre-trained on ImageNet, plus a dense layer for the output.

[7]:
input_shape = (128, 128, 3)  # input shape
inputs = tf.keras.layers.Input(shape=input_shape)

# load Resnet with pre-trained imagenet weights
x = tf.keras.applications.ResNet50(
    weights="imagenet",
    include_top=False,
    input_shape=input_shape,
    classes=15,
    pooling="avg",
)(inputs)
outputs = tf.keras.layers.Dense(
    units=15,  # add output layer
)(x)
model = tf.keras.Model(inputs, outputs)  # create model
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),  # add optimizer
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),  # add loss
)
[8]:
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 128, 128, 3)]     0
_________________________________________________________________
resnet50 (Functional)        (None, 2048)              23587712
_________________________________________________________________
dense (Dense)                (None, 15)                30735
=================================================================
Total params: 23,618,447
Trainable params: 23,565,327
Non-trainable params: 53,120
_________________________________________________________________
[9]:
history = model.fit(
    train_ds,
    validation_data=test_ds,
    epochs=50,
    verbose=2,
)
Epoch 1/50
34/34 - 43s - loss: 2.1304 - val_loss: 3.6685
Epoch 2/50
34/34 - 7s - loss: 1.1984 - val_loss: 5.2714
Epoch 3/50
34/34 - 7s - loss: 0.9193 - val_loss: 6.3917
Epoch 4/50
34/34 - 7s - loss: 0.7265 - val_loss: 10.2531
Epoch 5/50
34/34 - 7s - loss: 0.5830 - val_loss: 8.8467
Epoch 6/50
34/34 - 7s - loss: 0.4502 - val_loss: 12.2926
Epoch 7/50
34/34 - 7s - loss: 0.3471 - val_loss: 9.1756
Epoch 8/50
34/34 - 7s - loss: 0.3040 - val_loss: 10.4599
Epoch 9/50
34/34 - 7s - loss: 0.2202 - val_loss: 9.2584
Epoch 10/50
34/34 - 7s - loss: 0.2127 - val_loss: 13.0235
Epoch 11/50
34/34 - 7s - loss: 0.1827 - val_loss: 12.9605
Epoch 12/50
34/34 - 7s - loss: 0.1609 - val_loss: 8.7757
Epoch 13/50
34/34 - 7s - loss: 0.1482 - val_loss: 7.8238
Epoch 14/50
34/34 - 7s - loss: 0.1055 - val_loss: 8.8139
Epoch 15/50
34/34 - 7s - loss: 0.1072 - val_loss: 8.7499
Epoch 16/50
34/34 - 7s - loss: 0.0782 - val_loss: 5.8379
Epoch 17/50
34/34 - 7s - loss: 0.0828 - val_loss: 5.7679
Epoch 18/50
34/34 - 7s - loss: 0.0779 - val_loss: 6.7843
Epoch 19/50
34/34 - 7s - loss: 0.0891 - val_loss: 9.8836
Epoch 20/50
34/34 - 7s - loss: 0.0720 - val_loss: 8.3936
Epoch 21/50
34/34 - 7s - loss: 0.0602 - val_loss: 14.4134
Epoch 22/50
34/34 - 7s - loss: 0.0523 - val_loss: 10.7274
Epoch 23/50
34/34 - 7s - loss: 0.0585 - val_loss: 5.9731
Epoch 24/50
34/34 - 7s - loss: 0.0488 - val_loss: 6.1676
Epoch 25/50
34/34 - 7s - loss: 0.0547 - val_loss: 6.5565
Epoch 26/50
34/34 - 7s - loss: 0.0403 - val_loss: 5.5574
Epoch 27/50
34/34 - 7s - loss: 0.0438 - val_loss: 4.3810
Epoch 28/50
34/34 - 7s - loss: 0.0404 - val_loss: 4.1031
Epoch 29/50
34/34 - 7s - loss: 0.0350 - val_loss: 4.1813
Epoch 30/50
34/34 - 7s - loss: 0.0325 - val_loss: 4.1099
Epoch 31/50
34/34 - 7s - loss: 0.0263 - val_loss: 3.8373
Epoch 32/50
34/34 - 7s - loss: 0.0251 - val_loss: 4.1451
Epoch 33/50
34/34 - 7s - loss: 0.0262 - val_loss: 3.6394
Epoch 34/50
34/34 - 7s - loss: 0.0425 - val_loss: 3.7185
Epoch 35/50
34/34 - 7s - loss: 0.0573 - val_loss: 3.3745
Epoch 36/50
34/34 - 7s - loss: 0.0563 - val_loss: 3.8144
Epoch 37/50
34/34 - 7s - loss: 0.0718 - val_loss: 3.3854
Epoch 38/50
34/34 - 7s - loss: 0.0877 - val_loss: 3.9512
Epoch 39/50
34/34 - 7s - loss: 0.0574 - val_loss: 3.8434
Epoch 40/50
34/34 - 7s - loss: 0.0629 - val_loss: 4.3814
Epoch 41/50
34/34 - 7s - loss: 0.0492 - val_loss: 4.3040
Epoch 42/50
34/34 - 7s - loss: 0.0531 - val_loss: 3.6752
Epoch 43/50
34/34 - 7s - loss: 0.0908 - val_loss: 3.9180
Epoch 44/50
34/34 - 7s - loss: 0.0621 - val_loss: 2.8373
Epoch 45/50
34/34 - 7s - loss: 0.1062 - val_loss: 4.7432
Epoch 46/50
34/34 - 7s - loss: 0.0752 - val_loss: 4.4942
Epoch 47/50
34/34 - 7s - loss: 0.0472 - val_loss: 4.6422
Epoch 48/50
34/34 - 7s - loss: 0.0502 - val_loss: 3.5747
Epoch 49/50
34/34 - 7s - loss: 0.0607 - val_loss: 4.7950
Epoch 50/50
34/34 - 7s - loss: 0.0432 - val_loss: 3.8837

We can plot the training and validation loss over the course of training. Clearly the model would benefit from some more fine-tuning :).

[10]:
sns.lineplot(x=np.arange(50), y="loss", data=history.history)
sns.lineplot(x=np.arange(50), y="val_loss", data=history.history)
[10]:
<AxesSubplot:ylabel='loss'>
../_images/external_tutorials_tutorial_tf_15_1.png

Calculate embedding and visualize results

What we are actually interested in is the ResNet embedding values of the data after training. We expect that such an embedding contains relevant features of the image that can be used for downstream analysis such as clustering or integration with gene expression.

To generate this embedding, we first create a new dataset that contains the full list of spots, in the correct order and without augmentation.

[11]:
full_ds = create_dataset(
    adata, img, adata.obs_names.values, "cluster", augment=False, shuffle=False
)
/home/icb/giovanni.palla/miniconda3/envs/spatial/lib/python3.8/site-packages/pandas/core/arrays/categorical.py:2487: FutureWarning: The `inplace` parameter in pandas.Categorical.remove_unused_categories is deprecated and will be removed in a future version.
  res = method(*args, **kwargs)
/home/icb/giovanni.palla/miniconda3/envs/spatial/lib/python3.8/site-packages/pandas/core/arrays/categorical.py:2487: FutureWarning: The `inplace` parameter in pandas.Categorical.remove_unused_categories is deprecated and will be removed in a future version.
  res = method(*args, **kwargs)

Then, we instantiate another model without the output layer, in order to extract the final embedding.

[12]:
model_embed = tf.keras.Model(inputs, x)
embedding = model_embed.predict(full_ds)

We can then save the embedding in a new AnnData, and copy over all the relevant metadata from the AnnData with gene expression counts…

[13]:
adata_resnet = AnnData(embedding, obs=adata.obs.copy())
adata_resnet.obsm["spatial"] = adata.obsm["spatial"].copy()
adata_resnet.uns = adata.uns.copy()
adata_resnet
[13]:
AnnData object with n_obs × n_vars = 2688 × 2048
    obs: 'in_tissue', 'array_row', 'array_col', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'n_counts', 'leiden', 'cluster'
    uns: 'cluster_colors', 'hvg', 'leiden', 'leiden_colors', 'neighbors', 'pca', 'rank_genes_groups', 'spatial', 'umap'
    obsm: 'spatial'

… perform the standard clustering analysis.

[14]:
sc.pp.scale(adata_resnet)
sc.pp.pca(adata_resnet)
sc.pp.neighbors(adata_resnet)
sc.tl.leiden(adata_resnet, key_added="resnet_embedding_cluster")
sc.tl.umap(adata_resnet)

Interestingly, despite the poor performance on the test set, the model seems to have encoded information relevant for separating spots from each other. The cluster annotation also resembles the original annotation based on gene expression similarity.

[15]:
sc.set_figure_params(facecolor="white", figsize=(8, 8))
sc.pl.umap(
    adata_resnet, color=["cluster", "resnet_embedding_cluster"], size=100, wspace=0.7
)
../_images/external_tutorials_tutorial_tf_25_0.png

We can visualize the same information in spatial coordinates. Again, some clusters seem to closely recapitulate the Hippocampus and Pyramidal layer clusters. It seems to have worked surprisingly well!

[16]:
sc.pl.spatial(
    adata_resnet,
    color=["cluster", "resnet_embedding_cluster"],
    frameon=False,
    wspace=0.5,
)
../_images/external_tutorials_tutorial_tf_27_0.png

An additional analysis could be to integrate the gene expression with the features learned by the ResNet classifier, in order to obtain a joint representation of both gene expression and image information. Such integration could be done, for instance, by concatenating the PCA resulting from the gene expression adata with that of the ResNet embedding adata_resnet. After concatenating the principal components, you could follow the usual steps of building a KNN graph and clustering with the Leiden algorithm.
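A minimal sketch of this concatenation step, using random NumPy arrays as hypothetical stand-ins for the two adata.obsm["X_pca"] matrices (the names pca_gex and pca_img are illustrative, not part of the tutorial):

```python
import numpy as np

rng = np.random.default_rng(0)
# stand-ins for adata.obsm["X_pca"] and adata_resnet.obsm["X_pca"]
pca_gex = rng.normal(size=(2688, 50))  # gene expression PCs, one row per spot
pca_img = rng.normal(size=(2688, 50))  # ResNet embedding PCs, one row per spot


def zscore(x):
    # scale each block so neither modality dominates the joint space
    return (x - x.mean(axis=0)) / x.std(axis=0)


joint = np.concatenate([zscore(pca_gex), zscore(pca_img)], axis=1)
print(joint.shape)  # one row per spot, gene expression + image features
```

In the real workflow you could then store the joint matrix, e.g. in adata.obsm, and proceed with sc.pp.neighbors on that representation followed by sc.tl.leiden.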

With this tutorial we have shown how to interface the Squidpy workflow with modern deep learning frameworks, and hopefully inspired additional analyses that leverage several data modalities and powerful DL-based representations.