TensorFlow 2.0 Dataset and DataLoader
I am not familiar with PyTorch, but TensorFlow implements the Keras API, which has the Sequence class, described as:
Base object for fitting to a sequence of data, such as a dataset
https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence
This class defines __getitem__, which returns the batch at a given index.
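As a rough illustration of the index-based access described above, here is a minimal sketch of a Sequence subclass. The class name ArraySequence, the toy NumPy arrays, and the batch size are assumptions made up for this example, not something from the linked docs.

```python
import math

import numpy as np
import tensorflow as tf


class ArraySequence(tf.keras.utils.Sequence):
    """Hypothetical Sequence that serves batches from in-memory arrays."""

    def __init__(self, features, labels, batch_size):
        super().__init__()
        self.features = features
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch (the last batch may be smaller).
        return math.ceil(len(self.features) / self.batch_size)

    def __getitem__(self, index):
        # Return the whole batch at position `index`, not a single sample.
        start = index * self.batch_size
        stop = start + self.batch_size
        return self.features[start:stop], self.labels[start:stop]


# Toy data, assumed for illustration only.
features = np.arange(10, dtype=np.float32).reshape(10, 1)
labels = np.arange(10, dtype=np.int64)

sequence = ArraySequence(features, labels, batch_size=4)
first_x, first_y = sequence[0]
```

Note that, unlike a PyTorch Dataset, the index here selects a batch rather than a single example, so the batching logic lives inside the class.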
When using the tf.data API, you will usually also make use of the map function.
In PyTorch, your __getitem__ call basically fetches an element from the data structure set up in __init__ and transforms it if necessary.
In TF 2.0, you do the same by initializing a Dataset using one of the Dataset.from_... functions (see from_generator, from_tensor_slices, from_tensors); this is essentially the __init__ part of a PyTorch Dataset. Then, you can call map to do the element-wise manipulations you would have in __getitem__.
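A minimal sketch of that two-step pattern, assuming small in-memory NumPy arrays and a made-up scale function standing in for whatever per-element transform you would otherwise put in __getitem__:

```python
import numpy as np
import tensorflow as tf

# Toy data, assumed for illustration only.
features = np.arange(6, dtype=np.float32).reshape(6, 1)
labels = np.array([0, 1, 0, 1, 0, 1], dtype=np.int64)

# The "__init__" part: build the dataset from the underlying data.
dataset = tf.data.Dataset.from_tensor_slices((features, labels))


def scale(x, y):
    # The "__getitem__" part: a per-element transform (hypothetical example).
    return x / 5.0, y


# map applies the transform element-wise; batch plays the DataLoader's role.
dataset = dataset.map(scale).batch(2)

# Traverse the dataset and collect the batches as NumPy arrays.
batches = [(x.numpy(), y.numpy()) for x, y in dataset]
```

Batching (and shuffling or prefetching, if you need them) is chained onto the same Dataset object, so there is no separate DataLoader class to wrap it in.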
TensorFlow datasets are pretty much fancy iterators, so by design you don't access their elements using indices, but rather by traversing them.
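To make the traversal point concrete, here is a small sketch using Dataset.range as a stand-in for real data. If you do need something like "the element at position 2", one workaround is to combine skip and take, though this still walks the dataset from the start rather than jumping to an index.

```python
import tensorflow as tf

# Stand-in dataset holding the integers 0..4.
dataset = tf.data.Dataset.range(5)

# Normal usage: iterate over the elements in order.
values = [int(v.numpy()) for v in dataset]

# Emulating positional access: skip two elements, then take one.
third = next(iter(dataset.skip(2).take(1)))
third_value = int(third.numpy())
```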
The guide on tf.data is very useful and provides a wide variety of examples.