Training with checkpoints

The pipeline developed so far has a major drawback: the model parameters are never saved. Once you close the notebook after training, you no longer have access to the trained model, which costs you hours of re-training the same models over and over again. It is also a good idea to save “snapshots” of the model parameters obtained during training, not only once training has finished. Both can be achieved using checkpointing.

from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import torch
cuda_present = torch.cuda.is_available()
ndevices = torch.cuda.device_count()
use_cuda = cuda_present and ndevices > 0
device = torch.device("cuda" if use_cuda else "cpu")  # "cuda" uses the default GPU (index 0); "cuda:1", "cuda:2", ... would address other GPUs
print("number of devices:", ndevices, "\tchosen device:", device, "\tuse_cuda=", use_cuda)
from torch.utils.data import DataLoader
from data import DSBData, get_dsb2018_train_files
from monai.networks.nets import BasicUNet
train_img_files, train_lbl_files = get_dsb2018_train_files()

train_data = DSBData(
    image_files=train_img_files,
    label_files=train_lbl_files,
    target_shape=(256, 256)
)

print(len(train_data))

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=1, pin_memory=True)
model = BasicUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    features=[16, 16, 32, 64, 128, 16],
    act="relu",
    norm="batch",
    dropout=0.25,
)

# transfer the model to the chosen device
model = model.to(device)

Training a neural network means updating its parameters (weights) using the gradients of a loss function with respect to those parameters, so that the weights are adjusted step by step to minimize the loss.
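To make this concrete, here is a minimal, hand-written sketch of one plain gradient-descent update (the Adam optimizer created below performs the same kind of step, with additional per-parameter bookkeeping):

# one plain gradient-descent step, written out by hand for illustration only;
# it assumes .backward() has already been called so that .grad is populated
with torch.no_grad():
    for param in model.parameters():
        if param.grad is not None:
            param -= 1.e-3 * param.grad  # new weight = old weight - learning_rate * gradient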

optimizer = torch.optim.Adam(model.parameters(), lr=1.e-3)
init_params = list(model.parameters())[0].clone().detach()

Training is performed by iterating over the batches of the training dataset multiple times. Each full pass over the dataset is called an epoch.
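The number of batches (and hence optimizer updates) per epoch follows directly from the dataset and batch size; as a quick sanity check, using the train_data and train_loader created above:

import math

batches_per_epoch = math.ceil(len(train_data) / train_loader.batch_size)
assert batches_per_epoch == len(train_loader)
print("batches per epoch:", batches_per_epoch)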

During or after training, the TensorBoard logs can be visualized as follows: in a terminal, type

tensorboard --logdir "path/to/logs"

then open a browser at localhost:6006 (or whichever port the TensorBoard server reports it is running on). Alternatively, TensorBoard can also be used directly from Jupyter:

%load_ext tensorboard
%tensorboard --port 6006 --logdir ./logs
max_nepochs = 2
log_interval = 1
writer = SummaryWriter(log_dir="logs")  # note: SummaryWriter's "comment" argument is only used when no log_dir is given

model.train(True)

chpfolder = Path("chkpts")
chpfolder.mkdir(exist_ok=True)  # create the checkpoint folder if it does not exist yet

# BCEWithLogitsLoss expects raw, unnormalized scores (logits) and combines
# sigmoid + BCELoss internally for better numerical stability;
# inputs and targets are expected to have shape B x C x H x W
loss_function = torch.nn.BCEWithLogitsLoss(reduction="mean")
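
# quick numerical illustration (a small sketch, not part of the pipeline):
# BCEWithLogitsLoss on raw logits matches BCELoss applied to sigmoid(logits)
_logits = torch.randn(2, 1, 4, 4)
_targets = torch.randint(0, 2, (2, 1, 4, 4)).float()
assert torch.allclose(
    loss_function(_logits, _targets),
    torch.nn.BCELoss(reduction="mean")(torch.sigmoid(_logits), _targets),
)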

for epoch in range(1, max_nepochs + 1):
    for batch_idx, (X, y) in enumerate(train_loader):
        # the inputs and labels have to be on the same device as the model
        X, y = X.to(device), y.to(device)
        
        optimizer.zero_grad()

        prediction_logits = model(X)
        
        batch_loss = loss_function(prediction_logits, y)

        batch_loss.backward()

        optimizer.step()

        if batch_idx % log_interval == 0:
            print(
                "Train Epoch:",
                epoch,
                "Batch:",
                batch_idx,
                "Total samples processed:",
                (batch_idx + 1) * train_loader.batch_size,
                "Loss:",
                batch_loss.item(),
            )
            writer.add_scalar(
                "Loss/train",
                batch_loss.item(),
                (epoch - 1) * len(train_loader) + batch_idx,  # global step, so curves do not overlap across epochs
            )
    # epoch finished, save a checkpoint of the model and optimizer state
    cpath = chpfolder / f"epoch-{epoch:03d}.pth"
    torch.save(
        {
            "final_epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        cpath,
    )

    assert cpath.is_file() and cpath.stat().st_size > 0
writer.close()
final_params = list(model.parameters())[0].clone().detach()
assert not torch.allclose(init_params, final_params)

Restoring the model from a saved checkpoint, e.g. for running inference, can be done as follows:

# map_location ensures the checkpoint can also be loaded on a machine without a GPU
payload = torch.load(cpath, map_location="cpu")
model_from_ckpt = BasicUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    features=[16, 16, 32, 64, 128, 16],
    act="relu",
    norm="batch",
    dropout=0.25,
)
model_from_ckpt.load_state_dict(payload["model_state_dict"])
# training could be resumed from here, ideally after also restoring the optimizer state
loaded_params = list(model_from_ckpt.parameters())[0]
# compare on the CPU, as final_params may live on the GPU
assert torch.allclose(loaded_params, final_params.cpu())
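
Since the restored model is typically used for inference, a minimal sketch of a forward pass with it could look as follows (here a batch from the train_loader defined above stands in for new data):

model_from_ckpt.eval()  # switch dropout off and use the running batch-norm statistics
with torch.no_grad():
    X, _ = next(iter(train_loader))
    logits = model_from_ckpt(X)            # raw, unnormalized scores
    probabilities = torch.sigmoid(logits)  # map to [0, 1], matching the BCEWithLogitsLoss setup
print(probabilities.shape)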