Getting around tf.argmax, which is not differentiable

If you're OK with an approximation:

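import numpy as np
# This snippet uses the TensorFlow 1.x graph API (placeholders, Session).
# The compat module keeps it runnable on TensorFlow 2.x as well.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

sess = tf.Session()
x = tf.placeholder(dtype=tf.float32, shape=(None,))
beta = tf.placeholder(dtype=tf.float32)

# Pseudo-math for the below (0-based index i):
# y = sum( i * exp(beta * x[i]) ) / sum( exp(beta * x[i]) )
# tf.cumsum(tf.ones_like(x)) yields [1, 2, ..., n]. The softmax weights sum to 1,
# so subtracting 1 at the end shifts the indices to [0, 1, ..., n-1].
weights = tf.exp(beta * x) / tf.reduce_sum(tf.exp(beta * x))
y = tf.reduce_sum(tf.cumsum(tf.ones_like(x)) * weights) - 1

print("I can compute the gradient", tf.gradients(y, x))

for run in range(10):
    data = np.random.randn(10)
    # Normalizing keeps every |x[i]| <= 1, so beta * x stays bounded.
    print(data.argmax(), sess.run(y, feed_dict={x: data / np.linalg.norm(data), beta: 1e2}))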
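For reference, the derivative that tf.gradients computes here has a simple closed form. Writing w[i] = exp(beta * x[i]) / sum_j exp(beta * x[j]) for the softmax weights, dy/dx[k] = beta * w[k] * (k - y). The sketch below is plain Python rather than TensorFlow so it runs anywhere; soft_argmax and soft_argmax_grad are hypothetical helper names, and the finite-difference comparison is only a sanity check of the hand-derived formula.

```python
import math


def soft_argmax(x, beta):
    """Weighted mean of the indices 0..n-1 with softmax(beta * x) weights."""
    m = max(x)  # shifting by the max leaves the weights unchanged (beta > 0)
    e = [math.exp(beta * (v - m)) for v in x]
    z = sum(e)
    return sum(i * ei for i, ei in enumerate(e)) / z


def soft_argmax_grad(x, beta):
    """Hand-derived gradient: dy/dx[k] = beta * w[k] * (k - y)."""
    m = max(x)
    e = [math.exp(beta * (v - m)) for v in x]
    z = sum(e)
    w = [ei / z for ei in e]
    y = sum(i * wi for i, wi in enumerate(w))
    return [beta * wk * (k - y) for k, wk in enumerate(w)]


x = [0.1, 0.5, 0.3]
beta = 5.0
analytic = soft_argmax_grad(x, beta)

# Central finite differences as an independent check of the formula.
h = 1e-6
numeric = []
for k in range(len(x)):
    xp = list(x)
    xp[k] += h
    xm = list(x)
    xm[k] -= h
    numeric.append((soft_argmax(xp, beta) - soft_argmax(xm, beta)) / (2 * h))

print(analytic)
print(numeric)
```

The gradient entries sum to zero, which matches the fact that adding the same constant to every x[i] leaves y unchanged.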

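This uses the trick that, at low temperature, the mean of a Boltzmann (softmax) distribution approaches its most probable value. Here the weights exp(beta * x[i]) / sum(exp(beta * x[j])) form that distribution over the indices, so their weighted mean approaches the index of the largest x[i]. Low temperature corresponds to a very large beta (the temperature is 1/beta).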
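A minimal pure-Python sketch of that concentration effect, assuming a hypothetical soft_argmax helper that implements the pseudo-math above (with the max subtracted before exponentiating so exp() stays in range): as beta grows, the result moves toward the hard argmax.

```python
import math


def soft_argmax(x, beta):
    """Weighted mean of the indices 0..n-1 with softmax(beta * x) weights."""
    m = max(x)  # subtracting the max leaves the weights unchanged (beta > 0)
    e = [math.exp(beta * (v - m)) for v in x]
    z = sum(e)
    return sum(i * ei for i, ei in enumerate(e)) / z


x = [0.2, -0.4, 0.9, 0.7, 0.1]  # hard argmax is index 2
for beta in (1.0, 10.0, 100.0):
    print(beta, soft_argmax(x, beta))
```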

In fact, as beta approaches infinity, y converges to the index of the maximum (assuming the maximum is unique). Unfortunately, beta can't get too large before exp(beta * x) overflows and you get NaN, but there are tricks to solve that which I can go into if you care.
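The answer doesn't say which tricks it means; one standard fix, swapped in here as an assumption, is to subtract max(x) before exponentiating (the usual log-sum-exp shift). The factor exp(-beta * max(x)) cancels between numerator and denominator, so the weights are unchanged, but every exponent is now <= 0 (for beta > 0) and the largest term is exactly 1. In pure Python, math.exp raises OverflowError where float32 TensorFlow would produce inf and then NaN. The function names below are hypothetical.

```python
import math


def soft_argmax_naive(x, beta):
    # exp(beta * v) overflows once beta * v exceeds roughly 709 (float64).
    e = [math.exp(beta * v) for v in x]
    z = sum(e)
    return sum(i * ei for i, ei in enumerate(e)) / z


def soft_argmax_stable(x, beta):
    # Shift by the max: every exponent is <= 0, so exp() cannot overflow,
    # and the largest term is exactly 1, so z >= 1 and the division is safe.
    m = max(x)
    e = [math.exp(beta * (v - m)) for v in x]
    z = sum(e)
    return sum(i * ei for i, ei in enumerate(e)) / z


x = [0.3, 0.8, -0.5, 0.6]  # hard argmax is index 1
try:
    soft_argmax_naive(x, 1e4)
except OverflowError:
    print("naive version overflows")
print(soft_argmax_stable(x, 1e4))
```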

The output looks something like this (true argmax first, then the approximation):

0 2.24459
9 9.0
8 8.0
4 4.0
4 4.0
8 8.0
9 9.0
6 6.0
9 8.99995
1 1.0

So it occasionally misses (the first row) or is slightly off (8.99995 instead of 9), but it usually gets the right answer. Depending on your algorithm, this might be fine.
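A plausible cause of misses like the first row (not something the output alone confirms) is two entries that are tied or nearly tied: the weights then split between them and the result lands between the two indices. The pure-Python sketch below uses a hypothetical soft_argmax helper following the same formula.

```python
import math


def soft_argmax(x, beta):
    """Weighted mean of the indices 0..n-1 with softmax(beta * x) weights."""
    m = max(x)
    e = [math.exp(beta * (v - m)) for v in x]
    return sum(i * ei for i, ei in enumerate(e)) / sum(e)


# Exact tie between index 1 and index 3: the weights split evenly,
# so the result is their average, not either index.
tie = [0.0, 0.5, 0.1, 0.5]
print(soft_argmax(tie, 1e3))

# Near tie: beta * gap = 100 * 0.001 = 0.1, so the split is still almost even
# and the result again lands between the two candidate indices.
near = [0.0, 0.5, 0.1, 0.499]
print(soft_argmax(near, 1e2))
```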


As aidan suggested, this is just a softargmax pushed to its limit by a large beta. We can use tf.nn.softmax to get around the numerical issues:

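import tensorflow as tf

def softargmax(x, beta=1e10):
    x = tf.convert_to_tensor(x)
    # Indices 0..n-1 along the last axis (requires a statically known last dimension).
    x_range = tf.range(x.shape.as_list()[-1], dtype=x.dtype)
    # For large beta, softmax(beta * x) is close to one-hot at the argmax,
    # so the weighted sum of the indices is close to the argmax index.
    return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=-1)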
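To illustrate what the axis=-1 reduction does on a batch, here is a pure-Python analogue of softargmax that returns one value per row. It is a sketch, not TensorFlow's implementation; softargmax_rows is a hypothetical name, and the max shift inside stands in for whatever stabilization tf.nn.softmax performs internally.

```python
import math


def softargmax_rows(rows, beta=1e10):
    """Pure-Python analogue of softargmax: one result per row (the last axis)."""
    out = []
    for row in rows:
        m = max(row)
        # beta * (v - m) <= 0, so exp() cannot overflow; tiny values underflow to 0.0.
        e = [math.exp(beta * (v - m)) for v in row]
        z = sum(e)
        out.append(sum(i * ei for i, ei in enumerate(e)) / z)
    return out


batch = [
    [0.1, 0.9, 0.3],
    [2.0, -1.0, 0.5],
    [0.0, 0.0, 1.0],
]
print(softargmax_rows(batch))  # [1.0, 0.0, 2.0]
```

With the default beta=1e10 and a unique, well-separated maximum, the weights are effectively one-hot, so the output is the hard argmax; the flip side is that the softmax is saturated and its gradient with respect to x is essentially zero. A smaller beta trades some accuracy for a usable gradient.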