Implementation of the latest "Lookahead Optimizer" paper in Keras?

To demonstrate the concept behind it, one can implement the Lookahead Optimizer as a Keras callback; see my implementation here:

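from keras import backend as K

def on_train_batch_end(self, batch, logs=None):
    self.count += 1
    if self.slow_weights is None:
        # Snapshot the weight *values*; storing the variable list itself would make slow and fast the same objects
        self.slow_weights = K.batch_get_value(self.model.trainable_weights)
    if self.count % self.k == 0:  # every k batches (must not be nested under the check above)
        fast_weights = K.batch_get_value(self.model.trainable_weights)
        # slow <- slow + alpha * (fast - slow)
        self.slow_weights = [slow + self.alpha * (fast - slow)
                             for fast, slow in zip(fast_weights, self.slow_weights)]
        # fast <- slow: restart the fast weights from the updated slow weights
        K.batch_set_value(list(zip(self.model.trainable_weights, self.slow_weights)))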

What this does is conceptually embarrassingly simple: every k updates, the weights are moved halfway (alpha=0.5) towards the value they had k iterations ago.
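To make the rule concrete without any framework, here is a minimal sketch on a single scalar weight. It assumes plain gradient descent on the toy loss f(x) = x**2 as the inner ("fast") optimizer; that choice, and the function names, are purely illustrative and not part of the callback above.

```python
def sgd_step(x, lr=0.1):
    """One step of plain gradient descent on the toy loss f(x) = x**2 (gradient 2x)."""
    return x - lr * 2 * x


def lookahead(x0, steps, k=5, alpha=0.5, lr=0.1):
    """Run `steps` inner updates, syncing fast and slow weights every k steps."""
    fast = slow = x0
    history = []
    for t in range(1, steps + 1):
        fast = sgd_step(fast, lr)                # fast (inner optimizer) update
        if t % k == 0:
            slow = slow + alpha * (fast - slow)  # move slow weights toward fast
            fast = slow                          # restart fast weights from slow
        history.append(fast)
    return fast, slow, history


if __name__ == "__main__":
    final_fast, final_slow, hist = lookahead(1.0, steps=10)
    print(final_fast, final_slow)
```

Because the fast weights are reset to the slow weights at every sync, `slow` always holds the weights from k iterations ago, which is why with alpha=0.5 each sync lands exactly halfway between the current weights and that older value.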

N.B. The above implementation may not perform well on a GPU or TPU, since the slow_weights copy would probably be updated on the CPU, and copying the weights between devices every k steps takes time.

EDIT (2020.03): There is now an official implementation in TensorFlow!

Today, when I wanted to start implementing it, I found that somebody had already done it! (Of course, when I asked this question, nothing could be found on Google.)

Here is the link: (For non-Chinese readers, I have slightly modified the repo.)

The usage looks like this:

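from keras.optimizers import Adam
# Lookahead is the class provided by the repo linked above

model.compile(optimizer=Adam(1e-3), loss='mse')  # any optimizer works
lookahead = Lookahead(k=5, alpha=0.5)  # initialize Lookahead
lookahead.inject(model)  # add the Lookahead updates to the model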

Looking into his code, the core of the implementation is replacing model.train_function (i.e. model.train_function = ...) so that each training step can run two sets of updates: the inner optimizer's (fast) updates and the Lookahead (slow) updates.
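The repo patches Keras' internal training function, which I won't reproduce here. As a framework-free illustration of the same "inject" idea, the sketch below wraps a model's `train_function` so each call performs the fast update and, every k calls, the slow update. `ToyModel` and `ToyLookahead` are hypothetical stand-ins (a single scalar weight with loss w**2 and plain gradient descent), not the repo's classes or Keras' API.

```python
class ToyModel:
    """Hypothetical stand-in for a Keras model: one scalar weight, loss w**2."""

    def __init__(self, w, lr=0.1):
        self.w, self.lr = w, lr
        self.train_function = self._sgd_step  # what a fit loop would call per batch

    def _sgd_step(self):
        self.w -= self.lr * 2 * self.w  # fast (inner optimizer) update
        return self.w ** 2              # loss after the update


class ToyLookahead:
    """Hypothetical sketch of inject(): wrap train_function to add the slow updates."""

    def __init__(self, k=5, alpha=0.5):
        self.k, self.alpha = k, alpha

    def inject(self, model):
        inner = model.train_function
        state = {"step": 0, "slow": model.w}  # per-model counter and slow weight

        def train_function():
            loss = inner()                      # first set: the inner optimizer's update
            state["step"] += 1
            if state["step"] % self.k == 0:     # second set: the Lookahead update
                state["slow"] += self.alpha * (model.w - state["slow"])
                model.w = state["slow"]
            return loss

        model.train_function = train_function  # the "model.train_function = ..." trick


if __name__ == "__main__":
    model = ToyModel(1.0)
    ToyLookahead(k=5, alpha=0.5).inject(model)
    for _ in range(10):
        model.train_function()
    print(model.w)
```

Keeping the counter and slow weight in the closure means the training loop itself does not need to change; it simply keeps calling `model.train_function()`.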

In addition, it seems that the "hacking" trick of the repo comes from the following article (judging from his code and comments):