pytorch - connection between loss.backward() and optimizer.step()
When you call loss.backward(), all it does is compute the gradient of the loss w.r.t. all the parameters in the computational graph that have requires_grad=True, and store them in the parameter.grad attribute of every parameter.
optimizer.step() then updates all the parameters based on parameter.grad.
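For illustration, here is a minimal sketch of that sequence; the tiny nn.Linear model and its shapes are just assumptions for the example, not part of the question:
import torch
import torch.nn as nn

# A tiny illustrative model with two learnable parameters (weight and bias)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

loss = model(torch.randn(4, 3)).sum()

print(model.weight.grad)   # None - nothing has been computed yet
loss.backward()            # fills .grad for every parameter with requires_grad=True
print(model.weight.grad)   # a tensor with the same shape as the weight

optimizer.step()           # reads the stored .grad values and updates the parameters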
Perhaps this will clarify the connection between loss.backward() and optim.step() a little (although the other answers are to the point).
# Our "model"
x = torch.tensor([1., 2.], requires_grad=True)
y = 100*x
# Compute loss
loss = y.sum()
# Compute gradient of the loss w.r.t. to the parameters
print(x.grad) # None
loss.backward()
print(x.grad) # tensor([100., 100.])
# Modify the parameters by subtracting the gradient
optim = torch.optim.SGD([x], lr=0.001)
print(x) # tensor([1., 2.], requires_grad=True)
optim.step()
print(x) # tensor([0.9000, 1.9000], requires_grad=True)
loss.backward() sets the .grad attribute of all leaf tensors with requires_grad=True in the computational graph whose root is loss (only x in this case).
The optimizer just iterates through the list of parameters (tensors) it received on initialization, and wherever a tensor has requires_grad=True, it subtracts the value of its gradient stored in the .grad attribute (multiplied by the learning rate in the case of plain SGD). It doesn't need to know with respect to which loss the gradients were computed; it just accesses the .grad attribute so it can do x = x - lr * x.grad.
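For plain SGD (no momentum or weight decay), the update is roughly equivalent to this manual loop, continuing the example above (a sketch, not how the optimizer is actually implemented internally):
# Rough manual equivalent of optim.step() for plain SGD, continuing the example above
lr = 0.001
with torch.no_grad():            # the update itself must not be recorded by autograd
    for p in [x]:                # the optimizer iterates over the parameters it was given
        if p.grad is not None:
            p -= lr * p.grad     # in-place version of p = p - lr * p.grad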
Note that if we were doing this in a training loop, we would call optim.zero_grad(), because in each training step we want to compute fresh gradients and don't care about the gradients from the previous batch; not zeroing them would lead to gradient accumulation across batches.
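In a training loop this typically looks like the following sketch, where model, loader, and criterion are placeholders I'm assuming rather than anything defined above:
# Typical ordering inside a training loop (model, loader, criterion are placeholders)
for inputs, targets in loader:
    optim.zero_grad()                          # discard gradients from the previous batch
    loss = criterion(model(inputs), targets)
    loss.backward()                            # populate .grad for all parameters
    optim.step()                               # update the parameters using the fresh .grad values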
Without delving too deep into the internals of PyTorch, I can offer a simplistic answer:
Recall that when initializing the optimizer, you explicitly tell it which parameters (tensors) of the model it should be updating. The gradients are "stored" by the tensors themselves (they have grad and requires_grad attributes) once you call backward() on the loss. After the gradients for all tensors in the model have been computed, calling optimizer.step() makes the optimizer iterate over all the parameters (tensors) it is supposed to update and use their internally stored grad to update their values.
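If you want to see this for yourself, you can inspect the optimizer's param_groups, which hold references to those tensors (a small sketch, assuming an optimizer built as in the earlier example):
# The optimizer keeps references to the tensors it was given at construction time
for group in optimizer.param_groups:
    for p in group['params']:
        print(p.shape, p.grad if p.grad is None else p.grad.shape)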
More info on computational graphs and the additional "grad" information stored in PyTorch tensors can be found in this answer.
The fact that the optimizer holds references to the parameters can sometimes cause trouble, e.g., when the model is moved to the GPU after the optimizer has been initialized. Make sure you are done setting up your model before constructing the optimizer. See this answer for more details.
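A sketch of the safe ordering (the device handling and model shown here are just an illustration):
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(3, 1)
model.to(device)                                            # finish setting up / moving the model first
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)    # only then construct the optimizer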