Learning rate of a custom training loop in TensorFlow 2.0
In TensorFlow 2.1, the Optimizer class has an undocumented private method, `_decayed_lr` (see its definition in the TensorFlow source), which you can call in the training loop, passing the dtype to cast the result to:
```python
current_learning_rate = optimizer._decayed_lr(tf.float32)
```
Here's a more complete example that also logs the learning rate to TensorBoard.
```python
import tensorflow as tf

# Assumes `model`, `loss_object` and `optimizer` are defined elsewhere.
train_step_count = 0
summary_writer = tf.summary.create_file_writer('logs/')

def train_step(images, labels):
    global train_step_count  # required to rebind the module-level counter
    train_step_count += 1
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # optimizer._decayed_lr(tf.float32) is the current learning rate.
    # You can save it to TensorBoard like so:
    with summary_writer.as_default():
        tf.summary.scalar('learning_rate',
                          optimizer._decayed_lr(tf.float32),
                          step=train_step_count)
```
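In a custom training loop you can also call `print(optimizer.lr.numpy())` to get the learning rate. This works when the learning rate is a constant; if you passed a `LearningRateSchedule`, `optimizer.lr` is the schedule object itself, so evaluate it at the current step instead, e.g. `optimizer.lr(optimizer.iterations)`.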
If you are using the Keras API (`model.fit`), you can define your own callback that records the current learning rate:
```python
from tensorflow.keras.callbacks import Callback

class LRRecorder(Callback):
    """Record the current learning rate."""

    def on_epoch_begin(self, epoch, logs=None):
        lr = self.model.optimizer.lr
        print("The current learning rate is {}".format(lr.numpy()))

callbacks = []  # your other callbacks
callbacks.append(LRRecorder())
# Then pass the list to training: model.fit(..., callbacks=callbacks)
```
Update
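$$w \leftarrow w - \frac{\alpha_t\, m}{\sqrt{v} + \epsilon} = w - \mathrm{act\_lr} \odot m, \qquad \alpha_t = \mathrm{base\_lr}\,\frac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}}, \qquad \mathrm{act\_lr} = \frac{\alpha_t}{\sqrt{v} + \epsilon}$$
Here m and v are Adam's raw (not bias-corrected) moment estimates, t is the step count, β1, β2 and ε are the optimizer's beta_1, beta_2 and epsilon, and all operations are element-wise. This is the form of the update used by tf.keras Adam, which folds the bias corrections into α_t.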
The learning rate obtained above is base_lr. However, act_lr changes adaptively during training. Take the Adam optimizer as an example: the step it actually applies to a parameter is determined by base_lr, m and v, where m and v are the running estimates of the first and second moments of that parameter's gradient. Different parameters have different m and v values, so to know act_lr you need to look up m and v for a specific variable by name. For example, to find act_lr for the variable dense/kernel, whose Adam slot variables are named Adam/dense/kernel/m and Adam/dense/kernel/v, you can access m and v like this:
```python
# Adam's slot variables are named '<optimizer name>/<variable name>/<slot>'.
for var in optimizer.variables():
    if 'Adam/dense/kernel/m' in var.name:
        print(var.name, var.numpy())
    if 'Adam/dense/kernel/v' in var.name:
        print(var.name, var.numpy())
```
Then you can compute act_lr from the formula above.
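A minimal NumPy sketch of that calculation, assuming the update rule documented for tf.keras Adam without AMSGrad: α_t = base_lr · sqrt(1 − β2^t) / (1 − β1^t) and step = α_t · m / (sqrt(v) + ε), so act_lr = α_t / (sqrt(v) + ε) element-wise. It assumes you have `v` as a NumPy array (e.g. from `var.numpy()` in the loop above) and the step count t (e.g. `optimizer.iterations.numpy()` after the update that produced m and v). The function name `adam_effective_lr` is hypothetical; the defaults mirror tf.keras Adam's `beta_1`, `beta_2` and `epsilon`.

```python
import numpy as np

def adam_effective_lr(v, base_lr, t, beta_1=0.9, beta_2=0.999, epsilon=1e-7):
    """Per-element act_lr such that Adam's step equals act_lr * m.

    v is the raw second-moment slot and t the number of updates applied so far.
    """
    # Bias corrections folded into a scalar step size alpha_t.
    alpha_t = base_lr * np.sqrt(1.0 - beta_2 ** t) / (1.0 - beta_1 ** t)
    # Element-wise: parameters with a large v get a small act_lr.
    return alpha_t / (np.sqrt(v) + epsilon)


if __name__ == "__main__":
    # Toy check: one Adam step from zero-initialised slots with gradient 0.5.
    g = np.array([0.5])
    m = (1 - 0.9) * g          # first-moment slot after one step
    v = (1 - 0.999) * g ** 2   # second-moment slot after one step
    act_lr = adam_effective_lr(v, base_lr=0.001, t=1)
    print("act_lr:", act_lr, "step:", act_lr * m)
```

On the first step Adam moves each parameter by roughly base_lr in the direction of its gradient's sign, so the printed step should be very close to 0.001, while act_lr itself (about 0.02 here) depends on the gradient's magnitude.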