Practical 2: Dask with images#
In the previous practical, we’ve seen that dask can help us parallelise computations on arrays. This can be useful for many operations typically performed on arrays like filtering.
# Let's load an example image
import numpy as np
from skimage import data
from scipy import ndimage
import tifffile
%matplotlib notebook
# Example dataset shipped with scikit-image
# NOTE(review): presumably a (z, channel, y, x) stack — confirm against the skimage docs
img = data.cells3d()
img = img.max(0)[1] # max project along axis 0, then keep only the entry at index 1 (one channel)
img = ndimage.zoom(img, 10, order=1) # zoom in by a factor of 10 using linear interpolation (order=1)
tifffile.imshow(img)
(<Figure size 988.8x604.8 with 2 Axes>,
 <Axes: >,
 <matplotlib.image.AxesImage at 0x11daa1970>)
How long does a gaussian filter take when applied to the entire image?
%%timeit -r 3
# Baseline: filter the entire image with a single scipy call
ndimage.gaussian_filter(img, sigma=5, mode='constant')
103 ms ± 535 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
What if we subdivide the array into chunks and apply this filter to each chunk?
import dask.array as da
# Wrap the image in a dask array that is subdivided into chunks of 500 x 500
img_da = da.from_array(img,
                       chunks=(500, 500),
                       )
# Display the dask array summary
img_da
| 
 | ||||||||||||||||
map_blocks#
We can use dask.array.map_blocks to apply a function to each chunk (or block) of the dask array.
# Apply the gaussian filter to every chunk on its own; each chunk only sees
# its own pixels, so neighbouring chunks do not contribute to the result.
# The result is a dask array; it is evaluated by the .compute() calls below.
filtered = da.map_blocks(
            ndimage.gaussian_filter, # the function to apply to each chunk
            img_da, # the array to apply the function to
            sigma=5, # arguments to the function
            mode='constant',
            )
# Display the dask array summary
filtered
| 
 | ||||||||||||||||
Does this improve the timing?
%%timeit -r 3
# Time the chunk-wise filter using the multi-threaded scheduler
filtered.compute(scheduler='threads')
24.3 ms ± 373 µs per loop (mean ± std. dev. of 3 runs, 1 loop each)
%%timeit -r 3
# Time the same computation using the multi-process scheduler
filtered.compute(scheduler='processes')
712 ms ± 3.56 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)
Performance comparison: Applying the gaussian filter on each chunk is faster when using multi-threading than when using multi-processing.
Why is this? While threads share memory, different processes need to send data back and forth, which can create considerable overhead.
Let’s have a look at the output image.
print('entire image')
# Reference result: the filter applied to the whole image at once
filtered_ndimage = ndimage.gaussian_filter(img, sigma=5, mode='constant')
tifffile.imshow(filtered_ndimage)
print('dask.array.map_blocks')
# Chunk-wise result from map_blocks, shown for comparison with the reference
tifffile.imshow(filtered)
entire image
dask.array.map_blocks
(<Figure size 988.8x604.8 with 2 Axes>,
 <Axes: >,
 <matplotlib.image.AxesImage at 0x1560a0ac0>)
We can prevent these border artefacts by using map_overlap instead of map_blocks.
This:
- adds neighboring chunk values to the borders of each chunk 
- applies map_blocks as before 
- trims the previously added overlap from each chunk 
# The overlap must cover the full filter footprint. scipy's gaussian_filter
# truncates its kernel at `truncate` standard deviations (default 4.0), so for
# sigma=5 the kernel radius is int(4.0 * 5 + 0.5) = 20 pixels. A smaller depth
# only reduces the seams at the chunk borders instead of removing them.
filtered_overlap = \
    da.map_overlap(
            ndimage.gaussian_filter, # the function to apply to each chunk
            img_da, # the array to apply the function to
            sigma=5, # arguments to the function
            mode='constant',
            depth=20, # kernel radius for sigma=5 (see above)
            boundary='none', # no padding at the outer image edges: there, mode='constant' acts exactly as in the whole-image filter
            )
filtered_overlap
| 
 | ||||||||||||||||
# Evaluate the overlap-based result and display it
tifffile.imshow(filtered_overlap.compute())
(<Figure size 988.8x604.8 with 2 Axes>,
 <Axes: >,
 <matplotlib.image.AxesImage at 0x157b457c0>)
