Training a tf.keras model with a basic low-level TensorFlow training loop doesn't work
Replacing the low-level TF loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.stop_gradient(labels), logits=model_output))
with its Keras equivalent
loss = tf.reduce_mean(tf.keras.backend.categorical_crossentropy(target=labels, output=model_output, from_logits=True))
does the trick. The low-level TensorFlow training loop now behaves just like model.fit().
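For context, here is a minimal sketch of the kind of loop in question, with the Keras loss in place. The toy data, the single Dense layer, and the Adam optimizer are placeholders of my own, not the original code; only the loss line reflects the fix above:

```python
import numpy as np
import tensorflow as tf

# Toy stand-in data: 100 samples, 20 features, 5 classes (assumption, not from the post).
x_train = np.random.rand(100, 20).astype(np.float32)
y_train = np.eye(5, dtype=np.float32)[np.random.randint(0, 5, size=100)]

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(10).repeat()
features, labels = dataset.make_one_shot_iterator().get_next()

# Build the model directly on the iterator output (see the note below).
input_tensor = tf.keras.layers.Input(tensor=features)
model_output = tf.keras.layers.Dense(5)(input_tensor)
model = tf.keras.Model(inputs=input_tensor, outputs=model_output)

# The Keras loss that makes the loop behave like model.fit().
loss = tf.reduce_mean(
    tf.keras.backend.categorical_crossentropy(
        target=labels, output=model_output, from_logits=True))

train_op = tf.train.AdamOptimizer().minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(100):
        _, loss_value = sess.run([train_op, loss])
        if step % 20 == 0:
            print('step', step, 'loss', loss_value)
```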
However, I don't know why this is. If anyone knows why tf.keras.backend.categorical_crossentropy() behaves well while tf.nn.softmax_cross_entropy_with_logits_v2() doesn't work at all, please post an answer.
Another important note: in order to train a tf.keras model with a low-level TF training loop and a tf.data.Dataset object, one generally shouldn't call the model on the iterator output. That is, one shouldn't do this:
model_output = model(features)
Instead, one should create a model whose input layer builds on the iterator output rather than on a placeholder, like so:
input_tensor = tf.keras.layers.Input(tensor=features)
This doesn't matter in this example, but it becomes relevant if any layers in the model have internal updates that need to run during training (e.g. BatchNormalization).
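To illustrate that last point, here is a hedged sketch of how such internal updates can be wired into the loop. The architecture and data are again placeholders of my own; the key detail is that, with the model built on the iterator tensor, the BatchNormalization moving-average updates are concrete ops exposed via model.updates and can be grouped with the train op:

```python
import numpy as np
import tensorflow as tf

# Toy stand-in data (assumption, not from the post).
x_train = np.random.rand(100, 20).astype(np.float32)
y_train = np.eye(5, dtype=np.float32)[np.random.randint(0, 5, size=100)]

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(10).repeat()
features, labels = dataset.make_one_shot_iterator().get_next()

# Input layer built directly on the iterator output, not a placeholder.
input_tensor = tf.keras.layers.Input(tensor=features)
x = tf.keras.layers.Dense(32, activation='relu')(input_tensor)
x = tf.keras.layers.BatchNormalization()(x, training=True)
model_output = tf.keras.layers.Dense(5)(x)
model = tf.keras.Model(inputs=input_tensor, outputs=model_output)

loss = tf.reduce_mean(
    tf.keras.backend.categorical_crossentropy(
        target=labels, output=model_output, from_logits=True))

# Run the model's internal update ops (the BatchNormalization moving-average
# updates) together with every training step.
with tf.control_dependencies(model.updates):
    train_op = tf.train.AdamOptimizer().minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(50):
        sess.run(train_op)
```

Had the model instead been called on the iterator output (model(features)), the update ops would not be tied to the training graph in this way, and the BatchNormalization statistics would silently go stale.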