Image reconstruction after using extract_image_patches

I have an autoencoder that takes an image as input and creates a new image as an output.

The input image (1x1024x1024x3) is broken into patches (1024x32x32x3) before being sent to the network.

Once I have an output, also a batch of patches of size 1024x32x32x3, I want to be able to reconstruct a 1024x1024x3 image. I thought I could simply reshape it, but here's what happened.

First, the image as read by TensorFlow: [Input image]

I split the image into patches with the following code:

    patch_size = [1, 32, 32, 1]
    patches = tf.extract_image_patches([image], patch_size, patch_size, [1, 1, 1, 1], 'VALID')
    patches = tf.reshape(patches, [1024, 32, 32, 3])

Here are some patches from this image:

[Patch #168] [Patch #169]

However, it's when I reshape the patch data back into an image that things go pear-shaped:

    reconstructed = tf.reshape(patches, [1, 1024, 1024, 3])
    converted = tf.image.convert_image_dtype(reconstructed, tf.uint8)
    encoded = tf.image.encode_png(converted)

[Reconstructed output]

This example demonstrates that there is no processing happening between the patching and the reconstruction. I made a version of the code you can use to test this behaviour. To use it:

 echo "/path/to/test-image.png" > inputs.txt mkdir images python3 image_test.py inputs.txt images 

The code will produce one input image, one patch image for each of the 1024 patches, and one output image per input image, so comment out the lines that create the input and output images if you only care about saving all the patches.

Can someone please explain what is happening? :(

5 answers

Use Update 2 below. A small example for your task (TF 1.0):

Given an image of shape (4, 4, 1), convert it to patches of shape (4, 2, 2, 1) and reconstruct the image from them:

    import tensorflow as tf

    image = tf.constant([[[1],   [2],  [3],  [4]],
                         [[5],   [6],  [7],  [8]],
                         [[9],  [10], [11], [12]],
                         [[13], [14], [15], [16]]])

    patch_size = [1, 2, 2, 1]
    patches = tf.extract_image_patches([image], patch_size, patch_size, [1, 1, 1, 1], 'VALID')
    patches = tf.reshape(patches, [4, 2, 2, 1])
    reconstructed = tf.reshape(patches, [1, 4, 4, 1])
    rec_new = tf.space_to_depth(reconstructed, 2)
    rec_new = tf.reshape(rec_new, [4, 4, 1])

    sess = tf.Session()
    I, P, R_n = sess.run([image, patches, rec_new])
    print(I)
    print(I.shape)
    print(P.shape)
    print(R_n)
    print(R_n.shape)

Output:

    [[[ 1] [ 2] [ 3] [ 4]]
     [[ 5] [ 6] [ 7] [ 8]]
     [[ 9] [10] [11] [12]]
     [[13] [14] [15] [16]]]
    (4, 4, 1)
    (4, 2, 2, 1)
    [[[ 1] [ 2] [ 3] [ 4]]
     [[ 5] [ 6] [ 7] [ 8]]
     [[ 9] [10] [11] [12]]
     [[13] [14] [15] [16]]]
    (4, 4, 1)
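For contrast, reshaping the patches straight back to [1, 4, 4, 1] without space_to_depth scrambles the pixel order, which is exactly the asker's pear-shaped output. A quick check (my addition, continuing the snippet above):

    # Naive reshape, for contrast: the pixels of each patch stay contiguous
    # in memory, so the rows of the "image" interleave patch contents
    naive = tf.reshape(patches, [1, 4, 4, 1])
    print(sess.run(naive)[0, :, :, 0])
    # [[ 1  2  5  6]
    #  [ 3  4  7  8]
    #  [ 9 10 13 14]
    #  [11 12 15 16]]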

Update - for 3 channels (debugging):

This only works for p = sqrt(h). After reshaping the patch batch to [1, h, h, c], space_to_depth(reconstructed, p) regroups each p×p block, and that regrouping lines up with the original patch layout only when the patch grid has the same side length as the patches themselves, i.e. when h/p = p.

    import tensorflow as tf
    import numpy as np

    c = 3
    h = 1024
    p = 32

    image = tf.random_normal([h, h, c])
    patch_size = [1, p, p, 1]
    patches = tf.extract_image_patches([image], patch_size, patch_size, [1, 1, 1, 1], 'VALID')
    patches = tf.reshape(patches, [h, p, p, c])
    reconstructed = tf.reshape(patches, [1, h, h, c])
    rec_new = tf.space_to_depth(reconstructed, p)
    rec_new = tf.reshape(rec_new, [h, h, c])

    sess = tf.Session()
    I, P, R_n = sess.run([image, patches, rec_new])
    print(I.shape)
    print(P.shape)
    print(R_n.shape)
    err = np.sum((R_n - I) ** 2)
    print(err)

Output:

    (1024, 1024, 3)
    (1024, 32, 32, 3)
    (1024, 1024, 3)
    0.0

Update 2

Reconstructing from the output of extract_image_patches seems difficult. This version uses other functions (space_to_batch_nd and batch_to_space_nd) to extract the patches and to reverse the process for reconstruction, which turns out to be simpler:

    import tensorflow as tf
    import numpy as np

    c = 3
    h = 1024
    p = 128

    image = tf.random_normal([1, h, h, c])

    # Image-to-patches conversion
    pad = [[0, 0], [0, 0]]
    patches = tf.space_to_batch_nd(image, [p, p], pad)
    patches = tf.split(patches, p * p, 0)
    patches = tf.stack(patches, 3)
    patches = tf.reshape(patches, [(h // p) ** 2, p, p, c])  # integer division for Python 3

    # Do processing on patches here

    # Using the patches to reconstruct the image
    patches_proc = tf.reshape(patches, [1, h // p, h // p, p * p, c])
    patches_proc = tf.split(patches_proc, p * p, 3)
    patches_proc = tf.stack(patches_proc, axis=0)
    patches_proc = tf.reshape(patches_proc, [p * p, h // p, h // p, c])
    reconstructed = tf.batch_to_space_nd(patches_proc, [p, p], pad)

    sess = tf.Session()
    I, P, R_n = sess.run([image, patches, reconstructed])
    print(I.shape)
    print(P.shape)
    print(R_n.shape)
    err = np.sum((R_n - I) ** 2)
    print(err)

Output:

    (1, 1024, 1024, 3)
    (64, 128, 128, 3)
    (1, 1024, 1024, 3)
    0.0

You can find other interesting tensor transformation functions here: https://www.tensorflow.org/api_guides/python/array_ops


Since I also struggled with this, I am posting a solution that may be useful to others. The trick is to realize that the inverse of tf.extract_image_patches is its gradient, as suggested here. Since the gradient of this op is implemented in TensorFlow, it is easy to build the reconstruction function:

    import tensorflow as tf
    from keras import backend as K
    import numpy as np

    def extract_patches(x):
        return tf.extract_image_patches(
            x,
            (1, 3, 3, 1),
            (1, 1, 1, 1),
            (1, 1, 1, 1),
            padding="VALID"
        )

    def extract_patches_inverse(x, y):
        _x = tf.zeros_like(x)
        _y = extract_patches(_x)
        grad = tf.gradients(_y, _x)[0]
        # Divide by grad to "average" together the overlapping patches,
        # otherwise they would simply sum up
        return tf.gradients(_y, _x, grad_ys=y)[0] / grad

    # Generate 10 fake images; the last dimension can be different than 3
    images = np.random.random((10, 28, 28, 3)).astype(np.float32)

    # Extract patches
    patches = extract_patches(images)

    # Reconstruct the images
    # Notice that the original images are only passed to infer the right shape
    images_reconstructed = extract_patches_inverse(images, patches)

    # Compare with the originals (evaluating the tf.Tensor into a numpy array)
    # Here using the Keras session
    images_r = images_reconstructed.eval(session=K.get_session())
    print(np.sum(np.square(images - images_r)))  # 2.3820458e-11
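If you are on TensorFlow 2.x, where tf.gradients and sessions are unavailable in eager mode, the same gradient trick can be written with tf.GradientTape. A minimal sketch of that adaptation (my own, not part of the original answer; it uses tf.image.extract_patches, the TF 2.x name of the op):

    import tensorflow as tf  # TF 2.x

    def extract_patches(x):
        return tf.image.extract_patches(
            x, sizes=(1, 3, 3, 1), strides=(1, 1, 1, 1),
            rates=(1, 1, 1, 1), padding="VALID")

    def extract_patches_inverse(shape, patches):
        _x = tf.zeros(shape)
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(_x)
            _y = extract_patches(_x)
        # The gradient with the default all-ones upstream gradient counts how
        # many patches cover each pixel; dividing by it averages the overlaps
        overlap_count = tape.gradient(_y, _x)
        return tape.gradient(_y, _x, output_gradients=patches) / overlap_count

    images = tf.random.normal((10, 28, 28, 3))
    patches = extract_patches(images)
    images_r = extract_patches_inverse(tf.shape(images), patches)
    print(float(tf.reduce_sum(tf.square(images - images_r))))  # ~0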

tf.extract_image_patches is quite difficult to use, since it does a lot of things behind the scenes.

If you only need non-overlapping patches, it is much easier to write the extraction yourself. You can then reconstruct the full image by inverting all the operations in image_to_patches (a sketch of such an inverse follows the sample code below).

Sample code (plots the original image and the patches):

    import tensorflow as tf
    from skimage import io
    import matplotlib.pyplot as plt

    def image_to_patches(image, patch_height, patch_width):
        # Resize the image so that its dimensions are divisible by patch_height and patch_width
        image_height = tf.cast(tf.shape(image)[0], dtype=tf.float32)
        image_width = tf.cast(tf.shape(image)[1], dtype=tf.float32)
        height = tf.cast(tf.ceil(image_height / patch_height) * patch_height, dtype=tf.int32)
        width = tf.cast(tf.ceil(image_width / patch_width) * patch_width, dtype=tf.int32)

        num_rows = height // patch_height
        num_cols = width // patch_width
        # Zero-pad the image to the computed size
        image = tf.squeeze(tf.image.resize_image_with_crop_or_pad(image, height, width))
        # Get slices along the 0-th axis
        image = tf.reshape(image, [num_rows, patch_height, width, -1])
        # h/patch_h, w, patch_h, c
        image = tf.transpose(image, [0, 2, 1, 3])
        # Get slices along the 1-st axis
        # h/patch_h, w/patch_w, patch_w, patch_h, c
        image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1])
        # num_patches, patch_w, patch_h, c
        image = tf.reshape(image, [num_rows * num_cols, patch_width, patch_height, -1])
        # num_patches, patch_h, patch_w, c
        return tf.transpose(image, [0, 2, 1, 3])

    image = io.imread('http://www.petful.com/wp-content/uploads/2011/09/slow-blinking-cat.jpg')
    print('Original image shape:', image.shape)

    tile_size = 200
    image = tf.constant(image)
    tiles = image_to_patches(image, tile_size, tile_size)

    sess = tf.Session()
    I, tiles = sess.run([image, tiles])
    print(I.shape)
    print(tiles.shape)

    plt.figure(figsize=(1 * (4 + 1), 5))
    plt.subplot(5, 1, 1)
    plt.imshow(I)
    plt.title('original')
    plt.axis('off')
    for i, tile in enumerate(tiles):
        plt.subplot(5, 5, 5 + 1 + i)
        plt.imshow(tile)
        plt.title(str(i))
        plt.axis('off')
    plt.show()
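For completeness, here is a sketch of the inverse (my addition, not part of the original answer). It continues the snippet above and undoes each transpose and reshape of image_to_patches in reverse order, assuming the padded height and width (computed the same way image_to_patches computes them) and the patch size are known:

    def patches_to_image(patches, height, width, patch_height, patch_width):
        # Inverse of image_to_patches for a padded image of size (height, width),
        # where height and width are exact multiples of the patch size
        num_rows = height // patch_height
        num_cols = width // patch_width
        # Undo the final transpose: num_patches, patch_w, patch_h, c
        image = tf.transpose(patches, [0, 2, 1, 3])
        # Regroup the flat patch axis into the (row, col) grid
        image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1])
        # Merge the column axis back into the width: num_rows, width, patch_h, c
        image = tf.reshape(image, [num_rows, width, patch_height, -1])
        # Undo the first transpose: num_rows, patch_h, width, c
        image = tf.transpose(image, [0, 2, 1, 3])
        # Merge the row axis back into the height: height, width, c
        return tf.reshape(image, [height, width, -1])

With tiles as produced above, patches_to_image(tiles, padded_height, padded_width, tile_size, tile_size) should reproduce the padded image exactly (padded_height and padded_width are hypothetical names for the values computed inside image_to_patches).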

To specifically address the original question, namely "reconstructing an image after using extract_image_patches", I suggest using tf.scatter_nd() and building a stratified (layered) image. This works even when the extracted patches overlap or when the image is not evenly divided by the patch size. Here is my suggested solution:

    import cv2
    import numpy as np
    import tensorflow as tf

    # Function to extract patches using 'extract_image_patches'
    def img_to_patches(raw_input, _patch_size=(128, 128), _stride=100):
        with tf.variable_scope('im2_patches'):
            patches = tf.image.extract_image_patches(
                images=raw_input,
                ksizes=[1, _patch_size[0], _patch_size[1], 1],
                strides=[1, _stride, _stride, 1],
                rates=[1, 1, 1, 1],
                padding='SAME'
            )
            h = tf.shape(patches)[1]
            w = tf.shape(patches)[2]
            patches = tf.reshape(patches, (patches.shape[0], -1, _patch_size[0], _patch_size[1], 3))
        return patches, (h, w)

    # Function to reconstruct the image
    def patches_to_img(update, _block_shape, _stride=100):
        with tf.variable_scope('patches2im'):
            _h = _block_shape[0]
            _w = _block_shape[1]
            bs = tf.shape(update)[0]      # batch size
            n_p = tf.shape(update)[1]     # number of patches (renamed from "np" to avoid shadowing numpy)
            ps_h = tf.shape(update)[2]    # patch height
            ps_w = tf.shape(update)[3]    # patch width
            col_ch = tf.shape(update)[4]  # colour channel count

            # Recalculate the output shape of "extract_image_patches", including padded pixels
            wout = (_w - 1) * _stride + ps_w
            hout = (_h - 1) * _stride + ps_h

            x, y = tf.meshgrid(tf.range(ps_w), tf.range(ps_h))
            x = tf.reshape(x, (1, 1, ps_h, ps_w, 1, 1))
            y = tf.reshape(y, (1, 1, ps_h, ps_w, 1, 1))
            xstart, ystart = tf.meshgrid(tf.range(0, (wout - ps_w) + 1, _stride),
                                         tf.range(0, (hout - ps_h) + 1, _stride))

            bb = tf.zeros((1, n_p, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(bs), (-1, 1, 1, 1, 1, 1))  # batch indices
            yy = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + y + tf.reshape(ystart, (1, -1, 1, 1, 1, 1))           # y indices
            xx = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + x + tf.reshape(xstart, (1, -1, 1, 1, 1, 1))           # x indices
            cc = tf.zeros((bs, n_p, ps_h, ps_w, 1, 1), dtype=tf.int32) + tf.reshape(tf.range(col_ch), (1, 1, 1, 1, -1, 1))  # colour indices
            dd = tf.zeros((bs, 1, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(n_p), (1, -1, 1, 1, 1, 1))  # patch (layer) indices
            idx = tf.concat([bb, yy, xx, cc, dd], -1)

            stratified_img = tf.scatter_nd(idx, update, (bs, hout, wout, col_ch, n_p))
            stratified_img = tf.transpose(stratified_img, (0, 4, 1, 2, 3))

            stratified_img_count = tf.scatter_nd(idx, tf.ones_like(update), (bs, hout, wout, col_ch, n_p))
            stratified_img_count = tf.transpose(stratified_img_count, (0, 4, 1, 2, 3))

        with tf.variable_scope("consolidate"):
            sum_stratified_img = tf.reduce_sum(stratified_img, axis=1)
            stratified_img_count = tf.reduce_sum(stratified_img_count, axis=1)
            reconstructed_img = tf.divide(sum_stratified_img, stratified_img_count)

        return reconstructed_img, stratified_img

    if __name__ == "__main__":
        # Load the initial image
        image_org = cv2.imread('orig_img.jpg')
        # Add a batch dimension
        image = np.expand_dims(image_org, axis=0)

        # Set parameters
        patch_size = (228, 228)
        stride = 200

        input_img = tf.placeholder(dtype=tf.float32, shape=image.shape, name="input_img")

        # Extract patches using "extract_image_patches()"
        extracted_patches, block_shape = img_to_patches(input_img, _patch_size=patch_size, _stride=stride)
        # block_shape is the number of patches extracted in the x and in the y dimension
        # extracted_patches.shape = (1, block_shape[0] * block_shape[1], patch_size[0], patch_size[1], 3)
        reconstructed_img, stratified_img = patches_to_img(extracted_patches, block_shape, stride)

        # Reconstruct the image
        with tf.Session() as sess:
            ep, bs, ri, si = sess.run([extracted_patches, block_shape, reconstructed_img, stratified_img],
                                      feed_dict={input_img: image})
        # print(bs)
        si = si.astype(np.int32)

        # Show the reconstructed image
        cv2.imshow('sd', ri[0, :, :, :].astype(np.float32) / 255)
        cv2.waitKey(0)
        # Show the stratified images
        for i in range(si.shape[1]):
            im_1 = si[0, i, :, :, :]
            cv2.imshow('sd', im_1.astype(np.float32) / 255)
            cv2.waitKey(0)  # added so each layer is actually displayed

The above solution should work for batched images with any number of color channels.
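The scatter-and-average idea at the heart of patches_to_img is easiest to see in one dimension. A toy sketch (my illustration, not from the answer) with two overlapping patches of a length-5 signal:

    import tensorflow as tf

    signal = tf.constant([1., 2., 3., 4., 5.])
    # Two overlapping patches: [1,2,3] at offset 0 and [3,4,5] at offset 2
    patches = tf.constant([[1., 2., 3.], [3., 4., 5.]])
    offsets = [0, 2]
    # Scatter target index for every patch element
    idx = tf.constant([[[o + i] for i in range(3)] for o in offsets])

    summed = tf.scatter_nd(idx, patches, [5])                # overlaps sum up: [1, 2, 6, 4, 5]
    counts = tf.scatter_nd(idx, tf.ones_like(patches), [5])  # coverage count:  [1, 1, 2, 1, 1]
    reconstructed = summed / counts                          # back to [1, 2, 3, 4, 5]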

Finally, a simple reshape/split/concat approach using tf.image.extract_patches (the TF 2.x name of the op), for non-overlapping square patches:

    _, n_row, n_col, n_channel = image.shape
    n_patch = n_row * n_col // (patch_size ** 2)  # assume a square patch

    patches = tf.image.extract_patches(image,
                                       sizes=[1, patch_size, patch_size, 1],
                                       strides=[1, patch_size, patch_size, 1],
                                       rates=[1, 1, 1, 1],
                                       padding='VALID')
    patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])
    # One split per row of the patch grid (the original split into n_col // patch_size
    # groups, which only gives the same result for square images)
    rows = tf.split(patches, n_row // patch_size, axis=0)
    rows = [tf.concat(tf.unstack(r), axis=1) for r in rows]
    reconstructed = tf.concat(rows, axis=0)

I do not know whether this is an efficient implementation, but it works!
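A quick round-trip check of the snippet above (my addition; it assumes TF 2.x running eagerly):

    import numpy as np
    import tensorflow as tf

    patch_size = 4
    image = tf.random.normal([1, 8, 12, 3])  # non-square, to exercise the row/column logic
    _, n_row, n_col, n_channel = image.shape
    n_patch = n_row * n_col // (patch_size ** 2)

    patches = tf.image.extract_patches(image,
                                       sizes=[1, patch_size, patch_size, 1],
                                       strides=[1, patch_size, patch_size, 1],
                                       rates=[1, 1, 1, 1],
                                       padding='VALID')
    patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])
    rows = tf.split(patches, n_row // patch_size, axis=0)
    rows = [tf.concat(tf.unstack(r), axis=1) for r in rows]
    reconstructed = tf.concat(rows, axis=0)

    print(np.sum((reconstructed.numpy() - image[0].numpy()) ** 2))  # 0.0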

