What does model.train() do in PyTorch?
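model.train() tells your model that it is being trained. This matters for layers such as Dropout and BatchNorm, which are designed to behave differently during training and evaluation. For instance, in training mode BatchNorm normalizes with the statistics of the current batch and updates its running estimates of mean and variance on each new batch, whereas in evaluation mode those running estimates are frozen and used for normalization. Likewise, Dropout randomly zeroes activations only in training mode and passes its input through unchanged in evaluation mode.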
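As a small illustration of the mode-dependent behavior described above, here is a sketch using nn.Dropout. The tensor size, the probability p=0.5, and the seed are arbitrary choices made for this example, not anything prescribed by PyTorch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # only so the random mask is repeatable between runs

drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()        # training mode: elements are zeroed at random
y_train = drop(x)   # surviving elements are scaled by 1 / (1 - p) = 2.0

drop.eval()         # evaluation mode: dropout is a no-op
y_eval = drop(x)    # identical to x

print(y_train)
print(y_eval)
```

In training mode every entry of the output is either 0 or 2.0; in evaluation mode the input comes back unchanged.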
More details:
model.train() sets the mode to train (see source code). You can call either model.eval() or model.train(mode=False) to tell the model that you are testing.
It is somewhat intuitive to expect the train function to train the model, but it does not do that. It only sets the mode.
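To make that point concrete, the sketch below checks that calling train() leaves the parameters untouched, and that they only change once an explicit optimization step runs. The nn.Linear model, the random data, and the SGD settings are placeholders chosen for illustration.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
before = [p.detach().clone() for p in model.parameters()]

model.train()  # flips the mode flag only: no forward pass, no backward pass, no update

unchanged = all(
    torch.equal(b, p.detach()) for b, p in zip(before, model.parameters())
)
print(unchanged)

# Actual training still needs an explicit step like this one.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, target = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

changed = any(
    not torch.equal(b, p.detach()) for b, p in zip(before, model.parameters())
)
print(changed)
```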
Here is the code for nn.Module.train():
def train(self, mode=True):
    r"""Sets the module in training mode."""
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self
Here is the code for nn.Module.eval():
def eval(self):
    r"""Sets the module in evaluation mode."""
    return self.train(False)
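Because train() calls itself on every child, the flag reaches nested submodules too, and because it returns self, the call can be chained. A quick way to see both, assuming a small nn.Sequential model made up for this example:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 4),
    nn.Sequential(nn.Dropout(0.5), nn.BatchNorm1d(4)),  # nested on purpose
)

returned = model.eval()  # same effect as model.train(False)
flags_eval = [m.training for m in model.modules()]  # includes model itself

model.train()
flags_train = [m.training for m in model.modules()]

print(returned is model)
print(flags_eval)
print(flags_train)
```

Here model.modules() walks the whole tree (the outer container, the Linear layer, the inner container, Dropout, and BatchNorm1d), so every flag in the list flips together.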
By default, the self.training flag is True, i.e., modules start out in train mode. When self.training is False, the module is in the opposite state, eval mode.
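A custom module can branch on the same flag. The NoisyIdentity layer below is a hypothetical example written for this answer (it is not part of PyTorch): it adds Gaussian noise only while self.training is True.

```python
import torch
import torch.nn as nn

class NoisyIdentity(nn.Module):  # hypothetical layer, for illustration only
    def __init__(self, std=0.1):
        super().__init__()
        self.std = std

    def forward(self, x):
        if self.training:  # flag maintained by train() / eval()
            return x + self.std * torch.randn_like(x)
        return x           # eval mode: pass the input through unchanged

layer = NoisyIdentity()
default_flag = layer.training  # a freshly constructed module is in train mode

x = torch.zeros(5)
layer.eval()
out_eval = layer(x)

print(default_flag)
print(out_eval)
```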
Of the most commonly used layers, only Dropout and BatchNorm care about that flag.
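For the BatchNorm side, one way to observe the flag at work is to watch the running_mean buffer. This is a minimal sketch; the batch size, feature count, and the +5.0 offset are arbitrary values picked so the change is easy to see.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3)
x = torch.randn(16, 3) + 5.0  # batch mean far from the initial running_mean of 0

bn.train()
bn(x)                                       # forward pass updates the running stats
mean_after_train = bn.running_mean.clone()  # has moved toward the batch mean

bn.eval()
bn(x)                                       # forward pass uses, but does not update, the stats
mean_after_eval = bn.running_mean.clone()

print(mean_after_train)
print(mean_after_eval)
```

The two printed tensors are identical: the forward pass in eval mode leaves the running estimates where the training-mode pass put them.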