tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?

len(list(dataset)) works in eager mode, but it loads every element into memory at once, so it's not a good general solution.
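
If memory is the concern, a minimal sketch of the same idea that counts elements one at a time instead of building a list is shown below. The helper name count_elements is made up for illustration. It should work on any finite Python iterable, which in eager mode includes a tf.data.Dataset, and it still costs one full pass over the data.

```python
def count_elements(iterable):
    """Count items by iterating once, without storing them."""
    return sum(1 for _ in iterable)


# Demo on a plain generator; in eager mode you would pass the dataset itself.
demo_count = count_elements(x * x for x in range(7))
print(demo_count)
```

Do not call this on a dataset that repeats forever, since the loop would never finish.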


Take a look here: https://github.com/tensorflow/tensorflow/issues/26966

It doesn't work for TFRecord datasets (the cardinality is reported as unknown), but it works fine for other types.

TL;DR:

num_elements = tf.data.experimental.cardinality(dataset).numpy()
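
Since the call returns a sentinel rather than raising when the size cannot be determined, it may be worth checking the result. Here is a rough sketch, assuming TensorFlow 2.x in eager mode; the helper name known_size is hypothetical, while the two sentinel constants come from tf.data.experimental.

```python
import tensorflow as tf


def known_size(dataset):
    """Return the element count, or None if it is infinite or unknown."""
    n = int(tf.data.experimental.cardinality(dataset).numpy())
    if n == tf.data.experimental.INFINITE_CARDINALITY:
        return None  # e.g. repeat() with no count
    if n == tf.data.experimental.UNKNOWN_CARDINALITY:
        return None  # e.g. after filter(), or a TFRecord dataset
    return n


batched = tf.data.Dataset.range(10).batch(3)  # 4 batches, the last one partial
filtered = tf.data.Dataset.range(10).filter(lambda x: x % 2 == 0)
endless = tf.data.Dataset.range(10).repeat()

print(known_size(batched), known_size(filtered), known_size(endless))
```

When the result is None for a finite dataset, the only option left is to iterate over it once and count.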


tf.data.Dataset.list_files creates a tensor called MatchingFiles:0 (with the appropriate prefix if applicable).

You could evaluate

tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]

to get the number of files.

Of course, this works only in simple cases, in particular when you have exactly one sample (or a known number of samples) per file.

In more complex situations, e.g. when you do not know the number of samples in each file, you can only observe the number of samples as an epoch ends.

To do this, you can watch the number of epochs counted by your Dataset. repeat() creates a private member called _count that counts the number of epochs. By observing it during your iterations, you can spot when it changes and compute your dataset size from there. Since _count is an undocumented implementation detail, it may change between TensorFlow versions.

This counter may be buried in the hierarchy of Datasets that is created when calling member functions successively, so we have to dig it out like this:

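import warnings

import tensorflow as tf

d = my_dataset
# RepeatDataset does not seem to be exposed publicly, so recover the class
# from an instance. tf.data.Dataset is abstract and cannot be instantiated
# directly, hence the throwaway range dataset.
RepeatDataset = type(tf.data.Dataset.range(1).repeat())
try:
  # Walk up the chain of wrapped datasets until the repeat() stage is found.
  while not isinstance(d, RepeatDataset):
    d = d._input_dataset
except AttributeError:
  # Reached a source dataset without finding a repeat() stage.
  warnings.warn('no epoch counter found')
  epoch_counter = None
else:
  epoch_counter = d._count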

Note that with this technique, the computed dataset size is not exact, because the batch during which epoch_counter is incremented typically mixes samples from two successive epochs. The result is therefore accurate only to within one batch size.
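
To make the size of that error concrete, here is a small pure-Python simulation (it does not use TensorFlow). It mimics batching after repeat() and stops counting as soon as the epoch of the last sample in a batch changes. The function names and the sizes 10 and 4 are made up for the example.

```python
import itertools


def repeated_batches(size, batch_size):
    """Simulate repeat() followed by batch(): yield (samples, epoch)."""
    # Each stream item is (sample, index of the epoch the sample belongs to).
    stream = ((i % size, i // size) for i in itertools.count())
    while True:
        chunk = list(itertools.islice(stream, batch_size))
        samples = [s for s, _ in chunk]
        # The epoch observed after a batch is that of its last sample.
        yield samples, chunk[-1][1]


def estimate_size(batches):
    """Count samples until the observed epoch first changes."""
    seen = 0
    for samples, epoch in batches:
        seen += len(samples)
        if epoch > 0:
            return seen


true_size, batch_size = 10, 4
estimate = estimate_size(repeated_batches(true_size, batch_size))
print(estimate, estimate - true_size)  # prints "12 2"
```

Here the third batch holds samples 8, 9, 0, 1, so the estimate overshoots the true size by 2, which is less than the batch size of 4.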