PyTorch: how to set requires_grad to False
Here is one way:
import torch
import torch.nn as nn

linear = nn.Linear(1, 1)
# Turn off gradient tracking for every parameter of the layer
for param in linear.parameters():
    param.requires_grad = False
with torch.no_grad():
    linear.eval()
    print(linear.weight.requires_grad)
OUTPUT: False
To complete @Salih_Karagoz's answer, there is also the torch.set_grad_enabled() context manager (see the PyTorch documentation), which lets you switch gradient tracking on or off with a single flag when moving between training and evaluation:
linear = nn.Linear(1, 1)
is_train = False
for param in linear.parameters():
    param.requires_grad = is_train
with torch.set_grad_enabled(is_train):
    linear.eval()
    print(linear.weight.requires_grad)
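To see what the context manager itself changes, as opposed to the flags on the parameters, a minimal sketch along these lines may help. The names x, out_off and out_on are only illustrative and do not come from the answers above.

```python
import torch
import torch.nn as nn

linear = nn.Linear(1, 1)
x = torch.ones(1, 1)

# The parameters keep their default requires_grad=True; only the context differs.
with torch.set_grad_enabled(False):
    out_off = linear(x)  # no graph is recorded for this result

with torch.set_grad_enabled(True):
    out_on = linear(x)   # a graph is recorded for this result

print(linear.weight.requires_grad)  # True: the flag on the parameter is untouched
print(out_off.requires_grad)        # False
print(out_on.requires_grad)         # True
```

In other words, the context decides whether results are tracked, while the loop over parameters() decides whether the parameters themselves can receive gradients.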
Nice. The thing to check is that when you define a Linear layer, its parameters have requires_grad=True by default, because we would like to learn, right?
l = nn.Linear(1, 1)
p = l.parameters()
for _ in p:
    print(_)
# Parameter containing:
# tensor([[-0.3258]], requires_grad=True)
# Parameter containing:
# tensor([0.6040], requires_grad=True)
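The other construct, with torch.no_grad():, means that nothing computed inside the block is tracked for gradients, so you cannot learn in there. It does not change the requires_grad flag of parameters. So your code just shows that the layer is still capable of learning, even though it was created inside torch.no_grad(), where learning is forbidden.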
with torch.no_grad():
    linear = nn.Linear(1, 1)
    linear.eval()
    print(linear.weight.requires_grad)  # True
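The following sketch tries to separate the three things that are mixed together in that snippet: the no_grad context, eval(), and the requires_grad flag. The variable names inside and outside are illustrative assumptions, not part of the original question.

```python
import torch
import torch.nn as nn

with torch.no_grad():
    linear = nn.Linear(1, 1)           # created inside no_grad
    linear.eval()                      # only flips linear.training
    inside = linear(torch.ones(1, 1))  # computed without a graph

print(linear.training)              # False: effect of eval()
print(linear.weight.requires_grad)  # True: the flag is unchanged
print(inside.grad_fn)               # None: result was not tracked

# Outside the block autograd records operations again,
# so the same layer can still receive gradients.
outside = linear(torch.ones(1, 1))
outside.sum().backward()
print(linear.weight.grad is not None)  # True
```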
If you really want to turn off requires_grad for the weight parameter, you can also do it with:
linear.weight.requires_grad_(False)
or
linear.weight.requires_grad = False
So your code may become like this:
with torch.no_grad():
    linear = nn.Linear(1, 1)
    linear.weight.requires_grad_(False)
    linear.eval()
    print(linear.weight.requires_grad)
If you plan to switch off requires_grad for all parameters in a module:
l = nn.Linear(1, 1)
for _ in l.parameters():
    _.requires_grad_(False)
    print(_)
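As an alternative to the explicit loop, nn.Module has its own requires_grad_ method that applies the flag to every parameter of the module in one call. A small sketch, where the Sequential model and the names all_frozen and all_trainable are made up for illustration:

```python
import torch.nn as nn

# Hypothetical small model, used only to have more than one layer
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze every parameter of the module in one call
model.requires_grad_(False)
all_frozen = all(not p.requires_grad for p in model.parameters())
print(all_frozen)  # True

# Undo it again
model.requires_grad_(True)
all_trainable = all(p.requires_grad for p in model.parameters())
print(all_trainable)  # True
```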
requires_grad=False
If you want to freeze part of your model and train the rest, you can set requires_grad of the parameters you want to freeze to False.
For example, if you only want to keep the convolutional part of VGG16 fixed:
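import torchvision

# Note: newer torchvision releases replace pretrained=True with the weights= argument.
model = torchvision.models.vgg16(pretrained=True)
for param in model.features.parameters():
    param.requires_grad = False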
When the requires_grad flags are switched to False, no intermediate buffers are saved until the computation reaches an operation where one of the inputs requires a gradient.
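To try this out without downloading VGG16, here is a minimal sketch using a tiny stand-in model. TinyNet, its layer sizes, and the SGD settings are assumptions for illustration only; the features/classifier split just mirrors the VGG16 attribute names.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for VGG16 with a features/classifier split."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
        self.classifier = nn.Linear(8, 2)

    def forward(self, x):
        return self.classifier(self.features(x))

model = TinyNet()
for param in model.features.parameters():
    param.requires_grad = False  # freeze the feature extractor

# Hand only the parameters that are still trainable to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

x = torch.randn(3, 4)
feats = model.features(x)
print(feats.requires_grad)  # False: nothing upstream needs a gradient
out = model.classifier(feats)
print(out.requires_grad)    # True: the classifier weights require grad

out.sum().backward()
optimizer.step()
print(model.features[0].weight.grad)             # None: frozen
print(model.classifier.weight.grad is not None)  # True: trained
```

Filtering the parameters before building the optimizer is optional, since frozen parameters never receive a gradient, but it makes the intent explicit.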
torch.no_grad()
Using the context manager torch.no_grad is a different way to achieve that goal: in the no_grad context, all the results of the computations will have requires_grad=False, even if the inputs have requires_grad=True. Notice that you won't be able to backpropagate the gradient to layers before the no_grad. For example:
x = torch.randn(2, 2)
x.requires_grad = True
lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)
x1 = lin0(x)
with torch.no_grad():
    x2 = lin1(x1)
x3 = lin2(x2)
x3.sum().backward()
print(lin0.weight.grad, lin1.weight.grad, lin2.weight.grad)
outputs:
(None, None, tensor([[-1.4481, -1.1789],
[-1.4481, -1.1789]]))
Here lin1.weight.requires_grad was True, but the gradient wasn't computed because the operation was done in the no_grad context.
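The same effect can be read off the intermediate tensors directly. This sketch repeats the setup and inspects the requires_grad flag at each step; it is only an illustration, and the exact tensor values will differ from run to run.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 2, requires_grad=True)
lin0, lin1, lin2 = nn.Linear(2, 2), nn.Linear(2, 2), nn.Linear(2, 2)

x1 = lin0(x)
with torch.no_grad():
    x2 = lin1(x1)  # the graph is cut here
x3 = lin2(x2)

print(x1.requires_grad, x2.requires_grad, x3.requires_grad)  # True False True
print(x2.grad_fn)  # None: x2 has no history

x3.sum().backward()
print(x.grad)  # None: the gradient cannot flow back past the no_grad block
```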
model.eval()
If your goal is not to finetune but to run your model in inference mode, the most convenient way is to use the torch.no_grad context manager. In this case you also have to set your model to evaluation mode, which is done by calling eval() on the nn.Module, for example:
model = torchvision.models.vgg16(pretrained=True)
model.eval()
This operation sets the attribute self.training of the layers to False. In practice this changes the behavior of operations like Dropout or BatchNorm, which must behave differently at training and test time.
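A small sketch of what eval() does and does not change, using Dropout as the example. The two-layer model and the names a and b are assumptions made for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical model with a layer whose behavior depends on the mode
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(2, 4)

model.eval()
print(model.training, model[1].training)  # False False: set recursively

with torch.no_grad():
    a = model(x)
    b = model(x)
print(torch.equal(a, b))  # True: dropout is inactive in eval mode

# eval() does not touch the requires_grad flags
print(model[0].weight.requires_grad)  # True

model.train()  # switch back before training again
print(model.training)  # True
```

This is why inference code usually combines both: eval() for the layer behavior and torch.no_grad() to skip gradient tracking.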