Reconstructing an image after using extract_image_patches

Since I also struggled with this, I am posting a solution that might be useful to others. The trick is to realize that the inverse of tf.extract_image_patches is (up to a per-pixel normalization) its gradient, as suggested here: the gradient puts every patch value back on the pixel it came from and sums the overlapping contributions. Since the gradient of this op is implemented in TensorFlow, it is easy to build the reconstruction function:

import tensorflow as tf
from keras import backend as K
import numpy as np

def extract_patches(x):
    """Return all 3x3 patches of `x`, taken with stride 1, rate 1 and no padding.

    Thin wrapper around ``tf.extract_image_patches`` with a fixed window
    configuration.

    Args:
        x: Image batch to cut into patches (presumably laid out as
            [batch, rows, cols, channels], as in the usage example below
            -- confirm against the caller).

    Returns:
        The tensor produced by ``tf.extract_image_patches`` for `x`.
    """
    window = (1, 3, 3, 1)
    unit = (1, 1, 1, 1)
    return tf.extract_image_patches(
        images=x,
        ksizes=window,
        strides=unit,
        rates=unit,
        padding="VALID",
    )

def extract_patches_inverse(x, y):
    """Rebuild an image tensor shaped like `x` from the patches `y`.

    The reconstruction is the gradient of ``extract_patches`` with `y`
    supplied as the upstream gradient, divided by the plain gradient so
    that overlapping patches are averaged rather than summed.

    Args:
        x: Tensor (or array) with the shape of the image batch to rebuild;
            only ``tf.zeros_like(x)`` is used, never its values.
        y: Patches as returned by ``extract_patches``.

    Returns:
        The reconstructed tensor, same shape as `x`.
    """
    # Stand-in input: the gradient graph only needs something of x's shape.
    blank = tf.zeros_like(x)
    blank_patches = extract_patches(blank)
    # Gradient without grad_ys: the per-pixel weight used to "average"
    # the overlapping patches together.
    overlap = tf.gradients(blank_patches, blank)[0]
    # Gradient with grad_ys=y: the overlapping patch values summed per pixel.
    summed = tf.gradients(blank_patches, blank, grad_ys=y)[0]
    return summed / overlap

# Batch of 10 random 28x28 "images"; the channel count (3 here) is arbitrary.
images = np.random.random((10, 28, 28, 3)).astype(np.float32)

# Forward direction: cut the batch into patches.
patches = extract_patches(images)

# Backward direction: stitch the patches together again.
# `images` is passed only so the inverse knows which shape to produce.
images_reconstructed = extract_patches_inverse(images, patches)

# Evaluate the tf.Tensor into a numpy array using the Keras session,
# so it can be compared with the original.
images_r = K.get_session().run(images_reconstructed)

# Sum of squared differences between original and reconstruction.
print(np.sum(np.square(images - images_r)))
# 2.3820458e-11

Use Update #2 below. For reference, here is one small example for your task (TF 1.0):

Consider an image of size (4,4,1) that is converted to patches of size (4,2,2,1) and then reconstructed back into the image.

import tensorflow as tf
# 4x4 single-channel image holding 1..16 in row-major order.
image = tf.constant([
    [[1], [2], [3], [4]],
    [[5], [6], [7], [8]],
    [[9], [10], [11], [12]],
    [[13], [14], [15], [16]],
])

# The same vector is used as window size and as stride, so the 2x2 patches
# do not overlap.
patch_size = [1, 2, 2, 1]
patches = tf.extract_image_patches(
    images=[image],  # wrap in a list to add the leading batch dimension
    ksizes=patch_size,
    strides=patch_size,
    rates=[1, 1, 1, 1],
    padding='VALID',
)
patches = tf.reshape(patches, [4, 2, 2, 1])

# Undo the patching: view the patches as one 4x4 image, then let
# space_to_depth regroup the 2x2 blocks before reshaping to the image layout.
reconstructed = tf.reshape(patches, [1, 4, 4, 1])
rec_new = tf.space_to_depth(reconstructed, block_size=2)
rec_new = tf.reshape(rec_new, [4, 4, 1])

sess = tf.Session()
I, P, R_n = sess.run([image, patches, rec_new])
for shown in (I, I.shape, P.shape, R_n, R_n.shape):
    print(shown)

Output:

