Processing batches of data

Most deep learning workflows benefit greatly from running on GPUs, which can process many samples in parallel. Therefore, during training, the data is passed to the network in batches of samples rather than one sample at a time. PyTorch supports this well through utilities that build on top of a given dataset. For convenience, the dataset class introduced in the previous notebook is now part of the data module, so we can simply import it.

from data import DSBData, get_dsb2018_train_files
train_img_files, train_lbl_files = get_dsb2018_train_files()

train_data = DSBData(
    image_files=train_img_files,
    label_files=train_lbl_files,
    target_shape=(256, 256)
)

print(len(train_data))
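Conceptually, batching with shuffling amounts to drawing a random order of the sample indices and cutting it into consecutive chunks of the batch size. The pure-Python sketch below illustrates this idea; make_batches is a hypothetical helper for illustration only, not part of PyTorch, and the real DataLoader additionally loads the samples and stacks them into tensors.

```python
import random


def make_batches(num_samples, batch_size, shuffle=True, seed=None):
    """Split sample indices into consecutive batches (hypothetical helper, illustration only)."""
    indices = list(range(num_samples))
    if shuffle:
        # a random order of the samples, similar in spirit to shuffle=True
        random.Random(seed).shuffle(indices)
    # consecutive chunks of batch_size; the last chunk may be smaller
    return [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]


batches = make_batches(num_samples=10, batch_size=4, shuffle=True, seed=0)
print([len(b) for b in batches])  # [4, 4, 2]
```

Every index appears exactly once per pass over the data; only the grouping into batches changes with the random order.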

Before working with the data and actual models, we wrap our dataset object in a DataLoader, which takes care of batching and shuffling.

from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

A DataLoader is iterable out of the box, which makes looping over the batches concise.

for batch_idx, (batch_images, batch_labels) in enumerate(train_loader):
    print("Batch", batch_idx, batch_images.shape, batch_labels.shape)
    # break  # uncomment to stop after the first batch
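The loop prints one line per batch. Whenever the number of samples is not a multiple of the batch size, the last batch is smaller; the DataLoader argument drop_last (False by default) controls whether this incomplete batch is kept, and len(train_loader) reports the resulting number of batches. A stdlib sketch of the batch sizes in one epoch, using an arbitrary dataset size of 100 rather than the actual size of train_data; batch_sizes is a hypothetical helper that only mimics the drop_last behaviour.

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Sizes of the batches yielded in one epoch (hypothetical helper mimicking drop_last)."""
    full_batches, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full_batches
    if remainder and not drop_last:
        # the incomplete final batch is kept unless drop_last=True
        sizes.append(remainder)
    return sizes


print(batch_sizes(100, 32))                  # [32, 32, 32, 4]
print(batch_sizes(100, 32, drop_last=True))  # [32, 32, 32]
```

Note that after the loop above finishes, batch_images and batch_labels hold the last (possibly smaller) batch, which is the batch the model is applied to further below.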

Neural network architecture

For semantic segmentation problems, one specific convolutional neural network architecture has been shown to work well across a wide range of image domains. Like any network architecture, it is a defined sequence of operations (also called layers), here involving convolutional filters, data aggregation via pooling, and nonlinear activation functions. This architecture is called UNet, and its basic structure is shown below. (Image taken from here.)

[Figure: basic structure of the UNet architecture]
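To make "nonlinear activation" and "data aggregation via pooling" concrete, here is a pure-Python sketch of a ReLU followed by 2x2 max pooling with stride 2, the downsampling step used in the original UNet design. It operates on a plain nested list standing in for a single-channel feature map; relu and max_pool_2x2 are illustrative helpers, and real implementations work on tensors and also include the learned convolutions omitted here.

```python
def relu(feature_map):
    """Elementwise ReLU: negative activations are set to zero."""
    return [[max(0.0, v) for v in row] for row in feature_map]


def max_pool_2x2(feature_map):
    """2x2 max pooling with stride 2: keep the largest value of each 2x2 block.

    Height and width are halved; an odd trailing row or column is dropped.
    """
    h, w = len(feature_map), len(feature_map[0])
    return [
        [
            max(feature_map[i][j], feature_map[i][j + 1],
                feature_map[i + 1][j], feature_map[i + 1][j + 1])
            for j in range(0, w - 1, 2)
        ]
        for i in range(0, h - 1, 2)
    ]


x = [
    [1.0, -2.0, 3.0, 0.5],
    [-1.0, 4.0, -3.0, 2.0],
    [0.0, 1.0, -5.0, -6.0],
    [2.0, -1.0, -4.0, -7.0],
]
print(max_pool_2x2(relu(x)))  # [[4.0, 3.0], [2.0, 0.0]]
```

The 4x4 map shrinks to 2x2, which is exactly the halving of the resolution that happens in each downsampling block of the UNet.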

As this architecture is rather cumbersome to implement from scratch, we will use the MONAI library, which provides a convenient PyTorch implementation of it called BasicUNet.

If you are interested, MONAI offers many more architectures; see the network architectures section of its documentation.

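import matplotlib.pyplot as plt
import torch

from monai.networks.nets import BasicUNet

# show the constructor documentation (in Jupyter, `BasicUNet?` does the same)
help(BasicUNet)

model = BasicUNet(
    spatial_dims=2,                      # 2D images
    in_channels=1,                       # single-channel (grayscale) input
    out_channels=1,                      # one output channel: foreground vs. background
    features=[16, 16, 32, 64, 128, 16],  # five encoder levels + feature size after the last upsampling
    act="relu",                          # activation function used in the convolution blocks
    norm="batch",                        # batch normalization
    dropout=0.25,                        # dropout probability
)
print(model)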

We can now feed a batch of images directly through the model to obtain predictions. Note, however, that these predictions are not yet useful for segmentation, because the model has not been trained and its parameters are randomly initialized.

Importantly, the model output has the same spatial size as the input (the number of output channels is set by out_channels). Because the UNet is fully convolutional, i.e. it contains no fully connected layers that would fix the input size, it can process inputs of (almost) arbitrary size. It is nevertheless recommended to use resolutions divisible by 16, because the resolution is halved in each of the four downsampling blocks (2^4 = 16).

# forward pass of the last batch from the loop above through the (untrained) model
batch_preds = model(batch_images)
print(batch_preds.shape)

# compare input, ground truth and prediction for the first image of the batch
plt.subplot(131)
plt.imshow(batch_images[0, 0].numpy(), cmap="gray")
plt.title("Input")

plt.subplot(132)
plt.imshow(batch_labels[0, 0].numpy(), cmap="gray")
plt.title("Ground truth")

plt.subplot(133)
# detach from the autograd graph before converting to NumPy
plt.imshow(batch_preds.detach()[0, 0].numpy(), cmap="gray")
plt.title("Predictions")

# differently sized dummy input should be processable as well
dummy_batch = torch.zeros(8, 1, 512, 512)
dummy_preds = model(dummy_batch)
print(dummy_preds.shape)

# dummy input whose size is not divisible by 16 still produces an output of the same shape
dummy_batch = torch.zeros(8, 1, 114, 87)
dummy_preds = model(dummy_batch)
print(dummy_preds.shape)
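The divisibility recommendation can be checked by tracing the spatial size through the four 2x downsampling steps. The stdlib sketch below assumes that pooling rounds odd sizes down, as 2x2 max pooling with stride 2 does; encoder_sizes and pad_to_multiple are hypothetical helpers, not MONAI functions. For the width 87 from the dummy example above, the sizes become 43, 21, 10 and 5, and upsampling 10 by 2 gives 20, which no longer matches the skip connection of size 21, so the decoder has to compensate (for example by padding) before concatenating. Inputs padded to a multiple of 16 avoid this mismatch entirely.

```python
import math


def encoder_sizes(size, num_downsamplings=4):
    """Spatial size at each encoder level, assuming each 2x pooling floors odd sizes."""
    sizes = [size]
    for _ in range(num_downsamplings):
        sizes.append(sizes[-1] // 2)
    return sizes


def pad_to_multiple(size, multiple=16):
    """Smallest size >= `size` that is divisible by `multiple`."""
    return math.ceil(size / multiple) * multiple


print(encoder_sizes(256))   # [256, 128, 64, 32, 16]
print(encoder_sizes(87))    # [87, 43, 21, 10, 5]
print(pad_to_multiple(87))  # 96
```

With 256 x 256 inputs, as used for train_data, every level halves exactly, so no correction is needed in the decoder.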

The model outputs are not restricted to the range [0, 1], because the output layer applies no nonlinear activation function that would map the values into this range.

To fix this and make the output usable for segmentation purposes, we apply a sigmoid activation function per pixel.

print(batch_preds.min(), batch_preds.max())
# torch.sigmoid maps every pixel value into the interval (0, 1)
batch_preds_seg = torch.sigmoid(batch_preds)
print(batch_preds_seg.min(), batch_preds_seg.max())
plt.imshow(batch_preds_seg.detach()[0, 0].numpy(), cmap="gray")
plt.colorbar(orientation="horizontal")
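The sigmoid is applied to every output value independently: σ(x) = 1 / (1 + e^(-x)). A stdlib sketch on a made-up 2x2 map of raw outputs (logits); torch.sigmoid performs the same elementwise computation on tensors. The separate branch for negative inputs is only there to avoid an OverflowError in math.exp for large negative values.

```python
import math


def sigmoid(x):
    """Numerically stable logistic sigmoid, mapping any real number into (0, 1)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # for negative x, rewrite to avoid computing exp of a large positive number
    z = math.exp(x)
    return z / (1.0 + z)


# made-up raw model outputs for a 2x2 image
logits = [[-3.0, 0.0], [2.5, 10.0]]
probs = [[sigmoid(v) for v in row] for row in logits]
print(probs[0][1])  # 0.5
```

Because the sigmoid is strictly increasing, the relative order of the pixel values is preserved; only their range changes.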

To obtain binary (0/1) predictions, a straightforward approach is to threshold the sigmoid outputs at 0.5:

batch_preds_seg_binary = (batch_preds_seg > 0.5).to(torch.uint8)
plt.imshow(batch_preds_seg_binary.detach()[0, 0], cmap="gray")
plt.colorbar(orientation="horizontal")

Our model is not trained yet, so don't be bothered that the plot above shows mostly meaningless output.
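Since the sigmoid is strictly increasing and σ(0) = 0.5, thresholding the probabilities at 0.5 selects exactly the pixels whose raw output is positive (up to floating-point rounding for outputs extremely close to 0). A stdlib sketch with made-up values illustrates this; in practice it means the sigmoid can be skipped when only a binary mask is needed.

```python
import math


def sigmoid(x):
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


# made-up raw model outputs (logits) for five pixels
logits = [-2.0, -0.1, 0.0, 0.3, 4.0]

# threshold the probabilities at 0.5 ...
from_probs = [int(sigmoid(v) > 0.5) for v in logits]
# ... or, equivalently, the raw outputs at 0
from_logits = [int(v > 0.0) for v in logits]

print(from_probs)   # [0, 0, 0, 1, 1]
print(from_logits)  # [0, 0, 0, 1, 1]
```

A pixel with a raw output of exactly 0 gets probability 0.5 and, with the strict comparison used here, is assigned to the background in both variants.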

Exercise: My first MONAI BasicUNet

Play with the model a bit: call the constructor with different parameters, e.g. change the features, the activation or the normalisation. Then have the model predict on the same image as above, display the prediction and compare it to what we saw earlier. Do you spot a difference?