Keras TensorBoard: plot train and validation scalars in the same figure
To handle the validation logs with a separate writer, you can write a custom callback that wraps around the original TensorBoard
methods.
import os
import tensorflow as tf
from keras.callbacks import TensorBoard
class TrainValTensorBoard(TensorBoard):
    """TensorBoard callback that logs validation scalars with a separate writer.

    The parent ``TensorBoard`` callback is pointed at ``<log_dir>/training``.
    Entries of ``logs`` whose key starts with ``val_`` are written by this
    class into ``<log_dir>/validation`` with the ``val_`` prefix removed from
    the tag, so a training metric and its validation counterpart share one
    tag and can be plotted on the same figure.

    Args:
        log_dir: Root directory under which the ``training`` and
            ``validation`` subdirectories are placed.
        **kwargs: Forwarded unchanged to ``TensorBoard.__init__``.
    """

    def __init__(self, log_dir='./logs', **kwargs):
        # Make the original `TensorBoard` log to a subdirectory 'training'
        training_log_dir = os.path.join(log_dir, 'training')
        super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)

        # Log the validation metrics to a separate subdirectory
        self.val_log_dir = os.path.join(log_dir, 'validation')

    @staticmethod
    def _as_python_scalar(value):
        """Return `value` as a plain Python number.

        Objects exposing ``.item()`` (e.g. NumPy scalars) are converted with
        it; anything else, such as a built-in ``float``, is returned as is
        instead of raising ``AttributeError``.
        """
        return value.item() if hasattr(value, 'item') else value

    def set_model(self, model):
        """Create the validation writer, then let the parent set itself up."""
        # Setup writer for validation metrics
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        """Write ``val_*`` entries with `self.val_writer`; delegate the rest.

        Args:
            epoch: Step value recorded with every validation summary.
            logs: Mapping of metric name to value; ``None`` is treated as
                an empty mapping.
        """
        logs = logs or {}
        prefix = 'val_'

        # Strip only the leading 'val_' so that a metric whose own name
        # happens to contain 'val_' (e.g. 'val_interval_loss') keeps it.
        val_logs = {k[len(prefix):]: v
                    for k, v in logs.items() if k.startswith(prefix)}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = self._as_python_scalar(value)
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()

        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        logs = {k: v for k, v in logs.items() if not k.startswith(prefix)}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        """Run the parent's end-of-training hook, then close the validation writer."""
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()
- In `__init__`, two subdirectories are set up for training and validation logs.
- In `set_model`, a writer `self.val_writer` is created for the validation logs.
- In `on_epoch_end`, the validation logs are separated from the training logs and written to file with `self.val_writer`.
Using the MNIST dataset as an example:
from keras.models import Sequential
from keras.layers import Dense
from keras.datasets import mnist
# Load MNIST and flatten every sample into a 784-element float32 vector
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32')
x_test = x_test.reshape(10000, 784).astype('float32')

# Rescale the inputs in place by 1/255
x_train /= 255
x_test /= 255

