What is the use of a *.pb file in TensorFlow and how does it work?
`pb` stands for protobuf. In TensorFlow 1.x, a frozen protobuf file contains the graph definition as well as the weights of the model (stored as constants). Thus, a `pb` file is all you need to be able to run a given trained model.
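The weights only end up inside the `.pb` when the graph is "frozen", i.e. its variables are replaced by constants holding their current values. A minimal sketch, assuming TensorFlow 2.x with the `tf.compat.v1` graph API; the toy graph and the file name `frozen_model.pb` are hypothetical:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1

with tf.Graph().as_default() as g:
    # Toy graph: a placeholder input, a trainable weight, and a named output.
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
    w = tf1.Variable([[1.0], [2.0]], name="weights")
    y = tf.identity(tf.matmul(x, w), name="output")

    with tf1.Session(graph=g) as sess:
        sess.run(tf1.global_variables_initializer())
        # Replace the variables feeding "output" with Const nodes holding their values.
        frozen_graph_def = tf1.graph_util.convert_variables_to_constants(
            sess, g.as_graph_def(), ["output"])

# Write the frozen GraphDef as a binary protobuf (.pb) file.
pb_path = os.path.join(tempfile.mkdtemp(), "frozen_model.pb")
with tf.io.gfile.GFile(pb_path, "wb") as f:
    f.write(frozen_graph_def.SerializeToString())

ops = sorted({node.op for node in frozen_graph_def.node})
print(ops)
```

After freezing, the printed op types should contain `Const` nodes in place of variable ops, which is what lets the single file carry the weights.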
Given a `.pb` file, you can load it as follows:
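```python
import tensorflow as tf

def load_pb(path_to_pb):
    # Read and parse the serialized GraphDef from disk
    with tf.io.gfile.GFile(path_to_pb, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import it into a fresh graph; name='' avoids the default "import/" prefix
    with tf.Graph().as_default() as graph:
        tf.compat.v1.import_graph_def(graph_def, name='')
    return graph
```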
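Once the graph is loaded, you can work with it like any other TensorFlow graph. For instance, you can retrieve the tensors of interest by name (the names depend on how the model was built):

```python
input_tensor = graph.get_tensor_by_name('input:0')
output_tensor = graph.get_tensor_by_name('output:0')
```

and run them with a regular TensorFlow (1.x-style) session:

```python
with tf.compat.v1.Session(graph=graph) as sess:
    result = sess.run(output_tensor, feed_dict={input_tensor: some_data})
```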
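Putting the steps together, here is a self-contained sketch, assuming TensorFlow 2.x with `tf.compat.v1`. It builds a hypothetical toy graph whose weights are already constants (as in a frozen graph), writes it to a `.pb` file with `tf.io.write_graph`, then loads and runs it the same way as above. The tensor names `input` and `output` are chosen only to match the snippet above.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

# 1. Build a toy graph with constant weights (like a frozen graph).
with tf.Graph().as_default() as g:
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
    w = tf.constant([[1.0], [2.0]], name="weights")
    y = tf.identity(tf.matmul(x, w), name="output")
    graph_def = g.as_graph_def()

# 2. Serialize the GraphDef to a binary .pb file.
out_dir = tempfile.mkdtemp()
tf.io.write_graph(graph_def, out_dir, "toy.pb", as_text=False)
pb_path = os.path.join(out_dir, "toy.pb")

# 3. Load it back (same steps as load_pb above).
def load_pb(path_to_pb):
    with tf.io.gfile.GFile(path_to_pb, "rb") as f:
        gd = tf1.GraphDef()
        gd.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf1.import_graph_def(gd, name="")
    return graph

# 4. Look up tensors by name and run them in a session.
graph = load_pb(pb_path)
input_tensor = graph.get_tensor_by_name("input:0")
output_tensor = graph.get_tensor_by_name("output:0")
with tf1.Session(graph=graph) as sess:
    result = sess.run(output_tensor, feed_dict={input_tensor: np.array([[3.0, 4.0]])})
print(result)  # [[11.]]
```

The result is `3 * 1 + 4 * 2 = 11`, computed entirely from what was stored in `toy.pb`.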
Explanation
The `.pb` format is the protocol buffer (protobuf) format, and in TensorFlow it is used to hold models. Protocol buffers are Google's general-purpose mechanism for serializing structured data; compared with text formats such as JSON or XML, they are more compact, which makes them easier to transport, and they enforce a schema on the data. In TensorFlow 2 the file is a SavedModel protocol buffer (`saved_model.pb`), and SavedModel is the default format when saving Keras / TensorFlow 2.x models.
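To give a feel for why protobuf output is compact, here is a minimal pure-Python sketch of how a single integer field is written in the protobuf wire format: a tag (field number plus wire type) followed by a base-128 varint. It only illustrates the encoding; TensorFlow itself serializes through the generated protobuf classes (e.g. `GraphDef.SerializeToString()`), and the field name `a` in the JSON comparison is hypothetical.

```python
def encode_varint(value: int) -> bytes:
    """Encode a non-negative int as a protobuf base-128 varint."""
    if value < 0:
        raise ValueError("this sketch only handles non-negative values")
    out = bytearray()
    while True:
        byte = value & 0x7F          # take the low 7 bits
        value >>= 7
        if value:
            out.append(byte | 0x80)  # more bytes follow: set the continuation bit
        else:
            out.append(byte)
            return bytes(out)


def encode_int_field(field_number: int, value: int) -> bytes:
    """Encode one varint-typed field: tag = (field_number << 3) | wire_type 0."""
    tag = (field_number << 3) | 0
    return encode_varint(tag) + encode_varint(value)


encoded = encode_int_field(1, 150)
print(encoded.hex())  # 089601

json_text = '{"a": 150}'
print(len(encoded), len(json_text))  # 3 10
```

Protobuf stores a field number instead of the field name and packs small integers into few bytes, which is where most of the size saving over text formats comes from.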
For example, the following code (specifically, `m.save`) creates a folder called `my_new_model` and saves in it the `saved_model.pb` file, an `assets/` folder, and a `variables/` folder.
```python
import tensorflow as tf
import tensorflow_hub as hub

# First download a SavedModel from tfhub.dev, a website with pretrained models
m = tf.keras.Sequential([
    hub.KerasLayer("https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/classification/4")
])
m.build([None, 224, 224, 3])  # Batch input shape.
# With Keras 2 (TF 2.x up to 2.15), a path without an extension is saved as a SavedModel
m.save("my_new_model")
```
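To see this layout without downloading anything from TF Hub, here is a minimal sketch, assuming TensorFlow 2.x, that saves a hypothetical toy `tf.Module` (the `Scale` class is made up for illustration) with `tf.saved_model.save` and lists the export directory.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical toy model: multiplies its input by a trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.w


export_dir = os.path.join(tempfile.mkdtemp(), "my_new_model")
tf.saved_model.save(Scale(), export_dir)

# Expect saved_model.pb (graph and signatures) and variables/ (weights);
# assets/ and other files may also appear depending on the TF version.
contents = sorted(os.listdir(export_dir))
print(contents)
```

The graph structure lives in the protobuf file, while the variable values are stored separately under `variables/`.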
In some places you may also see `.h5` (HDF5) models; this was the default format of Keras `model.save` in TF 1.x.
Extra information: TensorFlow Lite, the library for running models on mobile and IoT devices, uses FlatBuffers instead of protocol buffers. This is the format the TensorFlow Lite Converter produces (`.tflite` files). FlatBuffers is another efficient Google format: unlike JSON, XML, or protobuf, it allows access to any part of a message without first deserializing it. On devices with little memory (RAM), it makes more sense to read only the parts of the model file you need than to load the entire file into memory and deserialize it.
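To see the FlatBuffer side, here is a minimal sketch, assuming TensorFlow 2.x, that converts a hypothetical toy SavedModel with `tf.lite.TFLiteConverter.from_saved_model` and runs the resulting bytes with `tf.lite.Interpreter`. The check on bytes 4 to 8 assumes the `TFL3` file identifier declared in the TensorFlow Lite FlatBuffer schema.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical toy model used only to have something to convert."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([1, 2], tf.float32)])
    def __call__(self, x):
        return x * self.w


module = Scale()
saved_dir = os.path.join(tempfile.mkdtemp(), "toy_model")
tf.saved_model.save(module, saved_dir,
                    signatures=module.__call__.get_concrete_function())

# Convert the protobuf-based SavedModel into a FlatBuffer (.tflite) blob.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_dir)
tflite_bytes = converter.convert()
print(tflite_bytes[4:8])  # b'TFL3' (FlatBuffer file identifier)

# The interpreter works directly on the in-memory FlatBuffer.
interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
interpreter.set_tensor(inp["index"], np.array([[1.0, 3.0]], dtype=np.float32))
interpreter.invoke()
result = interpreter.get_tensor(out["index"])
print(result)  # [[2. 6.]]
```

In a real app you would write `tflite_bytes` to a `.tflite` file and ship that to the device.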
Loading SavedModels in TensorFlow 2
I noticed that BiBi's answer showing how to load models is popular; for SavedModels there is a shorter way to do this in TF2:
```python
import tensorflow as tf

model_path = "/path/to/directory/inception_v1_224_quant_20181026"
model = tf.saved_model.load(model_path)
```
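Note:
- The directory (i.e. `inception_v1_224_quant_20181026`) must contain a `saved_model.pb` or `saved_model.pbtxt` file, otherwise loading fails with an error. You cannot pass the path of the `.pb` file itself; pass the directory.
- For older models you might get `TypeError: 'AutoTrackable' object is not callable`. In that case the model usually has to be called through one of its signatures (e.g. `model.signatures["serving_default"]`) rather than directly.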
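As a self-contained sketch of the directory point above, assuming TensorFlow 2.x (the toy `Scale` module and its export path are hypothetical stand-ins for a real model), the following saves a SavedModel and loads it back by passing the directory, not the `saved_model.pb` file, to `tf.saved_model.load`:

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical toy model standing in for a real SavedModel."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.w


model_path = os.path.join(tempfile.mkdtemp(), "toy_model")
tf.saved_model.save(Scale(), model_path)

# Pass the directory that contains saved_model.pb, not the .pb file itself.
model = tf.saved_model.load(model_path)

# Objects saved from TF2 keep their traced tf.functions, so they are callable.
result = model(tf.constant([1.0, 3.0]))
print(result.numpy())  # [2. 6.]
```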
If you load a TF1 model, I found that I don't get any errors, but the loaded object doesn't behave as expected (e.g. it doesn't have any methods on it, such as `predict`).
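For the `AutoTrackable` error and for loaded objects without methods like `predict`, one approach is to go through the `signatures` mapping. A minimal sketch, assuming TensorFlow 2.x and a model exported with a serving signature; the toy module, the input name `x`, and the output key `y` are hypothetical, and your model's signature keys and input names may differ:

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical toy model exported with an explicit serving signature."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def serve(self, x):
        return {"y": x * self.w}


module = Scale()
model_path = os.path.join(tempfile.mkdtemp(), "toy_model")
tf.saved_model.save(module, model_path, signatures={"serving_default": module.serve})

loaded = tf.saved_model.load(model_path)
print(list(loaded.signatures.keys()))  # ['serving_default']

# Signature functions take keyword arguments and return a dict of tensors.
infer = loaded.signatures["serving_default"]
outputs = infer(x=tf.constant([1.0, 3.0]))
print(outputs["y"].numpy())  # [2. 6.]
```

Listing `loaded.signatures.keys()` is a quick way to discover what a model you did not export yourself actually provides.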