PyTorch set_grad_enabled(False) vs with no_grad():
Actually no, there is no difference in the way they are used in the question. When you take a look at the source code of no_grad, you see that it actually uses torch.set_grad_enabled to achieve this behaviour:
class no_grad(object):
    r"""Context-manager that disabled gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    Also functions as a decorator.

    Example::

        >>> x = torch.tensor([1], requires_grad=True)
        >>> with torch.no_grad():
        ...   y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """

    def __init__(self):
        self.prev = torch.is_grad_enabled()

    def __enter__(self):
        torch._C.set_grad_enabled(False)

    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
        return False

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_no_grad(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return decorate_no_grad
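Since no_grad just delegates to set_grad_enabled, the two context managers behave identically here. A minimal check (the tensor and variable names below are just illustrative):

import torch

x = torch.tensor([1.0], requires_grad=True)

# Both context managers disable gradient tracking for the enclosed ops.
with torch.no_grad():
    a = x * 2
with torch.set_grad_enabled(False):
    b = x * 2

print(a.requires_grad, b.requires_grad)  # False False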
However, torch.set_grad_enabled offers one additional piece of functionality over torch.no_grad when used in a with-statement: it takes a boolean argument that lets you switch gradient computation on or off:
>>> x = torch.tensor([1], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...   y = x * 2
>>> y.requires_grad
False

Source: https://pytorch.org/docs/stable/_modules/torch/autograd/grad_mode.html
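Besides that, torch.set_grad_enabled can also be called as a plain function to toggle grad mode globally until it is changed again, which no_grad by itself does not offer. A short sketch (variable names are illustrative):

import torch

x = torch.tensor([1.0], requires_grad=True)

torch.set_grad_enabled(False)   # globally switch gradient tracking off
y = x * 2
print(y.requires_grad)          # False

torch.set_grad_enabled(True)    # switch it back on
z = x * 2
print(z.requires_grad)          # True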
Edit:

@TomHale Regarding your comment: I just ran a short test with PyTorch 1.0, and it turned out that gradient computation is indeed active:
import torch

w = torch.rand(5, requires_grad=True)
print('Grad Before:', w.grad)
torch.set_grad_enabled(False)
with torch.enable_grad():
    scalar = w.sum()
    scalar.backward()
    # Gradient tracking will be enabled here.
torch.set_grad_enabled(True)
print('Grad After:', w.grad)
Output:
Grad Before: None
Grad After: tensor([1., 1., 1., 1., 1.])
So gradients are computed in this setting.

The other setting you posted in your answer also yields the same result:
import torch

w = torch.rand(5, requires_grad=True)
print('Grad Before:', w.grad)
with torch.no_grad():
    with torch.enable_grad():
        # Gradient tracking IS enabled here.
        scalar = w.sum()
        scalar.backward()
print('Grad After:', w.grad)
Output:
Grad Before: None
Grad After: tensor([1., 1., 1., 1., 1.])
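If you want to confirm which mode is active at each point without calling backward(), you can query torch.is_grad_enabled() inside the nested contexts (a small sketch; the print labels are illustrative):

import torch

print('outside:', torch.is_grad_enabled())                      # True
with torch.no_grad():
    print('inside no_grad:', torch.is_grad_enabled())           # False
    with torch.enable_grad():
        print('inside enable_grad:', torch.is_grad_enabled())   # True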
The torch.autograd.enable_grad documentation says:

Enables gradient calculation inside a no_grad context. This has no effect outside of no_grad.
Given this wording, the following is expected:

torch.set_grad_enabled(False)
with torch.enable_grad():
    ...  # Gradient tracking will NOT be enabled here.
torch.set_grad_enabled(True)

vs:

with torch.no_grad():
    with torch.enable_grad():
        ...  # Gradient tracking IS enabled here.
But as blue-phoenix shows, this is not the case.
I raised an issue here.