Connect input and output tensors of two different graphs in TensorFlow
The accepted answer does connect the two graphs, but it does not restore the collections or the global and trainable variables. After an extensive search I arrived at a better solution:
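# TensorFlow 1.x API (tf.placeholder and tf.get_default_graph moved to tf.compat.v1 in 2.x).
import tensorflow as tf
from tensorflow.python.framework import meta_graph

# Two independent graphs, each with an 'input' placeholder and an 'output' tensor.
with tf.Graph().as_default() as graph1:
    input = tf.placeholder(tf.float32, (None, 20), name='input')
    output = tf.identity(input, name='output')

with tf.Graph().as_default() as graph2:
    input = tf.placeholder(tf.float32, (None, 20), name='input')
    output = tf.identity(input, name='output')

# The combined graph lives in the default graph and is fed through a new placeholder.
graph = tf.get_default_graph()
x = tf.placeholder(tf.float32, (None, 20), name='input')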
We use tf.train.export_meta_graph, which also exports the CollectionDef entries, and meta_graph.import_scoped_meta_graph to import the result. The connection happens at import time, specifically through the input_map parameter.
meta_graph1 = tf.train.export_meta_graph(graph=graph1)
meta_graph.import_scoped_meta_graph(meta_graph1, input_map={'input': x}, import_scope='graph1')
out1 = graph.get_tensor_by_name('graph1/output:0')
meta_graph2 = tf.train.export_meta_graph(graph=graph2)
meta_graph.import_scoped_meta_graph(meta_graph2, input_map={'input': out1}, import_scope='graph2')
The graphs are now connected, and the global variables have been re-mapped as well.
print(tf.global_variables())
You can also import meta graphs directly from a file.
Assuming that your Protobuf files contain serialized tf.GraphDef protos, you can use the input_map argument of tf.import_graph_def() to connect the two graphs:
# Import graph1.
graph1_def = ... # tf.GraphDef object
out1_name = "..." # name of the graph1out tensor in graph1_def.
graph1out, = tf.import_graph_def(graph1_def, return_elements=[out1_name])
# Import graph2 and connect it to graph1.
graph2_def = ... # tf.GraphDef object
inp2_name = "..." # name of the graph2inp tensor in graph2_def.
out2_name = "..." # name of the graph2out tensor in graph2_def.
graph2out, = tf.import_graph_def(graph2_def, input_map={inp2_name: graph1out},
                                 return_elements=[out2_name])
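If it helps to see what input_map is doing conceptually, here is a minimal pure-Python sketch. It is a toy model and not TensorFlow's implementation: the dict-based graph format and the helper import_toy_graph are hypothetical names invented for illustration. It only shows the idea that imported nodes get a scope prefix, while any input listed in input_map is rewired to a tensor that already exists in the target graph.

```python
from typing import Dict, List

# Toy "graph": node name -> list of input node names.
# This is an illustrative stand-in, not TensorFlow's GraphDef format.
ToyGraph = Dict[str, List[str]]


def import_toy_graph(target: ToyGraph, source: ToyGraph,
                     input_map: Dict[str, str], scope: str) -> None:
    """Copy `source` into `target` under `scope`, rewiring mapped inputs.

    Hypothetical helper: every node is copied with a scope prefix, but any
    consumer of a name found in `input_map` reads the replacement instead.
    """
    for name, inputs in source.items():
        rewired = [
            input_map[i] if i in input_map else f"{scope}/{i}"
            for i in inputs
        ]
        target[f"{scope}/{name}"] = rewired


# Each source graph has an 'input' placeholder feeding an 'output' node.
graph1 = {"input": [], "output": ["input"]}
graph2 = {"input": [], "output": ["input"]}

# The combined graph starts with a single new placeholder, 'x'.
combined: ToyGraph = {"x": []}
import_toy_graph(combined, graph1, {"input": "x"}, scope="graph1")
import_toy_graph(combined, graph2, {"input": "graph1/output"}, scope="graph2")

for name, inputs in combined.items():
    print(name, "<-", inputs)
```

In this toy, the original placeholders are still copied as graph1/input and graph2/input, but nothing reads from them any more: graph1/output reads x, and graph2/output reads graph1/output, which is the chaining both answers above set up.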