Practical 2: Dask with images

In the previous practical, we saw that dask can help us parallelise computations on arrays. This is useful for many operations that are typically applied to image arrays, such as filtering.

# Let's load an example image

import numpy as np
from skimage import data
from scipy import ndimage
import tifffile

%matplotlib notebook

img = data.cells3d()
img = img.max(0)[1] # max-project along z, then keep only channel 1
img = ndimage.zoom(img, 10, order=1) # upsample 10x to 2560 x 2560 pixels (linear interpolation)

tifffile.imshow(img)

How long does a Gaussian filter take when applied to the entire image?

%%timeit -r 3
ndimage.gaussian_filter(img, sigma=5, mode='constant')
103 ms ± 535 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

What if we subdivide the array into chunks and apply this filter to each chunk?

import dask.array as da

img_da = da.from_array(img,
                       chunks=(500, 500),
                       )
img_da
              Array           Chunk
Bytes         12.50 MiB       488.28 kiB
Shape         (2560, 2560)    (500, 500)
Dask graph    36 chunks in 1 graph layer
Data type     uint16 numpy.ndarray
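The numbers in this summary follow from the chunk grid: 2560 = 5 × 500 + 60, so each axis has five full chunks and one 60-pixel remainder chunk, giving 6 × 6 = 36 chunks of at most 500 × 500 × 2 bytes = 488.28 KiB. A quick sketch to inspect the grid directly, using a zero-filled stand-in array of the same shape and dtype (an assumption, so that the example runs without the image):

```python
import numpy as np
import dask.array as da

# stand-in with the same shape and dtype as the zoomed image
stand_in = np.zeros((2560, 2560), dtype=np.uint16)
arr = da.from_array(stand_in, chunks=(500, 500))

print(arr.chunks)               # chunk sizes along each axis
print(arr.numblocks)            # number of chunks along each axis
print(arr.nbytes / 2**20, 'MiB in total')
```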

map_blocks

We can use dask.array.map_blocks to apply a function to each chunk (or block) of the dask array.

filtered = da.map_blocks(
            ndimage.gaussian_filter, # the function to apply to each chunk
            img_da, # the array to apply the function to
            sigma=5, # arguments to the function
            mode='constant',
            )
filtered
              Array           Chunk
Bytes         12.50 MiB       488.28 kiB
Shape         (2560, 2560)    (500, 500)
Dask graph    36 chunks in 2 graph layers
Data type     uint16 numpy.ndarray
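Conceptually, map_blocks behaves like the loop below: it calls the function on every chunk independently and stitches the results back together. This NumPy-only sketch uses a small random stand-in image and a hypothetical helper `apply_per_block` (not part of dask). It reproduces dask's result and also shows that the chunk-wise result differs from filtering the whole image:

```python
import numpy as np
import dask.array as da
from scipy import ndimage


def apply_per_block(func, image, block, **kwargs):
    """Hypothetical helper: apply `func` to each block independently."""
    out = np.empty_like(image)
    for r in range(0, image.shape[0], block):
        for c in range(0, image.shape[1], block):
            out[r:r + block, c:c + block] = func(image[r:r + block, c:c + block], **kwargs)
    return out


rng = np.random.default_rng(0)
image = rng.random((200, 200))  # random stand-in for `img`

manual = apply_per_block(ndimage.gaussian_filter, image, 50, sigma=5, mode='constant')
with_dask = da.map_blocks(ndimage.gaussian_filter, da.from_array(image, chunks=50),
                          sigma=5, mode='constant').compute()
whole = ndimage.gaussian_filter(image, sigma=5, mode='constant')

print(np.allclose(manual, with_dask))  # True: the same per-block computation
print(np.allclose(manual, whole))      # False: pixels near chunk borders differ
```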

Does this improve the timing?

%%timeit -r 3
filtered.compute(scheduler='threads')
24.3 ms ± 373 µs per loop (mean ± std. dev. of 3 runs, 1 loop each)
%%timeit -r 3
filtered.compute(scheduler='processes')
712 ms ± 3.56 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)

Performance comparison: applying the Gaussian filter chunk by chunk with multi-threading (24.3 ms) is faster than filtering the entire image at once (103 ms), whereas multi-processing is much slower (712 ms).

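Why is this? Threads share memory, so each thread can read its chunk directly. Separate processes do not share memory: every input chunk and every result has to be serialized and sent back and forth between processes, which can create considerable overhead.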
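To get a feeling for this overhead, here is the kind of serialization round trip a single 500 × 500 uint16 chunk has to go through when it is handed to another process, sketched with Python's pickle module. This only illustrates serialization; it does not reproduce dask's internal transfer mechanism, and the timing will vary between machines:

```python
import pickle
import time

import numpy as np

chunk = np.zeros((500, 500), dtype=np.uint16)  # same size as one chunk above

start = time.perf_counter()
payload = pickle.dumps(chunk)      # serialize the chunk into bytes
restored = pickle.loads(payload)   # rebuild a new array from those bytes
elapsed = time.perf_counter() - start

print(f'{chunk.nbytes} bytes of data, {len(payload)} bytes serialized, {elapsed * 1e3:.2f} ms')
```

With threads, none of this copying is needed because every thread sees the same array in memory.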

Let’s have a look at the output image.

print('entire image')
filtered_ndimage = ndimage.gaussian_filter(img, sigma=5, mode='constant')
tifffile.imshow(filtered_ndimage)

print('dask.array.map_blocks')
tifffile.imshow(filtered)
entire image
dask.array.map_blocks

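The image filtered with map_blocks shows artefacts along the chunk borders: every chunk was filtered in isolation with mode='constant', so pixels close to a chunk edge were blurred with zeros instead of with their actual neighbors. We can prevent these border artefacts by using map_overlap instead of map_blocks.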

map_overlap:

  1. adds values from neighboring chunks to the borders of each chunk,

  2. applies the function to each enlarged chunk, just like map_blocks,

  3. trims the previously added overlap from each chunk.
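The three steps can be sketched in plain NumPy for a single 2D image. `apply_with_overlap` is a hypothetical helper written for illustration, not part of dask, and the random image is a stand-in. With a large enough overlap the chunked result matches filtering the whole image, while an overlap of 0 reproduces the map_blocks artefacts:

```python
import numpy as np
from scipy import ndimage


def apply_with_overlap(func, image, block, depth, **kwargs):
    """Hypothetical helper: the three map_overlap steps for a 2D NumPy image."""
    out = np.empty_like(image)
    rows, cols = image.shape
    for r in range(0, rows, block):
        for c in range(0, cols, block):
            h, w = min(block, rows - r), min(block, cols - c)
            # 1. enlarge the block by `depth` pixels of neighboring data,
            #    stopping at the image border
            r0, c0 = max(r - depth, 0), max(c - depth, 0)
            r1, c1 = min(r + h + depth, rows), min(c + w + depth, cols)
            # 2. apply the function to the enlarged block
            result = func(image[r0:r1, c0:c1], **kwargs)
            # 3. trim the overlap again, keeping only the original block
            out[r:r + h, c:c + w] = result[r - r0:r - r0 + h, c - c0:c - c0 + w]
    return out


rng = np.random.default_rng(0)
image = rng.random((200, 200))  # random stand-in for `img`
whole = ndimage.gaussian_filter(image, sigma=5, mode='constant')

# 20 pixels = reach of the sigma=5 kernel with scipy's default truncate=4.0
no_overlap = apply_with_overlap(ndimage.gaussian_filter, image, 64, 0, sigma=5, mode='constant')
overlapped = apply_with_overlap(ndimage.gaussian_filter, image, 64, 20, sigma=5, mode='constant')

print('depth=0: ', np.abs(no_overlap - whole).max())
print('depth=20:', np.abs(overlapped - whole).max())
```

In dask, map_overlap performs these steps for every chunk, in parallel.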

filtered_overlap = \
    da.map_overlap(
            ndimage.gaussian_filter, # the function to apply to each chunk
            img_da, # the array to apply the function to
            sigma=5, # arguments to the function
            mode='constant',
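            # overlap in pixels added on each side of a chunk
            # (the sigma=5 Gaussian kernel reaches 20 pixels with the default truncate=4.0)
            depth=11,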
            )
filtered_overlap
              Array           Chunk
Bytes         12.50 MiB       488.28 kiB
Shape         (2560, 2560)    (500, 500)
Dask graph    36 chunks in 5 graph layers
Data type     uint16 numpy.ndarray
tifffile.imshow(filtered_overlap.compute())
%%timeit -r 3
filtered_overlap.compute(scheduler='threads')
40.2 ms ± 216 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
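How large should `depth` be? With the default `truncate=4.0`, scipy's Gaussian kernel for `sigma=5` reaches `int(4.0 * 5 + 0.5) = 20` pixels, so `depth=11` removes most, but not strictly all, of the border error. The sketch below measures the remaining error against the whole-image result on a random stand-in image. It assumes `boundary=0`, which pads the outer image edges with zeros and so matches `mode='constant'`:

```python
import numpy as np
import dask.array as da
from scipy import ndimage

rng = np.random.default_rng(0)
image = rng.random((256, 256))              # random stand-in for `img`
image_da = da.from_array(image, chunks=64)
whole = ndimage.gaussian_filter(image, sigma=5, mode='constant')

errors = {}
blocks = da.map_blocks(ndimage.gaussian_filter, image_da, sigma=5, mode='constant')
errors['map_blocks'] = np.abs(blocks.compute() - whole).max()

for depth in (11, 20):
    # boundary=0 pads the outer image edges with zeros, like mode='constant'
    overlapped = da.map_overlap(ndimage.gaussian_filter, image_da, depth=depth,
                                boundary=0, sigma=5, mode='constant')
    errors[f'depth={depth}'] = np.abs(overlapped.compute() - whole).max()

for name, err in errors.items():
    print(f'{name:>10}: max abs difference {err:.1e}')
```

A larger overlap means more data is copied between chunks, so `depth` is a trade-off between accuracy at the chunk borders and speed.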

dask-image

The dask-image package provides versions of many scipy.ndimage functions that work directly on chunked dask arrays. It automatically deals with these border effects and with other problems that can occur when applying scipy.ndimage functions to tiled dask arrays.

https://image.dask.org/en/latest/

The available ndimage functions: https://image.dask.org/en/latest/coverage.html

Among others:

  • affine_transform

  • label

from dask_image import ndfilters

filtered_di = ndfilters.gaussian_filter(img_da, sigma=5, mode='constant')
filtered_di
tifffile.imshow(filtered_di.compute())
%%timeit -r 3
filtered_di.compute()

More dask-image features

Connected components

from dask_image import ndfilters

img_da = da.from_array(img, chunks=500)
seg = (ndfilters.gaussian_filter(img_da, sigma=10, mode='constant') > 10000)
tifffile.imshow(seg)
# Let's calculate connected components on each chunk of the segmentation image

def connected_components(im):
    return ndimage.label(im)[0]

labels = seg.map_blocks(connected_components)

tifffile.imshow(labels)
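Why does the chunk-wise result look wrong? Each chunk is labelled on its own, so label IDs restart at 1 in every chunk: an object crossing a chunk border is split into two IDs, and unrelated objects in different chunks can end up with the same ID. A tiny made-up mask with two 4 × 4 chunks shows both effects:

```python
import numpy as np
import dask.array as da
from scipy import ndimage


def connected_components(im):
    return ndimage.label(im)[0]


mask = np.zeros((4, 8), dtype=bool)
mask[0, 0] = True      # a single pixel in the left chunk
mask[2, 2:6] = True    # a bar crossing the border between the two chunks

blockwise = (da.from_array(mask, chunks=(4, 4))
             .map_blocks(connected_components, dtype=np.int32)
             .compute())
print(blockwise)
print('objects in the whole image:', ndimage.label(mask)[1])
```

The bar gets one ID in the left chunk and another in the right chunk, and its right half shares ID 1 with the unrelated single pixel, although the whole image contains only two objects.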
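# Using overlap does not help in this case: each chunk is still labelled on its own,
# so label IDs restart at 1 in every chunk and objects crossing chunk borders are not merged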

def connected_components(im):
    return ndimage.label(im)[0]

labels = seg.map_overlap(
    connected_components,
    depth=100,
)
tifffile.imshow(np.array(labels))
# dask-image implements connected components

from dask_image import ndmeasure
labels = ndmeasure.label(seg)[0]
tifffile.imshow(labels)
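Getting consistent labels requires merging IDs across chunk borders. The sketch below shows one common way to do that with plain NumPy/SciPy, without dask: label each block independently, give every block its own range of IDs, record which IDs touch across block borders, and merge them with a graph connected-components step. It is an illustration under simple assumptions (2D, 4-connectivity as in ndimage.label's default, a hypothetical helper `label_by_blocks`, a random stand-in mask), not dask-image's actual implementation:

```python
import numpy as np
from scipy import ndimage
from scipy.sparse import coo_matrix, csgraph


def label_by_blocks(mask, block):
    """Hypothetical helper: label a 2D mask block by block, then merge labels
    that touch across block borders (4-connectivity)."""
    labels = np.zeros(mask.shape, dtype=np.int64)
    offset = 0
    # 1. label every block independently; shift IDs so they are unique overall
    for r in range(0, mask.shape[0], block):
        for c in range(0, mask.shape[1], block):
            lab, n = ndimage.label(mask[r:r + block, c:c + block])
            labels[r:r + block, c:c + block] = np.where(lab > 0, lab + offset, 0)
            offset += n
    # 2. collect pairs of IDs that face each other across a block border
    faces = [(labels[r - 1, :], labels[r, :]) for r in range(block, mask.shape[0], block)]
    faces += [(labels[:, c - 1], labels[:, c]) for c in range(block, mask.shape[1], block)]
    pairs = [np.stack([a[(a > 0) & (b > 0)], b[(a > 0) & (b > 0)]], axis=1) for a, b in faces]
    pairs = np.concatenate(pairs) if pairs else np.empty((0, 2), dtype=np.int64)
    # 3. IDs are graph nodes and touching pairs are edges; each connected
    #    component of this graph is one object in the whole image
    graph = coo_matrix((np.ones(len(pairs)), (pairs[:, 0], pairs[:, 1])),
                       shape=(offset + 1, offset + 1))
    _, component = csgraph.connected_components(graph, directed=False)
    # renumber so that the background (node 0) becomes 0 and objects 1..n
    component = np.where(component == component[0], -1, component)
    _, new_ids = np.unique(component, return_inverse=True)
    return new_ids[labels], int(new_ids.max())


rng = np.random.default_rng(1)
mask = ndimage.gaussian_filter(rng.random((256, 256)), sigma=4) > 0.5
reference, n_reference = ndimage.label(mask)
merged, n_merged = label_by_blocks(mask, block=64)
print('whole image:', n_reference, 'objects; block-wise and merged:', n_merged, 'objects')
```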

Affine transformations

# Define a transformation

from scipy.spatial.transform import Rotation as R

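# rotation by 45 degrees about the z axis; keep the upper-left 2 x 2 block for the 2D image
matrix = R.from_rotvec(np.pi/4. * np.array([0, 0, 1])).as_matrix()[:2, :2]
# translation in pixels: each output pixel o is sampled from the input at matrix @ o + offset
offset = np.array([1200., -600])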

print('Matrix:', matrix)
print('Offset:', offset)
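ndimage.affine_transform works "backwards": for every output pixel o it samples the input at matrix @ o + offset. Rotating about the image centre c instead of the top-left corner therefore needs offset = c - matrix @ c, because then matrix @ c + offset = c. A small sketch checking both statements; the tiny test image and the names `rot`, `centre` and `offset_centre` are illustrative only, while the notebook itself uses the hand-picked offset above:

```python
import numpy as np
from scipy import ndimage
from scipy.spatial.transform import Rotation as R

# 1. convention check: output pixel o is sampled from input[matrix @ o + offset]
tiny = np.zeros((5, 5))
tiny[3, 2] = 1.0
shifted = ndimage.affine_transform(tiny, np.eye(2), offset=[2.0, 1.0], order=1)
print(np.argwhere(shifted == 1.0))  # the bright pixel moves from (3, 2) to (1, 1)

# 2. offset for a 45 degree rotation about the centre of a 2560 x 2560 image
rot = R.from_rotvec(np.pi / 4 * np.array([0, 0, 1])).as_matrix()[:2, :2]
centre = (np.array([2560, 2560]) - 1) / 2
offset_centre = centre - rot @ centre
print(offset_centre)
```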
# Transform the image using plain scipy

img_t = ndimage.affine_transform(
    img,
    matrix=matrix,
    offset=offset,
    order=1, # linear interpolation
    )

tifffile.imshow(img_t)
# Transform the image using dask_image.ndinterp.affine_transform

from dask_image import ndinterp

img_t = ndinterp.affine_transform(
    img_da,
    matrix=matrix,
    offset=offset,
    order=1, # linear interpolation
    output_chunks=500,
    ).compute()
    
tifffile.imshow(img_t)

Performance comparison

%%timeit -r 1

img_t = ndimage.affine_transform(
    img,
    matrix=matrix,
    offset=offset,
    order=1, # linear interpolation
    )
%%timeit -r 1

img_t = ndinterp.affine_transform(
    img_da,
    matrix=matrix,
    offset=offset,
    order=1, # linear interpolation
    output_chunks=500,
    ).compute()

Exercise: Apply a median filter

filtered = ...