Training a Unet#

Now we have all the pieces in place to train a network to segment images for us. Let’s do it!

from torch.utils.data import DataLoader
from data import DSBData, get_dsb2018_train_files
import torch
from monai.networks.nets import BasicUNet
train_img_files, train_lbl_files = get_dsb2018_train_files()

train_data = DSBData(
    image_files=train_img_files,
    label_files=train_lbl_files,
    target_shape=(256, 256)
)

print(len(train_data))

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
model = BasicUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    features=[16, 16, 32, 64, 128, 16],
    act="relu",
    norm="batch",
    dropout=0.25,
)
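Before starting to train, it is worth sanity-checking the shapes with a single batch and a forward pass. This is a quick sketch; the printed sizes assume the 256 x 256 target shape, the single input channel and the batch size of 32 chosen above.

X, y = next(iter(train_loader))
with torch.no_grad():
    out = model(X)
print(X.shape, y.shape, out.shape)  # expected: torch.Size([32, 1, 256, 256]) for all three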

Training a neural network means updating its parameters (weights) in order to decrease what is called the loss function. This is done with an optimizer (Adam here), which uses the gradient of the loss function with respect to the model parameters to adjust the weights. Over the course of training, this should lead to an ever-decreasing loss.

optimizer = torch.optim.Adam(model.parameters(), lr=1.e-3)
init_params = list(model.parameters())[0].clone().detach()  # store a copy of the initial weights for a later sanity check
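To make this concrete, here is a tiny, self-contained sketch (a toy quadratic loss, independent of the segmentation model) showing how an optimizer uses gradients to drive a loss down step by step.

# toy example: minimize (w - 3)^2 with Adam, so w should move towards 3
w = torch.tensor([0.0], requires_grad=True)
toy_optimizer = torch.optim.Adam([w], lr=0.1)

for step in range(100):
    toy_optimizer.zero_grad()   # clear gradients from the previous step
    toy_loss = (w - 3.0) ** 2   # the quantity we want to decrease
    toy_loss.backward()         # compute d(loss)/dw
    toy_optimizer.step()        # update w using the gradient

print(w.item(), toy_loss.item())  # w should end up close to 3 and the loss close to 0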

Training is performed by iterating over the batches of the training dataset multiple times. Each full pass over the dataset is called an epoch.

max_nepochs = 1
log_interval = 1
model.train(True)  # put the model in training mode, i.e. layers such as dropout and batch norm behave as they should during training

# BCEWithLogitsLoss expects raw, unnormalized scores (logits) and combines a sigmoid with BCELoss
# for better numerical stability (see the small check after the training loop).
# It expects tensors of shape B x C x H x W.
loss_function = torch.nn.BCEWithLogitsLoss(reduction="mean")

for epoch in range(1, max_nepochs + 1):
    for batch_idx, (X, y) in enumerate(train_loader):
        # print("train", batch_idx, X.shape, y.shape)

        optimizer.zero_grad()

        prediction_logits = model(X)
        
        batch_loss = loss_function(prediction_logits, y)

        batch_loss.backward()

        optimizer.step()

        if batch_idx % log_interval == 0:
            print(
                "Train Epoch:",
                epoch,
                "Batch:",
                batch_idx,
                "Total samples processed:",
                (batch_idx + 1) * train_loader.batch_size,
                "Loss:",
                batch_loss.item(),
            )
final_params = list(model.parameters())[0].clone().detach()
assert not torch.allclose(init_params, final_params)
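As a side check, the claim from the comment above can be verified on small random tensors (hypothetical data, not from the training set): BCEWithLogitsLoss applied to raw logits gives the same value as a sigmoid followed by BCELoss, but computes it in a numerically safer way.

logits = torch.randn(4, 1, 8, 8)                     # hypothetical raw model outputs
targets = torch.randint(0, 2, (4, 1, 8, 8)).float()  # hypothetical binary masks

with_logits = torch.nn.BCEWithLogitsLoss()(logits, targets)
plain_bce = torch.nn.BCELoss()(torch.sigmoid(logits), targets)

print(with_logits.item(), plain_bce.item())  # the two values agree up to floating point error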

Look at some predictions#

Now that the model has been trained for a little bit, let’s look at its predictions again. Usually, model training has to be performed for much longer, so don’t expect any wonders. Also keep in mind that the predictions here are made on the data the model was trained on; they may look far better than predictions on data not used during training. But that is a story for later.

import matplotlib.pyplot as plt
# convert logits to per-pixel probabilities in [0, 1]
prediction = torch.sigmoid(prediction_logits)
prediction_binary = (prediction > 0.5).to(torch.uint8)

sidx = 0
plt.subplot(131)
plt.imshow(X[sidx, 0].numpy(), cmap="gray")
plt.title("Input")

plt.subplot(132)
plt.imshow(y[sidx, 0].numpy(), cmap="gray")
plt.title("Ground truth")

plt.subplot(133)
plt.imshow(prediction_binary.detach()[sidx, 0].numpy(), cmap="gray")
plt.title("Predictions")

Exercise: We can do better!#

Take the training code from above and train the model for longer, for example 10 or 20 epochs. Do you see any improvements?
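One possible starting point is sketched below; it reuses the model, optimizer, loss function and data loader defined above, only changes the number of epochs, and additionally keeps track of the mean loss per epoch.

model.train(True)
max_nepochs = 10  # try 10 or 20 epochs

for epoch in range(1, max_nepochs + 1):
    epoch_losses = []
    for X, y in train_loader:
        optimizer.zero_grad()
        prediction_logits = model(X)
        batch_loss = loss_function(prediction_logits, y)
        batch_loss.backward()
        optimizer.step()
        epoch_losses.append(batch_loss.item())
    print("Epoch", epoch, "mean loss:", sum(epoch_losses) / len(epoch_losses))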