How to load and use a saved model in TensorFlow?
TensorFlow's preferred way of deploying a model and using it from different languages is TensorFlow Serving.
In your case, you are using `saver.save` to save the model. This writes a `.meta` file, `.ckpt` checkpoint files, and a few other files that store the weights, the network structure, the number of training steps, etc. This is the preferred way of saving while you are still training.
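Once you are done with training, you should export the graph using `SavedModelBuilder`, starting from the files written by `saver.save`. This is often loosely called "freezing". Strictly, `SavedModelBuilder` writes a SavedModel, which is a `saved_model.pb` file plus a `variables/` directory. A frozen graph, by contrast, is a single `.pb` file with the weights baked in as constants. Either way, the exported model contains the whole network and its weights.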
This model should then be served by TensorFlow Serving, and clients in other languages can use it over the gRPC protocol.
The whole procedure is described in this excellent tutorial.
Here is a code snippet that worked for me to load a `.pb` file and run inference on a single image.
The code performs these steps:
1. Load the `.pb` file into a `GraphDef` (a serialized version of a graph, used to read `.pb` files).
2. Load the `GraphDef` into a `Graph`.
3. Get the input and output tensors by their names.
4. Run inference on a single image.
import tensorflow as tf
import numpy as np
import cv2

# Placeholders: replace with the real paths and tensor names of your model.
PB_PATH = 'path/to/frozen_model.pb'
IMAGE_PATH = 'path/to/image.jpg'
INPUT_TENSOR_NAME = 'input_tensor_name:0'
OUTPUT_TENSOR_NAME = 'output_tensor_name:0'

# Read the image. OpenCV loads channels as BGR, so convert to RGB.
img = cv2.imread(IMAGE_PATH)
if img is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise fail later with a confusing cvtColor error.
    raise IOError('Could not read image: {}'.format(IMAGE_PATH))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
image = img.astype(float)
# Add a leading dimension so the single image fits the model's batch shape.
image = np.expand_dims(image, 0)

# Read the .pb file into a GraphDef (the serialized form of a graph).
# tf.gfile.GFile replaces the deprecated tf.gfile.FastGFile.
with tf.gfile.GFile(PB_PATH, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import the GraphDef into a new Graph. name="" keeps the original node names
# (no "import/" prefix), so the tensor names above resolve unchanged.
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

# Look up the input and output tensors by name.
input_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
output_tensor = graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)

# Run inference on the single-image batch.
with tf.Session(graph=graph) as sess:
    output_vals = sess.run(output_tensor, feed_dict={input_tensor: image})
Here's a code snippet to save a model with `simple_save`, then restore it and predict with it:
# Save the model. simple_save exports a SavedModel using TensorFlow's
# defaults, with the signature inputs/outputs taken from the dicts below.
tf.saved_model.simple_save(
    sess,
    export_dir=saveModelPath,
    inputs={
        "inputImageBatch": X_train,
        "inputClassBatch": Y_train,
        "isTrainingBool": isTraining,
    },
    outputs={"predictedClassBatch": predClass},
)
Note that `simple_save` fills in certain default values, such as the SERVING tag and the default serving signature key (this can be seen at: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/simple_save.py).
Now, to restore and use the inputs/outputs dict:
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import signature_constants

# Load into a fresh Graph so imported tensor names cannot be uniquified against
# ops already present in the default graph (e.g. from training in this process).
with tf.Session(graph=tf.Graph()) as sess:
    # Note: simple_save tags the MetaGraph with SERVING by default.
    model = tf.saved_model.loader.load(export_dir=saveModelPath, sess=sess,
                                       tags=[tag_constants.SERVING])
    # simple_save stores the inputs/outputs dicts under the default serving
    # signature; look it up once and map dict keys to real tensor names.
    signature = model.signature_def[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    graph = sess.graph

    inputImage_name = signature.inputs['inputImageBatch'].name
    inputImage = graph.get_tensor_by_name(inputImage_name)
    inputLabel_name = signature.inputs['inputClassBatch'].name
    inputLabel = graph.get_tensor_by_name(inputLabel_name)
    isTraining_name = signature.inputs['isTrainingBool'].name
    isTraining = graph.get_tensor_by_name(isTraining_name)
    outputPrediction_name = signature.outputs['predictedClassBatch'].name
    outputPrediction = graph.get_tensor_by_name(outputPrediction_name)

    # inputLabel is not fed here; presumably the prediction does not depend
    # on it. Verify against your model.
    outPred = sess.run(outputPrediction,
                       feed_dict={inputImage: sampleImages, isTraining: False})
    print("predicted classes:", outPred)
Note: the default signature_def was needed to make use of the tensor names specified in the input & output dicts.
What was missing was the signature definition (`signature_def_map`):
# Saving: export the session's variables together with a prediction
# signature registered under the key "model", whose input key is "x" and
# whose output key is "finalnode".
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def(
    inputs={"x": x},
    outputs={"finalnode": model},
)
builder.add_meta_graph_and_variables(
    sess, ["tag"], signature_def_map={"model": prediction_signature})
builder.save()
# Loading: import the SavedModel into a fresh graph, then resolve the tensors
# through the "model" signature instead of assuming their node names.
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph_def = tf.saved_model.loader.load(sess, ["tag"], export_dir)
    graph = sess.graph
    # "x" and "finalnode" are signature keys chosen at save time; the actual
    # tensor names are recorded in the SignatureDef. Hard-coding "x:0" /
    # "finalnode:0" only works if the tensors happened to carry those names.
    signature = meta_graph_def.signature_def["model"]
    x = graph.get_tensor_by_name(signature.inputs["x"].name)
    model = graph.get_tensor_by_name(signature.outputs["finalnode"].name)
    print(sess.run(model, {x: [5, 6, 7, 8]}))