Stardist in Python: Training#

This notebook demonstrates how to train your own StarDist model on the data you downloaded in the [previous notebook](./01_Stardist2d_in_python_getting_data.ipynb).

from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = 'none'  # use the string 'none' (setting None is deprecated since matplotlib 3.5)
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, normalize

from skimage import transform

from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist.matching import matching, matching_dataset
from stardist.models import Config2D, StarDist2D, StarDistData2D

np.random.seed(42)
lbl_cmap = random_label_cmap()

Data#

Training data (for input `X` with associated label masks `Y`) can be provided via lists of numpy arrays, where each image can have a different size. Alternatively, a single numpy array can also be used if all images have the same size. Input images can either be two-dimensional (single-channel) or three-dimensional (multi-channel) arrays, where the channel axis comes last. Label images need to be integer-valued.
X = sorted(glob('data/dsb2018/train/images/*.tif'))
Y = sorted(glob('data/dsb2018/train/masks/*.tif'))
assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))
X = list(map(imread,X))
Y = list(map(imread,Y))
n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]
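
As a quick sanity check (a minimal sketch, assuming the images and masks were loaded as above), you can verify that the label masks are integer-valued and match the spatial shape of their corresponding input images:

# optional sanity check: label masks must be integer-valued and
# have the same spatial shape as the corresponding input image
assert all(np.issubdtype(y.dtype, np.integer) for y in Y)
assert all(x.shape[:2] == y.shape[:2] for x,y in zip(X,Y))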

Normalize images and fill small label holes.

axis_norm = (0,1)   # normalize channels independently
# axis_norm = (0,1,2) # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))
    sys.stdout.flush()

X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]
Y = [fill_label_holes(y) for y in tqdm(Y)]
100%|██████████████████████████████████████████████████████████████████████| 447/447 [00:00<00:00, 452.66it/s]
100%|██████████████████████████████████████████████████████████████████████| 447/447 [00:01<00:00, 238.41it/s]

Split into train and validation datasets.

assert len(X) > 1, "not enough training data"
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
n_val = max(1, int(round(0.15 * len(ind))))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val]  , [Y[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] 
print('number of images: %3d' % len(X))
print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))
number of images: 447
- training:       380
- validation:      67

Training data consists of pairs of input image and label instances.

def plot_img_label(img, lbl, img_title="image", lbl_title="label", **kwargs):
    fig, (ai,al) = plt.subplots(1,2, figsize=(12,5), gridspec_kw=dict(width_ratios=(1.25,1)))
    im = ai.imshow(img, cmap='gray', clim=(0,1))
    ai.set_title(img_title)    
    fig.colorbar(im, ax=ai)
    al.imshow(lbl, cmap=lbl_cmap)
    al.set_title(lbl_title)
    plt.tight_layout()
i = min(9, len(X)-1)
img, lbl = X[i], Y[i]
assert img.ndim in (2,3)
img = img if (img.ndim==2 or img.shape[-1]==3) else img[...,0]
plot_img_label(img,lbl)
None;
../_images/02_Stardist2d_in_python_training_11_0.png

Configuration#

A StarDist2D model is specified via a Config2D object.

print(Config2D.__doc__)
Configuration for a :class:`StarDist2D` model.

    Parameters
    ----------
    axes : str or None
        Axes of the input images.
    n_rays : int
        Number of radial directions for the star-convex polygon.
        Recommended to use a power of 2 (default: 32).
    n_channel_in : int
        Number of channels of given input image (default: 1).
    grid : (int,int)
        Subsampling factors (must be powers of 2) for each of the axes.
        Model will predict on a subsampled grid for increased efficiency and larger field of view.
    n_classes : None or int
        Number of object classes to use for multi-class prediction (use None to disable)
    backbone : str
        Name of the neural network architecture to be used as backbone.
    kwargs : dict
        Overwrite (or add) configuration attributes (see below).


    Attributes
    ----------
    unet_n_depth : int
        Number of U-Net resolution levels (down/up-sampling layers).
    unet_kernel_size : (int,int)
        Convolution kernel size for all (U-Net) convolution layers.
    unet_n_filter_base : int
        Number of convolution kernels (feature channels) for first U-Net layer.
        Doubled after each down-sampling layer.
    unet_pool : (int,int)
        Maxpooling size for all (U-Net) convolution layers.
    net_conv_after_unet : int
        Number of filters of the extra convolution layer after U-Net (0 to disable).
    unet_* : *
        Additional parameters for U-net backbone.
    train_shape_completion : bool
        Train model to predict complete shapes for partially visible objects at image boundary.
    train_completion_crop : int
        If 'train_shape_completion' is set to True, specify number of pixels to crop at boundary of training patches.
        Should be chosen based on (largest) object sizes.
    train_patch_size : (int,int)
        Size of patches to be cropped from provided training images.
    train_background_reg : float
        Regularizer to encourage distance predictions on background regions to be 0.
    train_foreground_only : float
        Fraction (0..1) of patches that will only be sampled from regions that contain foreground pixels.
    train_sample_cache : bool
        Activate caching of valid patch regions for all training images (disable to save memory for large datasets)
    train_dist_loss : str
        Training loss for star-convex polygon distances ('mse' or 'mae').
    train_loss_weights : tuple of float
        Weights for losses relating to (probability, distance)
    train_epochs : int
        Number of training epochs.
    train_steps_per_epoch : int
        Number of parameter update steps per epoch.
    train_learning_rate : float
        Learning rate for training.
    train_batch_size : int
        Batch size for training.
    train_n_val_patches : int
        Number of patches to be extracted from validation images (``None`` = one patch per image).
    train_tensorboard : bool
        Enable TensorBoard for monitoring training progress.
    train_reduce_lr : dict
        Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable.
    use_gpu : bool
        Indicate that the data generator should use OpenCL to do computations on the GPU.

        .. _ReduceLROnPlateau: https://keras.io/api/callbacks/reduce_lr_on_plateau/
    

Before we start training, we need to set a few parameters. The most important ones are:

  • Number of rays: See previous notebook

  • Use the graphics card for training: This is disabled by default and will not work on PCs without an NVIDIA graphics card. It may be possible to enable this feature on your PC by installing cudatoolkit, though no guarantee is given.

  • grid: The model predicts on a grid subsampled by these factors (powers of 2), which reduces memory usage and computation and enlarges the field of view (see the Config2D docstring above).

# 32 is a good default choice (see 1_data.ipynb)
n_rays = 32

# Use OpenCL-based computations for data generator during training (requires 'gputools')
use_gpu = False and gputools_available()

# Predict on subsampled grid for increased efficiency and larger field of view
grid = (2,2)

conf = Config2D (
    n_rays       = n_rays,
    grid         = grid,
    use_gpu      = use_gpu,
    n_channel_in = n_channel,
)
print(conf)
vars(conf)
Config2D(n_dim=2, axes='YXC', n_channel_in=1, n_channel_out=33, train_checkpoint='weights_best.h5', train_checkpoint_last='weights_last.h5', train_checkpoint_epoch='weights_now.h5', n_rays=32, grid=(2, 2), backbone='unet', n_classes=None, unet_n_depth=3, unet_kernel_size=(3, 3), unet_n_filter_base=32, unet_n_conv_per_depth=2, unet_pool=(2, 2), unet_activation='relu', unet_last_activation='relu', unet_batch_norm=False, unet_dropout=0.0, unet_prefix='', net_conv_after_unet=128, net_input_shape=(None, None, 1), net_mask_shape=(None, None, 1), train_shape_completion=False, train_completion_crop=32, train_patch_size=(256, 256), train_background_reg=0.0001, train_foreground_only=0.9, train_sample_cache=True, train_dist_loss='mae', train_loss_weights=(1, 0.2), train_class_weights=(1, 1), train_epochs=400, train_steps_per_epoch=100, train_learning_rate=0.0003, train_batch_size=4, train_n_val_patches=None, train_tensorboard=True, train_reduce_lr={'factor': 0.5, 'patience': 40, 'min_delta': 0}, use_gpu=False)
{'n_dim': 2,
 'axes': 'YXC',
 'n_channel_in': 1,
 'n_channel_out': 33,
 'train_checkpoint': 'weights_best.h5',
 'train_checkpoint_last': 'weights_last.h5',
 'train_checkpoint_epoch': 'weights_now.h5',
 'n_rays': 32,
 'grid': (2, 2),
 'backbone': 'unet',
 'n_classes': None,
 'unet_n_depth': 3,
 'unet_kernel_size': (3, 3),
 'unet_n_filter_base': 32,
 'unet_n_conv_per_depth': 2,
 'unet_pool': (2, 2),
 'unet_activation': 'relu',
 'unet_last_activation': 'relu',
 'unet_batch_norm': False,
 'unet_dropout': 0.0,
 'unet_prefix': '',
 'net_conv_after_unet': 128,
 'net_input_shape': (None, None, 1),
 'net_mask_shape': (None, None, 1),
 'train_shape_completion': False,
 'train_completion_crop': 32,
 'train_patch_size': (256, 256),
 'train_background_reg': 0.0001,
 'train_foreground_only': 0.9,
 'train_sample_cache': True,
 'train_dist_loss': 'mae',
 'train_loss_weights': (1, 0.2),
 'train_class_weights': (1, 1),
 'train_epochs': 400,
 'train_steps_per_epoch': 100,
 'train_learning_rate': 0.0003,
 'train_batch_size': 4,
 'train_n_val_patches': None,
 'train_tensorboard': True,
 'train_reduce_lr': {'factor': 0.5, 'patience': 40, 'min_delta': 0},
 'use_gpu': False}
if use_gpu:
    from csbdeep.utils.tf import limit_gpu_memory
    # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations
    limit_gpu_memory(0.8)
    # alternatively, try this:
    # limit_gpu_memory(None, allow_growth=True)

Note: The trained StarDist2D model will not predict completed shapes for partially visible objects at the image boundary if `train_shape_completion=False` (which is the default option).
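
If your data contains many objects touching the image border, shape completion could instead be enabled in the configuration. A minimal sketch (not used further in this notebook; the crop size should be chosen based on the largest object sizes, cf. the Config2D docstring above):

# sketch: configuration with shape completion enabled (illustrative values)
conf_sc = Config2D(
    n_rays                 = n_rays,
    grid                   = grid,
    n_channel_in           = n_channel,
    train_shape_completion = True,
    train_completion_crop  = 32,   # pixels cropped at the boundary of training patches
)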

model = StarDist2D(conf, name='stardist', basedir='models')
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.24019, nms_thresh=0.3.

Check if the neural network has a large enough field of view to see up to the boundary of most objects.

median_size = calculate_extents(list(Y), np.median)
fov = np.array(model._axes_tile_overlap('YX'))
print(f"median object size:      {median_size}")
print(f"network field of view :  {fov}")
if any(median_size > fov):
    print("WARNING: median object size larger than field of view of the neural network.")
1/1 [==============================] - 0s 182ms/step
1/1 [==============================] - 0s 61ms/step
median object size:      [17.5 18. ]
network field of view :  [94 94]
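
If the warning above is triggered for your own data, the field of view can be enlarged, for example by using a coarser grid and/or a deeper U-Net. A hedged sketch (the values are illustrative, not a recommendation):

# sketch: configuration with a larger field of view (illustrative values)
conf_fov = Config2D(
    n_rays       = n_rays,
    grid         = (4,4),   # coarser subsampling -> larger field of view
    unet_n_depth = 4,       # one more down/up-sampling level
    n_channel_in = n_channel,
)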

Data Augmentation#

You can define a function/callable that applies augmentation to each batch of the data generator.
Here we use an augmenter that applies random flips, 90-degree rotations, and intensity changes, which are typically sensible for (2D) microscopy images (you can disable augmentation by setting `augmenter = None`).

def random_fliprot(img, mask): 
    assert img.ndim >= mask.ndim
    axes = tuple(range(mask.ndim))
    perm = tuple(np.random.permutation(axes))
    img = img.transpose(perm + tuple(range(mask.ndim, img.ndim))) 
    mask = mask.transpose(perm) 
    for ax in axes: 
        if np.random.rand() > 0.5:
            img = np.flip(img, axis=ax)
            mask = np.flip(mask, axis=ax)
    return img, mask 

def random_intensity_change(img):
    img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)
    return img


def augmenter(x, y):
    """Augmentation of a single input/label image pair.
    x is an input image
    y is the corresponding ground-truth label image
    """
    x, y = random_fliprot(x, y)
    x = random_intensity_change(x)
    # add some gaussian noise
    sig = 0.02*np.random.uniform(0,1)
    x = x + sig*np.random.normal(0,1,x.shape)
    return x, y
# plot some augmented examples
img, lbl = X[0],Y[0]
plot_img_label(img, lbl)
for _ in range(3):
    img_aug, lbl_aug = augmenter(img,lbl)
    plot_img_label(img_aug, lbl_aug, img_title="image augmented", lbl_title="label augmented")
../_images/02_Stardist2d_in_python_training_24_0.png ../_images/02_Stardist2d_in_python_training_24_1.png ../_images/02_Stardist2d_in_python_training_24_2.png ../_images/02_Stardist2d_in_python_training_24_3.png

Training#

We recommend monitoring the training progress with TensorBoard. You can start it in the shell from the current working directory like this:

$ tensorboard --logdir=.

Then connect to http://localhost:6006/ with your browser.
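
Alternatively, if the TensorBoard Jupyter extension is installed, it can be started directly from a notebook cell (a sketch; adjust the log directory as needed):

# sketch: run TensorBoard inside the notebook (requires the tensorboard extension)
# %load_ext tensorboard
# %tensorboard --logdir .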

quick_demo = True

if quick_demo:
    print (
        "NOTE: This is only for a quick demonstration!\n"
        "      Please set the variable 'quick_demo = False' for proper (long) training.",
        file=sys.stderr, flush=True
    )
    model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,
                epochs=2, steps_per_epoch=10)

    print("====> Stopping training and loading previously trained demo model from disk.", file=sys.stderr, flush=True)
    model = StarDist2D.from_pretrained('2D_demo')
else:
    model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter)
None;
NOTE: This is only for a quick demonstration!
      Please set the variable 'quick_demo = False' for proper (long) training.
Epoch 1/2
10/10 [==============================] - 10s 1s/step - loss: 1.7712 - prob_loss: 0.3109 - dist_loss: 7.3018 - prob_kld: 0.2339 - dist_relevant_mae: 7.3011 - dist_relevant_mse: 111.6556 - dist_dist_iou_metric: 0.3070 - val_loss: 1.8127 - val_prob_loss: 0.2679 - val_dist_loss: 7.7241 - val_prob_kld: 0.1960 - val_dist_relevant_mae: 7.7235 - val_dist_relevant_mse: 135.9931 - val_dist_dist_iou_metric: 0.2631 - lr: 3.0000e-04
Epoch 2/2
10/10 [==============================] - 10s 1s/step - loss: 1.9356 - prob_loss: 0.3372 - dist_loss: 7.9920 - prob_kld: 0.2581 - dist_relevant_mae: 7.9912 - dist_relevant_mse: 129.1611 - dist_dist_iou_metric: 0.3146 - val_loss: 1.8624 - val_prob_loss: 0.3237 - val_dist_loss: 7.6934 - val_prob_kld: 0.2518 - val_dist_relevant_mae: 7.6928 - val_dist_relevant_mse: 135.1480 - val_dist_dist_iou_metric: 0.2563 - lr: 3.0000e-04

Loading network weights from 'weights_best.h5'.
====> Stopping training and loading previously trained demo model from disk.
Found model '2D_demo' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.486166, nms_thresh=0.5.
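
After training, the model (weights and thresholds) is stored under `basedir`. A minimal sketch of how it could be reloaded later by passing `None` as the configuration (commented out here so the demo model above is kept):

# sketch: reload a previously trained model from disk
# model = StarDist2D(None, name='stardist', basedir='models')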

Threshold optimization#

While the default values for the probability and non-maximum suppression thresholds already yield good results in many cases, we still recommend adapting them to your data. The optimized threshold values are saved to disk and will be automatically loaded with the model.

if quick_demo:
    model.optimize_thresholds(X_val[:2], Y_val[:2])
else:
    model.optimize_thresholds(X_val, Y_val)
1/1 [==============================] - 0s 195ms/step
1/1 [==============================] - 0s 83ms/step
NMS threshold = 0.3:  75%|███████████████████████████         | 15/20 [00:01<00:00, 10.23it/s, 0.511 -> 0.700]
NMS threshold = 0.4:  75%|███████████████████████████         | 15/20 [00:01<00:00,  9.82it/s, 0.511 -> 0.688]
NMS threshold = 0.5:  75%|███████████████████████████         | 15/20 [00:01<00:00, 10.23it/s, 0.511 -> 0.688]
Using optimized values: prob_thresh=0.508553, nms_thresh=0.3.
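
If needed, the stored thresholds can also be overridden for a single prediction call; a minimal sketch (the values are illustrative):

# sketch: override the stored thresholds for one prediction
# labels, details = model.predict_instances(X_val[0], prob_thresh=0.5, nms_thresh=0.3)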

Evaluation and Detection Performance#

Besides the losses and metrics during training, we can also quantitatively evaluate the actual detection/segmentation performance on the validation data: an object in the ground truth is considered correctly matched if a predicted object overlaps it (here measured by intersection over union, IoU) beyond a chosen IoU threshold \(\tau\).

The corresponding matching statistics (average overlap, accuracy, recall, precision, etc.) are typically of greater practical relevance than the losses/metrics computed during training (but harder to formulate as a loss function). The value of \(\tau\) can range from 0 (even slightly overlapping objects count as correctly predicted) to 1 (only pixel-perfect overlaps count); which \(\tau\) to use depends on the segmentation precision required by your application.

Please see `help(matching)` for definitions of the abbreviations used in the evaluation below, and see the Wikipedia page on Sensitivity and specificity for further details.

# help(matching)

First predict the labels for all validation images:

Y_val_pred = [model.predict_instances(x, n_tiles=model._guess_n_tiles(x), show_tile_progress=False)[0]
              for x in tqdm(X_val)]
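
Per-image statistics can also be computed directly with `matching` (imported above); a quick sketch for a single validation image:

# sketch: matching statistics for one image pair at IoU threshold 0.5
print(matching(Y_val[0], Y_val_pred[0], thresh=0.5))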

Plot a GT/prediction example

plot_img_label(X_val[0],Y_val[0], lbl_title="label GT")
plot_img_label(X_val[0],Y_val_pred[0], lbl_title="label Pred")
../_images/02_Stardist2d_in_python_training_37_0.png ../_images/02_Stardist2d_in_python_training_37_1.png

Choose several IoU thresholds \(\tau\) that might be of interest and for each compute matching statistics for the validation data.

taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
stats = [matching_dataset(Y_val, Y_val_pred, thresh=t, show_progress=False) for t in tqdm(taus)]
100%|███████████████████████████████████████████████████████████████████████████| 9/9 [00:02<00:00,  3.82it/s]

Example: Print all available matching statistics for \(\tau=0.5\)

stats[taus.index(0.5)]
DatasetMatching(criterion='iou', thresh=0.5, fp=108, tp=2239, fn=304, precision=0.953983809118023, recall=0.8804561541486433, accuracy=0.8445869483213881, f1=0.9157464212678936, n_true=2543, n_pred=2347, mean_true_score=0.7459801213740365, mean_matched_score=0.8472654973890911, panoptic_quality=0.775880347097822, by_image=False)

Plot the matching statistics and the number of true/false positives/negatives as a function of the IoU threshold \(\tau\).

fig, (ax1,ax2) = plt.subplots(1,2, figsize=(15,5))

for m in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):
    ax1.plot(taus, [s._asdict()[m] for s in stats], '.-', lw=2, label=m)
ax1.set_xlabel(r'IoU threshold $\tau$')
ax1.set_ylabel('Metric value')
ax1.grid()
ax1.legend()

for m in ('fp', 'tp', 'fn'):
    ax2.plot(taus, [s._asdict()[m] for s in stats], '.-', lw=2, label=m)
ax2.set_xlabel(r'IoU threshold $\tau$')
ax2.set_ylabel('Number #')
ax2.grid()
ax2.legend();
../_images/02_Stardist2d_in_python_training_43_0.png

Receptive field#

The receptive field of a convolutional network determines how large an area of the input image the deepest layers of the network (those with the most feature maps and the smallest spatial size) can see. We can demonstrate this by simply upscaling an input image and checking whether StarDist is still able to segment the objects:

sample = X_val[0]
sample_upscaled = transform.rescale(sample, 2)[:sample.shape[0], :sample.shape[1]]
# rescale the label image with nearest-neighbor interpolation to keep it integer-valued
sample_upscaled_GT = transform.rescale(Y_val[0], 2, order=0, preserve_range=True,
                                       anti_aliasing=False).astype(Y_val[0].dtype)[:sample.shape[0], :sample.shape[1]]

Let’s have a look at the upscaled data:

fig, axes = plt.subplots(ncols=2, nrows=2)
axes[0,0].imshow(sample, cmap='gray')
axes[1,0].imshow(sample_upscaled, cmap='gray')

axes[0,1].imshow(Y_val[0])
axes[1,1].imshow(sample_upscaled_GT)

axes[0,0].set_title('Raw')
axes[1,0].set_title('Upscaled x2')
axes[0,1].set_title('Ground truth')
axes[1,1].set_title('Upscaled x2')
fig.tight_layout()
../_images/02_Stardist2d_in_python_training_47_0.png

Exercise 1#

Run the prediction on both images (`sample` and `sample_upscaled`) and compare the predictions - what do you observe? Hint: run the prediction on some data with `model.predict_instances()`.
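
A minimal sketch to get started (the comparison and interpretation are left to you):

# sketch: predict instances for both the original and the upscaled image
labels_orig, _     = model.predict_instances(sample)
labels_upscaled, _ = model.predict_instances(sample_upscaled)
plot_img_label(sample,          labels_orig,     lbl_title="prediction (original)")
plot_img_label(sample_upscaled, labels_upscaled, lbl_title="prediction (upscaled x2)")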