Extract target from Tensorflow PrefetchDataset
You can use map to select either the input or the label from every (input, label) pair, and turn the result into a list:
import tensorflow as tf
import numpy as np
inputs = np.random.rand(100, 99)
targets = np.random.rand(100)
ds = tf.data.Dataset.from_tensor_slices((inputs, targets))
X_train = list(map(lambda x: x[0], ds))
y_train = list(map(lambda x: x[1], ds))
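If the dataset is batched and prefetched (which is usually the case for a PrefetchDataset), a variant that stays inside tf.data may be handier. This is only a sketch: the batch size of 32, the prefetch(1) call and the names labels_ds and y_all are assumptions made for illustration, not part of the original question.

```python
import numpy as np
import tensorflow as tf

inputs = np.random.rand(100, 99)
targets = np.random.rand(100)

# Assumed setup: a batched, prefetched dataset of (input, label) pairs
ds = tf.data.Dataset.from_tensor_slices((inputs, targets)).batch(32).prefetch(1)

# Keep only the label component, then undo the batching
labels_ds = ds.map(lambda x, y: y).unbatch()

# Collect every label into a single NumPy array
y_all = np.array(list(labels_ds.as_numpy_iterator()))
```

Without shuffling, y_all should come back in the same order as targets.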
If you want to retain the batches or extract all the labels as a single tensor you could use the following function:
def get_labels_from_tfdataset(tfdataset, batched=False):
    labels = list(map(lambda x: x[1], tfdataset))  # Get labels
    if not batched:
        return tf.concat(labels, axis=0)  # concat the list of batched labels
    return labels
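A quick usage sketch for the function above, assuming the dataset is already batched (tf.concat along axis 0 needs tensors of rank 1 or higher, so scalar labels from an unbatched dataset would not concatenate). The toy data and the batch size of 4 are made up for the example.

```python
import numpy as np
import tensorflow as tf

def get_labels_from_tfdataset(tfdataset, batched=False):
    labels = list(map(lambda x: x[1], tfdataset))  # one label tensor per batch
    if not batched:
        return tf.concat(labels, axis=0)  # merge the batches into one tensor
    return labels

# Toy data: 10 examples, batched in groups of 4 (so the last batch has 2)
inputs = np.random.rand(10, 3).astype(np.float32)
targets = np.arange(10, dtype=np.float32)
ds = tf.data.Dataset.from_tensor_slices((inputs, targets)).batch(4).prefetch(1)

all_labels = get_labels_from_tfdataset(ds)              # single tensor, shape (10,)
per_batch = get_labels_from_tfdataset(ds, batched=True)  # list of 3 tensors
```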
You can convert it to a list with list(ds) and then recompile it as a normal Dataset with tf.data.Dataset.from_tensor_slices(list(ds)). From there your nightmare begins again, but at least it's a nightmare that other people have had before.
Note that for more complex datasets (e.g. nested dictionaries) you will need more preprocessing after calling list(ds), but this should work for the example you asked about.
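For the nested-dictionary case, the extra preprocessing might look roughly like this. The structure (the "features" and "label" keys and the arrays behind them) is entirely hypothetical; adapt the keys to whatever your dataset actually yields.

```python
import numpy as np
import tensorflow as tf

# Hypothetical nested structure, invented for this example
data = {
    "features": {"a": np.random.rand(6, 2), "b": np.random.rand(6)},
    "label": np.arange(6),
}
ds = tf.data.Dataset.from_tensor_slices(data).prefetch(1)

# Each element of list(ds) is a dict with the same nesting as `data`
elements = list(ds)

# Pull the label out of every element and stack the scalars into one tensor
labels = tf.stack([e["label"] for e in elements])

# Rebuild a plain Dataset holding only the labels
labels_ds = tf.data.Dataset.from_tensor_slices(labels)
```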
This is far from a satisfying answer but unfortunately the class is entirely undocumented and none of the standard Dataset tricks work.
You can build a list by looping over your PrefetchDataset, which is train_dataset in my example:
train_data = [(example.numpy(), label.numpy()) for example, label in train_dataset]
You can then access each example and label separately by index:
train_data[0][0]
train_data[0][1]
You can also convert them into a DataFrame with two columns using pandas:
import pandas as pd
pd.DataFrame(train_data, columns=['example', 'label'])
Then, if you want to convert your list back into a tf.data.Dataset, you can simply use the following (add .prefetch() to get a PrefetchDataset again):
dataset = tf.data.Dataset.from_generator(
    lambda: train_data, (tf.string, tf.int32))  # replace with the dtypes of your own data
You can check that it worked with:
list(dataset.as_numpy_iterator())
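Putting the round trip together, here is a minimal end-to-end sketch. The three-element text dataset is invented for the example, and it uses the output_signature argument of from_generator (the newer replacement for passing dtypes positionally), which assumes a reasonably recent TensorFlow 2.x.

```python
import tensorflow as tf

# Made-up stand-in for the real train_dataset: (string example, int label)
train_dataset = tf.data.Dataset.from_tensor_slices(
    (["good", "bad", "fine"], [1, 0, 1])
).prefetch(1)

# Dataset -> plain Python list of (example, label) pairs
train_data = [(example.numpy(), label.numpy()) for example, label in train_dataset]

# List -> Dataset again, declaring the dtype and shape of each component
dataset = tf.data.Dataset.from_generator(
    lambda: iter(train_data),
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.string),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
).prefetch(tf.data.AUTOTUNE)  # prefetch again to restore a PrefetchDataset

round_trip = list(dataset.as_numpy_iterator())
```

Note that the strings come back as bytes (e.g. b"good"), so decode them if you need str.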