How to get the global_step when restoring checkpoints in Tensorflow?
The general pattern is to have a global_step variable that keeps track of training steps:
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)
Passing global_step to minimize() makes the optimizer increment it by one on every training step. You can then save with:
saver.save(sess, save_path, global_step=global_step)
When you restore, the value of global_step is restored as well.
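For reference, here is a minimal end-to-end sketch of that pattern (assuming TF 1.x; the quadratic loss is just a stand-in for a real model):
import tensorflow as tf

x = tf.Variable(3.0)
loss = tf.square(x)  # toy loss standing in for your model

global_step = tf.Variable(0, name='global_step', trainable=False)
# minimize() bumps global_step by one each time the train op runs
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(5):
        sess.run(train_op)
    # global_step is appended to the filename, e.g. './model-5'
    saver.save(sess, './model', global_step=global_step)

# In a fresh session, restoring brings global_step back with everything else
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('.'))
    print(sess.run(global_step))  # prints 5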
This is a bit of a hack, but the other answers did not work for me at all:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
# Extract the step count from the checkpoint filename, e.g. 'model-1000' -> 1000
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
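A slightly more defensive version of the same hack (a sketch, assuming checkpoints are saved as 'name-<step>' and that checkpoint_dir, saver, and sess come from your own setup):
import os
import tensorflow as tf

ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    # take the last '-'-separated field so names like 'my-model-1000' still parse
    step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[-1])
else:
    step = 0  # no checkpoint found, start from scratch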
Update 9/2017
I'm not sure if this started working due to updates, but the following method seems effective for getting global_step to update and load properly:
Create two ops: a variable to hold global_step and an op to increment it:
global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step, 1, name='increment_global_step')
Now, in your training loop, run the increment op every time you run your training op:
sess.run([train_op, increment_global_step], feed_dict=feed_dict)
If you ever want to retrieve your global step value as an integer at any point, just use the following command after loading the model:
sess.run(global_step)
This can be useful for creating filenames or calculating your current epoch without needing a second TensorFlow variable to hold that value. For instance, since an epoch is num_train_records // batch_size steps, calculating the current epoch on loading would be something like:
loaded_epoch = sess.run(global_step) * batch_size // num_train_records
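Putting the pieces of this update together, a self-contained sketch (TF 1.x; the toy loss and the batch_size/num_train_records values are made up just to exercise the epoch arithmetic):
import tensorflow as tf

batch_size, num_train_records = 4, 32  # hypothetical values: 8 steps per epoch

x = tf.Variable(3.0)
loss = tf.square(x)  # stand-in for a real model
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step, 1, name='increment_global_step')

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(24):  # 24 steps = 3 epochs at 8 steps per epoch
        sess.run([train_op, increment_global_step])
    saver.save(sess, './model', global_step=global_step)

with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('.'))
    loaded_epoch = sess.run(global_step) * batch_size // num_train_records
    print(sess.run(global_step), loaded_epoch)  # prints 24 3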