How to use the merge and switch functions of TensorFlow?
switch
Let's start by examining the `control_flow_ops.switch` function:
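```python
import tensorflow as tf  # TensorFlow 1.x graph mode
from tensorflow.python.ops import control_flow_ops
# switch(data, pred) returns (output_false, output_true)
x_0, x_1 = control_flow_ops.switch(tf.constant(2), False)
x_2, x_3 = control_flow_ops.switch(tf.constant(7), True)
with tf.Session() as sess:
    print(sess.run(x_0))  # prints 2
    print(sess.run(x_3))  # prints 7
```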
`control_flow_ops.switch` returns a tuple of tensors, but only one of them will have a value (depending on the `pred` argument). In the example above, it's `x_0 = 2` from the first `switch` and `x_3 = 7` from the second one. An attempt to evaluate `x_1` or `x_2` will result in a `Retval does not have value` error:
```python
with tf.Session() as sess:  # each fetch below fails on its own
    sess.run(x_1)  # FAILS!
    sess.run(x_2)  # FAILS!
```
In other words, `x_0` and `x_3` are available, while `x_1` and `x_2` aren't.
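To make the "only one output has a value" idea concrete without a TensorFlow install, here is a small pure-Python model of `switch`. It is only a sketch: `DEAD` is a hypothetical sentinel standing in for a tensor that carries no value, and this `switch` is a plain function, not the TensorFlow op. Like `control_flow_ops.switch`, it returns `(output_false, output_true)`.

```python
# Pure-Python model of switch's dataflow semantics (not TensorFlow code).


class _Dead:
    """Hypothetical marker for an output that carries no value."""

    def __repr__(self):
        return "DEAD"


DEAD = _Dead()


def switch(data, pred):
    """Return (output_false, output_true); only the side chosen by pred is alive."""
    if pred:
        return DEAD, data
    return data, DEAD


x_0, x_1 = switch(2, False)
x_2, x_3 = switch(7, True)
print(x_0, x_1)  # 2 DEAD
print(x_2, x_3)  # DEAD 7
```

In real TensorFlow the dead output is not something you can inspect; fetching it fails with the error shown above, which is why the model uses an explicit sentinel instead.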
merge
`control_flow_ops.merge` performs the inverse operation: given a tuple of tensors, it selects the available one. More precisely, it returns a named tuple with the fields `output` and `value_index`: the tensor that has a value and its position in the input. According to the current docs, the input should contain exactly one available tensor, which means that your demo is, strictly speaking, unsupported and leads to undefined behavior. Here's an example:
```python
from tensorflow.python.ops.control_flow_ops import merge
with tf.Session() as sess:
    print(sess.run(merge([x_0, x_1])))       # Merge(output=2, value_index=0)
    print(sess.run(merge([x_1, x_0])))       # Merge(output=2, value_index=1)
    print(sess.run(merge([x_2, x_3])))       # Merge(output=7, value_index=1)
    print(sess.run(merge([x_3, x_2])))       # Merge(output=7, value_index=0)
    print(sess.run(merge([x_0, x_1, x_2])))  # Merge(output=2, value_index=0)
    print(sess.run(merge([x_1, x_2, x_3])))  # Merge(output=7, value_index=2)
```
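The selection rule of `merge` can be modeled the same way. This pure-Python sketch uses a hypothetical `DEAD` sentinel and plain-function stand-ins for `switch` and `merge` (not the TensorFlow ops); the `Merge` named tuple mirrors the fields printed above. It enforces the documented "exactly one available input" contract by raising, and does not attempt to reproduce what TensorFlow actually does when that contract is broken.

```python
from collections import namedtuple


class _Dead:
    """Hypothetical marker for a value-less output."""

    def __repr__(self):
        return "DEAD"


DEAD = _Dead()

# Same field names as the result printed by sess.run(merge(...)).
Merge = namedtuple("Merge", ["output", "value_index"])


def switch(data, pred):
    """Return (output_false, output_true); only one side is alive."""
    return (DEAD, data) if pred else (data, DEAD)


def merge(inputs):
    """Return the single live input together with its index in inputs."""
    live = [(i, v) for i, v in enumerate(inputs) if v is not DEAD]
    if len(live) != 1:
        # Model choice: treat a contract violation as an error.
        raise ValueError("merge expects exactly one live input, got %d" % len(live))
    index, value = live[0]
    return Merge(output=value, value_index=index)


x_0, x_1 = switch(2, False)
x_2, x_3 = switch(7, True)
print(merge([x_0, x_1]))       # Merge(output=2, value_index=0)
print(merge([x_1, x_2, x_3]))  # Merge(output=7, value_index=2)
```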
Both of these functions are handy for controlling computation flow; e.g. the gradient of `merge` is implemented through `switch`, and the gradient of `switch` through `merge` (see the TensorFlow source code).
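A typical way to combine the two is a conditional: `switch` routes the input to one branch, each branch runs only if its input is live, and `merge` picks whichever branch produced a value. Roughly speaking, graph-mode `tf.cond` in TF 1.x is built on this pattern. The sketch below is a pure-Python model under stated assumptions: `DEAD` is a hypothetical sentinel, and `switch`, `merge`, `lift` and `cond` are illustrative functions, not TensorFlow APIs. `lift` models the rule that an op with a dead input produces a dead output instead of running.

```python
from collections import namedtuple


class _Dead:
    """Hypothetical marker for a value-less output."""

    def __repr__(self):
        return "DEAD"


DEAD = _Dead()
Merge = namedtuple("Merge", ["output", "value_index"])


def switch(data, pred):
    """Return (output_false, output_true); only one side is alive."""
    return (DEAD, data) if pred else (data, DEAD)


def merge(inputs):
    """Return the single live input and its index."""
    live = [(i, v) for i, v in enumerate(inputs) if v is not DEAD]
    if len(live) != 1:
        raise ValueError("merge expects exactly one live input, got %d" % len(live))
    index, value = live[0]
    return Merge(output=value, value_index=index)


def lift(fn):
    """Wrap fn so that a dead input yields a dead output and fn is skipped."""
    def wrapped(x):
        return DEAD if x is DEAD else fn(x)
    return wrapped


def cond(pred, x, true_fn, false_fn):
    """Hypothetical conditional built only from switch and merge."""
    x_false, x_true = switch(x, pred)
    out_false = lift(false_fn)(x_false)  # runs only when pred is False
    out_true = lift(true_fn)(x_true)     # runs only when pred is True
    return merge([out_false, out_true]).output


calls = []  # records which branch functions actually executed


def double(v):
    calls.append("double")
    return v * 2


def negate(v):
    calls.append("negate")
    return -v


print(cond(True, 5, double, negate))   # 10
print(cond(False, 5, double, negate))  # -5
print(calls)                           # ['double', 'negate']
```

Each call runs exactly one branch function, which is the point of routing through `switch` rather than computing both branches and choosing afterwards.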