What does TensorFlow's `conv2d_transpose()` operation do?
The best explanation I've seen online of how transposed convolution works is here.
I'll give my own short description: it applies a convolution with a fractional stride. In other words, it spaces out the input values (with zeros) and then applies the filter over a region that is potentially smaller than the filter size.
As for why one would want to use it: it can serve as a form of upsampling with learned weights, as opposed to bilinear interpolation or some other fixed form of upsampling.
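For intuition, here is a rough single-channel NumPy sketch of that "spacing out with zeros" view. The 2x2 input, 3x3 kernel, stride of 2, and the helper names are purely illustrative and are not taken from TensorFlow:
import numpy as np

def dilate_with_zeros(x, stride):
    # Space the input values out with (stride - 1) zeros between them.
    h, w = x.shape
    out = np.zeros(((h - 1) * stride + 1, (w - 1) * stride + 1), dtype=x.dtype)
    out[::stride, ::stride] = x
    return out

def full_conv2d(x, k):
    # Apply the kernel wherever it overlaps the input at all ("FULL" padding);
    # the kernel is flipped, so this is a true convolution, not a correlation.
    kh, kw = k.shape
    xp = np.pad(x, ((kh - 1, kh - 1), (kw - 1, kw - 1)))
    out = np.zeros((xp.shape[0] - kh + 1, xp.shape[1] - kw + 1), dtype=x.dtype)
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(xp[i:i + kh, j:j + kw] * k[::-1, ::-1])
    return out

x = np.arange(1.0, 5.0).reshape(2, 2)   # tiny 2x2 "feature map"
k = np.ones((3, 3))                     # 3x3 kernel
up = full_conv2d(dilate_with_zeros(x, 2), k)
print(up.shape)                         # (5, 5) = ((2 - 1) * 2 + 3, ...)
A transposed convolution with stride 2 and VALID padding over this input produces an output of the same size; the difference is that its kernel is learned rather than fixed.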
Here's another viewpoint from the "gradients" perspective, i.e. why the TensorFlow documentation says conv2d_transpose() is "actually the transpose (gradient) of conv2d rather than an actual deconvolution". For more details on the actual computation done in conv2d_transpose, I would highly recommend this article, starting from page 19.
Four Related Functions
In tf.nn, there are 4 closely related and rather confusing functions for 2d convolution:
tf.nn.conv2d
tf.nn.conv2d_backprop_filter
tf.nn.conv2d_backprop_input
tf.nn.conv2d_transpose
One-sentence summary: they are all just 2d convolutions. Their differences lie in the ordering of the input arguments, input rotation or transposition, strides (including fractional stride sizes), and padding. With tf.nn.conv2d in hand, one can implement all 3 of the other ops by transforming the inputs and changing the conv2d arguments.
Problem Settings
- Forward and backward computations:
# forward
out = conv2d(x, w)
# backward, given d_out
=> find d_x?
=> find d_w?
In the forward computation, we compute the convolution of the input image x with the filter w, and the result is out.
In the backward computation, assume we're given d_out, which is the gradient w.r.t. out. Our goal is to find d_x and d_w, which are the gradients w.r.t. x and w respectively.
For ease of discussion, we assume:
- All stride sizes are 1
- All in_channels and out_channels are 1
- VALID padding
- Odd filter sizes (this avoids some asymmetric shape problems)
Short Answer
Conceptually, with the assumptions above, we have the following relations:
out = conv2d(x, w, padding='VALID')
d_x = conv2d(d_out, rot180(w), padding='FULL')
d_w = conv2d(x, d_out, padding='VALID')
Here rot180 means the 2d matrix is rotated 180 degrees (a left-right flip followed by a top-down flip), and FULL padding means "apply the filter wherever it partly overlaps with the input" (see the Theano docs). Note that this is only valid under the assumptions above; however, one can change the conv2d arguments to generalize it.
The key takeaways:
- The input gradient d_x is the convolution of the output gradient d_out and the weight w, with some modifications.
- The weight gradient d_w is the convolution of the input x and the output gradient d_out, with some modifications.
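To make these relations concrete, here is a minimal NumPy/SciPy check of them under the assumptions above (single channel, stride 1). The 5x5/3x3 shapes and the use of scipy.signal.correlate2d are my own illustration, not part of the linked scripts; correlate2d is used because TF's conv2d is really a cross-correlation:
import numpy as np
from scipy.signal import correlate2d

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 5))
w = rng.standard_normal((3, 3))

out = correlate2d(x, w, mode='valid')        # forward pass: 3x3 output
d_out = rng.standard_normal(out.shape)       # pretend upstream gradient

rot180 = lambda k: k[::-1, ::-1]
d_x = correlate2d(d_out, rot180(w), mode='full')   # 5x5, same shape as x
d_w = correlate2d(x, d_out, mode='valid')          # 3x3, same shape as w

# Spot-check d_x[0, 0] by finite differences on the loss sum(out * d_out)
eps = 1e-6
x_pert = x.copy()
x_pert[0, 0] += eps
out_pert = correlate2d(x_pert, w, mode='valid')
fd = (np.sum(out_pert * d_out) - np.sum(out * d_out)) / eps
print(np.allclose(d_x[0, 0], fd, atol=1e-4))       # True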
Long Answer
Now, let's give an actual working code example of how to use the 4 functions above to compute d_x and d_w given d_out. This shows how conv2d, conv2d_backprop_filter, conv2d_backprop_input, and conv2d_transpose are related to each other.
Please find the full scripts here.
Computing d_x in 4 different ways:
# Method 1: TF's autodiff
d_x = tf.gradients(f, x)[0]
# Method 2: manually using conv2d
d_x_manual = tf.nn.conv2d(input=tf_pad_to_full_conv2d(d_out, w_size),
filter=tf_rot180(w),
strides=strides,
padding='VALID')
# Method 3: conv2d_backprop_input
d_x_backprop_input = tf.nn.conv2d_backprop_input(input_sizes=x_shape,
filter=w,
out_backprop=d_out,
strides=strides,
padding='VALID')
# Method 4: conv2d_transpose
d_x_transpose = tf.nn.conv2d_transpose(value=d_out,
filter=w,
output_shape=x_shape,
strides=strides,
padding='VALID')
Computing d_w in 3 different ways:
# Method 1: TF's autodiff
d_w = tf.gradients(f, w)[0]
# Method 2: manually using conv2d
d_w_manual = tf_NHWC_to_HWIO(tf.nn.conv2d(input=x,
filter=tf_NHWC_to_HWIO(d_out),
strides=strides,
padding='VALID'))
# Method 3: conv2d_backprop_filter
d_w_backprop_filter = tf.nn.conv2d_backprop_filter(input=x,
filter_sizes=w_shape,
out_backprop=d_out,
strides=strides,
padding='VALID')
Please see the full scripts for the implementations of tf_rot180, tf_pad_to_full_conv2d, and tf_NHWC_to_HWIO. In the scripts, we check that the final output values of the different methods are the same; a numpy implementation is also available.
conv2d_transpose() simply transposes the weights and flips them by 180 degrees, then applies the standard conv2d(). "Transposes" practically means that it swaps the last two dimensions (the "columns") of the weights tensor. Please check the example below.
Here is an example that uses convolutions with stride=1 and padding='SAME'. It is a simple case, but the same reasoning can be applied to the other cases.
Say we have:
- Input: MNIST image of 28x28x1, shape = [28,28,1]
- Convolutional layer: 32 filters of 7x7, weights shape = [7, 7, 1, 32], name = W_conv1
If we perform the convolution of the input, then the activations will have shape [1, 28, 28, 32]:
activations = sess.run(h_conv1,feed_dict={x:np.reshape(image,[1,784])})
Where:
W_conv1 = weight_variable([7, 7, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = conv2d(x, W_conv1, strides=[1, 1, 1, 1], padding='SAME') + b_conv1
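The weight_variable and bias_variable helpers are not shown here; presumably they are the usual initializers from the TensorFlow MNIST tutorial, something along these lines:
def weight_variable(shape):
    # Small random initial weights (TF1-style variable creation assumed).
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def bias_variable(shape):
    # Small constant initial biases.
    return tf.Variable(tf.constant(0.1, shape=shape))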
To obtain the "deconvolution" or "transposed convolution" we can use conv2d_transpose() on the convolution activations in this way:
deconv = conv2d_transpose(activations, W_conv1, output_shape=[1,28,28,1], strides=[1,1,1,1], padding='SAME')
Or, using conv2d(), we need to transpose and flip the weights:
transposed_weights = tf.transpose(W_conv1, perm=[0, 1, 3, 2])
Here we change the order of the "columns" from [0,1,2,3] to [0,1,3,2]. So from [7, 7, 1, 32] we obtain a tensor with shape [7, 7, 32, 1]. Then we flip the weights:
n_filters = 32
transposed_and_flipped_weights = np.zeros([7, 7, 32, 1], dtype=np.float32)
for i in range(n_filters):
    # Flip each 7x7 kernel by 180 degrees (reverse both spatial axes)
    transposed_and_flipped_weights[:, :, i, 0] = sess.run(
        tf.reverse(transposed_weights[:, :, i, 0], axis=[0, 1]))
Then we can compute the convolution with conv2d() as:
strides = [1,1,1,1]
deconv = conv2d(activations,transposed_and_flipped_weights,strides=strides,padding='SAME')
And we will obtain the same result as before. The very same result can also be obtained with conv2d_backprop_input() using:
deconv = conv2d_backprop_input([1,28,28,1],W_conv1,activations, strides=strides, padding='SAME')
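Assuming the three results above are kept in separate variables (hypothetically named deconv_transpose, deconv_conv2d and deconv_backprop, since the snippets all reuse the name deconv), a quick numerical check that they agree could look like this:
import numpy as np

# Evaluate the three "deconv" tensors built above; they all consume the
# already-evaluated activations, so no feed_dict is needed here.
out_t, out_c, out_b = sess.run([deconv_transpose, deconv_conv2d, deconv_backprop])
print(np.allclose(out_t, out_c, atol=1e-5))  # expected: True
print(np.allclose(out_t, out_b, atol=1e-5))  # expected: True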
The results are shown here:
Test of conv2d(), conv2d_transpose() and conv2d_backprop_input()
We can see that the results are the same. To see it in more detail, please check my code at:
https://github.com/simo23/conv2d_transpose
Here I replicate the output of the conv2d_transpose() function using the standard conv2d().