Nuclei segmentation using StarDist

In this tutorial, we show how to use the StarDist segmentation method with squidpy.im.segment for nuclei segmentation.

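StarDist (Schmidt et al. (2018) and Weigert et al. (2020), code) localizes nuclei with star-convex polygons. A convolutional neural network is trained to predict, for every pixel, the probability that the pixel belongs to an object and the distances from that pixel to the object boundary along a fixed set of radial directions. Together, these distances define a star-convex polygon for each cell position.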
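To make the star-convex representation concrete, the sketch below turns one pixel's predicted ray distances into polygon vertices. The pre-trained models inspected later in this tutorial use 32 rays (n_rays=32 in their configuration). This is a minimal NumPy illustration under our own assumptions, not StarDist's implementation: the function name star_convex_vertices is hypothetical, and we assume evenly spaced ray angles with y = cy + d·sin(angle), x = cx + d·cos(angle).

```python
import numpy as np


def star_convex_vertices(center, distances):
    """Hypothetical helper: polygon vertices (y, x) of a star-convex shape.

    `distances[k]` is the distance from `center` to the object boundary along
    ray k. Rays are assumed to be spaced evenly over the full circle.
    """
    distances = np.asarray(distances, dtype=float)
    n_rays = len(distances)
    # Evenly spaced ray angles in [0, 2*pi), one per predicted distance.
    angles = np.linspace(0.0, 2.0 * np.pi, n_rays, endpoint=False)
    cy, cx = center
    # Walk out from the center along each ray by the predicted distance.
    ys = cy + distances * np.sin(angles)
    xs = cx + distances * np.cos(angles)
    return np.stack([ys, xs], axis=1)


# Four rays of length 2 around (10, 10) give a diamond-shaped polygon.
print(star_convex_vertices((10.0, 10.0), [2.0, 2.0, 2.0, 2.0]))
```

Any shape that every ray from the center leaves exactly once can be described this way, which is why a fixed number of distances per pixel is enough to describe roughly convex nuclei.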

To run the notebook locally, create a conda environment with conda env create -f stardist_environment.yml using this stardist_environment.yml file, which installs Squidpy, TensorFlow, and StarDist.

Note: We frequently observed the notebook kernel dying with the message “The kernel appears to have died. It will restart automatically.” when other packages were imported before StarDist. We therefore recommend importing StarDist first.

# Import the StarDist 2D segmentation models first (see the note above).
from stardist.models import StarDist2D
# Import the normalization technique recommended for StarDist.
from csbdeep.utils import normalize

import numpy as np

import matplotlib.pyplot as plt

# Import squidpy and additional packages needed for this tutorial.
import squidpy as sq

StarDist provides four pre-trained models for 2D images. We show examples for the Versatile (fluorescent nuclei) and the Versatile (H&E nuclei) models. To use a StarDist model, we define a wrapper function that normalizes the image with the recommended method, initializes the model, and returns the segmentation mask.

StarDist2D.from_pretrained()
There are 4 registered models for 'StarDist2D':

Name                  Alias(es)
────                  ─────────
'2D_versatile_fluo'   'Versatile (fluorescent nuclei)'
'2D_versatile_he'     'Versatile (H&E nuclei)'
'2D_paper_dsb2018'    'DSB 2018 (from StarDist 2D paper)'
'2D_demo'             None

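The method parameter of sq.im.segment accepts any callable with the signature numpy.ndarray (height, width, channels) -> numpy.ndarray (height, width[, channels]). Additional keyword arguments passed to sq.im.segment are forwarded to this callable, which is how model-specific parameters such as nms_thresh and prob_thresh reach the StarDist wrappers below.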
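As an illustration of this contract, here is a minimal sketch of a custom segmentation callable. The name threshold_segmentation and its thresh parameter are hypothetical and not part of Squidpy or StarDist. Unlike StarDist, it produces a binary foreground mask (every foreground pixel gets label 1) rather than separate instances; it only shows the expected input and output shapes.

```python
import numpy as np


def threshold_segmentation(img: np.ndarray, thresh: float = 0.5) -> np.ndarray:
    """Hypothetical toy segmenter following the (H, W, C) -> (H, W) contract.

    Returns an int32 mask where foreground pixels are 1 and background is 0.
    """
    if img.ndim != 3:
        raise ValueError(f"expected (height, width, channels), got shape {img.shape}")
    # Collapse the channels by averaging, then rescale intensities to [0, 1].
    gray = img.astype(np.float64).mean(axis=-1)
    lo, hi = gray.min(), gray.max()
    scaled = (gray - lo) / (hi - lo) if hi > lo else np.zeros_like(gray)
    # Pixels brighter than the threshold become foreground (label 1).
    return (scaled > thresh).astype(np.int32)


demo = np.zeros((4, 4, 1))
demo[1:3, 1:3, 0] = 10.0
print(threshold_segmentation(demo))
```

Following the call pattern used later in this tutorial, such a function would be passed as method=threshold_segmentation, with thresh given as an extra keyword argument to sq.im.segment; we have not run that call here.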

Cell segmentation on Visium fluorescence data

# Load the image and visualize its channels.
img = sq.datasets.visium_fluo_image_crop()
crop = img.crop_corner(1000, 1000, size=1000)
crop.show(channelwise=True)
[Figure: the channels of the fluorescence image crop, one panel per channel]

Next, we inspect the pre-trained StarDist model. The 2D_versatile_fluo model expects a single input channel (n_channel_in=1 in the configuration below), so in this example we run the segmentation on the first channel (DAPI) of the image.

StarDist2D.from_pretrained("2D_versatile_fluo")
Found model '2D_versatile_fluo' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.479071, nms_thresh=0.3.
StarDist2D(2D_versatile_fluo): YXC → YXC
├─ Directory: None
└─ Config2D(axes='YXC', backbone='unet', grid=(2, 2), n_channel_in=1, n_channel_out=33, n_dim=2, n_rays=32, net_conv_after_unet=128, net_input_shape=[None, None, 1], net_mask_shape=[None, None, 1], train_background_reg=0.0001, train_batch_size=8, train_checkpoint='weights_best.h5', train_checkpoint_epoch='weights_now.h5', train_checkpoint_last='weights_last.h5', train_completion_crop=32, train_dist_loss='mae', train_epochs=800, train_foreground_only=0.9, train_learning_rate=0.0003, train_loss_weights=[1, 0.2], train_n_val_patches=None, train_patch_size=[256, 256], train_reduce_lr={'factor': 0.5, 'patience': 80, 'min_delta': 0}, train_shape_completion=False, train_steps_per_epoch=400, train_tensorboard=True, unet_activation='relu', unet_batch_norm=False, unet_dropout=0.0, unet_kernel_size=[3, 3], unet_last_activation='relu', unet_n_conv_per_depth=2, unet_n_depth=3, unet_n_filter_base=32, unet_pool=[2, 2], unet_prefix='', use_gpu=False)

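StarDist expects a normalized input image. Rather than supplying a normalizer to the prediction function, our callable normalizes the image itself before prediction, using normalize from csbdeep.utils, the percentile-based normalization recommended by StarDist (here between the 1st and 99.8th percentiles).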
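The percentile normalization can be sketched in a few lines of NumPy. This is our own reimplementation for illustration, not csbdeep's code; to our understanding csbdeep.utils.normalize performs the same rescaling and does not clip by default, but treat that as an assumption. The name percentile_normalize is hypothetical. The axis argument corresponds to the axis_norm choice in the H&E wrapper further down: axis=(0, 1) normalizes each channel independently, while axis=(0, 1, 2) normalizes all channels jointly.

```python
import numpy as np


def percentile_normalize(x, pmin=1.0, pmax=99.8, axis=(0, 1), eps=1e-20):
    """Map the pmin-th percentile to 0 and the pmax-th percentile to 1.

    Percentiles are computed over `axis`. Values are not clipped, so pixels
    outside the percentile range end up slightly below 0 or above 1.
    """
    x = np.asarray(x, dtype=np.float64)
    lo = np.percentile(x, pmin, axis=axis, keepdims=True)
    hi = np.percentile(x, pmax, axis=axis, keepdims=True)
    # eps guards against division by zero for constant images.
    return ((x - lo) / (hi - lo + eps)).astype(np.float32)


# Two channels with very different intensity ranges.
img = np.arange(101, dtype=float).reshape(101, 1, 1)
two_channels = np.concatenate([img, 10.0 * img], axis=-1)
per_channel = percentile_normalize(two_channels, axis=(0, 1))
joint = percentile_normalize(two_channels, axis=(0, 1, 2))
```

Normalizing per channel makes both channels span the same range, whereas joint normalization preserves their relative brightness, which matters for RGB images such as H&E stains where color balance carries information.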

Calling model.predict_instances will:

  • predict object probabilities and star-convex polygon distances.

  • perform non-maximum suppression (with overlap threshold nms_thresh) for polygons above object probability threshold prob_thresh.

  • render all remaining polygon instances in a label image.

  • return the label instances image and also the details (coordinates, etc.) of all remaining polygons.
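The thresholding, suppression, and rendering steps above can be sketched on toy data. This is a simplified illustration, not StarDist's implementation: the function greedy_nms_to_labels is hypothetical, candidates are given as boolean masks instead of polygons, and overlap is measured as intersection over union (StarDist's own overlap measure may differ).

```python
import numpy as np


def greedy_nms_to_labels(masks, probs, prob_thresh=0.5, nms_thresh=0.3):
    """Toy post-processing: probability threshold, greedy NMS, label rendering.

    masks: boolean array (n_candidates, height, width), one candidate object each.
    probs: object probability of each candidate, shape (n_candidates,).
    Returns an int32 label image: 0 is background, 1..n are the kept objects.
    """
    masks = np.asarray(masks, dtype=bool)
    probs = np.asarray(probs, dtype=float)
    # Visit candidates from highest to lowest probability, dropping weak ones.
    order = [i for i in np.argsort(-probs, kind="stable") if probs[i] > prob_thresh]
    kept = []
    for i in order:
        suppressed = False
        for j in kept:
            # Overlap as intersection over union (IoU) of the two masks.
            inter = np.logical_and(masks[i], masks[j]).sum()
            union = np.logical_or(masks[i], masks[j]).sum()
            if union > 0 and inter / union > nms_thresh:
                suppressed = True
                break
        if not suppressed:
            kept.append(i)
    labels = np.zeros(masks.shape[1:], dtype=np.int32)
    for new_label, i in enumerate(kept, start=1):
        # Higher-probability objects were kept first and keep shared pixels.
        labels[masks[i] & (labels == 0)] = new_label
    return labels


masks = np.zeros((3, 6, 6), dtype=bool)
masks[0, 0:3, 0:3] = True  # candidate A
masks[1, 0:3, 1:4] = True  # candidate B, overlaps A with IoU 0.5
masks[2, 4:6, 4:6] = True  # candidate C, separate
labels = greedy_nms_to_labels(masks, [0.9, 0.8, 0.6])
```

With these toy candidates, raising prob_thresh removes low-probability objects before suppression, and raising nms_thresh lets overlapping candidates survive as separate labels. Both parameters therefore directly control how many objects end up in the label image.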

For our purpose, we only return the label image. See the detailed StarDist example notebooks for more information.

def stardist_2D_versatile_fluo(img, nms_thresh=None, prob_thresh=None):
    # StarDist expects normalized input; we normalize here instead of passing a normalizer to predict_instances.
    # Percentile normalization (1st to 99.8th percentile) is the default used in the StarDist examples.
    img = normalize(img, 1, 99.8, axis=(0, 1))
    model = StarDist2D.from_pretrained("2D_versatile_fluo")
    # Passing None for nms_thresh/prob_thresh falls back to the model's default thresholds.
    labels, _ = model.predict_instances(
        img, nms_thresh=nms_thresh, prob_thresh=prob_thresh
    )
    return labels

sq.im.segment(
    img=crop,
    layer="image",
    channel=0,
    method=stardist_2D_versatile_fluo,
    layer_added="segmented_stardist",
    nms_thresh=None,
    prob_thresh=None,
)
Found model '2D_versatile_fluo' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.479071, nms_thresh=0.3.
# Plot the DAPI channel of the image crop and the segmentation result.
print(crop)
# Note: np.unique also counts the background label 0.
print(f"Number of segments in crop: {len(np.unique(crop['segmented_stardist']))}")

fig, axes = plt.subplots(1, 2)
crop.show("image", channel=0, ax=axes[0])
_ = axes[0].set_title("DAPI")
crop.show("segmented_stardist", cmap="jet", interpolation="none", ax=axes[1])
_ = axes[1].set_title("segmentation")
ImageContainer[shape=(1000, 1000), layers=['image', 'segmented_stardist']]
Number of segments in crop: 412
[Figure: left, DAPI channel of the crop; right, StarDist segmentation]
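The segment counts printed in this tutorial use len(np.unique(...)), which also counts the background label 0 whenever the crop contains background pixels, so they are one higher than the number of segmented nuclei. A minimal sketch of a count that ignores the background (count_segments is a hypothetical helper, not a Squidpy function):

```python
import numpy as np


def count_segments(labels):
    """Hypothetical helper: number of objects in a label image, ignoring label 0."""
    ids = np.unique(np.asarray(labels))
    # Every non-zero id is one segmented object.
    return int(np.count_nonzero(ids))


example = np.array([[0, 0, 1], [2, 2, 0], [3, 0, 0]])
print(count_segments(example), len(np.unique(example)))
```

Applied to the segmentation layer, for example count_segments(crop['segmented_stardist']), it should report one less than the printed value as long as the crop contains background.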

Cell segmentation on H&E stained tissue data

# Load the H&E stained tissue image and crop it to a smaller region.
img = sq.datasets.visium_hne_image_crop()
crop = img.crop_corner(0, 0, size=1000)
crop.show("image")
[Figure: H&E image crop]
def stardist_2D_versatile_he(img, nms_thresh=None, prob_thresh=None):
    # axis_norm = (0,1)   # normalize channels independently
    axis_norm = (0, 1, 2)  # normalize channels jointly
    # StarDist expects normalized input; we normalize here instead of passing a normalizer to predict_instances.
    # Percentile normalization (1st to 99.8th percentile) is the default used in the StarDist examples.
    img = normalize(img, 1, 99.8, axis=axis_norm)
    model = StarDist2D.from_pretrained("2D_versatile_he")
    labels, _ = model.predict_instances(
        img, nms_thresh=nms_thresh, prob_thresh=prob_thresh
    )
    return labels

The StarDist H&E model works on three input channels (n_channel_in=3 in the configuration below). We therefore pass channel=None to sq.im.segment, which then runs the segmentation method on all channels of the image.
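Conceptually, the channel argument decides how many channels the segmentation callable receives, and that number has to match the model's n_channel_in. The sketch below is a hypothetical helper, not Squidpy's code; we only assume that selecting a single channel keeps a trailing channel axis, consistent with the (height, width, channels) input signature of the method callable.

```python
import numpy as np


def select_channels(img, channel=None):
    """Hypothetical helper mimicking the effect of a `channel` argument.

    channel=None keeps every channel; an int keeps only that channel while
    preserving the trailing axis, giving shape (height, width, 1).
    """
    if img.ndim != 3:
        raise ValueError("expected an array of shape (height, width, channels)")
    if channel is None:
        return img
    # Indexing with a list keeps the channel axis instead of dropping it.
    return img[..., [channel]]


rgb = np.zeros((5, 4, 3))
print(select_channels(rgb).shape, select_channels(rgb, 0).shape)
```

For the fluorescence model, one channel (n_channel_in=1) is the right input; for the H&E model, the full three-channel image is needed, hence channel=None.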

StarDist2D.from_pretrained("2D_versatile_he")
Found model '2D_versatile_he' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.692478, nms_thresh=0.3.
StarDist2D(2D_versatile_he): YXC → YXC
├─ Directory: None
└─ Config2D(axes='YXC', backbone='unet', grid=(2, 2), n_channel_in=3, n_channel_out=33, n_dim=2, n_rays=32, net_conv_after_unet=128, net_input_shape=[None, None, 3], net_mask_shape=[None, None, 1], train_background_reg=0.0001, train_batch_size=8, train_checkpoint='weights_best.h5', train_checkpoint_epoch='weights_now.h5', train_checkpoint_last='weights_last.h5', train_completion_crop=32, train_dist_loss='mae', train_epochs=200, train_foreground_only=0.9, train_learning_rate=0.0003, train_loss_weights=[1, 0.1], train_n_val_patches=3, train_patch_size=[512, 512], train_reduce_lr={'factor': 0.5, 'patience': 50, 'min_delta': 0}, train_shape_completion=False, train_steps_per_epoch=200, train_tensorboard=True, unet_activation='relu', unet_batch_norm=False, unet_dropout=0.0, unet_kernel_size=[3, 3], unet_last_activation='relu', unet_n_conv_per_depth=2, unet_n_depth=3, unet_n_filter_base=32, unet_pool=[2, 2], unet_prefix='', use_gpu=False)
sq.im.segment(
    img=crop,
    layer="image",
    channel=None,
    method=stardist_2D_versatile_he,
    layer_added="segmented_stardist_default",
    prob_thresh=None,
    nms_thresh=None,
)
Found model '2D_versatile_he' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.692478, nms_thresh=0.3.
print(crop)
print(
    f"Number of segments in crop: {len(np.unique(crop['segmented_stardist_default']))}"
)

fig, axes = plt.subplots(1, 2)
crop.show("image", ax=axes[0])
_ = axes[0].set_title("H&E")
crop.show("segmented_stardist_default", cmap="jet", interpolation="none", ax=axes[1])
_ = axes[1].set_title("segmentation")
ImageContainer[shape=(1000, 1000), layers=['image', 'segmented_stardist_default']]
Number of segments in crop: 193
[Figure: left, H&E crop; right, StarDist segmentation with default thresholds]

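Adjusting the prob_thresh parameter can improve the segmentation. Here we lower it from the default of 0.692478 to 0.3, which keeps more candidate polygons and yields considerably more segments in this crop (632 versus 193 unique labels). Note that the log message “Using default values: ...” is printed when the model is loaded and still reports the defaults, even though the value passed to predict_instances overrides them.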

sq.im.segment(
    img=crop,
    layer="image",
    channel=None,
    method=stardist_2D_versatile_he,
    layer_added="segmented_stardist",
    prob_thresh=0.3,
    nms_thresh=None,
)
Found model '2D_versatile_he' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.692478, nms_thresh=0.3.
print(crop)
print(f"Number of segments in crop: {len(np.unique(crop['segmented_stardist']))}")

fig, axes = plt.subplots(1, 2)
crop.show("image", ax=axes[0])
_ = axes[0].set_title("H&E")
crop.show("segmented_stardist", cmap="jet", interpolation="none", ax=axes[1])
_ = axes[1].set_title("segmentation")
ImageContainer[shape=(1000, 1000), layers=['image', 'segmented_stardist', 'segmented_stardist_default']]
Number of segments in crop: 632
[Figure: left, H&E crop; right, StarDist segmentation with prob_thresh=0.3]