In what order are SessionRunHook's member functions called?
You can find a tutorial here. It is a little long, but you can skip the part about building the network. Or you can read my short summary below, based on my experience.
First, a MonitoredSession should be used instead of a normal Session.
A SessionRunHook extends the session.run() calls made through the MonitoredSession.
Some common SessionRunHook classes can be found here. A simple one is LoggingTensorHook, but you might want to add the following line after your imports so that the logs are shown when running:
```python
tf.logging.set_verbosity(tf.logging.INFO)
```
Alternatively, you can implement your own SessionRunHook class. Here is a simple one from the cifar10 tutorial:
```python
import time
from datetime import datetime

import tensorflow as tf


# `loss` (a scalar tensor) and FLAGS (with log_frequency and batch_size)
# are defined elsewhere in the cifar10 tutorial.
class _LoggerHook(tf.train.SessionRunHook):
    """Logs loss and runtime."""

    def begin(self):
        self._step = -1
        self._start_time = time.time()

    def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

    def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
            current_time = time.time()
            duration = current_time - self._start_time
            self._start_time = current_time

            loss_value = run_values.results
            examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
            sec_per_batch = float(duration / FLAGS.log_frequency)

            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), self._step, loss_value,
                                examples_per_sec, sec_per_batch))
```
Here, loss is defined outside the class. This _LoggerHook uses print to output the information, whereas LoggingTensorHook logs it at the tf.logging.INFO level.
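To make the throughput arithmetic in after_run concrete, here is the same computation pulled out into a plain Python function (no TensorFlow needed). The function name throughput and the sample numbers are hypothetical; log_frequency and batch_size play the role of FLAGS.log_frequency and FLAGS.batch_size.

```python
def throughput(duration, log_frequency, batch_size):
    """Return (examples_per_sec, sec_per_batch) as computed in _LoggerHook.after_run.

    duration is the wall-clock time in seconds since the previous report,
    during which log_frequency steps of batch_size examples were processed.
    """
    examples_per_sec = log_frequency * batch_size / duration
    sec_per_batch = float(duration / log_frequency)
    return examples_per_sec, sec_per_batch


# Example: 10 steps of 128 examples each took 2.0 seconds.
eps, spb = throughput(duration=2.0, log_frequency=10, batch_size=128)
print("%.1f examples/sec; %.3f sec/batch" % (eps, spb))
# prints "640.0 examples/sec; 0.200 sec/batch"
```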
Finally, to better understand how it works, here is the execution order inside a MonitoredSession, written as pseudocode:
```
call hooks.begin()
sess = tf.Session()
call hooks.after_create_session()
while not stop is requested:  # py code: while not mon_sess.should_stop():
    call hooks.before_run()
    try:
        results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
        break
    call hooks.after_run()
call hooks.end()
sess.close()
```
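The pseudocode can be turned into a small runnable simulation to check the call order. This is a pure-Python mock, not TensorFlow's implementation: MockRunContext, RecordingHook and run_monitored are hypothetical names, run_step stands in for sess.run, and a plain StopIteration plays the role of OutOfRangeError. Only the hook method names and signatures mirror tf.train.SessionRunHook.

```python
# Pure-Python mock of the MonitoredSession loop (does NOT import TensorFlow).


class MockRunContext:
    """Stands in for tf.train.SessionRunContext: lets a hook request a stop."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


class RecordingHook:
    """Appends the name of each hook method to `log` when it is called."""

    def __init__(self, log, stop_after):
        self.log = log
        self.stop_after = stop_after  # request a stop after this many runs
        self.runs = 0

    def begin(self):
        self.log.append("begin")

    def after_create_session(self, session, coord):
        self.log.append("after_create_session")

    def before_run(self, run_context):
        self.log.append("before_run")
        return None  # no extra fetches requested

    def after_run(self, run_context, run_values):
        self.log.append("after_run")
        self.runs += 1
        if self.runs >= self.stop_after:
            run_context.request_stop()

    def end(self, session):
        self.log.append("end")


def run_monitored(hooks, run_step):
    """Follow the pseudocode: begin, create session, run loop, end."""
    for hook in hooks:
        hook.begin()
    session = object()  # placeholder for tf.Session()
    for hook in hooks:
        hook.after_create_session(session, None)
    stop = False
    while not stop:  # real code: while not mon_sess.should_stop()
        ctx = MockRunContext()
        for hook in hooks:
            hook.before_run(ctx)
        try:
            run_step()  # placeholder for sess.run(merged_fetches, ...)
        except StopIteration:  # input exhausted, like OutOfRangeError
            break
        for hook in hooks:
            hook.after_run(ctx, None)
        stop = ctx.stop_requested
    for hook in hooks:
        hook.end(session)


log = []
run_monitored([RecordingHook(log, stop_after=2)], run_step=lambda: None)
print(log)
# prints ['begin', 'after_create_session', 'before_run', 'after_run',
#         'before_run', 'after_run', 'end']
```

If the input runs out instead (for example run_step=lambda: next(it) with it = iter(range(1))), the last before_run gets no matching after_run, which mirrors the break in the pseudocode.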
Hope this helps.
tf.train.SessionRunHook enables you to add your own code around each session run call you execute. To illustrate it, I have created a simple example below:
- We want to print the loss value after each update of the parameters.
- We will use a SessionRunHook to achieve this.
Create a TensorFlow graph
```python
import tensorflow as tf
import numpy as np

x = tf.placeholder(shape=(10, 2), dtype=tf.float32)
w = tf.Variable(initial_value=[[10.], [10.]])
w0 = [[1.], [1.]]  # the "true" weights used to generate the targets
y = tf.matmul(x, w0)
loss = tf.reduce_mean((tf.matmul(x, w) - y) ** 2)
optimizer = tf.train.AdamOptimizer(0.001).minimize(loss)  # the training op
```
Create the hook
```python
class _Hook(tf.train.SessionRunHook):
    def __init__(self, loss):
        self.loss = loss

    def begin(self):
        pass

    def before_run(self, run_context):
        # Ask MonitoredSession to fetch the loss along with the caller's fetches.
        return tf.train.SessionRunArgs(self.loss)

    def after_run(self, run_context, run_values):
        # run_values.results holds the value of the fetch requested above.
        loss_value = run_values.results
        print("loss value:", loss_value)
```
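The _Hook gets the loss value even though the training loop only fetches the training op: before_run returns SessionRunArgs(self.loss), MonitoredSession adds that fetch to the caller's fetches for the same run, and the hook's share comes back as run_values.results. The pure-Python mock below sketches that merging under simplifying assumptions: "tensors" are zero-argument callables, the session is a plain function, and MockArgs, MockValues, hooked_run and LossHook are hypothetical names. To my knowledge TensorFlow keys the caller's fetches under "caller" and each hook's fetches under the hook object, which the mock imitates, but treat it as an approximation rather than the real implementation.

```python
from collections import namedtuple

# Hypothetical stand-ins for tf.train.SessionRunArgs / SessionRunValues.
MockArgs = namedtuple("MockArgs", ["fetches"])
MockValues = namedtuple("MockValues", ["results"])


def evaluate(fetches):
    """Toy 'session': call each fetch, recursing into dicts."""
    if isinstance(fetches, dict):
        return {key: evaluate(value) for key, value in fetches.items()}
    return fetches()


def hooked_run(fetches, hooks):
    """Merge caller and hook fetches into one run, then split the results."""
    merged = {"caller": fetches}
    for hook in hooks:
        args = hook.before_run(None)
        if args is not None:
            merged[hook] = args.fetches
    results = evaluate(merged)  # a single 'sess.run' for everything
    for hook in hooks:
        hook.after_run(None, MockValues(results=results.get(hook)))
    return results["caller"]


class LossHook:
    """Mirrors _Hook: asks for the loss and records it after each run."""

    def __init__(self, loss):
        self.loss = loss
        self.seen = []

    def before_run(self, run_context):
        return MockArgs(fetches=self.loss)

    def after_run(self, run_context, run_values):
        self.seen.append(run_values.results)


losses = iter([3.0, 2.0, 1.0])  # hypothetical loss values, one per run
hook = LossHook(loss=lambda: next(losses))
for _ in range(3):
    out = hooked_run(lambda: "train_op result", [hook])
print(out, hook.seen)  # prints: train_op result [3.0, 2.0, 1.0]
```

The caller only ever sees its own fetch result; the hook's extra fetch is routed back to the hook alone.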
Create a monitored session with the hook
```python
sess = tf.train.MonitoredSession(hooks=[_Hook(loss)])
```
Train
```python
for _ in range(10):
    x_ = np.random.random((10, 2))
    sess.run(optimizer, {x: x_})
```
```
# Output
loss value: 21.244701
loss value: 19.39169
loss value: 16.02665
loss value: 16.717144
loss value: 15.389178
loss value: 16.23935
loss value: 14.299083
loss value: 9.624525
loss value: 5.654896
loss value: 10.689494
```