How do I find the variable names and values that are saved in a checkpoint?
You can use the inspect_checkpoint.py tool.
For example, if you stored the checkpoint in the current directory, you can print the variables and their values as follows:
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
latest_ckp = tf.train.latest_checkpoint('./')
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
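If you want to try this without an existing checkpoint, here is a minimal self-contained sketch. It assumes TensorFlow 2.x, where tf.train.Checkpoint.save also updates the "checkpoint" state file that tf.train.latest_checkpoint reads. The variable names weights and bias and the temporary directory are made up for illustration.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# Hypothetical variables, saved only so there is something to inspect.
ckpt_dir = tempfile.mkdtemp()
checkpoint = tf.train.Checkpoint(
    weights=tf.Variable([[1.0, 2.0], [3.0, 4.0]]),
    bias=tf.Variable([0.5, -0.5]),
)
# Writes ckpt-1.index, ckpt-1.data-00000-of-00001 and the "checkpoint" state file.
checkpoint.save(os.path.join(ckpt_dir, "ckpt"))

# latest_checkpoint returns the prefix of the newest checkpoint, or None if none exists.
latest_ckp = tf.train.latest_checkpoint(ckpt_dir)
if latest_ckp is None:
    raise FileNotFoundError("no checkpoint found in " + ckpt_dir)

# all_tensors=True prints every tensor name followed by its value.
print_tensors_in_checkpoint_file(latest_ckp, tensor_name="", all_tensors=True)
```

Checking for None first gives a clearer error than passing None on to print_tensors_in_checkpoint_file.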
Example usage:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
model_dir = "/path/to/model_dir"  # placeholder: the directory your checkpoint was saved to
checkpoint_path = os.path.join(model_dir, "model.ckpt")
# List ALL tensor names with dtype and shape. Example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='', all_tensors=False)
# List contents of v0 tensor.
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0', all_tensors=False)
# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1', all_tensors=False)
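Update: since TensorFlow 0.12.0-rc0, print_tensors_in_checkpoint_file takes an all_tensors argument, so you may need to add all_tensors=False or all_tensors=True as required. With all_tensors=True it prints every tensor and its value; with all_tensors=False it prints only the tensor named by tensor_name, or, if tensor_name is empty, a list of all tensor names with their dtypes and shapes.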
Alternative method:
from tensorflow.python import pywrap_tensorflow
import os
model_dir = "/path/to/model_dir"  # placeholder: the directory your checkpoint was saved to
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))  # Remove this if you want to print only variable names
Hope it helps.
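On TensorFlow 2.x, the same reader is exposed publicly as tf.compat.v1.train.NewCheckpointReader, which is a safer spelling than the private pywrap_tensorflow module. The sketch below assumes TF 2.x, writes a tiny checkpoint with made-up variables (weights, bias) into a temporary directory, and then lists names, shapes and dtypes without loading any tensor values.

```python
import os
import tempfile

import tensorflow as tf

# Save a tiny checkpoint with hypothetical variables so the reader has something to open.
ckpt_dir = tempfile.mkdtemp()
checkpoint_path = tf.train.Checkpoint(
    weights=tf.Variable(tf.zeros([2, 2])),
    bias=tf.Variable(tf.zeros([2])),
).save(os.path.join(ckpt_dir, "model.ckpt"))

# Public TF 2.x alias of pywrap_tensorflow.NewCheckpointReader.
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
var_to_dtype_map = reader.get_variable_to_dtype_map()

# Only metadata is read here; reader.get_tensor(name) would load the values.
summary = {
    name: (list(shape), var_to_dtype_map[name].name)
    for name, shape in var_to_shape_map.items()
}
for name in sorted(summary):
    shape, dtype = summary[name]
    print(f"{name}: shape={shape} dtype={dtype}")
```

Note that checkpoints written by tf.train.Checkpoint use object-based keys such as weights/.ATTRIBUTES/VARIABLE_VALUE rather than the bare variable names you get from a TF 1.x Saver.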
A few more details.
If your model is saved in the V2 format, for example if you have the following files in the directory /my/dir/:
model-10000.data-00000-of-00001
model-10000.index
model-10000.meta
then the file_name parameter should be only the prefix, that is:
print_tensors_in_checkpoint_file(file_name='/my/dir/model-10000', tensor_name='', all_tensors=True)
See https://github.com/tensorflow/tensorflow/issues/7696 for a discussion.
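If you only have a directory listing, you can recover that prefix by stripping the V2 suffixes. Here is a minimal pure-Python sketch; v2_checkpoint_prefixes is a hypothetical helper, and it assumes the usual V2 naming of <prefix>.index plus one or more <prefix>.data-NNNNN-of-MMMMM shards.

```python
import os
import re

# Suffixes of a V2 checkpoint: "<prefix>.index" and "<prefix>.data-NNNNN-of-MMMMM".
_V2_SUFFIX = re.compile(r"\.(index|data-\d+-of-\d+)$")


def v2_checkpoint_prefixes(directory, filenames):
    """Return the sorted checkpoint prefixes (joined with directory) found in filenames.

    The .meta file holds the MetaGraph, not tensor values, so it does not match.
    """
    prefixes = set()
    for name in filenames:
        match = _V2_SUFFIX.search(name)
        if match:
            prefixes.add(os.path.join(directory, name[: match.start()]))
    return sorted(prefixes)


files = ["model-10000.data-00000-of-00001", "model-10000.index", "model-10000.meta"]
print(v2_checkpoint_prefixes("/my/dir", files))
```

On a real directory you would pass os.listdir("/my/dir") as filenames and hand any of the returned prefixes to print_tensors_in_checkpoint_file as file_name.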
An update to the answers above:
For recent TensorFlow versions (verified on TF 1.13+), a cleaner way to do this is:
ckpt_reader = tf.train.load_checkpoint(ckpt_dir_or_file)
value = ckpt_reader.get_tensor(name_of_the_tensor)
The name_of_the_tensor should correspond to the name of the variable whose value you're trying to inspect. To get a list of the variable names and shapes in a checkpoint, you can use:
vars_list = tf.train.list_variables(ckpt_dir_or_file)
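The two calls combine naturally into a dump of the whole checkpoint. A minimal sketch, assuming TF 2.x; checkpoint_to_dict is a hypothetical helper name, and the variable w exists only to make the example self-contained. Because both functions accept a directory as well as a file prefix, passing the directory picks up the latest checkpoint in it.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def checkpoint_to_dict(ckpt_dir_or_file):
    """Load every tensor stored in a checkpoint into a {name: value} dict."""
    reader = tf.train.load_checkpoint(ckpt_dir_or_file)
    return {
        name: reader.get_tensor(name)
        for name, _shape in tf.train.list_variables(ckpt_dir_or_file)
    }


# Hypothetical checkpoint with a single variable "w".
ckpt_dir = tempfile.mkdtemp()
tf.train.Checkpoint(w=tf.Variable([1.0, 2.0, 3.0])).save(os.path.join(ckpt_dir, "ckpt"))

values = checkpoint_to_dict(ckpt_dir)
for name, value in sorted(values.items()):
    print(name, np.shape(value))
```

This loads every tensor into memory, so for very large checkpoints you may prefer to call get_tensor only for the names you need.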