How does PyTorch backprop through argmax?
Imagine this:
```python
import torch

t = torch.tensor([-0.0627, 0.1373, 0.0616, -1.7994, 0.8853,
                  -0.0656, 1.0034, 0.6974, -0.2919, -0.0456])
torch.argmax(t).item()  # outputs 6
```
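If we increase `t[0]` by some δ close to 0, will this change the argmax? It will not, so we are dealing with zero gradients almost everywhere (and at the isolated points where the maximum switches, the output jumps, so there is no derivative at all). Just ignore this layer, or treat it as frozen.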
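A quick way to see this in PyTorch is sketched below (the `delta` value is an arbitrary choice for illustration). `torch.argmax` returns an integer tensor with no `grad_fn`, so autograd has nothing to backpropagate through, and nudging `t[0]` leaves the result unchanged:

```python
import torch

# Same values as in the question, but now tracking gradients.
t = torch.tensor([-0.0627, 0.1373, 0.0616, -1.7994, 0.8853,
                  -0.0656, 1.0034, 0.6974, -0.2919, -0.0456], requires_grad=True)

idx = torch.argmax(t)
# The index is an integer tensor with no autograd history attached.
print(idx.item(), idx.dtype, idx.requires_grad, idx.grad_fn)  # 6 torch.int64 False None

# Nudge t[0] by a small delta: the argmax does not move (the function is locally flat).
delta = 1e-3
with torch.no_grad():
    t_nudged = t.clone()
    t_nudged[0] += delta
print(torch.argmax(t_nudged).item())  # 6

# With no graph behind idx, calling backward() raises a RuntimeError.
backward_failed = False
try:
    idx.float().backward()
except RuntimeError as err:
    backward_failed = True
    print("backward failed:", err)
```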
The same holds for `argmin`, or any other function whose output (the dependent variable) changes only in discrete steps.
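As a small illustration (the tensor `x` here is made up), `argmin` behaves exactly like `argmax`, and a step function such as `torch.round` does get a backward pass in PyTorch, but its derivative is defined as zero, so nothing useful flows back through it:

```python
import torch

x = torch.tensor([0.3, 1.7, 2.4], requires_grad=True)

# argmin, like argmax, returns an integer index with no autograd history.
i = torch.argmin(x)
print(i.item(), i.requires_grad)  # 0 False

# torch.round is a step function; its gradient is zero everywhere it is defined.
torch.round(x).sum().backward()
print(x.grad.tolist())  # [0.0, 0.0, 0.0]
```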
As alvas noted in the comments, `argmax` is not differentiable. However, once you compute it and assign each data point to a cluster, the derivative of the loss with respect to the locations of these clusters is well-defined. This is what your algorithm does.
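The asker's code is not shown here, so the following is only a hypothetical sketch of that pattern, assuming a squared-distance loss and made-up 1-D data: the assignment step uses `argmax` over negative squared distances and carries no gradient, while the loss computed with those fixed assignments has a perfectly ordinary gradient with respect to the centroids `C`:

```python
import torch

# Hypothetical toy data: six 1-D points and two centroids (not from the original post).
x = torch.tensor([0., 1., 2., 8., 9., 10.])
C = torch.tensor([0., 3.], requires_grad=True)

# Assignment step: argmax of negative squared distance (i.e. the nearest centroid).
# The result is an integer tensor, so no gradient flows through it.
with torch.no_grad():
    assign = torch.argmax(-(x[:, None] - C[None, :]) ** 2, dim=1)

# With the assignments fixed, the loss is an ordinary differentiable function of C.
loss = ((x - C[assign]) ** 2).sum()
loss.backward()

print(assign.tolist())  # [0, 0, 1, 1, 1, 1]
print(C.grad.tolist())  # [-2.0, -34.0], i.e. 2 * sum(C_k - x_i) over each cluster
```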
Why does it work? If you had only one cluster (so that the `argmax` operation didn't matter), your loss function would be quadratic, with its minimum at the mean of the data points. With multiple clusters, your loss function is piecewise quadratic (in higher dimensions, think volumewise): for any set of centroids `[C1, C2, C3, ...]`, each data point is assigned to some centroid `CN`, and the loss is locally quadratic. The extent of this locality is given by all alternative centroids `[C1', C2', C3', ...]` for which the assignment coming from `argmax` remains the same. Within this region, `argmax` can be treated as a constant rather than a function, and thus the derivative of the loss is well-defined.
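Assuming the usual squared-Euclidean k-means loss (the answer does not spell out the exact loss), the piecewise-quadratic picture can be written out as follows, where $S_k$ is the set of points that `argmax` assigns to centroid $C_k$:

```latex
\begin{aligned}
S_k &= \{\, i : \arg\max_j \bigl(-\lVert x_i - C_j \rVert^2\bigr) = k \,\}, \qquad
n_k = |S_k|, \qquad m_k = \frac{1}{n_k} \sum_{i \in S_k} x_i \\
L(C_1, \dots, C_K) &= \sum_{k=1}^{K} \sum_{i \in S_k} \lVert x_i - C_k \rVert^2 \\
\frac{\partial L}{\partial C_k} &= 2 \sum_{i \in S_k} (C_k - x_i) = 2\, n_k \,(C_k - m_k)
\end{aligned}
```

As long as the centroids move within a region where every $S_k$ stays the same, $L$ is an ordinary quadratic in each $C_k$, and its gradient vanishes exactly at $C_k = m_k$.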
Now, in reality, it's unlikely you can treat `argmax` as a constant, but you can still treat the naive "argmax-is-a-constant" derivative as pointing approximately towards a minimum, because the majority of data points are likely to stay in the same cluster between iterations. And once you get close enough to a local minimum that the points no longer change their assignments, the process can converge to that minimum.
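The sketch below illustrates that behaviour on hypothetical 1-D data, with an initialization chosen so that one point switches clusters early on. It alternates an `argmax` assignment with one plain gradient-descent step (via `torch.optim.SGD`, learning rate picked by hand) and counts how often the assignment pattern changes:

```python
import torch

# Hypothetical 1-D toy data and initial centroids, chosen so one point switches clusters.
x = torch.tensor([0., 1., 2., 8., 9., 10.])
C = torch.tensor([0., 3.], requires_grad=True)
opt = torch.optim.SGD([C], lr=0.05)

history = []    # assignments seen at each iteration
n_changes = 0   # how many times the assignment pattern changed
for step in range(200):
    # "Compute assignments": argmax of negative squared distance, treated as a constant.
    with torch.no_grad():
        assign = torch.argmax(-(x[:, None] - C[None, :]) ** 2, dim=1)
    if history and assign.tolist() != history[-1]:
        n_changes += 1
    history.append(assign.tolist())

    # One gradient-descent step on the locally quadratic loss.
    opt.zero_grad()
    loss = ((x - C[assign]) ** 2).sum()
    loss.backward()
    opt.step()

print(history[0], history[-1], n_changes)  # [0, 0, 1, 1, 1, 1] [0, 0, 0, 1, 1, 1] 1
print(C.detach().tolist())  # approaches the final cluster means, about [1.0, 9.0]
```

The point `2.0` moves to the first cluster after the first step; from then on the assignments stay fixed and the centroids converge geometrically. With a summed loss, each step scales `C_k - m_k` by `1 - 2 * lr * n_k`, so the learning rate must stay below `1 / n_k` for the iteration to be stable.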
Another, more theoretical way to look at it is that you're doing an approximation of expectation maximization (EM). Normally, you would have the "compute assignments" step, which is mirrored by `argmax`, and the "minimize" step, which boils down to finding the minimizing cluster centers given the current assignments. The minimum is given by `d(loss)/d([C1, C2, ...]) == 0`, which for a quadratic loss is solved analytically by the means of the data points within each cluster. In your implementation, you're solving the same equation, but with a gradient descent step. In fact, because the loss is exactly quadratic while the assignments are fixed, a single 2nd-order (Newton) step lands exactly on those means, so if you used a Newton update scheme instead of 1st-order gradient descent, you would be implicitly reproducing exactly the baseline EM scheme.
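To check the Newton claim numerically, here is a sketch on the same kind of hypothetical 1-D data, assuming a squared-distance loss. With assignments fixed, the gradient for centroid `k` is `2 * n_k * (C_k - m_k)` and the Hessian is `2 * n_k` on the diagonal, so one Newton step should return the per-cluster means. The Hessian is computed with `torch.autograd.functional.hessian` rather than written by hand:

```python
import torch
from torch.autograd.functional import hessian

# Hypothetical toy data and starting centroids.
x = torch.tensor([0., 1., 2., 8., 9., 10.])
C = torch.tensor([0., 3.], requires_grad=True)

# "Compute assignments" step (EM's E-step analogue), held fixed below.
with torch.no_grad():
    assign = torch.argmax(-(x[:, None] - C[None, :]) ** 2, dim=1)

def loss_fn(centroids):
    # Quadratic loss for the fixed assignments above.
    return ((x - centroids[assign]) ** 2).sum()

# Gradient and Hessian of the loss with respect to the centroids.
(g,) = torch.autograd.grad(loss_fn(C), C)
H = hessian(loss_fn, C.detach())

# One Newton step: C_new = C - H^{-1} g.
C_newton = C.detach() - torch.linalg.solve(H, g)

# The analytic "minimize" step: the mean of the points in each cluster.
means = torch.stack([x[assign == k].mean() for k in range(len(C))])

print(C_newton.tolist(), means.tolist())  # both [0.5, 7.25]
```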