Batch Normalization in tf.keras does not update the moving mean and moving variance
This is because tf.keras.layers.BatchNormalization inherits from tf.keras.layers.Layer, and the Keras API handles update ops as part of its fit and evaluate loops. As a result, outside those loops the layer does not add its update ops to the tf.GraphKeys.UPDATE_OPS collection.
So, to make it work, you need to add them to the collection manually:
hidden = tf.keras.layers.Dense(units, activation=None)(out)
batch_normed = tf.keras.layers.BatchNormalization(trainable=True)
layer = batch_normed(hidden)
Keeping the layer in its own variable (batch_normed) gives you a handle on that instance, so you can access its update ops afterwards:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, batch_normed.updates)
This adds the layer's update ops to the required collection. See also https://github.com/tensorflow/tensorflow/issues/25525
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[0])
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[1])
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
This solves the error raised by
tf.control_dependencies(update_ops)
If you use
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, batch_normed.updates)
then
tf.get_collection(tf.GraphKeys.UPDATE_OPS)
returns a list inside a list, like [[something]], because add_to_collection stores the whole updates list as a single item.
Whereas if you use
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[0])
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[1])
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
then
tf.get_collection(tf.GraphKeys.UPDATE_OPS)
returns a flat list: [something1, something2, ...]
I think this is the solution,
but the output is different, and I don't know which one is correct.
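The nested versus flat difference comes down to how a graph collection stores values: in TF 1.x, add_to_collection appends its value as one item, so passing a list stores the whole list as a single element. The sketch below models that behavior in plain Python so it can run without TensorFlow. FakeGraph, flatten_update_ops, and the string stand-ins for bn1.updates are all hypothetical and only illustrate the list structure; they are not TensorFlow APIs.

```python
from collections import defaultdict


class FakeGraph:
    """Hypothetical plain-Python model of TF 1.x graph collections (not TF code)."""

    def __init__(self):
        self._collections = defaultdict(list)

    def add_to_collection(self, name, value):
        # Modeled on tf.add_to_collection: the value is appended as ONE item,
        # so passing a list stores the entire list as a single element.
        self._collections[name].append(value)

    def get_collection(self, name):
        # Return a new list each call, like tf.get_collection.
        return list(self._collections[name])


UPDATE_OPS = "update_ops"
# Hypothetical string stand-ins for bn1.updates (moving mean / moving variance).
bn_updates = ["assign_moving_mean", "assign_moving_variance"]

# Variant 1: pass the whole updates list at once -> nested collection.
g1 = FakeGraph()
g1.add_to_collection(UPDATE_OPS, bn_updates)
nested = g1.get_collection(UPDATE_OPS)

# Variant 2: add each update op on its own -> flat collection.
g2 = FakeGraph()
for op in bn_updates:
    g2.add_to_collection(UPDATE_OPS, op)
flat = g2.get_collection(UPDATE_OPS)


def flatten_update_ops(items):
    """Hypothetical helper: remove one level of nesting left by variant 1."""
    result = []
    for item in items:
        if isinstance(item, (list, tuple)):
            result.extend(item)
        else:
            result.append(item)
    return result


print(nested)  # [['assign_moving_mean', 'assign_moving_variance']]
print(flat)    # ['assign_moving_mean', 'assign_moving_variance']
print(flatten_update_ops(nested) == flat)  # True
```

In real TF 1.x code, the flat variant corresponds to a loop such as for op in bn1.updates: tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, op), which also covers layers with more than two update ops. This is consistent with the tf.control_dependencies error mentioned above, since that function expects each entry to be an op or tensor rather than a list. The sketch only models the list structure; it says nothing about which numerical output is correct.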