How to understand static shape and dynamic shape in TensorFlow?
Sometimes the shape of a tensor depends on a value that is computed at runtime. Let's take the following example, where x is defined as a tf.placeholder() vector with four elements:
import tensorflow as tf  # TensorFlow 1.x API
x = tf.placeholder(tf.int32, shape=[4])
print(x.get_shape())
# ==> '(4,)'
The value of x.get_shape() is the static shape of x, and the (4,) means that it is a vector of length 4. Now let's apply the tf.unique() op to x:
y, _ = tf.unique(x)
print(y.get_shape())
# ==> '(?,)'
The (?,) means that y is a vector of unknown length. Why is it unknown? tf.unique(x) returns the unique values from x, and the values of x are unknown because it is a tf.placeholder(), so it doesn't have a value until you feed it. Let's see what happens if you feed two different values:
sess = tf.Session()
print(sess.run(y, feed_dict={x: [0, 1, 2, 3]}).shape)
# ==> '(4,)'
print(sess.run(y, feed_dict={x: [0, 0, 0, 0]}).shape)
# ==> '(1,)'
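For comparison, here is a minimal sketch of the same experiment written in TensorFlow 2.x style, assuming TF 2.x is installed. A tf.function with an input_signature of shape [4] stands in for the placeholder. Python statements inside the function run only while TensorFlow traces it into a graph, so the list traced_static_shapes (a hypothetical helper name, not a TensorFlow API) records the static shape seen at trace time, while the tensors returned from each call carry the dynamic shape.

```python
import tensorflow as tf

# Hypothetical helper: Python side effects inside a tf.function run only at
# trace time, so this list captures the static shape TensorFlow inferred.
traced_static_shapes = []

# The input_signature plays the role of the 4-element placeholder.
@tf.function(input_signature=[tf.TensorSpec(shape=[4], dtype=tf.int32)])
def unique_values(x):
    y, _ = tf.unique(x)
    traced_static_shapes.append(y.shape.as_list())
    return y

a = unique_values(tf.constant([0, 1, 2, 3]))
b = unique_values(tf.constant([0, 0, 0, 0]))

print(traced_static_shapes[0])  # [None]  static shape: length unknown
print(a.shape, b.shape)         # (4,) (1,)  dynamic shapes differ per input
```

The same traced graph produces both results, which is exactly the situation described above: one static shape, many possible dynamic shapes.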
Hopefully this makes it clear that a tensor can have a different static and dynamic shape. The dynamic shape is always fully defined (it has no ? dimensions), but the static shape can be less specific. This is what allows TensorFlow to support operations like tf.unique() and tf.dynamic_partition(), whose output sizes depend on the input values, and which are used in advanced applications.
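As a sketch of the tf.dynamic_partition() case, assuming TensorFlow 2.x: run eagerly, each output partition has a concrete, fully defined shape; traced inside a tf.function, the same partitions only have the static shape [None]. The data values, the function name partition_lengths and the helper list traced are illustrative choices, not part of the original answer.

```python
import tensorflow as tf

data = tf.constant([10, 20, 30, 40, 50, 60])
partitions = tf.constant([0, 1, 0, 0, 1, 0])

# Eager execution: outputs hold concrete values, so their shapes are fully
# defined (4 elements go to partition 0, 2 go to partition 1).
parts = tf.dynamic_partition(data, partitions, num_partitions=2)
print([p.shape.as_list() for p in parts])  # [[4], [2]]

# Hypothetical helper list recording the static shapes seen while tracing.
traced = []

@tf.function(input_signature=[tf.TensorSpec(shape=[6], dtype=tf.int32),
                              tf.TensorSpec(shape=[6], dtype=tf.int32)])
def partition_lengths(values, ids):
    outs = tf.dynamic_partition(values, ids, num_partitions=2)
    # Runs only at trace time, when just the static shape is known.
    traced.append([o.shape for o in outs])
    # tf.shape() reads each partition's length at run time.
    return [tf.shape(o)[0] for o in outs]

lengths = partition_lengths(data, partitions)
print([int(n) for n in lengths])                  # [4, 2]
print([s.as_list() for s in traced[0]])           # [[None], [None]]
print([s.is_fully_defined() for s in traced[0]])  # [False, False]
```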
Finally, the tf.shape() op can be used to get the dynamic shape of a tensor and use it in a TensorFlow computation:
z = tf.shape(y)
print(sess.run(z, feed_dict={x: [0, 1, 2, 3]}))
# ==> [4]
print(sess.run(z, feed_dict={x: [0, 0, 0, 0]}))
# ==> [1]
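A minimal sketch of using the dynamic shape inside a computation, assuming TensorFlow 2.x. While the function is traced, y.shape[0] is None, so it cannot serve as a length; tf.shape(y)[0] is instead a scalar tensor whose value is only known at run time, and it can feed ops such as tf.range(). The function name positions_of_unique is hypothetical.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[4], dtype=tf.int32)])
def positions_of_unique(x):
    y, _ = tf.unique(x)
    # Static length of y is None here; tf.shape() gives the runtime length.
    n = tf.shape(y)[0]
    # Build a tensor whose size depends on the dynamic shape of y.
    return tf.range(n)

print(positions_of_unique(tf.constant([0, 1, 2, 3])).numpy())  # [0 1 2 3]
print(positions_of_unique(tf.constant([0, 0, 0, 0])).numpy())  # [0]
```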
[Figure: schematic image showing both the static and the dynamic shape]
TensorFlow 2.0 compatible answer: here is the code mrry gave in his answer, adapted to TensorFlow 2.x, for the benefit of the community.
# Install TensorFlow 2.1 (notebook shell syntax)
!pip install tensorflow==2.1
import tensorflow as tf
# Placeholders are incompatible with eager execution, the TF 2.x default;
# without this call, tf.compat.v1.placeholder() raises a RuntimeError
tf.compat.v1.disable_eager_execution()
x = tf.compat.v1.placeholder(tf.int32, shape=[4])
print(x.get_shape())
# ==> (4,)
y, _ = tf.unique(x)
print(y.get_shape())
# ==> (None,)
sess = tf.compat.v1.Session()
print(sess.run(y, feed_dict={x: [0, 1, 2, 3]}).shape)
# ==> '(4,)'
print(sess.run(y, feed_dict={x: [0, 0, 0, 0]}).shape)
# ==> '(1,)'
z = tf.shape(y)
print(sess.run(z, feed_dict={x: [0, 1, 2, 3]}))
# ==> [4]
print(sess.run(z, feed_dict={x: [0, 0, 0, 0]}))
# ==> [1]
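An alternative to calling tf.compat.v1.disable_eager_execution(), which switches off eager mode for the whole process, is to build the placeholder graph inside an explicit tf.Graph context and run it with a session bound to that graph. The sketch below assumes TensorFlow 2.x; eager execution stays enabled outside the with block, and the variable names distinct and repeated are illustrative.

```python
import tensorflow as tf

# Ops created inside this context go into `graph`, so placeholders work
# even though eager execution remains enabled globally.
graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.int32, shape=[4])
    y, _ = tf.unique(x)
    z = tf.shape(y)

print(x.shape)  # static shape: (4,)
print(y.shape)  # static shape: (None,)

# A session bound to the same graph evaluates the dynamic shape.
with tf.compat.v1.Session(graph=graph) as sess:
    distinct = sess.run(z, feed_dict={x: [0, 1, 2, 3]})
    repeated = sess.run(z, feed_dict={x: [0, 0, 0, 0]})

print(distinct)  # [4]
print(repeated)  # [1]
```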