%%timeit -r 3
# Time the overlap-based filter using the multi-threaded scheduler
filtered_overlap.compute(scheduler='threads')
40.2 ms ± 216 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
dask-image#
There’s a python package which automatically deals with these border effects and other problems that can occur when applying the functions available from scipy.ndimage to tiled dask arrays.
https://image.dask.org/en/latest/
The available ndimage functions:
https://image.dask.org/en/latest/coverage.html
Among others:
- affine_transform 
- label 
- … 
from dask_image import ndfilters
# Same gaussian filter parameters as before, using the dask-image implementation on the chunked array
filtered_di = ndfilters.gaussian_filter(img_da, sigma=5, mode='constant')
# Display the dask array summary
filtered_di
# Evaluate the result and display it
tifffile.imshow(filtered_di.compute())
%%timeit -r 3
# Time the dask-image filter (no scheduler is specified here)
filtered_di.compute()
More dask-image features#
Connected components#
from dask_image import ndfilters
# Re-create the chunked dask array from the image (chunks=500)
img_da = da.from_array(img, chunks=500)
# Binary segmentation: smooth with sigma=10, then keep pixels with values above 10000
seg = (ndfilters.gaussian_filter(img_da, sigma=10, mode='constant') > 10000)
tifffile.imshow(seg)
/Users/malbert/miniconda3/envs/devbio-napari-env/lib/python3.9/site-packages/tifffile/tifffile.py:22741: UserWarning: Attempting to set identical low and high xlims makes transformation singular; automatically expanding.
  pyplot.Slider(
(<Figure size 988.8x604.8 with 3 Axes>,
 <Axes: >,
 <matplotlib.image.AxesImage at 0x157fc93a0>)
# Let's calculate connected components on each chunk of the segmentation image
def connected_components(im):
    """Return the label image of ``im`` computed by ``scipy.ndimage.label``.

    ``ndimage.label`` returns a ``(labeled_array, num_features)`` pair; only
    the labeled array is returned and the feature count is discarded.
    """
    labeled, _num_features = ndimage.label(im)
    return labeled
# connected_components is called once per chunk, so every chunk is labelled
# independently of its neighbours
labels = seg.map_blocks(connected_components)
tifffile.imshow(labels)
(<Figure size 988.8x604.8 with 2 Axes>,
 <Axes: >,
 <matplotlib.image.AxesImage at 0x2840e4a00>)
# Using overlap does not help in this case
def connected_components(im):
    """Label the connected regions of ``im`` and return only the label image."""
    # ndimage.label yields (labeled_array, num_features); keep the first item
    label_result = ndimage.label(im)
    label_image = label_result[0]
    return label_image
# Same per-chunk labelling, but with an overlap of depth 100 between chunks
labels = seg.map_overlap(
    connected_components,
    depth=100,
)
# Convert the dask array to a NumPy array before plotting
tifffile.imshow(np.array(labels))
(<Figure size 988.8x604.8 with 2 Axes>,
 <Axes: >,
 <matplotlib.image.AxesImage at 0x285af74f0>)
# dask-image implements connected components
from dask_image import ndmeasure
# Keep only the first element of the result of ndmeasure.label: the label image shown below
labels = ndmeasure.label(seg)[0]
tifffile.imshow(labels)
(<Figure size 988.8x604.8 with 2 Axes>,
 <Axes: >,
 <matplotlib.image.AxesImage at 0x28fc857f0>)
Affine transformations#
# Define a transformation
from scipy.spatial.transform import Rotation as R
# rotation: rotation vector of length pi/4 along [0, 0, 1];
# only the top-left 2x2 part of the rotation matrix is kept for the 2D image
matrix = R.from_rotvec(np.pi/4. * np.array([0, 0, 1])).as_matrix()[:2, :2]
# translation, passed as `offset` to the affine transforms below
offset = np.array([1200., -600])
print('Matrix:', matrix)
print('Offset:', offset)
# Transform the image using plain scipy (scipy.ndimage.affine_transform),
# operating on the full NumPy image with the matrix and offset defined above
img_t = ndimage.affine_transform(
    img,
    matrix=matrix,
    offset=offset,
    order=1, # linear interpolation
    )
tifffile.imshow(img_t)
# Transform the image using dask_image.ndinterp.affine_transform
from dask_image import ndinterp
# Same matrix, offset and interpolation order as the scipy call, applied to
# the chunked dask array; the result is evaluated immediately via .compute()
img_t = ndinterp.affine_transform(
    img_da,
    matrix=matrix,
    offset=offset,
    order=1, # linear interpolation
    output_chunks=500, # chunk size requested for the output array
    ).compute()
    
tifffile.imshow(img_t)
Performance comparison
%%timeit -r 1
# Time the plain scipy affine transform on the full image
img_t = ndimage.affine_transform(
    img,
    matrix=matrix,
    offset=offset,
    order=1, # linear interpolation
    )
%%timeit -r 1
# Time the dask-image affine transform, including the .compute() call
img_t = ndinterp.affine_transform(
    img_da,
    matrix=matrix,
    offset=offset,
    order=1, # linear interpolation
    output_chunks=500,
    ).compute()
Exercise: Apply a median filter#
# TODO: exercise placeholder — replace the Ellipsis with your median filter applied to the image
filtered = ...
