TensorFlow: What does tf.nn.separable_conv2d do?
tf.nn.separable_conv2d() implements the so-called 'separable convolution' described in this talk, from slide 26 onwards.
The idea is that instead of convolving jointly across all channels of an image, you run a separate 2D convolution on each channel, with channel_multiplier filters per channel (the depthwise step). The resulting in_channels * channel_multiplier intermediate channels are concatenated and then mapped to out_channels using a 1x1 convolution (the pointwise step).
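As a rough illustration of the two steps, here is a minimal NumPy sketch, not TensorFlow's actual implementation. It assumes stride 1 and 'VALID' padding to keep the indexing simple, and the names depthwise_step, pointwise_step and separable_sketch are hypothetical helpers. The intermediate channel for input channel q and multiplier index r is stored at position q * channel_multiplier + r, the same ordering the official documentation's formula uses.

```python
import numpy as np


def depthwise_step(x, depthwise_filter):
    """Depthwise step (hypothetical helper), stride 1, 'VALID' padding.

    x: (batch, height, width, in_channels)
    depthwise_filter: (fh, fw, in_channels, channel_multiplier)
    returns: (batch, out_h, out_w, in_channels * channel_multiplier)
    """
    batch, h, w, in_channels = x.shape
    fh, fw, _, multiplier = depthwise_filter.shape
    out_h, out_w = h - fh + 1, w - fw + 1
    out = np.zeros((batch, out_h, out_w, in_channels * multiplier), dtype=x.dtype)
    for q in range(in_channels):        # every input channel is handled on its own...
        for r in range(multiplier):     # ...with channel_multiplier separate 2D kernels
            kernel = depthwise_filter[:, :, q, r]
            for i in range(out_h):
                for j in range(out_w):
                    patch = x[:, i:i + fh, j:j + fw, q]  # (batch, fh, fw)
                    out[:, i, j, q * multiplier + r] = np.sum(patch * kernel, axis=(1, 2))
    return out


def pointwise_step(x, pointwise_filter):
    """1x1 convolution: the same linear map over channels at every pixel."""
    return x @ pointwise_filter[0, 0]  # (in_channels * multiplier, out_channels)


def separable_sketch(x, depthwise_filter, pointwise_filter):
    return pointwise_step(depthwise_step(x, depthwise_filter), pointwise_filter)


rng = np.random.default_rng(0)
x = rng.random((2, 8, 8, 3))             # batch 2, 8x8, RGB
dw = rng.random((3, 3, 3, 4))            # in_channels=3, channel_multiplier=4
pw = rng.random((1, 1, 3 * 4, 16))       # 12 intermediate channels -> 16 outputs
y = separable_sketch(x, dw, pw)
print(y.shape)  # (2, 6, 6, 16)
```

Looping explicitly over q makes it visible that the depthwise step never combines two input channels; only the 1x1 pointwise step mixes them. A real implementation would vectorize these loops.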
Separable convolution is often an effective way to reduce the parametric complexity of early convolutions in a convnet, and can materially speed up training. channel_multiplier controls that complexity, and would typically be 4 to 8 for an RGB input. For a grayscale input there is only one channel to begin with, so there is nothing to keep separate and using it makes little sense.
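To make the parameter saving concrete, here is a small count of filter weights. It assumes 3x3 kernels and ignores biases (tf.nn.separable_conv2d itself takes no bias); the channel counts are illustrative numbers chosen for this example, and the two helper functions are hypothetical.

```python
def conv2d_params(fh, fw, in_channels, out_channels):
    # A regular filter spans every input channel, once per output channel.
    return fh * fw * in_channels * out_channels


def separable_conv2d_params(fh, fw, in_channels, channel_multiplier, out_channels):
    # Depthwise step: channel_multiplier (fh x fw) kernels per input channel.
    depthwise = fh * fw * in_channels * channel_multiplier
    # Pointwise step: 1x1 kernels from in_channels * channel_multiplier to out_channels.
    pointwise = in_channels * channel_multiplier * out_channels
    return depthwise + pointwise


# RGB input, channel_multiplier = 4, 64 output channels (illustrative numbers)
print(conv2d_params(3, 3, 3, 64))               # 1728
print(separable_conv2d_params(3, 3, 3, 4, 64))  # 876  (108 depthwise + 768 pointwise)

# A deeper layer: 64 -> 128 channels, channel_multiplier = 1
print(conv2d_params(3, 3, 64, 128))               # 73728
print(separable_conv2d_params(3, 3, 64, 1, 128))  # 8768
```

Dividing the two counts gives channel_multiplier / out_channels + channel_multiplier / (fh * fw), so the saving shrinks as channel_multiplier grows.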
In a regular 2D convolution over multiple input channels, the filter is as deep as the input and lets us freely mix channels to generate each element of the output. Depthwise convolutions don't do that: each channel is kept separate, hence the name depthwise. See [1] for a diagram of how this works.
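The difference can be checked numerically. The following NumPy sketch uses hypothetical helpers (conv2d_valid, depthwise_valid) and assumes stride 1, 'VALID' padding and channel_multiplier fixed at 1. It perturbs a single input channel and reports which output channels change under a regular convolution versus a depthwise one.

```python
import numpy as np


def conv2d_valid(x, w):
    """Regular 2D convolution (cross-correlation), stride 1, 'VALID' padding.

    x: (height, width, in_channels); w: (fh, fw, in_channels, out_channels).
    """
    fh, fw = w.shape[0], w.shape[1]
    out_h, out_w = x.shape[0] - fh + 1, x.shape[1] - fw + 1
    out = np.empty((out_h, out_w, w.shape[3]))
    for i in range(out_h):
        for j in range(out_w):
            # Each output element sums over the window AND over all input channels.
            out[i, j] = np.tensordot(x[i:i + fh, j:j + fw, :], w, axes=3)
    return out


def depthwise_valid(x, w):
    """Depthwise 2D convolution with channel_multiplier = 1, stride 1, 'VALID' padding.

    x: (height, width, channels); w: (fh, fw, channels).
    """
    fh, fw = w.shape[0], w.shape[1]
    out_h, out_w = x.shape[0] - fh + 1, x.shape[1] - fw + 1
    out = np.empty((out_h, out_w, x.shape[2]))
    for i in range(out_h):
        for j in range(out_w):
            # Each channel is weighted only by its own kernel slice: no mixing.
            out[i, j] = np.sum(x[i:i + fh, j:j + fw, :] * w, axis=(0, 1))
    return out


rng = np.random.default_rng(0)
x = rng.random((6, 6, 3))
x_perturbed = x.copy()
x_perturbed[:, :, 2] += 1.0  # change input channel 2 only

w_regular = rng.random((3, 3, 3, 3))
w_depthwise = rng.random((3, 3, 3))

regular_changed = np.abs(conv2d_valid(x_perturbed, w_regular)
                         - conv2d_valid(x, w_regular)).max(axis=(0, 1)) > 0
depthwise_changed = np.abs(depthwise_valid(x_perturbed, w_depthwise)
                           - depthwise_valid(x, w_depthwise)).max(axis=(0, 1)) > 0

print(regular_changed)    # [ True  True  True]
print(depthwise_changed)  # [False False  True]
```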
The official documentation [2] defines the output as:
```
output[b, i, j, k] = sum_{di, dj, q, r}
    input[b, strides[1] * i + di, strides[2] * j + dj, q] *
    depthwise_filter[di, dj, q, r] *
    pointwise_filter[0, 0, q * channel_multiplier + r, k]
```
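Because pointwise_filter does not depend on di or dj, this sum can be regrouped to show that it really is a depthwise convolution followed by a 1x1 convolution. In this short derivation, m stands for channel_multiplier, s_1 and s_2 for strides[1] and strides[2], C_in for in_channels, and D is a name introduced here for the intermediate depthwise output, indexed by c = q m + r:

```latex
\begin{aligned}
\text{output}[b,i,j,k]
  &= \sum_{q,r} \text{pointwise\_filter}[0,0,\,q m + r,\,k]\;
     \underbrace{\sum_{d_i,d_j} \text{input}[b,\,s_1 i + d_i,\,s_2 j + d_j,\,q]\;
     \text{depthwise\_filter}[d_i,d_j,q,r]}_{D[b,\,i,\,j,\,q m + r]} \\
  &= \sum_{c=0}^{C_{\text{in}} m - 1} D[b,i,j,c]\;\text{pointwise\_filter}[0,0,c,k]
\end{aligned}
```

The inner sum only ever touches input channel q, which is the depthwise step; the outer sum is a per-pixel linear map over the C_in * m intermediate channels, which is the 1x1 pointwise step.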
Here is some sample TensorFlow code to test it:
```python
import numpy as np
import tensorflow as tf  # written for TensorFlow 2.x (eager execution)

batch_size = 100
height = 8
width = 8
filter_height = 3
filter_width = 3
in_channels = 3
channel_multiplier = 1
out_channels = 3

# Random input in NHWC layout: (batch, height, width, channels).
# float32 matches TensorFlow's default floating-point dtype.
input_tensor = np.random.rand(batch_size, height, width, in_channels).astype(np.float32)

# Depthwise filter: channel_multiplier 2D kernels for each input channel.
depthwise_filter = np.random.rand(
    filter_height, filter_width, in_channels, channel_multiplier).astype(np.float32)

# Pointwise filter: a 1x1 convolution from in_channels * channel_multiplier
# intermediate channels to out_channels.
pointwise_filter = np.random.rand(
    1, 1, in_channels * channel_multiplier, out_channels).astype(np.float32)

output = tf.nn.separable_conv2d(
    input_tensor,
    depthwise_filter,
    pointwise_filter,
    strides=[1, 1, 1, 1],
    padding='SAME',
)

print(output.shape)  # (100, 8, 8, 3)
```
Credit:
[1] https://eli.thegreenplace.net/2018/depthwise-separable-convolutions-for-machine-learning/
[2] https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d