Why do we need to explicitly call zero_grad()?
There is a cycle in PyTorch:
- Forward pass, where we get the output (or y_hat) from the input,
- Loss computation, where loss = loss_fn(y_hat, y),
- loss.backward(), where we compute the gradients,
- optimizer.step(), where we update the parameters.
Or in code:
for mb in range(10):  # 10 mini-batches
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
If we did not clear the gradients after optimizer.step() (or, equivalently, just before the next backward()), the gradients would accumulate.
Here is an example showing accumulation:
import torch
w = torch.rand(5)
w.requires_grad_()
print(w)
s = w.sum()
s.backward()  # each backward() call adds into w.grad rather than overwriting it
print(w.grad) # tensor([1., 1., 1., 1., 1.])
s.backward()
print(w.grad) # tensor([2., 2., 2., 2., 2.])
s.backward()
print(w.grad) # tensor([3., 3., 3., 3., 3.])
s.backward()
print(w.grad) # tensor([4., 4., 4., 4., 4.])
loss.backward()
does not have any way of specifying this. Its signature is:
torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)
None of the options you can specify there zeroes the gradients; you have to do it manually, like this in the previous mini example:
w.grad.zero_()
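Continuing the mini example above (a small illustrative sketch; the value in the comment is what one would expect), zeroing the gradient and calling backward() again starts the accumulation over:
w.grad.zero_()
s.backward()
print(w.grad) # tensor([1., 1., 1., 1., 1.])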
There was some discussion about having backward() zero the previous gradients by default and keeping them only when asked, with a preserve_grads=True option, but this never came to life.
I have a use case for the current setup in PyTorch.
If one is using a recurrent neural network (RNN) that makes predictions at every step, one might want a hyperparameter that allows accumulating gradients back in time. Not zeroing the gradients at every time step allows one to use backpropagation through time (BPTT) in interesting and novel ways.
If you would like more info on BPTT or RNNs see the article Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gradients or The Unreasonable Effectiveness of Recurrent Neural Networks.
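As a minimal sketch of this idea (the model, the dimensions, the synthetic data, and the accumulate_steps hyperparameter below are all made-up placeholders, not part of any PyTorch API), one could let the per-step gradients add up in .grad and only step and zero the optimizer at window boundaries:
import torch
import torch.nn as nn

rnn = nn.RNNCell(input_size=4, hidden_size=8)
head = nn.Linear(8, 1)
optimizer = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.01)
loss_fn = nn.MSELoss()

seq = torch.randn(20, 1, 4)      # 20 time steps, batch of 1, 4 features (synthetic data)
targets = torch.randn(20, 1, 1)
h = torch.zeros(1, 8)
accumulate_steps = 5             # hypothetical hyperparameter: how many steps to accumulate

optimizer.zero_grad()
for t in range(seq.size(0)):
    h = rnn(seq[t], h)
    loss = loss_fn(head(h), targets[t])
    loss.backward(retain_graph=True)    # gradients from each step add into .grad
    if (t + 1) % accumulate_steps == 0:
        optimizer.step()                # update with gradients accumulated over the window
        optimizer.zero_grad()           # only now clear them
        h = h.detach()                  # truncate the graph at the window boundary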
We explicitly need to call zero_grad() because, after loss.backward() (when the gradients are computed), we need to use optimizer.step() to perform the gradient descent update. More specifically, the gradients are not zeroed automatically because these two operations, loss.backward() and optimizer.step(), are separated, and optimizer.step() requires the just-computed gradients.
In addition, sometimes we need to accumulate gradients over several batches; to do that, we can simply call backward() multiple times and optimize once, as sketched below.
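Here is a minimal sketch of that pattern (the model, loss, optimizer, the synthetic batches, and the accumulation_steps value are all illustrative assumptions, not prescribed by PyTorch):
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

batches = [(torch.randn(16, 10), torch.randn(16, 1)) for _ in range(8)]  # stand-in for a data loader
accumulation_steps = 4   # hypothetical hyperparameter

optimizer.zero_grad()
for i, (x, y) in enumerate(batches):
    loss = loss_fn(model(x), y) / accumulation_steps  # scale so the summed gradients average out
    loss.backward()                                   # adds into .grad, does not overwrite
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()        # one update for every accumulation_steps mini-batches
        optimizer.zero_grad()   # clear the accumulated gradients only now
The effect is roughly that of training with a 4x larger batch size, at the cost of one optimizer update per several forward/backward passes.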