How to Properly Combine TensorFlow's Dataset API and Keras?
Update June 09, 2018
- Starting from TensorFlow 1.9, one can pass a `tf.data.Dataset` object directly into `keras.Model.fit()`, and it will act similarly to `fit_generator`.
- A complete example can be found on this gist.
# Load mnist training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# Build the shuffled, batched and repeating tf.data pipeline for training.
training_set = tfdata_generator(x_train, y_train,is_training=True)
# Placeholder line: substitute a real Keras model before running this snippet.
model = # your keras model here
model.fit(
# A one-shot iterator over the dataset is passed in place of in-memory arrays.
training_set.make_one_shot_iterator(),
# 128 mirrors tfdata_generator's default batch_size, so one epoch is roughly
# one pass over x_train; the dataset repeats, so Keras needs an explicit count.
steps_per_epoch=len(x_train) // 128,
epochs=5,
verbose = 1)
`tfdata_generator` is a function that returns an iterable `tf.data.Dataset`.
def tfdata_generator(images, labels, is_training, batch_size=128,
                     shuffle_buffer_size=1000):
    '''Construct a batched, endlessly repeating `tf.data.Dataset`.

    Args:
        images: Raw images; every element is cast to float32 and reshaped
            to (28, 28, 1).
        labels: Integer class labels aligned with `images`; each one is
            one-hot encoded to `_NUM_CLASSES` entries.
        is_training: If true, samples are shuffled before preprocessing.
        batch_size: Number of samples per emitted batch.
        shuffle_buffer_size: Size of the shuffle buffer used when
            `is_training` is true. The default of 1000 keeps the previously
            hard-coded value; a buffer closer to the number of samples
            gives a more thorough shuffle.

    Returns:
        A `tf.data.Dataset` yielding `(x, y)` batches that repeats
        indefinitely, so the consumer must bound the number of steps.
    '''
    def map_fn(image, label):
        '''Preprocess one raw sample into trainable input.'''
        x = tf.reshape(tf.cast(image, tf.float32), (28, 28, 1))
        # `_NUM_CLASSES` is not defined in this function; it is expected to
        # be available at module level.
        y = tf.one_hot(tf.cast(label, tf.uint8), _NUM_CLASSES)
        return x, y

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    if is_training:
        # Shuffle raw samples before mapping so the buffer holds the
        # un-preprocessed elements.
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(map_fn)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat()
    # NOTE(review): `tf.contrib` only exists in TensorFlow 1.x; later
    # releases expose this constant as `tf.data.experimental.AUTOTUNE` --
    # confirm against the installed version.
    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
    return dataset
Old Solution:
In addition to @Yu-Yang's answer, you can also turn a `tf.data.Dataset` into a generator for `fit_generator`, as follows:
from tensorflow.contrib.learn.python.learn.datasets import mnist
# Load MNIST through the contrib `mnist` helper module.
data = mnist.load_mnist()
# Placeholder line: substitute a real Keras model before running this snippet.
model = # your Keras model
# Train from the Python generator returned by tfdata_generator.
model.fit_generator(generator = tfdata_generator(data.train.images, data.train.labels),
# The generator never terminates, so the number of batches per epoch is fixed here.
steps_per_epoch=200,
# NOTE(review): presumably workers=0 keeps the generator on the main thread,
# where the Keras session it evaluates batches with lives -- confirm against
# the fit_generator documentation.
workers = 0 , # This is important
verbose = 1)
def tfdata_generator(images, labels, batch_size=128, shuffle=True,
                     shuffle_buffer_size=1000):
    '''Yield evaluated `(x, y)` batches forever, for use with `fit_generator`.

    A `tf.data.Dataset` pipeline is built from the given arrays and each
    batch is evaluated in the Keras backend session before being yielded.

    Args:
        images: Raw images; every element is cast to float32 and reshaped
            to `image_shape`.
        labels: Integer class labels aligned with `images`; each one is
            one-hot encoded to `num_classes` entries.
        batch_size: Number of samples per yielded batch.
        shuffle: If true, shuffle the samples. Previously this flag was
            accepted but ignored.
        shuffle_buffer_size: Buffer size passed to `Dataset.shuffle`, which
            requires one; the earlier argument-less `shuffle()` call could
            not run.

    Yields:
        The result of evaluating the next `(x, y)` batch in the session.
    '''
    def map_func(image, label):
        '''A transformation function'''
        # `image_shape` and `num_classes` are not defined in this function;
        # they are expected to be available at module level.
        x_train = tf.reshape(tf.cast(image, tf.float32), image_shape)
        y_train = tf.one_hot(tf.cast(label, tf.uint8), num_classes)
        # Return a tuple, not a list. NOTE(review): the tf.data `map`
        # documentation says a returned list is converted to a single
        # tensor, which cannot work for components of different shapes.
        return x_train, y_train

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.map(map_func)
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.batch(batch_size).repeat()
    iterator = dataset.make_one_shot_iterator()
    next_batch = iterator.get_next()
    while True:
        # The session is looked up on every step, exactly as before, rather
        # than cached, in case the backend session is replaced.
        yield K.get_session().run(next_batch)
