Understanding tf.extract_image_patches for extracting patches from an image
Here is how the method works:
ksizes determines the dimensions of each patch, in other words how many pixels each patch contains. It has the form [1, ksize_rows, ksize_cols, 1].
strides is the distance between the start of one patch and the start of the next consecutive patch within the original image, in the form [1, stride_rows, stride_cols, 1].
rates is a dilation factor in the form [1, rate_rows, rate_cols, 1]: the patch jumps by rates pixels in the original image for each consecutive pixel that ends up in the patch, so a rate of 1 gives ordinary contiguous patches. (The example below helps illustrate this.)
padding is either "VALID", which means every patch must be fully contained in the image, or "SAME", which means patches are allowed to be incomplete (the missing pixels are filled in with zeros).
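These parameters combine into a simple rule for how many patches you get along each axis. The sketch below assumes the usual TensorFlow conventions: a dilated ("effective") patch span of ksize + (ksize - 1) * (rate - 1), ceil((in - effective + 1) / stride) patches for "VALID" and ceil(in / stride) for "SAME". The helper names are hypothetical, not TensorFlow functions.

```python
import math


def effective_ksize(ksize: int, rate: int) -> int:
    """Span of one patch in the original image once `rate` gaps are inserted."""
    return ksize + (ksize - 1) * (rate - 1)


def output_size(in_size: int, ksize: int, stride: int, rate: int, padding: str) -> int:
    """Number of patches along one spatial axis (out_rows or out_cols)."""
    k_eff = effective_ksize(ksize, rate)
    if padding == "VALID":
        # Every patch must fit entirely inside the image.
        return max(0, math.ceil((in_size - k_eff + 1) / stride))
    if padding == "SAME":
        # A patch starts every `stride` pixels; missing pixels are zero-filled.
        return math.ceil(in_size / stride)
    raise ValueError(f"unknown padding: {padding!r}")


# The four calls from the example below, on a 10-pixel axis:
print(output_size(10, 3, 5, 1, "VALID"))  # 2
print(output_size(10, 3, 5, 2, "VALID"))  # 2 (effective span 5)
print(output_size(10, 4, 7, 1, "VALID"))  # 1
print(output_size(10, 4, 7, 1, "SAME"))   # 2
```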
Here is some sample code with output to help demonstrate how it works:
```python
# TensorFlow 1.x code (tf.Session, tf.extract_image_patches). In TF 2.x the
# eager equivalent is tf.image.extract_patches, which calls the argument
# `sizes` instead of `ksizes`.
import tensorflow as tf

n = 10
# images is a 1 x 10 x 10 x 1 array that contains the numbers 1 through 100 in order
images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]]

# We generate four outputs as follows:
# 1. 3x3 patches with stride length 5
# 2. Same as above, but the rate is increased to 2
# 3. 4x4 patches with stride length 7; only one patch should be generated
# 4. Same as above, but with padding set to 'SAME'
with tf.Session() as sess:
    print(tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 2, 2, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='SAME').eval())
```
Output:
```
[[[[ 1 2 3 11 12 13 21 22 23]
   [ 6 7 8 16 17 18 26 27 28]]

  [[51 52 53 61 62 63 71 72 73]
   [56 57 58 66 67 68 76 77 78]]]]

[[[[ 1 3 5 21 23 25 41 43 45]
   [ 6 8 10 26 28 30 46 48 50]]

  [[ 51 53 55 71 73 75 91 93 95]
   [ 56 58 60 76 78 80 96 98 100]]]]

[[[[ 1 2 3 4 11 12 13 14 21 22 23 24 31 32 33 34]]]]

[[[[ 1 2 3 4 11 12 13 14 21 22 23 24 31 32 33 34]
   [ 8 9 10 0 18 19 20 0 28 29 30 0 38 39 40 0]]

  [[ 71 72 73 74 81 82 83 84 91 92 93 94 0 0 0 0]
   [ 78 79 80 0 88 89 90 0 98 99 100 0 0 0 0 0]]]]
```
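To make the indexing concrete, here is a pure-Python model of the four calls for a single-channel image. It is a sketch, not TensorFlow's implementation: for "SAME" it assumes the padding rule TensorFlow documents for convolutions (pad_total = max((out - 1) * stride + effective - in, 0), with pad_total // 2 placed before the first patch). That rule reproduces the outputs above, but as the next answer points out, extract_image_patches can place patches differently in other cases. The function name is hypothetical.

```python
import math


def extract_patches_2d(image, ksize, stride, rate, padding):
    """Pure-Python model of extract_image_patches for one single-channel image.

    image: list of rows (a 2-D list). Returns out_rows x out_cols lists, each
    holding a flattened ksize*ksize patch in row-major order.
    """
    n_rows, n_cols = len(image), len(image[0])
    k_eff = ksize + (ksize - 1) * (rate - 1)  # span of a dilated patch

    def axis_layout(size):
        # Returns (number of patches, zero padding before the first patch).
        if padding == "VALID":
            return max(0, math.ceil((size - k_eff + 1) / stride)), 0
        out = math.ceil(size / stride)
        pad_total = max((out - 1) * stride + k_eff - size, 0)
        return out, pad_total // 2  # assumption: any odd pixel goes bottom/right

    out_rows, pad_top = axis_layout(n_rows)
    out_cols, pad_left = axis_layout(n_cols)

    def pixel(r, c):
        # Out-of-bounds positions read as zero (the "SAME" padding).
        if 0 <= r < n_rows and 0 <= c < n_cols:
            return image[r][c]
        return 0

    result = []
    for i in range(out_rows):
        row = []
        for j in range(out_cols):
            top = i * stride - pad_top
            left = j * stride - pad_left
            row.append([pixel(top + a * rate, left + b * rate)
                        for a in range(ksize) for b in range(ksize)])
        result.append(row)
    return result


n = 10
image = [[x * n + y + 1 for y in range(n)] for x in range(n)]
print(extract_patches_2d(image, 3, 5, 2, "VALID")[1][1])
# [56, 58, 60, 76, 78, 80, 96, 98, 100]
```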
So, for example, our first result extracts the pixels marked with * below:
```
* * * 4 5 * * * 9 10
* * * 14 15 * * * 19 20
* * * 24 25 * * * 29 30
31 32 33 34 35 36 37 38 39 40
41 42 43 44 45 46 47 48 49 50
* * * 54 55 * * * 59 60
* * * 64 65 * * * 69 70
* * * 74 75 * * * 79 80
81 82 83 84 85 86 87 88 89 90
91 92 93 94 95 96 97 98 99 100
```
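The same picture can be generated programmatically, which makes it easy to try other ksizes, strides and rates. This is a small sketch for a square n x n image with "VALID" padding only; coverage_grid is a hypothetical helper, not part of TensorFlow.

```python
def coverage_grid(n, ksize, stride, rate=1):
    """Render an n x n grid holding 1..n*n, replacing every pixel that is
    read by some VALID patch with '*'. Assumes n >= the dilated patch span."""
    k_eff = ksize + (ksize - 1) * (rate - 1)
    out = (n - k_eff) // stride + 1  # VALID patches per axis
    covered = set()
    for i in range(out):
        for j in range(out):
            for a in range(ksize):
                for b in range(ksize):
                    covered.add((i * stride + a * rate, j * stride + b * rate))
    lines = []
    for x in range(n):
        cells = ["*" if (x, y) in covered else str(x * n + y + 1) for y in range(n)]
        lines.append(" ".join(cells))
    return lines


# First result (3x3 patches, stride 5); pass rate=2 to see the second one.
for line in coverage_grid(10, 3, 5):
    print(line)
```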
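As you can see, we get 2 rows and 2 columns of patches. These are the out_rows and out_cols dimensions of the result, whose shape is [batch, out_rows, out_cols, ksize_rows * ksize_cols * depth] (here [1, 2, 2, 9]).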
To expand on Neal's detailed answer, zero padding with "SAME" has several subtleties, because extract_image_patches tries to center the patches in the image if possible. Depending on the stride, the image may or may not be padded on the top and left, and the first patch does not necessarily start in the upper-left corner.
For instance, extending the previous example:
```python
# Run inside the `with tf.Session()` block above; `stride` is the value being varied.
print(tf.extract_image_patches(images, [1, 3, 3, 1], [1, stride, stride, 1], [1, 1, 1, 1], 'SAME').eval()[0])
```
With a stride of 1, the image is padded with zeros on all sides and the first patch starts in the padding. Some other strides pad the image only on the right and bottom, or not at all. With a stride of 10, the single patch starts at element 34 (in the middle of the image).
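One way to model this centering is sketched below. It assumes, without confirmation from the source, that the padding before the first patch is ((out - 1) * stride + effective - in) / 2 computed with C++-style integer division (rounding toward zero) and not clamped at zero, so a negative value shifts the first patch into the image. Under that assumption the stride-1 and stride-10 cases above are reproduced; the model also predicts top/left padding for some other strides (for example 3). Newer TensorFlow releases may clamp this padding at zero, so compare against the version you run. The helper names are hypothetical.

```python
import math


def trunc_div(a: int, b: int) -> int:
    """Integer division rounding toward zero, like C++ `/` on ints."""
    q = abs(a) // b
    return q if a >= 0 else -q


def same_layout(in_size: int, ksize: int, stride: int, rate: int = 1):
    """Return (num_patches, first_patch_start) along one axis for "SAME".

    Assumption: padding before the first patch is
    ((out - 1) * stride + k_eff - in) / 2, truncated toward zero and not
    clamped at 0. A negative start means the first patch begins in the padding.
    """
    k_eff = ksize + (ksize - 1) * (rate - 1)
    out = math.ceil(in_size / stride)
    pad_before = trunc_div((out - 1) * stride + k_eff - in_size, 2)
    return out, -pad_before


n = 10
for stride in range(1, n + 1):
    out, start = same_layout(n, 3, stride)
    # Rows and columns behave identically for this square image, so the first
    # patch's top-left element is at (start, start) in the 1..100 image.
    first = start * n + start + 1 if start >= 0 else "padding"
    print(stride, out, start, first)
```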
tf.extract_image_patches is implemented using the Eigen library, as described in this answer. You can study that code to see exactly how patch positions and padding are computed.
Introduction
Here I would like to present a simple demonstration of tf.image.extract_patches on an actual image. I found very few implementations of the method that use real images with proper visualizations, so here it is.
The image we will use has shape (256, 256, 3), and the patches we extract will have shape (128, 128, 3). Since 256 / 128 = 2 along both height and width, we will retrieve 2 x 2 = 4 tiles from the image.
Data used
I will be using the flowers dataset. Because this answer needs a small data pipeline, I am linking my Kaggle kernel here, which covers consuming the dataset with the tf.data.Dataset API.
Once the pipeline is in place, we go through the following code snippets.
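```python
import matplotlib.pyplot as plt
import tensorflow as tf

# train_ds is the batched tf.data.Dataset of flower images built in the linked kernel
images, _ = next(iter(train_ds.take(1)))
image = images[0]
plt.imshow(image.numpy().astype("uint8"))
```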
Here we are taking one image from the batch of images and visualizing it as is.
```python
image = tf.expand_dims(image, 0)  # add the batch dimension: (1, 256, 256, 3)
patches = tf.image.extract_patches(images=image,
                                   sizes=[1, 128, 128, 1],
                                   strides=[1, 128, 128, 1],
                                   rates=[1, 1, 1, 1],
                                   padding='VALID')
```
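With this snippet, we extract patches of size (128, 128) from the (256, 256) image. Because the strides equal the patch size and the padding is 'VALID', the patches do not overlap and the image is split into exactly 4 tiles. The result has shape (1, 2, 2, 49152): one image, a 2 x 2 grid of patches, and 128 * 128 * 3 = 49152 values per flattened patch.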
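Because the strides equal the patch size, this particular call is just a tiling of the image, which can be cross-checked with plain NumPy. The sketch below assumes each patch is flattened in (row, col, channel) order, the same order the visualization relies on when it reshapes a patch to (128, 128, 3). tile_image is a hypothetical helper, and the arange array stands in for a flower photo; compare its output with patches[0].numpy(), since the TensorFlow result has an extra leading batch axis.

```python
import numpy as np


def tile_image(image: np.ndarray, size: int) -> np.ndarray:
    """Non-overlapping size x size tiles of an (H, W, C) image, laid out like
    extract_patches with sizes == strides and padding='VALID'.

    Returns shape (H // size, W // size, size * size * C), with each tile
    flattened in (row, col, channel) order.
    """
    h, w, c = image.shape
    rows, cols = h // size, w // size
    image = image[: rows * size, : cols * size]  # VALID: drop partial tiles
    tiles = image.reshape(rows, size, cols, size, c).transpose(0, 2, 1, 3, 4)
    return tiles.reshape(rows, cols, size * size * c)


img = np.arange(256 * 256 * 3).reshape(256, 256, 3)  # stand-in image
tiles = tile_image(img, 128)
print(tiles.shape)  # (2, 2, 49152)
```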
Visualization
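```python
plt.figure(figsize=(10, 10))
for imgs in patches:  # iterate over the batch axis (a single image here)
    count = 0
    for r in range(2):
        for c in range(2):
            ax = plt.subplot(2, 2, count + 1)
            # Each patch is a flat vector of 128*128*3 values; reshape it back into an image
            plt.imshow(tf.reshape(imgs[r, c], shape=(128, 128, 3)).numpy().astype("uint8"))
            count += 1
plt.show()
```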
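Going the other way, for example to confirm that the four tiles really cover the whole image, the flattened patches can be stitched back together. This is a NumPy sketch that assumes non-overlapping "VALID" patches (strides equal to sizes) flattened in (row, col, channel) order; untile is a hypothetical helper, and with the snippet above you would pass patches[0].numpy().

```python
import numpy as np


def untile(patches: np.ndarray, size: int, channels: int) -> np.ndarray:
    """Inverse of non-overlapping VALID patch extraction for one image.

    patches: (rows, cols, size * size * channels).
    Returns an array of shape (rows * size, cols * size, channels).
    """
    rows, cols, _ = patches.shape
    grid = patches.reshape(rows, cols, size, size, channels)
    return grid.transpose(0, 2, 1, 3, 4).reshape(rows * size, cols * size, channels)


# Round trip on a stand-in image: split into 2 x 2 tiles, then stitch back.
img = np.arange(256 * 256 * 3).reshape(256, 256, 3)
tiles = img.reshape(2, 128, 2, 128, 3).transpose(0, 2, 1, 3, 4).reshape(2, 2, -1)
print(np.array_equal(untile(tiles, 128, 3), img))  # True
```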