How to Implement Center Loss and Other Running Averages of Labeled Embeddings
The previously posted method is too simple for cases like center loss, where the expected value of the embeddings changes over time as the model becomes more refined. The previous center-finding routine averages all instances since the start and therefore tracks changes in the expected value very slowly. A moving-window average is preferred instead. An exponential moving-window variant is as follows:
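For a single embedding x with label y, the update the code below applies is the standard exponential moving average:

    ctrs[y] <- ctrs[y] - (1 - decay) * (ctrs[y] - x)
             = decay * ctrs[y] + (1 - decay) * x

(When a batch contains duplicate labels, tf.scatter_sub accumulates the corrections for that label.)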
import numpy as np
import tensorflow as tf

def get_embed_centers(embed_batch, label_batch):
    ''' Exponential moving-window average. Increase decay for longer
    windows; decay is in [0.0, 1.0].
    '''
    decay = 0.95
    with tf.variable_scope('embed', reuse=True):
        embed_ctrs = tf.get_variable("ctrs")
    label_batch = tf.reshape(label_batch, [-1])
    # Gather the current center for each label in the batch.
    old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    # EMA correction: move each center a fraction (1 - decay) toward the embedding.
    diff = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
    embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, diff)
    # Return the freshly updated centers for the labels in this batch.
    embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    return embed_ctrs_batch
ndims = 2
nclass = 4

with tf.Session() as sess:
    with tf.variable_scope('embed'):
        embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
                                     initializer=tf.constant_initializer(0), trainable=False)
    label_batch_ph = tf.placeholder(tf.int32)
    embed_batch_ph = tf.placeholder(tf.float32)
    embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
    sess.run(tf.global_variables_initializer())
    tf.get_default_graph().finalize()
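A minimal driver (hypothetical; the batch size and random data are illustrative, not part of the original) feeds one labelled batch and reads back the updated centers for those labels:

    # Inside the session block above: feed a labelled batch and read back
    # the updated moving-window center for each sample's class.
    batch = np.random.randn(16, ndims).astype(np.float32)
    labels = np.random.randint(0, nclass, 16)
    ctrs = sess.run(embed_ctrs_batch,
                    feed_dict={embed_batch_ph: batch, label_batch_ph: labels})
    print(ctrs.shape)   # (16, ndims): one updated center per batch sample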
The get_new_centers() routine below takes in labelled embeddings and updates the shared variables center/sums and center/cts. These variables are then used to calculate and return the embedding centers from the updated values. The loop simply exercises get_new_centers() and shows that it converges to the expected average embeddings for all classes over time. Note that the alpha term used in the original paper isn't included here, but should be straightforward to add if needed (see the sketch below).
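One hypothetical way to add it (a simplified sketch, not from the original post; the paper additionally normalizes each correction by the per-class batch count): alpha directly scales how far each center moves toward the batch embeddings, playing the role of (1 - decay) in the moving-window variant above.

    def update_centers_with_alpha(centers, embed_batch, label_batch, alpha=0.5):
        # Hypothetical helper: scale each center correction by the paper's alpha.
        # Duplicate labels in a batch accumulate their corrections via scatter_sub.
        old_ctrs_batch = tf.gather(centers, label_batch)
        diff = alpha * (old_ctrs_batch - embed_batch)
        return tf.scatter_sub(centers, label_batch, diff)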
ndims = 2
nclass = 4
nbatch = 100

with tf.variable_scope('center'):
    center_sums = tf.get_variable("sums", [nclass, ndims], dtype=tf.float32,
                                  initializer=tf.constant_initializer(0), trainable=False)
    center_cts = tf.get_variable("cts", [nclass], dtype=tf.float32,
                                 initializer=tf.constant_initializer(0), trainable=False)
def get_new_centers(embeddings, indices):
    '''
    Update the running sums/counts for the selected class indices and return
    the new average embeddings for those indices (duplicates included).
    Pass embeddings=None to read the current centers without updating.
    '''
    with tf.variable_scope('center', reuse=True):
        center_sums = tf.get_variable("sums")
        center_cts = tf.get_variable("cts")
    # Update the embedding sums and counts.
    if embeddings is not None:
        ones = tf.ones_like(indices, tf.float32)
        center_sums = tf.scatter_add(center_sums, indices, embeddings, name='sa1')
        center_cts = tf.scatter_add(center_cts, indices, ones, name='sa2')
    # Return the updated centers: running sum / running count per class.
    num = tf.gather(center_sums, indices)
    denom = tf.reshape(tf.gather(center_cts, indices), [-1, 1])
    return tf.div(num, denom)
with tf.Session() as sess:
    labels_ph = tf.placeholder(tf.int32)
    embeddings_ph = tf.placeholder(tf.float32)
    unq_labels, ul_idxs = tf.unique(labels_ph)
    indices = tf.gather(unq_labels, ul_idxs)
    new_centers_with_update = get_new_centers(embeddings_ph, indices)
    new_centers = get_new_centers(None, indices)
    sess.run(tf.global_variables_initializer())
    tf.get_default_graph().finalize()
    for i in range(100001):
        embeddings = 100 * np.random.randn(nbatch, ndims)
        labels = np.random.randint(0, nclass, nbatch)
        feed_dict = {embeddings_ph: embeddings, labels_ph: labels}
        rval = sess.run([new_centers_with_update], feed_dict)
        if i % 1000 == 0:
            feed_dict = {labels_ph: list(range(nclass))}
            rval = sess.run(new_centers, feed_dict)
            print('\nFor step ', i)
            for iclass in range(nclass):
                print('Class %d, center: %s' % (iclass, str(rval[iclass])))
A typical result at step 0 is:
For step 0
Class 0, center: [-1.7618252 -0.30574229]
Class 1, center: [ -4.50493908 10.12403965]
Class 2, center: [ 3.6156714 -9.94263649]
Class 3, center: [-4.20281982 -8.28845882]
and the output at step 10,000 demonstrates convergence toward the true class means (near zero here, since the embeddings are drawn from a zero-mean Gaussian):
For step 10000
Class 0, center: [ 0.00313433 -0.00757505]
Class 1, center: [-0.03476512 0.04682625]
Class 2, center: [-0.03865958 0.06585111]
Class 3, center: [-0.02502561 -0.03370816]