Creation of a dataset

In deep learning, everything starts with a well-prepared dataset that provides the inputs and target outputs for the network we want to train. Building on the data exploration of the previous notebook, we now create a dataset class that serves individual samples to us.

from data import get_dsb2018_train_files, get_dsb2018_validation_files, get_dsb2018_test_files, fill_label_holes, quantile_normalization
from tifffile import imread
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch

from torchvision import transforms
class DSBData:
    """In-memory dataset of (image, binary mask) pairs with one fixed resolution.

    All pairs whose image matches ``target_shape`` are read, normalized and
    stacked into two tensors at construction time, so indexing afterwards is
    just tensor slicing.
    """

    def __init__(self, image_files, label_files, target_shape=(256, 256)):
        """
        Parameters
        ----------
        image_files: list of pathlib.Path objects pointing to the *.tif images
        label_files: list of pathlib.Path objects pointing to the *.tif segmentation
                     masks; must be paired with ``image_files`` by position and
                     carry the same file names
        target_shape: sequence of length 2 specifying the sample resolution of files
                      that will be kept. All other files will NOT be used.
        """
        assert len(image_files) == len(label_files), (
            f"got {len(image_files)} image files but {len(label_files)} label files"
        )
        assert all(x.name == y.name for x, y in zip(image_files, label_files)), (
            "image and label files must be paired by file name"
        )

        # ndarray.shape is a tuple; normalizing here lets callers pass a list
        # (e.g. [256, 256]) without every file being silently discarded.
        target_shape = tuple(target_shape)

        tensor_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        images = []
        labels = []

        # use tqdm to show a progress bar while the files are read
        for image_file, label_file in tqdm(
            zip(image_files, label_files), total=len(image_files)
        ):
            # we use the same data reading approach as in the previous notebook
            image = imread(image_file)

            # keep only samples of the requested resolution; checking before
            # reading the mask avoids a wasted read for every discarded sample
            if image.shape != target_shape:
                continue
            label = imread(label_file)

            # do the normalizations
            image = quantile_normalization(
                image,
                quantile_low=0.01,
                quantile_high=0.998,
                clip=True)[0].astype(np.float32)

            # NOTE: we convert the label to dtype float32 and not uint8 because
            # the tensor transformation does a normalization if the input is of
            # dtype uint8, destroying the 0/1 labelling which we want to avoid.
            label = fill_label_holes(label)
            label_binary = (label != 0).astype(np.float32)

            # convert to torch tensor: for 2-D input this adds an artificial
            # channel axis in front. ToTensor does NOT resize anything; all
            # samples have the same size only because of the shape filter above.
            images.append(tensor_transform(image))
            labels.append(tensor_transform(label_binary))

        if images:
            self.images = torch.stack(images)
            self.labels = torch.stack(labels)
        else:
            # torch.stack rejects an empty list; return an empty dataset whose
            # tensors have the layout a non-empty one would have for 2-D samples
            # (N, 1, H, W), so len() is 0 instead of the constructor crashing.
            self.images = torch.empty((0, 1, *target_shape), dtype=torch.float32)
            self.labels = torch.empty((0, 1, *target_shape), dtype=torch.float32)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` tensor pair at position ``idx``."""
        return self.images[idx], self.labels[idx]

    def __len__(self):
        """Return the number of kept samples."""
        return len(self.images)
# Gather the paths of all training images and their masks; lower
# `n_samples` to experiment on a smaller subset.
train_img_files, train_lbl_files = get_dsb2018_train_files()
n_samples = len(train_img_files)

train_data = DSBData(
    train_img_files[:n_samples],
    train_lbl_files[:n_samples],
    target_shape=(256, 256),
)

# NOTE: len(train_data) can be smaller than n_samples, since every file
#       whose resolution differs from the target shape is dropped.
print(len(train_data))
print(train_data.images.shape, train_data.labels.shape)
print(train_data.images.min(), train_data.images.max())
print(train_data.labels.unique())
# Same procedure for the validation split; lower `n_samples` to use a subset.
val_img_files, val_lbl_files = get_dsb2018_validation_files()
n_samples = len(val_img_files)

val_data = DSBData(
    val_img_files[:n_samples],
    val_lbl_files[:n_samples],
    target_shape=(256, 256),
)

# NOTE: len(val_data) can be smaller than n_samples, since every file
#       whose resolution differs from the target shape is dropped.
print(len(val_data))
# Inspect the first training sample: the normalized image next to its binary mask.
image, label = train_data[0]
print(image.shape, label.shape)

# Index 0 drops the leading channel axis so imshow receives a 2-D array.
plt.subplot(121)
plt.imshow(image[0].numpy(), cmap="gray")
plt.title("image")

plt.subplot(122)
plt.imshow(label[0].numpy(), cmap="gray")
plt.title("label")

# Needed for the figure to appear when run as a script; harmless in a notebook.
plt.show()

Exercise: What if I had chosen a different shape?

Return to the previous notebook and check which image shapes occur in the dataset. Pick one of them and create a dataset object that contains only samples of that shape. Then take one or two samples and display them in this notebook.