Practical 2: Dask with images#
In the previous practical, we saw that dask can help us parallelise computations on arrays. This is useful for many operations typically performed on arrays, such as filtering.
# Let's load an example image
import numpy as np
from skimage import data
from scipy import ndimage
import tifffile
%matplotlib notebook
img = data.cells3d()
img = img.max(0)[1] # take only one channel and max project
img = ndimage.zoom(img, 10, order=1) # zoom in
tifffile.imshow(img)
How long does a Gaussian filter take when applied to the entire image?
%%timeit -r 3
ndimage.gaussian_filter(img, sigma=5, mode='constant')
103 ms ± 535 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
What if we subdivide the array into chunks and apply this filter to each chunk?
import dask.array as da
img_da = da.from_array(img,
chunks=(500, 500),
)
img_da
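The array is split into chunks of at most 500 × 500 pixels. A quick way to inspect the chunk layout (a small sketch, not part of the original notebook):
# Inspect how the image was split up
print(img_da.shape)      # shape of the full array
print(img_da.chunksize)  # size of a single (full) chunk
print(img_da.numblocks)  # number of chunks along each axis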
map_blocks#
We can use dask.array.map_blocks to apply a function to each chunk (or block) of the dask array.
filtered = da.map_blocks(
ndimage.gaussian_filter, # the function to apply to each chunk
img_da, # the array to apply the function to
sigma=5, # arguments to the function
mode='constant',
)
filtered
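Note that filtered is still a lazy dask array: map_blocks only records the operation in the task graph, and nothing is computed until we explicitly ask for the result. A minimal sketch making this explicit:
# No filtering has happened yet; compute() triggers it on every chunk
result = filtered.compute()
print(type(filtered))  # dask.array.core.Array (lazy)
print(type(result))    # numpy.ndarray (materialised)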
Does this improve the timing?
%%timeit -r 3
filtered.compute(scheduler='threads')
24.3 ms ± 373 µs per loop (mean ± std. dev. of 3 runs, 1 loop each)
%%timeit -r 3
filtered.compute(scheduler='processes')
712 ms ± 3.56 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)
Performance comparison: applying the Gaussian filter to each chunk is faster with multi-threading than with multi-processing.
Why is this? While threads share memory, different processes need to send data back and forth, which can create considerable overhead.
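To see the schedulers side by side, we can also select them via a dask.config.set context and time the computation ourselves; a minimal sketch (exact numbers will differ from the %%timeit results above):
import time
import dask

for scheduler in ('threads', 'processes'):
    with dask.config.set(scheduler=scheduler):
        start = time.perf_counter()
        filtered.compute()
        # 'processes' must pickle each chunk to a worker process and the
        # result back, while 'threads' shares memory with the workers
        print(f'{scheduler}: {time.perf_counter() - start:.3f} s')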
Let’s have a look at the output image.
print('entire image')
filtered_ndimage = ndimage.gaussian_filter(img, sigma=5, mode='constant')
tifffile.imshow(filtered_ndimage)
print('dask.array.map_blocks')
tifffile.imshow(filtered)
entire image
dask.array.map_blocks
(figure output: the unchunked result is smooth everywhere, while the map_blocks result shows artefacts at the chunk borders)
We can prevent these border artefacts by using map_overlap instead of map_blocks.
This:
adds neighboring chunk values to the borders of each chunk
applies map_blocks as before
trims the previously added overlap from each chunk
filtered_overlap = da.map_overlap(
ndimage.gaussian_filter, # the function to apply to each chunk
img_da, # the array to apply the function to
sigma=5, # arguments to the function
mode='constant',
depth=11, # width (in pixels) of the overlap added at each chunk border
)
filtered_overlap
tifffile.imshow(filtered_overlap.compute())
%%timeit -r 3
filtered_overlap.compute(scheduler='threads')
40.2 ms ± 216 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Using map_overlap is somewhat slower than plain map_blocks (ca. 40 ms vs. 24 ms here), since the overlap regions have to be copied and filtered redundantly.
dask-image#
There’s a Python package, dask-image, which automatically deals with these border effects and other problems that can occur when applying functions from scipy.ndimage to chunked dask arrays:
https://image.dask.org/en/latest/
The available ndimage functions are listed at:
https://image.dask.org/en/latest/coverage.html
Among others:
affine_transform
label
…
from dask_image import ndfilters
filtered_di = ndfilters.gaussian_filter(img_da, sigma=5, mode='constant')
filtered_di
tifffile.imshow(filtered_di.compute())
%%timeit -r 3
filtered_di.compute()
More dask-image features#
Connected components#
from dask_image import ndfilters
img_da = da.from_array(img, chunks=500)
seg = (ndfilters.gaussian_filter(img_da, sigma=10, mode='constant') > 10000) # smooth, then threshold into a boolean mask
tifffile.imshow(seg)
# Let's calculate connected components on each chunk of the segmentation image
def connected_components(im):
return ndimage.label(im)[0]
labels = seg.map_blocks(connected_components)
tifffile.imshow(labels)
# Using overlap does not help in this case: labels are still assigned
# independently per chunk, so label IDs clash across chunk borders
def connected_components(im):
return ndimage.label(im)[0]
labels = seg.map_overlap(
connected_components,
depth=100,
)
tifffile.imshow(np.array(labels))
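Why not? ndimage.label numbers the components of each chunk independently, starting from 1, so the same label ID can refer to different objects in different chunks, and an object spanning a chunk border ends up with inconsistent IDs. A tiny sketch (a hypothetical example, not part of the original notebook) illustrating the problem:
# One object spanning a 'chunk' border, labelled per chunk
a = np.zeros((4, 8), dtype=bool)
a[1:3, 2:6] = True                # the object crosses column 4
left, right = a[:, :4], a[:, 4:]  # split into two 'chunks'
print(ndimage.label(left)[0])     # the left half gets label 1
print(ndimage.label(right)[0])    # ...and the right half also gets label 1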
# dask-image implements connected components that are merged correctly across chunk borders
from dask_image import ndmeasure
labels = ndmeasure.label(seg)[0]
tifffile.imshow(labels)
Affine transformations#
# Define a transformation
from scipy.spatial.transform import Rotation as R
# rotation
matrix = R.from_rotvec(np.pi/4. * np.array([0, 0, 1])).as_matrix()[:2, :2]
offset = np.array([1200., -600])
print('Matrix:', matrix)
print('Offset:', offset)
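A detail worth keeping in mind about scipy's convention: affine_transform uses inverse ('pull') mapping, i.e. each output pixel o is filled with the input value found at matrix @ o + offset. A quick sketch (not part of the original notebook):
# Where does the value of output pixel (100, 200) come from?
o = np.array([100., 200.])
src = matrix @ o + offset  # the position sampled in the input image
print(src)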
# Transform the image using plain scipy
img_t = ndimage.affine_transform(
img,
matrix=matrix,
offset=offset,
order=1, # linear interpolation
)
tifffile.imshow(img_t)
# Transform the image using dask_image.ndinterp.affine_transform
from dask_image import ndinterp
img_t = ndinterp.affine_transform(
img_da,
matrix=matrix,
offset=offset,
order=1, # linear interpolation
output_chunks=500,
).compute()
tifffile.imshow(img_t)
Performance comparison
%%timeit -r 1
img_t = ndimage.affine_transform(
img,
matrix=matrix,
offset=offset,
order=1, # linear interpolation
)
%%timeit -r 1
img_t = ndinterp.affine_transform(
img_da,
matrix=matrix,
offset=offset,
order=1, # linear interpolation
output_chunks=500,
).compute()
Exercise: Apply a median filter#
filtered = ...
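One possible solution (a sketch with an arbitrary filter size; dask_image.ndfilters also provides a ready-made median_filter):
# Apply scipy's median filter per chunk; map_overlap handles the borders
size = 5  # side length of the median footprint (arbitrary choice)
filtered = da.map_overlap(
    ndimage.median_filter,  # the function to apply to each chunk
    img_da,                 # the array to apply the function to
    size=size,              # argument passed through to the function
    depth=size // 2 + 1,    # overlap must cover the filter footprint
)
tifffile.imshow(filtered.compute())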