Both eager and graph execution in TensorFlow tests
With the caveat that anything in the tf.contrib namespace is subject to change between releases, you can decorate your test with @tf.contrib.eager.run_test_in_graph_and_eager_modes. Some other projects, like TensorFlow Probability, seem to use this.
For non-tests, some things to look into are:
- tf.contrib.eager.defun: useful when you have eager execution enabled but want to "compile" some computation into a graph to benefit from memory and/or performance optimizations.
- tf.contrib.eager.py_func: useful when you do not have eager execution enabled but want to execute some computation in the graph as a Python function.
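If "compile some computation into a graph" sounds abstract, the following toy model may help. It is a pure-Python sketch of the trace-once-then-replay idea, under the assumption that this is roughly what defun does internally. Node, evaluate, toy_defun and scaled_sum are hypothetical names, not TensorFlow APIs, and the real thing builds an actual TensorFlow graph rather than a small expression tree.

```python
import operator


class Node:
    """Symbolic value that records an operation instead of computing it."""

    def __init__(self, op=None, inputs=(), index=None):
        self.op = op          # callable to apply, or None for an input slot
        self.inputs = inputs  # operands: Nodes or plain Python constants
        self.index = index    # position of the argument this slot stands for

    def __add__(self, other):
        return Node(operator.add, (self, other))

    def __mul__(self, other):
        return Node(operator.mul, (self, other))


def evaluate(value, args):
    """Replay a recorded expression with concrete argument values."""
    if not isinstance(value, Node):
        return value                # plain constant
    if value.op is None:
        return args[value.index]    # input slot
    return value.op(*(evaluate(v, args) for v in value.inputs))


def toy_defun(fn):
    """Trace fn once per argument count, then reuse the recorded expression."""
    graphs = {}

    def wrapper(*args):
        if len(args) not in graphs:
            wrapper.trace_count += 1
            slots = [Node(index=i) for i in range(len(args))]
            graphs[len(args)] = fn(*slots)  # the Python body runs only here
        return evaluate(graphs[len(args)], args)

    wrapper.trace_count = 0
    return wrapper


@toy_defun
def scaled_sum(x, y):
    return (x + y) * 3


print(scaled_sum(1, 2), scaled_sum(10, 20), scaled_sum.trace_count)  # 9 90 1
```

The Python body of scaled_sum runs a single time, during tracing. Later calls only replay the recorded operations, which is where a graph-based runtime gets its room for optimization.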
One may question the reasoning behind not allowing a call to tf.enable_eager_execution() to be undone. The idea is that library authors should not invoke it; only the end user should, in main(). This reduces the chances that libraries are written in incompatible ways, where, say, functions in one library disable eager execution and return symbolic tensors while functions in another library enable eager execution and expect concrete-valued tensors. That would make mixing the libraries problematic.
Hope that helps
There is an official way to use eager execution in a graph environment, but I'm not sure it is convenient enough for you, because you need to write quite a bit of code to wrap and run your test function. Anyway, here is your example, which should at least work:
import numpy as np
import tensorflow as tf


def test_normal_execution():
    # Graph mode: build the ops first, then evaluate them in a session.
    matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
    dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
    iterator = dataset.make_one_shot_iterator()
    first_elem = iterator.get_next()
    with tf.Session() as sess:
        result = sess.run(first_elem)
        assert (result == [1, 2, 3, 4]).all()
    # No explicit sess.close() needed: the with block closes the session.


def test_eager_execution():
    # Eager mode: iterating the dataset yields concrete tensors directly.
    matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
    dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
    iterator = iter(dataset)
    first_elem = next(iterator)
    assert (first_elem.numpy() == [1, 2, 3, 4]).all()


test_normal_execution()
# test_eager_execution()  # Instead, you have to use the following three lines.
with tf.Session() as sess:
    tfe = tf.contrib.eager
    # py_func wraps the Python function as a graph op whose body runs eagerly.
    sess.run(tfe.py_func(test_eager_execution, [], []))