[[[ 1][ 2][ 3][ 4]]
  [[ 5][ 6][ 7][ 8]]
  [[ 9][10][11][12]]
  [[13][14][15][16]]]
(4, 4, 1)
(4, 2, 2, 1)
[[[ 1][ 2][ 3][ 4]]
  [[ 5][ 6][ 7][ 8]]
  [[ 9][10][11][12]]
  [[13][14][15][16]]]
(4,4,1)

#Update - for 3 channels (still debugging); this works only for p = sqrt(h)

import tensorflow as tf
import numpy as np
# Channels, image side length and patch side length.
# This variant only works when p == sqrt(h) (here 32 * 32 == 1024).
c = 3
h = 1024
p = 32

image = tf.random_normal([h, h, c])

# Window size and stride are identical, so the p x p patches do not overlap.
patch_size = [1, p, p, 1]
patches = tf.extract_image_patches(
    images=[image],  # wrap in a list to add the leading batch dimension
    ksizes=patch_size,
    strides=patch_size,
    rates=[1, 1, 1, 1],
    padding='VALID',
)
patches = tf.reshape(patches, [h, p, p, c])

# Undo the patching via space_to_depth with block size p.
reconstructed = tf.reshape(patches, [1, h, h, c])
rec_new = tf.space_to_depth(reconstructed, block_size=p)
rec_new = tf.reshape(rec_new, [h, h, c])

sess = tf.Session()
# Fetch everything in one run() call so the random image is sampled once.
I, P, R_n = sess.run([image, patches, rec_new])
for shape in (I.shape, P.shape, R_n.shape):
    print(shape)

# Sum of squared differences between reconstruction and original.
err = np.sum(np.square(R_n - I))
print(err)

Output :

(1024, 1024, 3)
(1024, 32, 32, 3)
(1024, 1024, 3)
0.0

#Update 2

Reconstructing from the output of extract_image_patches seems difficult. Instead, I used other functions to extract the patches and then reversed that process to reconstruct the image, which seems easier.

import tensorflow as tf
import numpy as np
# Channels, image side length and patch side length (h must be a multiple of p).
c = 3
h = 1024
p = 128
# Number of patches along each image side. Floor division keeps this an int
# on Python 3 as well; true division (h / p) would yield a float, which is
# not a valid entry in a tf.reshape shape.
grid = h // p


image = tf.random_normal([1, h, h, c])

# Image to Patches Conversion
pad = [[0, 0], [0, 0]]
# [1, h, h, c] -> [p*p, grid, grid, c]: one batch entry per in-patch offset.
patches = tf.space_to_batch_nd(image, [p, p], pad)
# Move the p*p in-patch offsets from the batch axis to a new axis 3:
# [1, grid, grid, p*p, c]
patches = tf.split(patches, p * p, 0)
patches = tf.stack(patches, 3)
# One entry per patch: [grid*grid, p, p, c]
patches = tf.reshape(patches, [grid ** 2, p, p, c])

# Do processing on patches
# Using patches here to reconstruct: apply the inverse of each step above,
# in reverse order.
patches_proc = tf.reshape(patches, [1, grid, grid, p * p, c])
patches_proc = tf.split(patches_proc, p * p, 3)
patches_proc = tf.stack(patches_proc, axis=0)
patches_proc = tf.reshape(patches_proc, [p * p, grid, grid, c])

reconstructed = tf.batch_to_space_nd(patches_proc, [p, p], pad)

sess = tf.Session()
# Fetch everything in one run() call so the random image is sampled once.
I, P, R_n = sess.run([image, patches, reconstructed])
print(I.shape)
print(P.shape)
print(R_n.shape)
# Sum of squared differences between reconstruction and original.
err = np.sum((R_n - I) ** 2)
print(err)

Output:

(1, 1024, 1024, 3)
(64, 128, 128, 3)
(1, 1024, 1024, 3)
0.0

You could see other cool tensor transformation functions here: https://www.tensorflow.org/api_guides/python/array_ops


tf.extract_image_patches is quite difficult to use, as it does a lot of stuff in the background.

If you just need non-overlapping patches, then it's much easier to write the extraction ourselves. You can reconstruct the full image by inverting all operations in image_to_patches.

Code sample (plots original image and patches):

import tensorflow as tf
from skimage import io
import matplotlib.pyplot as plt


