Given a TensorFlow model graph, how do I find the input and output node names?
Try this:
Start an interactive Python session:
>>> import tensorflow as tf
>>> gf = tf.compat.v1.GraphDef()  # tf.GraphDef() in TensorFlow 1.x
>>> gf.ParseFromString(open('/your/path/to/graphname.pb','rb').read())
and then
>>> [n.name + '=>' + n.op for n in gf.node if n.op in ( 'Softmax','Placeholder')]
You should get a result similar to this:
['Mul=>Placeholder', 'final_result=>Softmax']
That said, judging by the error message, I'm not sure the node names are the problem. My guess is that you passed the wrong arguments when loading the graph file, or that something is wrong with the generated graph file.
Check this part:
E/AndroidRuntime(16821): java.lang.IllegalArgumentException: Incompatible shapes: [1,224,224,3] vs. [32,1,1,2048]
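TensorFlow typically reports "Incompatible shapes" when an element-wise binary op (such as Mul or Add) receives operands that cannot be broadcast together. A minimal sketch of that check, assuming NumPy-style broadcasting rules; `broadcast_shape` is a hypothetical helper, not a TensorFlow function:

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Broadcast two shapes NumPy-style, raising ValueError if they clash.

    Dimensions are compared right to left; a pair is compatible when the
    sizes are equal or one of them is 1 (a missing dimension counts as 1).
    """
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(
                f"Incompatible shapes: {list(a)} vs. {list(b)} ({x} vs. {y})"
            )
        result.append(x if y == 1 else y)
    return list(reversed(result))


try:
    broadcast_shape([1, 224, 224, 3], [32, 1, 1, 2048])
except ValueError as err:
    print(err)  # the last dimensions, 3 and 2048, cannot be broadcast
```

Comparing from the right, the trailing dimensions 3 and 2048 already disagree, so an image-shaped tensor is meeting a 2048-channel tensor somewhere in the graph. That is consistent with the graph being loaded or fed incorrectly rather than with a naming problem alone.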
UPDATE: Sorry, if you're using a (re)trained graph, try this instead:
[n.name + '=>' + n.op for n in gf.node if n.op in ( 'Softmax','Mul')]
It seems that a (re)trained graph stores its input/output ops with the op types "Mul" and "Softmax", while an optimized and/or quantized graph stores them as "Placeholder" and "Softmax".
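The two filters above can be folded into one helper that covers both kinds of graph. A minimal sketch, assuming the op types described in this answer ("Placeholder" or "Mul" for the input, "Softmax" for the output); `candidate_io_nodes` is a hypothetical helper that accepts any nodes with `.name` and `.op` attributes, such as `gf.node`. Because "Mul" is a common op type, a real graph may yield several input candidates, so check the names it reports.

```python
from types import SimpleNamespace

# Op types mentioned in this answer: retrained graphs use "Mul" for the input,
# optimized/quantized graphs use "Placeholder"; both use "Softmax" for the output.
INPUT_OP_TYPES = ("Placeholder", "Mul")
OUTPUT_OP_TYPES = ("Softmax",)


def candidate_io_nodes(nodes, input_ops=INPUT_OP_TYPES, output_ops=OUTPUT_OP_TYPES):
    """Return (input_names, output_names) for nodes whose op type matches.

    `nodes` may be any iterable of objects with `.name` and `.op`
    attributes, e.g. `gf.node` from a parsed GraphDef.
    """
    nodes = list(nodes)  # allow single-pass iterables
    inputs = [n.name for n in nodes if n.op in input_ops]
    outputs = [n.name for n in nodes if n.op in output_ops]
    return inputs, outputs


# Hypothetical stand-in nodes (name, op type) for a quick check.
demo_nodes = [
    SimpleNamespace(name="Mul", op="Placeholder"),
    SimpleNamespace(name="conv/Conv2D", op="Conv2D"),
    SimpleNamespace(name="final_result", op="Softmax"),
]
print(candidate_io_nodes(demo_nodes))  # (['Mul'], ['final_result'])
```

With a real model, call `candidate_io_nodes(gf.node)` after parsing the .pb file as shown at the top.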
BTW, according to Pete Warden's post (https://petewarden.com/2016/09/27/tensorflow-for-mobile-poets/), using a retrained graph in a mobile environment is not recommended. It's better to use a quantized or memmapped graph because of performance and file-size issues. I couldn't figure out how to load a memmapped graph on Android, though :( (loading an optimized/quantized graph on Android works fine).
Recently I came across this tool, which ships with TensorFlow itself:
bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
  --in_graph=custom_graph_name.pb
I wrote a simple script to analyze the dependency relations in a computational graph (usually a DAG, i.e. a directed acyclic graph). It's fairly obvious that the inputs are the nodes that have no inputs of their own. Outputs are less clear-cut: any node in a graph could be treated as an output because, in the weirdest but still valid case, the outputs could be the inputs themselves while all the other nodes are dummies. In the code I still define the output operations as nodes whose outputs are not consumed by any other node; feel free to ignore that part.
import tensorflow as tf


def load_graph(frozen_graph_filename):
    # Read the serialized GraphDef from the frozen .pb file
    with tf.io.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import it into a fresh Graph (op names get the default "import/" prefix)
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)
    return graph


def analyze_inputs_outputs(graph):
    ops = graph.get_operations()
    # Start by treating every op as an output, then drop each op that feeds another op
    outputs_set = set(ops)
    inputs = []
    for op in ops:
        # Ops without inputs are graph inputs, except constants (e.g. frozen weights)
        if len(op.inputs) == 0 and op.type != 'Const':
            inputs.append(op)
        else:
            for input_tensor in op.inputs:
                if input_tensor.op in outputs_set:
                    outputs_set.remove(input_tensor.op)
    outputs = list(outputs_set)
    return (inputs, outputs)