# One hidden ReLU layer followed by a 10-way softmax output
model = Sequential([
    Dense(64, activation='relu', input_shape=(784,)),
    Dense(10, activation='softmax'),
])
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Train with the custom callback; the test split is passed as validation data
model.fit(
    x_train,
    y_train,
    epochs=10,
    validation_data=(x_test, y_test),
    callbacks=[TrainValTensorBoard(write_graph=False)],
)
You can then visualize the two curves on the same figure in TensorBoard.
EDIT: I've modified the class a bit so that it can be used with eager execution.
The biggest change is that I use `tf.keras` in the following code. It seems that the `TensorBoard` callback in standalone Keras does not support eager mode yet.
import os
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.python.eager import context
class TrainValTensorBoard(TensorBoard):
    """`tf.keras` TensorBoard callback with a separate validation writer.

    The parent ``TensorBoard`` callback is pointed at ``<log_dir>/training``.
    Entries of ``logs`` whose key contains ``val_`` are written by this class
    into ``<log_dir>/validation`` with that marker removed from the tag, so a
    training metric and its validation counterpart share one tag. Both eager
    and graph execution are handled.

    Args:
        log_dir: Root directory under which the ``training`` and
            ``validation`` subdirectories are placed.
        **kwargs: Forwarded unchanged to ``TensorBoard.__init__``.
    """

    def __init__(self, log_dir='./logs', **kwargs):
        self.val_log_dir = os.path.join(log_dir, 'validation')
        training_log_dir = os.path.join(log_dir, 'training')
        super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)

    @staticmethod
    def _as_python_scalar(value):
        """Return `value` as a plain Python number.

        Objects exposing ``.item()`` (e.g. NumPy scalars) are converted with
        it; anything else, such as a built-in ``float``, is returned as is
        instead of raising ``AttributeError``.
        """
        return value.item() if hasattr(value, 'item') else value

    def set_model(self, model):
        """Create the validation writer matching the current execution mode."""
        if context.executing_eagerly():
            self.val_writer = tf.contrib.summary.create_file_writer(self.val_log_dir)
        else:
            self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def _write_custom_summaries(self, step, logs=None):
        """Write validation entries with `self.val_writer`; delegate the rest.

        Args:
            step: Step value recorded with every validation summary.
            logs: Mapping of metric name to value; ``None`` is treated as
                an empty mapping.
        """
        logs = logs or {}

        # NOTE(review): a substring match is used here rather than
        # `startswith` -- presumably because the parent callback hands over
        # keys that already carry a prefix (e.g. 'epoch_val_loss'); confirm
        # against the TensorBoard callback source of the TensorFlow version
        # in use. Only the first 'val_' is removed so that the remainder of
        # the metric name is left untouched.
        val_logs = {k.replace('val_', '', 1): v
                    for k, v in logs.items() if 'val_' in k}
        if context.executing_eagerly():
            with self.val_writer.as_default(), tf.contrib.summary.always_record_summaries():
                for name, value in val_logs.items():
                    tf.contrib.summary.scalar(
                        name, self._as_python_scalar(value), step=step)
        else:
            for name, value in val_logs.items():
                summary = tf.Summary()
                summary_value = summary.value.add()
                summary_value.simple_value = self._as_python_scalar(value)
                summary_value.tag = name
                self.val_writer.add_summary(summary, step)
        self.val_writer.flush()

        # Everything that is not a validation entry goes to the parent
        logs = {k: v for k, v in logs.items() if 'val_' not in k}
        super(TrainValTensorBoard, self)._write_custom_summaries(step, logs)

    def on_train_end(self, logs=None):
        """Run the parent's end-of-training hook, then close the validation writer."""
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()
The idea is the same:
- Check the source code of the `TensorBoard` callback
- See what it does to set up the writer
- Do the same thing in this custom callback
Again, you can use the MNIST data to test it:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.train import AdamOptimizer
# Switch TensorFlow to eager execution for the rest of the script
tf.enable_eager_execution()

# Load MNIST and flatten every sample into a 784-element float32 vector
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32')
x_test = x_test.reshape(10000, 784).astype('float32')

# Rescale the inputs in place by 1/255
x_train /= 255
x_test /= 255

# Cast the labels to the platform's default integer type
y_train = y_train.astype(int)
y_test = y_test.astype(int)

# One hidden ReLU layer followed by a 10-way softmax output
model = Sequential([
    Dense(64, activation='relu', input_shape=(784,)),
    Dense(10, activation='softmax'),
])
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=AdamOptimizer(),
    metrics=['accuracy'],
)

# Train with the custom callback; the test split is passed as validation data
model.fit(
    x_train,
    y_train,
    epochs=10,
    validation_data=(x_test, y_test),
    callbacks=[TrainValTensorBoard(write_graph=False)],
)
If you are using TensorFlow 2.0, you now get this by default using the Keras TensorBoard callback. (When using TensorFlow with Keras, make sure you're using tensorflow.keras.)
See this tutorial:
https://www.tensorflow.org/tensorboard/r2/scalars_and_keras