Epoch counter with TensorFlow Dataset API
To add to @mrry's great answer: if you want to stay within the tf.data pipeline and also track the iteration within each epoch, you can try my solution below. For a batch size greater than one, I guess you would also have to add the line data = data.batch(bs).
import itertools

import tensorflow as tf


def step_counter():
    # Yields 0, 1, 2, ...; starts over each time the dataset is iterated again.
    for i in itertools.count():
        yield i


num_examples = 3
num_epochs = 2
num_iters = num_examples * num_epochs

features = tf.data.Dataset.range(num_examples)
labels = tf.data.Dataset.range(num_examples)
data = tf.data.Dataset.zip((features, labels))
data = data.shuffle(num_examples)

# Pair every example with its step index within the epoch.
step = tf.data.Dataset.from_generator(step_counter, tf.int32)
data = tf.data.Dataset.zip((data, step))

# One pass over `data` per epoch index; flat_map already yields num_epochs
# passes, so no extra data.repeat(num_epochs) is needed afterwards.
epoch = tf.data.Dataset.range(num_epochs)
data = epoch.flat_map(
    lambda i: tf.data.Dataset.zip(
        (data, tf.data.Dataset.from_tensors(i).repeat())))

it = data.make_one_shot_iterator()
example = it.get_next()

with tf.Session() as sess:
    for _ in range(num_iters):
        ((x, y), st), ep = sess.run(example)
        print(f'step {st} \t epoch {ep} \t x {x} \t y {y}')
Prints (the order of x and y varies with the shuffle):
step 0 epoch 0 x 2 y 2
step 1 epoch 0 x 0 y 0
step 2 epoch 0 x 1 y 1
step 0 epoch 1 x 2 y 2
step 1 epoch 1 x 0 y 0
step 2 epoch 1 x 1 y 1
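On the batching remark above: here is a plain-Python sketch (not tf.data code) of what the pipeline would produce if the examples were batched before the step counter is attached, so that step counts batches rather than single examples. In the TensorFlow version this would presumably mean placing data = data.batch(bs) after the shuffle and before the zip with step. The names in_batches, epochs_with_steps, bs, and seed are made up for illustration.

```python
import itertools
import random


def in_batches(items, bs):
    """Group items into lists of at most bs elements."""
    it = iter(items)
    while True:
        chunk = list(itertools.islice(it, bs))
        if not chunk:
            return
        yield chunk


def epochs_with_steps(examples, num_epochs, bs, seed=0):
    """Yield (step, epoch, batch); step restarts at 0 in every epoch."""
    rng = random.Random(seed)  # fixed seed is an assumption, for repeatability
    for ep in range(num_epochs):
        shuffled = list(examples)
        rng.shuffle(shuffled)  # reshuffle at the start of each epoch
        for st, batch in enumerate(in_batches(shuffled, bs)):
            yield st, ep, batch


examples = [(i, i) for i in range(5)]  # (x, y) pairs, as in the answer
rows = list(epochs_with_steps(examples, num_epochs=2, bs=2))
for st, ep, batch in rows:
    print(f'step {st} \t epoch {ep} \t batch {batch}')
```

With 5 examples and bs = 2, each epoch has three steps and the last batch holds a single example.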
TL;DR: Replace the definition of epoch_counter with the following:
epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                trainable=False, use_resource=True)
There are some limitations around using TensorFlow variables inside tf.data.Dataset transformations. The principal limitation is that all variables must be "resource variables" and not the older "reference variables"; unfortunately, tf.Variable still creates "reference variables" for backwards-compatibility reasons.
Generally speaking, I wouldn't recommend using variables in a tf.data pipeline if it's possible to avoid it. For example, you might be able to use Dataset.range() to define an epoch counter, and then do something like:
epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
    (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat())))
The above snippet attaches an epoch counter to every value as a second component.
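To make the shape of the result concrete, here is a plain-Python analogue of the flat_map / zip / repeat pattern (a model of the semantics only, not tf.data code). The list data is a hypothetical stand-in for pre_processing_func(data), and with_epoch is a made-up helper name.

```python
import itertools

NUM_EPOCHS = 2
data = ['a', 'b', 'c']  # hypothetical stand-in for the preprocessed dataset


def with_epoch(values, num_epochs):
    """Chain one pass over values per epoch, tagging each value with the epoch."""
    # zip stops when values is exhausted, so the endless repeat(i) is safe.
    per_epoch = (zip(values, itertools.repeat(i)) for i in range(num_epochs))
    return itertools.chain.from_iterable(per_epoch)


elements = list(with_epoch(data, NUM_EPOCHS))
print(elements)
# [('a', 0), ('b', 0), ('c', 0), ('a', 1), ('b', 1), ('c', 1)]
```

Each element is a (value, epoch) pair, which mirrors the second component that the Dataset version adds.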