Practical 2: Dask with images

In the previous practical, we saw that dask can help us parallelise computations on arrays. This is useful for many common array operations, such as filtering.

# Let's load an example image

import numpy as np
from skimage import data
from scipy import ndimage
import tifffile

%matplotlib notebook

img = data.cells3d()
img = img.max(0)[1] # max project along the first axis, then keep only one channel
img = ndimage.zoom(img, 10, order=1) # upscale 10x with linear interpolation

tifffile.imshow(img)

How long does a Gaussian filter take when applied to the entire image?

%%timeit -r 3
ndimage.gaussian_filter(img, sigma=5, mode='constant')
103 ms ± 535 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

What if we subdivide the array into chunks and apply this filter to each chunk?

import dask.array as da

img_da = da.from_array(img,
                       chunks=(500, 500),
                       )
img_da
             Array         Chunk
Bytes        12.50 MiB     488.28 kiB
Shape        (2560, 2560)  (500, 500)
Dask graph   36 chunks in 1 graph layer
Data type    uint16 numpy.ndarray
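
The chunk layout can also be inspected programmatically. The sketch below uses a zero-filled stand-in array (the name arr is ours, not part of the practical) with the same shape and dtype as the image. It shows that 2560 is not a multiple of 500, so the last chunk along each axis is smaller than the others.

```python
import numpy as np
import dask.array as da

# Stand-in for the zoomed image: same shape and dtype, but filled with zeros
arr = np.zeros((2560, 2560), dtype=np.uint16)
arr_da = da.from_array(arr, chunks=(500, 500))

print(arr_da.numblocks)    # chunks per axis -> (6, 6)
print(arr_da.chunks[0])    # chunk sizes along axis 0 -> (500, 500, 500, 500, 500, 60)
print(arr_da.npartitions)  # total number of chunks -> 36
```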

map_blocks

We can use dask.array.map_blocks to apply a function to each chunk (or block) of the dask array.

filtered = da.map_blocks(
            ndimage.gaussian_filter, # the function to apply to each chunk
            img_da, # the array to apply the function to
            sigma=5, # arguments to the function
            mode='constant',
            )
filtered
             Array         Chunk
Bytes        12.50 MiB     488.28 kiB
Shape        (2560, 2560)  (500, 500)
Dask graph   36 chunks in 2 graph layers
Data type    uint16 numpy.ndarray
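
To make it concrete that the function only ever sees one chunk at a time, here is a minimal toy sketch on a 4x4 array. The helper flip_rows is a hypothetical function written for this illustration. It reverses the row order of whatever it receives, so applying it per chunk gives a different result from flipping the whole array at once.

```python
import numpy as np
import dask.array as da

small = np.arange(16).reshape(4, 4)
small_da = da.from_array(small, chunks=(2, 2))

def flip_rows(block):
    # Called once per chunk; reverses the rows of that chunk only
    return block[::-1]

per_chunk = da.map_blocks(flip_rows, small_da).compute()
whole = flip_rows(small)

print(per_chunk)  # rows are swapped within each 2x2 chunk
print(whole)      # rows are reversed across the entire array
```

The two printed arrays differ because each call to flip_rows has no access to the neighbouring chunks.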

Does this improve the timing?

%%timeit -r 3
filtered.compute(scheduler='threads')
24.3 ms ± 373 µs per loop (mean ± std. dev. of 3 runs, 1 loop each)
%%timeit -r 3
filtered.compute(scheduler='processes')
712 ms ± 3.56 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)

Performance comparison: Applying the Gaussian filter to each chunk is much faster with multi-threading (24.3 ms) than with multi-processing (712 ms). The threaded version is also faster than filtering the entire image in a single call (103 ms).

Why is this? Threads share memory, whereas separate processes need to send data back and forth, which can create considerable overhead.
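
As a rough illustration of that overhead, the sketch below serialises a single chunk-sized array with Python's pickle module and restores it again. This assumes that data passed between processes goes through a pickle-like round trip. The exact mechanism used by the process-based scheduler is not shown in this practical, so treat this as an approximation of the cost rather than a measurement of it.

```python
import pickle
import numpy as np

# One chunk-sized array, matching the (500, 500) uint16 chunks used above
chunk = np.zeros((500, 500), dtype=np.uint16)

payload = pickle.dumps(chunk)     # bytes that would have to be sent to another process
restored = pickle.loads(payload)  # unpacking on the receiving side

print(chunk.nbytes)   # size of the raw chunk data in bytes -> 500000
print(len(payload))   # the serialised payload is at least as large
```

With threads, none of this copying is needed, because every worker can read the same array in memory.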

Let’s have a look at the output image.

print('entire image')
filtered_ndimage = ndimage.gaussian_filter(img, sigma=5, mode='constant')
tifffile.imshow(filtered_ndimage)

print('dask.array.map_blocks')
tifffile.imshow(filtered)
entire image