def image_to_patches(image, patch_height, patch_width):
    """Cut `image` into non-overlapping `patch_height` x `patch_width` tiles.

    The image is first padded so that its height and width are multiples of
    the patch size, then rearranged with reshape/transpose steps.

    Args:
        image: Image tensor whose first two axes are height and width
            (presumably [height, width, channels] -- confirm with caller).
        patch_height: Tile height in pixels.
        patch_width: Tile width in pixels.

    Returns:
        Tensor of shape [num_patches, patch_height, patch_width, channels].
    """
    dims = tf.shape(image)
    src_height = tf.cast(dims[0], dtype=tf.float32)
    src_width = tf.cast(dims[1], dtype=tf.float32)

    # Round each side up to the next multiple of the patch size.
    height = tf.cast(tf.ceil(src_height / patch_height) * patch_height, dtype=tf.int32)
    width = tf.cast(tf.ceil(src_width / patch_width) * patch_width, dtype=tf.int32)
    num_rows = height // patch_height
    num_cols = width // patch_width

    # Pad up to the rounded size (the target is never smaller than the image).
    padded = tf.squeeze(tf.image.resize_image_with_crop_or_pad(image, height, width))

    # Slice along axis 0: [h/patch_h, patch_h, w, c]
    bands = tf.reshape(padded, [num_rows, patch_height, width, -1])
    # [h/patch_h, w, patch_h, c]
    bands = tf.transpose(bands, [0, 2, 1, 3])
    # Slice along axis 1: [h/patch_h, w/patch_w, patch_w, patch_h, c]
    grid = tf.reshape(bands, [num_rows, num_cols, patch_width, patch_height, -1])
    # Flatten the grid: [num_patches, patch_w, patch_h, c]
    flat = tf.reshape(grid, [num_rows * num_cols, patch_width, patch_height, -1])
    # Swap back to [num_patches, patch_h, patch_w, c]
    return tf.transpose(flat, perm=[0, 2, 1, 3])


# Download the sample picture and report its size.
image = io.imread('http://www.petful.com/wp-content/uploads/2011/09/slow-blinking-cat.jpg')
print('Original image shape:', image.shape)

# Square tiles of tile_size x tile_size pixels.
tile_size = 200
image = tf.constant(image)
tiles = image_to_patches(image, patch_height=tile_size, patch_width=tile_size)

sess = tf.Session()
I, tiles = sess.run([image, tiles])
print(I.shape)
print(tiles.shape)


# Layout: the original image on top, the tiles in a 5-column grid below it.
plt.figure(figsize=(5, 5))
plt.subplot(5, 1, 1)
plt.imshow(I)
plt.title('original')
plt.axis('off')
for i, tile in enumerate(tiles):
    # Cells 1-5 of the 5x5 grid lie under the original image, so the
    # tiles start at cell 6.
    plt.subplot(5, 5, 6 + i)
    plt.imshow(tile)
    plt.title(str(i))
    plt.axis('off')
plt.show()

I don't know if the following code is an efficient implementation but it works!

# Expects `image` (unpacked below as [batch, n_row, n_col, n_channel]) and a
# scalar `patch_size` (square patches) to be defined already. The patch count
# and the reshape below only account for a single image, i.e. batch == 1.
# The shape is read from `image`, the tensor the patches are extracted from.
_, n_row, n_col, n_channel = image.shape
# With 'VALID' padding only whole patches are produced, so count them per axis.
n_patch_rows = n_row // patch_size
n_patch_cols = n_col // patch_size
n_patch = n_patch_rows * n_patch_cols

patches = tf.image.extract_patches(
    image,
    sizes=[1, patch_size, patch_size, 1],
    strides=[1, patch_size, patch_size, 1],
    rates=[1, 1, 1, 1],
    padding='VALID',
)
patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])

# Patches are ordered row by row, so splitting into n_patch_rows groups
# yields one group per horizontal strip of the image (splitting by the
# column count is only correct for square images).
rows = tf.split(patches, n_patch_rows, axis=0)
# Glue the patches of each strip side by side (axis 1 is the width axis).
rows = [tf.concat(tf.unstack(row), axis=1) for row in rows]

# Stack the strips vertically: [n_patch_rows * patch_size,
# n_patch_cols * patch_size, n_channel]
reconstructed = tf.concat(rows, axis=0)