Noise2Void (3D)
# We import all our dependencies.
from n2v.models import N2VConfig, N2V
import numpy as np
from csbdeep.utils import plot_history
from n2v.utils.n2v_utils import manipulate_val_data
from n2v.internals.N2V_DataGenerator import N2V_DataGenerator
from matplotlib import pyplot as plt
import os
from skimage import io
from tifffile import imread
import stackview
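import ssl
# Disable SSL certificate verification for HTTPS downloads, presumably to work around
# certificate problems in this environment. This is insecure; avoid it elsewhere.
ssl._create_default_https_context = ssl._create_unverified_context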
Dataset
Some example data has been stored at /projects/p038/p_scads_trainings/BIAS/torch_segmentation_denoising_example_data/denoising_example_data/. Alternatively, the data can be downloaded here (CC-BY-4.0) (Haase, Vorkel, Myers). Similar to the PyTorch DataLoader, Noise2Void provides its own N2V_DataGenerator class, which serves a similar purpose: it loads the data and extracts patches for training and validation.
root = r'/projects/p038/p_scads_trainings/BIAS/torch_segmentation_denoising_example_data/denoising_example_data/n2v_3d'
filename = 'lund_i000000_oi_000000.tif'
data_generator = N2V_DataGenerator()
We will load all ‘.tif’ files from the data directory; in our case, there is only one. The function returns a list of images (NumPy arrays). With the dims parameter, we specify the order of dimensions in the image files we are reading. The load_imgs_from_directory function automatically adds two extra dimensions to each image: one at the front can hold a stack of images, such as the frames of a movie, and one at the end can hold color channels, such as RGB.
imgs = data_generator.load_imgs_from_directory(directory = root, dims='ZYX')
imgs[0].shape
(1, 71, 1024, 512, 1)
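The shape above follows a sample/Z/Y/X/channel layout: a leading sample axis, the three spatial axes, and a trailing channel axis. The following minimal sketch reproduces this with plain NumPy on a small synthetic volume (the sizes are made up for illustration). It also shows how `squeeze()` removes the singleton axes again before a maximum projection along Z.

```python
import numpy as np

# Small synthetic ZYX volume standing in for the real (71, 1024, 512) stack.
rng = np.random.default_rng(0)
volume = rng.integers(0, 1000, size=(7, 32, 16), dtype=np.uint16)

# Add a leading sample axis and a trailing channel axis: ZYX -> SZYXC.
szyxc = volume[np.newaxis, ..., np.newaxis]
print(szyxc.shape)  # (1, 7, 32, 16, 1)

# squeeze() drops both singleton axes again; max(axis=0) projects along Z.
projection = szyxc.squeeze().max(axis=0)
print(projection.shape)  # (32, 16)
```

Note that `squeeze()` removes every axis of length 1, so a single-slice volume would lose its Z axis as well.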
Let’s look at a maximum intensity projection of the volume along Z. To display it, we first have to remove the extra dimensions added by the loader.
fig, ax = plt.subplots()
ax.imshow(imgs[0].squeeze().max(axis=0), cmap='magma', vmin=np.percentile(imgs[0], 0.1), vmax=np.percentile(imgs[0], 99.9))
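The display limits above use the 0.1st and 99.9th percentiles rather than the absolute minimum and maximum. This keeps a few extremely bright or dark voxels from compressing the contrast of the rest of the image. The following minimal sketch demonstrates this on synthetic data; the helper name `percentile_window` is hypothetical.

```python
import numpy as np

def percentile_window(image, low=0.1, high=99.9):
    """Return (vmin, vmax) display limits taken from the given percentiles."""
    return np.percentile(image, low), np.percentile(image, high)

rng = np.random.default_rng(0)
image = rng.normal(200.0, 40.0, size=(256, 256))
image[0, 0] = 65535.0  # a single saturated "hot" pixel

vmin, vmax = percentile_window(image)
# The hot pixel barely affects the percentile-based upper limit,
# whereas image.max() is dominated by it.
print(vmin, vmax, image.max())
```

The limits are then passed to `imshow(..., vmin=vmin, vmax=vmax)`, as in the plotting cell above.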
Config
We have to provide a few settings for training, starting with the shape of the patches to be extracted from the image. Patches are necessary because pushing the whole image data to the device at once (even on an HPC system) would probably use up all available memory. Notice the changed size of the first axis (the sample or batch dimension): the single volume has become many patches.
# Here we extract patches for training and validation.
patch_shape = (32, 64, 64)
patches = data_generator.generate_patches_from_list(imgs[:1], shape=patch_shape)
Generated patches: (2048, 32, 64, 64, 1)
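Patches are created so that they do not overlap. This is not the case if you specify a number of patches (see the docstring for details!). Non-overlapping patches enable us to split them into a training and a validation set. Note, however, that generate_patches_from_list also augments the patches by default (augment=True: rotations in the XY plane and flips), which turns the 2 × 16 × 8 = 256 grid patches into 2048. Because the augmented copies are not shuffled by default, the simple slice below puts rotated or flipped versions of some training patches into the validation set, so the validation loss may be somewhat optimistic.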
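The reported number of patches can be checked by hand. This sketch makes two assumptions: incomplete patches at the volume borders are dropped, and the default augmentation yields 8 variants per patch (4 rotations in XY × 2 flips, as described in the `generate_patches_from_list` docstring). It only reproduces the arithmetic and is not the n2v implementation.

```python
import math

def expected_patch_count(volume_shape, patch_shape, augment=True):
    """Number of non-overlapping patches that fit into the volume,
    optionally multiplied by 8 for 4 XY rotations x 2 flips."""
    n = math.prod(v // p for v, p in zip(volume_shape, patch_shape))
    return n * 8 if augment else n

# ZYX shape of the example volume and the patch shape used above.
print(expected_patch_count((71, 1024, 512), (32, 64, 64), augment=False))  # 256
print(expected_patch_count((71, 1024, 512), (32, 64, 64)))                 # 2048
```

This matches the reported 2048 patches. Along Z only 2 × 32 = 64 of the 71 slices are covered, so the last 7 slices do not appear in any patch.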
X = patches[:600]
X_val = patches[600:]
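As far as we can tell from the n2v source, the augmented copies are appended block-wise: all originals first, then each rotation, then the flipped versions. A plain slice therefore mixes transformed twins of the same data into both sets. The following sketch illustrates this on random stand-in data, then shows a stricter alternative that splits the non-augmented patches first and augments only the training part. In the notebook, those non-augmented patches would come from extracting with `augment=False`. `augment_xy` is a hypothetical NumPy re-implementation of the 8-fold augmentation, not the n2v function, and the 75/25 split is an arbitrary choice.

```python
import numpy as np

def augment_xy(patches):
    """8-fold augmentation of SZYXC patches: 4 rotations in the YX plane,
    each also flipped along X (hypothetical re-implementation)."""
    rotated = np.concatenate([np.rot90(patches, k=k, axes=(2, 3)) for k in range(4)])
    return np.concatenate([rotated, np.flip(rotated, axis=-2)])

rng = np.random.default_rng(0)
# Stand-in for non-augmented patches (square in YX so rotations keep the shape).
originals = rng.random((256, 4, 8, 8, 1)).astype(np.float32)

# Block-wise augmentation followed by a plain slice leaks data:
# index 600 is the 180-degree rotation of original patch 88 (600 - 2 * 256).
leaky = augment_xy(originals)
print(np.array_equal(leaky[600], np.rot90(originals[88], k=2, axes=(1, 2))))  # True

# Stricter alternative: split the originals first, then augment only training.
n_train = 192  # hypothetical 75 % / 25 % split
train_orig, X_val_strict = originals[:n_train], originals[n_train:]
X_train_strict = augment_xy(train_orig)
print(X_train_strict.shape, X_val_strict.shape)  # (1536, 4, 8, 8, 1) (64, 4, 8, 8, 1)
```

With this ordering, no validation patch has a rotated or flipped twin in the training set. The price is a smaller validation set.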
fig, axes = plt.subplots(ncols=2)
axes[0].imshow(X[0,16].squeeze())
axes[0].set_title('Training Patch')
axes[1].imshow(X_val[0,16].squeeze())
axes[1].set_title('Validation Patch')
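You can increase train_steps_per_epoch to get even better results at the price of longer computation. With the setting below, int(X.shape[0]/128) = 4 steps of train_batch_size=4 patches each, every epoch processes only 16 patches, a small fraction of the 600 patches in X.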
config = N2VConfig(X, unet_kern_size=3,
                   train_steps_per_epoch=int(X.shape[0]/128), train_epochs=20,
                   train_loss='mse', batch_norm=True, train_batch_size=4,
                   n2v_perc_pix=0.198,            # percentage (not fraction) of voxels masked per patch
                   n2v_patch_shape=(32, 64, 64),  # same as patch_shape used for extraction
                   n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5)
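Each training step processes one batch of `train_batch_size` patches, so `train_steps_per_epoch` directly controls how much data an epoch sees. The following plain-Python arithmetic uses the values from this notebook and compares the current setting with one full pass over `X`.

```python
import math

n_train_patches = 600   # X.shape[0]
batch_size = 4          # train_batch_size
epochs = 20             # train_epochs

steps_per_epoch = int(n_train_patches / 128)              # setting used above
patches_per_epoch = steps_per_epoch * batch_size           # patches processed per epoch
steps_full_pass = math.ceil(n_train_patches / batch_size)  # steps for one pass over X

print(steps_per_epoch, patches_per_epoch, steps_full_pass)  # 4 16 150
print(patches_per_epoch * epochs)                           # 320
```

A full pass per epoch would require about 37 times as many steps as the current setting, and each epoch would take correspondingly longer. The "4/4" in the training log below confirms the 4 steps per epoch.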
Let’s look at the parameters stored in the config-object.
vars(config)
{'means': ['211.3554'],
'stds': ['43.559948'],
'n_dim': 3,
'axes': 'ZYXC',
'n_channel_in': 1,
'n_channel_out': 1,
'unet_residual': False,
'unet_n_depth': 2,
'unet_kern_size': 3,
'unet_n_first': 32,
'unet_last_activation': 'linear',
'unet_input_shape': (None, None, None, 1),
'train_loss': 'mse',
'train_epochs': 20,
'train_steps_per_epoch': 4,
'train_learning_rate': 0.0004,
'train_batch_size': 4,
'train_tensorboard': True,
'train_checkpoint': 'weights_best.h5',
'train_reduce_lr': {'factor': 0.5, 'patience': 10},
'batch_norm': True,
'n2v_perc_pix': 0.198,
'n2v_patch_shape': (32, 64, 64),
'n2v_manipulator': 'uniform_withCP',
'n2v_neighborhood_radius': 5,
'single_net_per_channel': True,
'blurpool': False,
'skip_skipone': False,
'structN2Vmask': None,
'probabilistic': False}
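`n2v_perc_pix` is a percentage: 0.198 % of the 32 × 64 × 64 = 131,072 voxels of a training patch are masked. This matches the "259 blind-spots" reported when training starts. The following sketch reproduces that number and gives a simplified picture of the `uniform_withCP` manipulator. Each selected voxel is replaced by a voxel drawn uniformly from the cube of radius `n2v_neighborhood_radius` around it, and the center pixel (CP) itself may be drawn. `blind_spot_mask` is a hypothetical illustration. It picks voxels uniformly at random, whereas n2v uses stratified sampling, so it is not the library's code.

```python
import numpy as np

n2v_patch_shape = (32, 64, 64)
perc_pix = 0.198   # n2v_perc_pix (a percentage, not a fraction)
radius = 5         # n2v_neighborhood_radius

n_voxels = int(np.prod(n2v_patch_shape))
n_blind_spots = int(n_voxels / 100.0 * perc_pix)
print(n_voxels, n_blind_spots)  # 131072 259

def blind_spot_mask(patch, n_spots, radius, rng):
    """Hypothetical, simplified 'uniform_withCP' manipulation of one ZYX patch.
    Returns the manipulated patch and the coordinates of the masked voxels."""
    manipulated = patch.copy()
    shape = np.array(patch.shape)
    # Uniformly random voxel coordinates (n2v uses stratified sampling instead).
    coords = rng.integers(0, shape, size=(n_spots, patch.ndim))
    # Random offsets in [-radius, radius]; offset 0 (the center pixel) is allowed.
    offsets = rng.integers(-radius, radius + 1, size=(n_spots, patch.ndim))
    # Clip so that the replacement voxel stays inside the patch.
    sources = np.clip(coords + offsets, 0, shape - 1)
    manipulated[tuple(coords.T)] = patch[tuple(sources.T)]
    return manipulated, coords

rng = np.random.default_rng(0)
patch = rng.normal(size=n2v_patch_shape).astype(np.float32)
manipulated, coords = blind_spot_mask(patch, n_blind_spots, radius, rng)
```

During training, the loss (`n2v_mse`) is evaluated only at the masked positions. There, the network output is compared with the original, unmanipulated values, so the network cannot simply copy its input.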
Model creation
# a name used to identify the model
model_name = 'n2v_3D'
# the base directory in which our model will live
basedir = 'models'
# We are now creating our network model.
model = N2V(config=config, name=model_name, basedir=basedir)
/app/env/lib/python3.9/site-packages/n2v/models/n2v_standard.py:429: UserWarning: output path for model already exists, files may be overwritten: /home/h1/johamuel/Projects/repos/PoL-BioImage-Analysis-TS-GPU-Accelerated-Image-Analysis/docs/70_AI_Segmentation_Denoising/models/n2v_3D
warnings.warn(
2023-08-25 15:10:05.841469: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-08-25 15:10:07.925917: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 35012 MB memory: -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:0b:00.0, compute capability: 8.0
history = model.train(X, X_val)
259 blind-spots will be generated per training patch of size (32, 64, 64).
Preparing validation data: 100%|██████████| 1448/1448 [00:00<00:00, 1717.83it/s]
Epoch 1/20
2023-08-25 15:10:14.661980: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8100
2023-08-25 15:10:17.572895: I tensorflow/core/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2023-08-25 15:10:17.574144: I tensorflow/core/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2023-08-25 15:10:17.574182: W tensorflow/stream_executor/gpu/asm_compiler.cc:80] Couldn't get ptxas version string: INTERNAL: Couldn't invoke ptxas --version
2023-08-25 15:10:17.576223: I tensorflow/core/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2023-08-25 15:10:17.576324: W tensorflow/stream_executor/gpu/redzone_allocator.cc:314] INTERNAL: Failed to launch ptxas
Relying on driver to perform ptx compilation.
Modify $PATH to customize ptxas location.
This message will be only logged once.
2023-08-25 15:10:20.276396: I tensorflow/stream_executor/cuda/cuda_blas.cc:1614] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
4/4 [==============================] - 16s 2s/step - loss: 3.2999 - n2v_mse: 3.2999 - n2v_abs: 1.5008 - val_loss: 1.0786 - val_n2v_mse: 1.0738 - val_n2v_abs: 0.9134 - lr: 4.0000e-04
Epoch 2/20
4/4 [==============================] - 3s 961ms/step - loss: 1.7476 - n2v_mse: 1.7476 - n2v_abs: 1.1515 - val_loss: 0.9133 - val_n2v_mse: 0.9089 - val_n2v_abs: 0.8382 - lr: 4.0000e-04
Epoch 3/20
4/4 [==============================] - 3s 960ms/step - loss: 1.9134 - n2v_mse: 1.9134 - n2v_abs: 1.2011 - val_loss: 0.8190 - val_n2v_mse: 0.8151 - val_n2v_abs: 0.7894 - lr: 4.0000e-04
Epoch 4/20
4/4 [==============================] - 3s 946ms/step - loss: 1.3084 - n2v_mse: 1.3084 - n2v_abs: 1.0001 - val_loss: 0.8296 - val_n2v_mse: 0.8255 - val_n2v_abs: 0.8026 - lr: 4.0000e-04
Epoch 5/20
4/4 [==============================] - 3s 960ms/step - loss: 0.8024 - n2v_mse: 0.8024 - n2v_abs: 0.7470 - val_loss: 0.7955 - val_n2v_mse: 0.7914 - val_n2v_abs: 0.7933 - lr: 4.0000e-04
Epoch 6/20
4/4 [==============================] - 3s 959ms/step - loss: 0.9119 - n2v_mse: 0.9119 - n2v_abs: 0.7164 - val_loss: 0.7908 - val_n2v_mse: 0.7869 - val_n2v_abs: 0.7878 - lr: 4.0000e-04
Epoch 7/20
4/4 [==============================] - 3s 960ms/step - loss: 1.1497 - n2v_mse: 1.1497 - n2v_abs: 0.9222 - val_loss: 0.6052 - val_n2v_mse: 0.6020 - val_n2v_abs: 0.6819 - lr: 4.0000e-04
Epoch 8/20
4/4 [==============================] - 3s 972ms/step - loss: 0.3489 - n2v_mse: 0.3489 - n2v_abs: 0.4908 - val_loss: 0.5721 - val_n2v_mse: 0.5689 - val_n2v_abs: 0.6524 - lr: 4.0000e-04
Epoch 9/20
4/4 [==============================] - 3s 943ms/step - loss: 0.8758 - n2v_mse: 0.8758 - n2v_abs: 0.7888 - val_loss: 0.6807 - val_n2v_mse: 0.6768 - val_n2v_abs: 0.6915 - lr: 4.0000e-04
Epoch 10/20
4/4 [==============================] - 3s 944ms/step - loss: 0.3037 - n2v_mse: 0.3037 - n2v_abs: 0.4558 - val_loss: 0.7176 - val_n2v_mse: 0.7134 - val_n2v_abs: 0.6956 - lr: 4.0000e-04
Epoch 11/20
4/4 [==============================] - 3s 963ms/step - loss: 0.5658 - n2v_mse: 0.5658 - n2v_abs: 0.5924 - val_loss: 0.4690 - val_n2v_mse: 0.4667 - val_n2v_abs: 0.5862 - lr: 4.0000e-04
Epoch 12/20
4/4 [==============================] - 3s 955ms/step - loss: 1.1235 - n2v_mse: 1.1235 - n2v_abs: 0.8216 - val_loss: 0.3430 - val_n2v_mse: 0.3418 - val_n2v_abs: 0.5007 - lr: 4.0000e-04
Epoch 13/20
4/4 [==============================] - 3s 955ms/step - loss: 0.8893 - n2v_mse: 0.8893 - n2v_abs: 0.7777 - val_loss: 0.2787 - val_n2v_mse: 0.2779 - val_n2v_abs: 0.4446 - lr: 4.0000e-04
Epoch 14/20
4/4 [==============================] - 3s 942ms/step - loss: 0.6305 - n2v_mse: 0.6305 - n2v_abs: 0.6323 - val_loss: 0.3837 - val_n2v_mse: 0.3827 - val_n2v_abs: 0.5253 - lr: 4.0000e-04
Epoch 15/20
4/4 [==============================] - 3s 939ms/step - loss: 0.7192 - n2v_mse: 0.7192 - n2v_abs: 0.7060 - val_loss: 0.5192 - val_n2v_mse: 0.5177 - val_n2v_abs: 0.6018 - lr: 4.0000e-04
Epoch 16/20
4/4 [==============================] - 3s 939ms/step - loss: 0.3091 - n2v_mse: 0.3091 - n2v_abs: 0.4417 - val_loss: 0.4662 - val_n2v_mse: 0.4648 - val_n2v_abs: 0.5686 - lr: 4.0000e-04
Epoch 17/20
4/4 [==============================] - 3s 938ms/step - loss: 0.3438 - n2v_mse: 0.3438 - n2v_abs: 0.4762 - val_loss: 0.4039 - val_n2v_mse: 0.4030 - val_n2v_abs: 0.5302 - lr: 4.0000e-04
Epoch 18/20
4/4 [==============================] - 3s 937ms/step - loss: 0.5125 - n2v_mse: 0.5125 - n2v_abs: 0.6299 - val_loss: 0.4013 - val_n2v_mse: 0.4004 - val_n2v_abs: 0.5267 - lr: 4.0000e-04
Epoch 19/20
4/4 [==============================] - 3s 950ms/step - loss: 0.2406 - n2v_mse: 0.2406 - n2v_abs: 0.4126 - val_loss: 0.4255 - val_n2v_mse: 0.4246 - val_n2v_abs: 0.5384 - lr: 4.0000e-04
Epoch 20/20
4/4 [==============================] - 3s 941ms/step - loss: 0.1748 - n2v_mse: 0.1748 - n2v_abs: 0.3264 - val_loss: 0.4494 - val_n2v_mse: 0.4486 - val_n2v_abs: 0.5547 - lr: 4.0000e-04
Loading network weights from 'weights_best.h5'.
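`model.train` returns a Keras `History` object, and the last line shows that the best checkpoint (`weights_best.h5`) is restored after training. Assuming the checkpoint tracks the lowest validation loss, the logged values point to epoch 13. The following check uses the validation losses copied by hand from the training log. In the notebook, the same list is available as `history.history['val_loss']`, and the curves can be drawn with the already imported csbdeep helper, e.g. `plot_history(history, ['loss', 'val_loss'])`.

```python
# val_loss per epoch, copied from the training log above
val_loss = [1.0786, 0.9133, 0.8190, 0.8296, 0.7955, 0.7908, 0.6052, 0.5721,
            0.6807, 0.7176, 0.4690, 0.3430, 0.2787, 0.3837, 0.5192, 0.4662,
            0.4039, 0.4013, 0.4255, 0.4494]

# 1-based epoch with the lowest validation loss
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1
print(best_epoch, val_loss[best_epoch - 1])  # 13 0.2787
```

The training loss keeps falling to 0.1748 in epoch 20, while the validation loss rises again after epoch 13. Restoring the best rather than the last weights is therefore the safer choice here.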
Inference
# We load the data we want to process.
img = imread(os.path.join(root, filename))
# Here we process the data.
# The 'n_tiles' parameter can be used if images are too big for the GPU memory.
# If we do not provide the 'n_tiles' parameter the system will automatically try to find an appropriate tiling.
pred = model.predict(img, axes='ZYX', n_tiles=(2,4,4))
The input image is of type uint16 and will be casted to float32 for prediction.
100%|██████████| 32/32 [00:03<00:00, 10.66it/s]
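With `n_tiles=(2, 4, 4)`, the volume is processed in 2 × 4 × 4 = 32 tiles, which is why the progress bar counts to 32. The following sketch only illustrates the partitioning with `np.array_split` on a small synthetic array. Csbdeep's actual tiled prediction also adds an overlapping margin around each tile to avoid seams at tile borders, which this simplified version (helper name `split_into_tiles` is hypothetical) leaves out.

```python
import numpy as np

def split_into_tiles(volume, n_tiles):
    """Split a ZYX array into a grid of non-overlapping tiles (no overlap margin)."""
    tiles = [volume]
    for axis, n in enumerate(n_tiles):
        tiles = [part for t in tiles for part in np.array_split(t, n, axis=axis)]
    return tiles

volume = np.arange(71 * 16 * 8).reshape(71, 16, 8)  # small stand-in volume
tiles = split_into_tiles(volume, (2, 4, 4))
print(len(tiles), tiles[0].shape)  # 32 (36, 4, 2)
```

For the 71 × 1024 × 512 input, this gives tiles of roughly 36 × 256 × 128 voxels before any overlap is added. Each tile then fits comfortably into GPU memory.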
fig, axes = plt.subplots(ncols=2)
axes[0].imshow(img[15], cmap='magma', vmin=np.percentile(img,0.1), vmax=np.percentile(img,99.9))
axes[0].set_title('Input');
axes[1].imshow(pred[15], cmap='magma', vmin=np.percentile(pred,0.1), vmax=np.percentile(pred,99.9))
axes[1].set_title('Prediction')