There is indeed a more efficient way to use `Dataset` without having to convert the tensors into numpy arrays. However, it is not (yet?) in the official documentation. According to the release notes, it's a feature introduced in Keras 2.0.7, so you may have to install keras>=2.0.7 in order to use it.
# Inputs: four float32 samples, served endlessly in batches of four.
x = np.arange(4, dtype='float32').reshape(-1, 1)
ds_x = Dataset.from_tensor_slices(x)
ds_x = ds_x.repeat()
ds_x = ds_x.batch(4)
it_x = ds_x.make_one_shot_iterator()

# Targets: built the same way so they stay aligned with the inputs.
y = np.arange(5, 9, dtype='float32').reshape(-1, 1)
ds_y = Dataset.from_tensor_slices(y)
ds_y = ds_y.repeat()
ds_y = ds_y.batch(4)
it_y = ds_y.make_one_shot_iterator()

# The model reads its input directly from the iterator's output tensor.
input_vals = Input(tensor=it_x.get_next())
output = Dense(units=1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)

# Targets are also supplied as a tensor, so fit() takes no x/y arrays.
model.compile(
    optimizer='rmsprop',
    loss='mse',
    target_tensors=[it_y.get_next()],
)
model.fit(steps_per_epoch=1, epochs=5, verbose=2)
Several differences:
- Supply the `tensor` argument to the `Input` layer. Keras will read values from this tensor and use them as the input to fit the model.
- Supply the `target_tensors` argument to `Model.compile()`.
- Remember to convert both x and y into `float32`. Under normal usage, Keras will do this conversion for you. But now you'll have to do it yourself.
- Batch size is specified during the construction of `Dataset`. Use `steps_per_epoch` and `epochs` to control when to stop model fitting.
In short, use `Input(tensor=...)`, `model.compile(target_tensors=...)` and `model.fit(x=None, y=None, ...)` if your data are to be read from tensors.
The other answers are good; however, it is important to note that using `from_tensor_slices` directly with large numpy arrays can quickly fill up your memory as, IIRC, the values are copied into the graph as `tf.constant`s. In my experience, this will cause a silent failure where training will eventually start but no improvement in loss etc. will occur.
A better way is to use placeholders. E.g. here is my code to create a generator for images and their one-hot targets:
def create_generator_tf_dataset(self, images, onehots, batch_size, map_fn=None):
    '''Yield `(inputs, labels)` batches forever from a placeholder-fed dataset.

    The arrays are fed through placeholders when the iterator is
    initialised, rather than being passed to `from_tensor_slices` directly.

    Args:
        images: Array with a 4-element `.shape`; the leading dimension is
            treated as the sample axis.
        onehots: Array with a 2-element `.shape`, aligned with `images`.
        batch_size: Number of samples per yielded batch.
        map_fn: Optional callable applied to each image (e.g. augmentation);
            labels pass through unchanged. Previously the body referenced
            `map_fn` without it being defined in this function.

    Yields:
        `(inputs, labels)` as evaluated by the Keras backend session.
    '''
    # Get shapes
    img_size = images.shape
    img_size = (None, img_size[1], img_size[2], img_size[3])
    onehot_size = onehots.shape
    onehot_size = (None, onehot_size[1])

    # Placeholders
    images_tensor = tf.placeholder(tf.float32, shape=img_size)
    onehots_tensor = tf.placeholder(tf.float32, shape=onehot_size)

    # Dataset
    dataset = tf.data.Dataset.from_tensor_slices((images_tensor, onehots_tensor))

    # Map function (e.g. augmentation)
    if map_fn is not None:
        dataset = dataset.map(lambda x, y: (map_fn(x), y), num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Combined shuffle and infinite repeat
    dataset = dataset.apply(
        tf.data.experimental.shuffle_and_repeat(len(images), None))
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(1)

    # Make the iterator
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    next_val = iterator.get_next()

    with K.get_session().as_default() as sess:
        # The actual arrays enter the pipeline here, via the feed dict.
        sess.run(init_op, feed_dict={images_tensor: images, onehots_tensor: onehots})
        while True:
            inputs, labels = sess.run(next_val)
            yield inputs, labels