Dask deconvolution

With modern, many-gigabyte image sizes it is often not possible to load the entire image into GPU memory. Even when the image itself fits, keep in mind that FFT convolution and deconvolution need several additional image-sized buffers (recall how many buffers were needed in example 2).
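As a rough back-of-the-envelope check, the sketch below multiplies the voxel count by the bytes per voxel and by an assumed number of image-sized buffers. The function name estimate_decon_bytes and the buffer count of 8 are hypothetical; the real count depends on the deconvolution implementation you use.

```python
import numpy as np

def estimate_decon_bytes(shape, n_buffers, dtype=np.float32):
    """Rough memory estimate: n_buffers arrays with the given shape and dtype.

    n_buffers is an assumption: count the image-sized buffers your
    deconvolution actually allocates (a full-size complex64 FFT buffer
    counts as two float32 buffers).
    """
    voxels = int(np.prod(shape))
    return voxels * np.dtype(dtype).itemsize * n_buffers

# The example stack in this notebook is 128 x 256 x 256 float32 voxels.
per_buffer = estimate_decon_bytes((128, 256, 256), n_buffers=1)
# 8 buffers is an illustrative guess, not a measured figure.
total = estimate_decon_bytes((128, 256, 256), n_buffers=8)
print(per_buffer / 2**20, "MiB per buffer,", total / 2**20, "MiB for 8 buffers")
```

For comparison, a 512 x 2048 x 2048 float32 stack is already 8 GiB per buffer, before any extra FFT buffers are counted.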

When there is not enough GPU memory for the full calculation, we can use Dask to split the image into chunks and deconvolve each chunk separately. This is a bit more complicated than the previous examples, but it still takes only a few lines of code.

import dask.array as da
import numpy as np
import stackview
from skimage.io import imread
from decon_helper import image_path
# the deconvolution library is imported further down, together with a fallback

image_name='Bars-G10-P30-stack.tif'
psf_name='PSF-Bars-stack.tif'
truth_name='Bars-stack.tif'

im=imread(image_path / image_name)
psf=imread(image_path / psf_name)
truth=imread(image_path / truth_name)
im=im.astype('float32')
psf=psf.astype('float32')
psf=psf/psf.sum()
print(im.shape, psf.shape, truth.shape)

# PSF XY half size; it should be smaller than the XY overlap used further down (24 pixels)
psfHalfSize = 16

# crop the PSF in XY around its centre to 2*psfHalfSize - 1 = 31 pixels
cy, cx = psf.shape[1] // 2, psf.shape[2] // 2
psf = psf[:, cy - psfHalfSize:cy + psfHalfSize - 1, cx - psfHalfSize:cx + psfHalfSize - 1]
(128, 256, 256) (128, 256, 256) (128, 256, 256)
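The crop and the half-size rule of thumb can be packaged into a small helper. This is a sketch, not part of the notebook: crop_psf_xy is a hypothetical name, it uses the same slice arithmetic as the cell above, and the optional renormalize flag is an assumption (the notebook normalizes the PSF before cropping, so the cropped PSF sums to slightly less than 1).

```python
import numpy as np

def crop_psf_xy(psf, half_size, renormalize=False):
    """Crop a (z, y, x) PSF around its XY centre with the notebook's slice
    arithmetic, giving 2 * half_size - 1 pixels in y and x.

    renormalize=True (an assumption, not done in the notebook) rescales the
    cropped PSF to sum to 1, since cropping discards part of its tails.
    """
    cy, cx = psf.shape[1] // 2, psf.shape[2] // 2
    out = psf[:, cy - half_size:cy + half_size - 1, cx - half_size:cx + half_size - 1]
    if renormalize:
        out = out / out.sum()
    return out

half_size, overlap = 16, 24  # the values used in this notebook
# The notebook's rule of thumb: the PSF half size should be smaller than the overlap.
assert half_size < overlap, "overlap should exceed the PSF half size"

# Synthetic stand-in PSF, normalized like the real one.
psf_demo = np.random.default_rng(0).random((8, 64, 64)).astype(np.float32)
psf_demo /= psf_demo.sum()
cropped = crop_psf_xy(psf_demo, half_size, renormalize=True)
print(cropped.shape)  # (8, 31, 31)
```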

Define number of chunks

Define the number of chunks to divide the image into.

(The image in this example is small enough that it, together with the arrays needed for the FFT-based calculation, would probably fit on the GPU without chunking. In a real-life application we would first compute the largest chunk size that fits within the available memory and choose the number of chunks accordingly.)
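One way to pre-compute the chunking from a memory budget is sketched below. It is a minimal sketch under stated assumptions: choose_xy_chunks is a hypothetical helper, the chunks are split equally in y and x, each chunk is padded by the overlap on both sides, and bytes_per_voxel lumps together every buffer the deconvolution allocates (a number you would have to measure or look up for your library).

```python
import math

def choose_xy_chunks(shape, overlap, budget_bytes, bytes_per_voxel):
    """Smallest n such that splitting y and x into n chunks each keeps one
    padded chunk (z, y/n + 2*overlap, x/n + 2*overlap) within budget_bytes."""
    z, y, x = shape
    for n in range(1, max(y, x) + 1):
        cy, cx = math.ceil(y / n), math.ceil(x / n)
        padded_voxels = z * (cy + 2 * overlap) * (cx + 2 * overlap)
        if padded_voxels * bytes_per_voxel <= budget_bytes:
            return n
    raise ValueError("no XY split fits; reduce the overlap or also split z")

# Hypothetical numbers: a 512 x 2048 x 2048 stack, 8 float32 buffers
# (8 * 4 bytes per voxel) and 4 GiB of GPU memory.
n = choose_xy_chunks((512, 2048, 2048), overlap=24, budget_bytes=4 * 2**30,
                     bytes_per_voxel=8 * 4)
print(n)  # 5, i.e. 5 x 5 chunks of about 410 x 410 pixels (458 x 458 with overlap)
```

Note that the overlap is part of the budget: larger overlaps shrink the usable core of each chunk.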

num_x_chunks = 2
num_y_chunks = 2
num_z_chunks = 1

# chunk size along each axis (z, y, x)
z_chunk_size = im.shape[0] // num_z_chunks
y_chunk_size = im.shape[1] // num_y_chunks
x_chunk_size = im.shape[2] // num_x_chunks
print('chunks', z_chunk_size, y_chunk_size, x_chunk_size)
chunks 128 128 128
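Integer division works here because 256 splits evenly into 2. When it does not, dask tiles each axis with full chunks of the given size and puts the remainder in a last, smaller chunk, so a rounded-down size can leave a thin sliver (256 // 3 = 85 gives chunks of 85, 85, 85 and 1). Rounding up keeps the count at most n. The two helpers below (hypothetical names) reproduce that tiling rule and the rounded-up size:

```python
def regular_chunks(length, chunk_size):
    """Chunk sizes from tiling `length` with a regular `chunk_size`:
    as many full chunks as fit, then the remainder (if any)."""
    sizes = [chunk_size] * (length // chunk_size)
    if length % chunk_size:
        sizes.append(length % chunk_size)
    return tuple(sizes)

def chunk_size_for(length, n_chunks):
    """Round up so that `length` splits into at most n_chunks pieces."""
    return -(-length // n_chunks)  # integer ceiling division

print(regular_chunks(256, 256 // 3))                # (85, 85, 85, 1)
print(regular_chunks(256, chunk_size_for(256, 3)))  # (86, 86, 84)
print(regular_chunks(256, chunk_size_for(256, 2)))  # (128, 128), as in this notebook
```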

Define the deconvolver

# Prefer the clij2fft non-circulant Richardson-Lucy and fall back to RedLionfish.
# Both versions take only the image chunk, because map_overlap calls deconv_chunk(chunk);
# psf, iterations and reg are read from the notebook's global variables at call time.
try:
    from clij2fft.richardson_lucy import richardson_lucy_nc

    def deconv_chunk(img):
        print(img.shape, psf.shape)
        result = richardson_lucy_nc(img, psf, iterations, reg)
        print('finished decon chunk')
        return result
except ImportError:
    print('clij2fft non-circulant rl not imported')
    try:
        import RedLionfishDeconv as rl
        print('redlionfish rl imported')

        def deconv_chunk(img):
            print(img.shape, psf.shape)
            result = rl.doRLDeconvolutionFromNpArrays(img, psf, niter=iterations, method='gpu', resAsUint8=False)
            print('finished decon chunk')
            return result
    except ImportError:
        print('redlionfish rl not imported')
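If neither GPU library is available, the chunking can still be tried out with a plain NumPy Richardson-Lucy. This is a swapped-in technique, not what the notebook uses: it is the classic circulant version (the FFT wraps around the image edges), unlike the non-circulant richardson_lucy_nc, so expect some wrap-around edge artifacts. The name richardson_lucy_cpu is hypothetical. Each iteration applies the standard update estimate = estimate * (mirrored PSF convolved with (image / (PSF convolved with estimate))).

```python
import numpy as np

def richardson_lucy_cpu(img, psf, iterations, eps=1e-12):
    """Circulant Richardson-Lucy deconvolution with NumPy FFTs (CPU sketch).

    Assumes psf is no larger than img along any axis and sums to 1; the PSF
    centre is taken to be at index size // 2 along each axis.
    """
    img = np.asarray(img, dtype=np.float64)
    # Embed the PSF in an image-sized array and shift its centre to the origin.
    psf_full = np.zeros(img.shape)
    psf_full[tuple(slice(0, s) for s in psf.shape)] = psf
    psf_full = np.roll(psf_full, [-(s // 2) for s in psf.shape], axis=tuple(range(img.ndim)))
    otf = np.fft.rfftn(psf_full)

    def convolve(a, transfer):
        # Circular convolution via the FFT (wraps around the image edges).
        return np.fft.irfftn(np.fft.rfftn(a) * transfer, s=img.shape)

    estimate = np.full(img.shape, img.mean())
    for _ in range(iterations):
        blurred = convolve(estimate, otf)
        ratio = img / np.maximum(blurred, eps)
        # Multiplying by the conjugate OTF correlates with the PSF,
        # i.e. convolves with the mirrored PSF.
        estimate *= convolve(ratio, np.conj(otf))
    return estimate.astype(np.float32)

# Usage: a point source at (8, 16, 16) blurred by a small Gaussian PSF.
zz, yy, xx = np.mgrid[-3:4, -3:4, -3:4]
psf_g = np.exp(-(zz**2 + yy**2 + xx**2) / 4.0)
psf_g /= psf_g.sum()
blurred_demo = np.zeros((16, 32, 32))
blurred_demo[5:12, 13:20, 13:20] = 1000.0 * psf_g
restored = richardson_lucy_cpu(blurred_demo, psf_g, iterations=30)
print(blurred_demo.max(), restored.max())
```

To plug it into the pipeline one could define deconv_chunk(img) to return richardson_lucy_cpu(img, psf, iterations); it will likely be much slower than the GPU versions.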

Deconvolve in chunks with overlap between chunks

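Here we run the chunked deconvolution with dask's map_overlap. Each chunk is padded with overlap = 24 extra pixels from its neighbours in y and x (mirrored at the outer image border, boundary='reflect'), deconvolved, and then trimmed back to its original size, which prevents edge artifacts at the seams between chunks. With 128 x 128 chunks this gives the 176 x 176 blocks printed below. The first printed shape, (0, 0, 0), is most likely dask calling deconv_chunk on an empty array to infer the output type.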
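To make the pad, process, trim and stitch steps concrete, here is a NumPy-only re-implementation of the idea behind map_overlap. It is an illustrative sketch, not dask's code; chunked_apply and local_filter are hypothetical names, and the padding uses np.pad with mode='symmetric', which mirrors including the edge pixel (dask's boundary='reflect' also mirrors at the border; check the dask documentation for its exact convention).

```python
import itertools
import numpy as np

def chunked_apply(img, func, chunks, depth, mode="symmetric"):
    """Apply func to each chunk of img with a halo of `depth` pixels per axis."""
    # Mirror-pad the whole image once so that border chunks also get a halo.
    padded = np.pad(img, [(d, d) for d in depth], mode=mode)
    out = np.empty(img.shape, dtype=np.float32)
    starts = [range(0, n, c) for n, c in zip(img.shape, chunks)]
    for origin in itertools.product(*starts):
        stops = [min(o + c, n) for o, c, n in zip(origin, chunks, img.shape)]
        # The chunk occupies [o + d, s + d) in padded coordinates; add d on each side.
        halo = tuple(slice(o, s + 2 * d) for o, s, d in zip(origin, stops, depth))
        result = func(padded[halo])
        # Trim the halo and write the core back into its place.
        core = tuple(slice(d, r - d) for d, r in zip(depth, result.shape))
        out[tuple(slice(o, s) for o, s in zip(origin, stops))] = result[core]
    return out

def local_filter(block):
    """Sum of each pixel and its 6 face neighbours. np.roll wraps at the block
    edges, but that only corrupts the outermost layer, which is trimmed away
    along every axis that is split into chunks with depth >= 1."""
    out = block.astype(np.float32)
    for ax in range(block.ndim):
        out += np.roll(block, 1, axis=ax) + np.roll(block, -1, axis=ax)
    return out

# A chunked run matches running the filter on the whole (padded) image.
rng = np.random.default_rng(1)
vol = rng.random((6, 20, 20)).astype(np.float32)
depth = (0, 2, 2)
chunked = chunked_apply(vol, local_filter, (6, 7, 7), depth)
reference = local_filter(np.pad(vol, [(d, d) for d in depth], mode="symmetric"))[:, 2:-2, 2:-2]
print(np.allclose(chunked, reference))  # True

# With this notebook's XY geometry each 128 x 128 chunk is processed as 176 x 176.
shapes = []
def record(block):
    shapes.append(block.shape)
    return block
chunked_apply(np.zeros((4, 256, 256), np.float32), record, (4, 128, 128), (0, 24, 24))
print(shapes)  # four blocks of shape (4, 176, 176)
```

The equality only holds because the filter reaches no further than the halo; a deconvolution spreads information further with every iteration, which is why the overlap is chosen larger than the PSF half size.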

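iterations = 100
reg = 0.0001

# overlap (in pixels) added on each side of a chunk in y and x; none in z
overlap = 24
dimg = da.from_array(im, chunks=(z_chunk_size, y_chunk_size, x_chunk_size))

# pad each chunk with `overlap` pixels from its neighbours (mirrored at the image border),
# deconvolve it, then trim the padding off again
out = dimg.map_overlap(deconv_chunk, depth={0: 0, 1: overlap, 2: overlap}, boundary='reflect', dtype=np.float32)

# a single worker processes the chunks one after another, so only one chunk is on the GPU at a time
decon_overlap_24 = out.compute(num_workers=1)
stackview.orthogonal(decon_overlap_24)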
(0, 0, 0) (128, 31, 31)
(128, 176, 176) (128, 31, 31)
get lib
Richardson Lucy Started
0 10 20 30 40 50 60 70 80 90 
Richardson Lucy Finishedfinished decon chunk
(128, 176, 176) (128, 31, 31)
get lib

Richardson Lucy Started
0 10 20 30 40 50 60 70 80 90 
Richardson Lucy Finishedfinished decon chunk
(128, 176, 176) (128, 31, 31)
get lib

Richardson Lucy Started
0 10 20 30 40 50 60 70 80 90 
Richardson Lucy Finishedfinished decon chunk
(128, 176, 176) (128, 31, 31)
get lib

Richardson Lucy Started
0 10 20 30 40 50 60 70 80 90 
Richardson Lucy Finishedfinished decon chunk

Deconvolve in chunks with no overlap

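Here we deconvolve in chunks without overlap. Each 128 x 128 (y, x) chunk is processed as is instead of being padded to 176 x 176 (about 1.9 times as many voxels), so this is faster and uses less memory, but, as the result shows, artifacts appear along the seams between the chunks.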

overlap = 0
dimg = da.from_array(im, chunks=(z_chunk_size, y_chunk_size, x_chunk_size))

# with depth 0, map_overlap simply applies deconv_chunk to each chunk independently
out = dimg.map_overlap(deconv_chunk, depth={0: 0, 1: overlap, 2: overlap}, boundary='reflect', dtype=np.float32)

decon_overlap_0 = out.compute(num_workers=1)
stackview.orthogonal(decon_overlap_0)
(0, 0, 0) (128, 31, 31)
(128, 128, 128) (128, 31, 31)
get lib
Richardson Lucy Started
0 10 20 30 40 50 60 70 80 90 
Richardson Lucy Finishedfinished decon chunk
(128, 128, 128) (128, 31, 31)
get lib

Richardson Lucy Started
0 10 20 30 40 50 60 70 80 90 
Richardson Lucy Finishedfinished decon chunk
(128, 128, 128) (128, 31, 31)
get lib

Richardson Lucy Started
0 10 20 30 40 50 60 70 80 90 
Richardson Lucy Finishedfinished decon chunk
(128, 128, 128) (128, 31, 31)
get lib

Richardson Lucy Started
0 10 20 30 40 50 60 70 80 90 
Richardson Lucy Finishedfinished decon chunk
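The seams can also be quantified rather than only inspected visually. One simple, ad hoc measure (not a standard metric; seam_ratio is a hypothetical name) compares the mean absolute jump between the two planes on either side of a chunk boundary with the mean jump between neighbouring planes everywhere else. A value near 1 means no visible seam; values well above 1 suggest one.

```python
import numpy as np

def seam_ratio(vol, axis, position):
    """Mean |jump| across the boundary between planes position-1 and position,
    divided by the mean |jump| between all other neighbouring planes."""
    diffs = np.abs(np.diff(np.asarray(vol, dtype=np.float64), axis=axis))
    seam = np.take(diffs, position - 1, axis=axis).mean()
    others = np.delete(diffs, position - 1, axis=axis).mean()
    return seam / others

# Synthetic check: a smooth ramp has no seam; adding a step creates one.
ramp = np.broadcast_to(np.arange(16, dtype=np.float64)[None, :, None], (4, 16, 16)).copy()
stepped = ramp.copy()
stepped[:, 8:, :] += 5.0
print(seam_ratio(ramp, axis=1, position=8), seam_ratio(stepped, axis=1, position=8))  # 1.0 6.0
```

For the two deconvolution results above, the chunk boundary sits at index y_chunk_size along axis 1 (and x_chunk_size along axis 2); one would expect a higher ratio for the no-overlap result than for the result computed with an overlap of 24.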