PyTorch: get all layers of a model
You can iterate over all modules of a model (including those inside each Sequential) with the modules() method. Here's a simple example:
>>> import torch.nn as nn
>>> model = nn.Sequential(nn.Linear(2, 2),
...                       nn.ReLU(),
...                       nn.Sequential(nn.Linear(2, 1),
...                                     nn.Sigmoid()))
>>> l = [module for module in model.modules() if not isinstance(module, nn.Sequential)]
>>> l
[Linear(in_features=2, out_features=2, bias=True),
ReLU(),
Linear(in_features=2, out_features=1, bias=True),
Sigmoid()]
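Filtering out nn.Sequential only works when Sequential is the only container type in the model. A slightly more general sketch, assuming that "layer" means a module with no children of its own, keeps only the leaf modules, so nn.ModuleList, nn.ModuleDict and custom container classes are skipped as well:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(2, 2),
    nn.ReLU(),
    nn.Sequential(nn.Linear(2, 1), nn.Sigmoid()),
)

# A module with no children is treated as a "leaf" layer. This also skips
# nn.ModuleList, nn.ModuleDict and custom containers, not just nn.Sequential.
leaf_layers = [m for m in model.modules() if not list(m.children())]

print(leaf_layers)
```

One caveat of this design: a custom module that owns its own parameters *and* has child modules is not a leaf, so it is left out of the list even though it does some computation itself.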
If you want the layers in a named dict, this is the simplest way:
named_layers = dict(model.named_modules())
This returns a dict that maps each module's name to the module, something like:
{
'': <the model itself>,
'conv1': <some conv layer>,
'fc1': <some fc layer>,
# ... and the other layers
}
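To make the key format concrete, here is a small sketch using the same toy Sequential model as above. Nested modules get dotted names built from their parents' names, and the root model appears under the empty string:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(2, 2),
    nn.ReLU(),
    nn.Sequential(nn.Linear(2, 1), nn.Sigmoid()),
)

named_layers = dict(model.named_modules())

# Keys are dotted paths; the model itself is stored under the empty string.
print(list(named_layers))   # ['', '0', '1', '2', '2.0', '2.1']

# Look up a nested layer by its dotted name.
inner_linear = named_layers["2.0"]
print(inner_linear)         # Linear(in_features=2, out_features=1, bias=True)
```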
Example:
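import torchvision.models as models

# On torchvision >= 0.13 `pretrained` is deprecated; there you would pass
# weights=models.Inception_V3_Weights.DEFAULT instead.
model = models.inception_v3(pretrained=True)
named_layers = dict(model.named_modules())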
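Because named_modules() yields (name, module) pairs, you can also select only the layers of a given type. The sketch below uses a small toy CNN as a stand-in for inception_v3 (an assumption made only so the example runs without downloading weights); the same comprehension works on any nn.Module:

```python
import torch.nn as nn

# Hypothetical toy network standing in for a real model such as inception_v3.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 16, kernel_size=3), nn.BatchNorm2d(16)),
    nn.Flatten(),
)

# Keep only the convolution layers, keyed by their dotted names.
conv_layers = {
    name: module
    for name, module in model.named_modules()
    if isinstance(module, nn.Conv2d)
}
print(list(conv_layers))  # ['0', '2.0']
```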
I needed this for a deeper model in which not all blocks were nn.Sequential, so I flatten the model recursively:
import torch


def get_children(model: torch.nn.Module):
    # get the direct children of the model
    children = list(model.children())
    if not children:
        # the model has no children, so it is itself a leaf layer;
        # wrap it in a list so callers can always use extend()
        return [model]
    # otherwise recurse into each child until we reach the leaves
    flat_children = []
    for child in children:
        flat_children.extend(get_children(child))
    return flat_children
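As a usage sketch, the recursive flattening (restated here in corrected form so the snippet is self-contained) is applied below to a model containing a hypothetical custom Block that holds its layers in an nn.ModuleList rather than an nn.Sequential. The result lists the leaf layers in registration order:

```python
import torch
import torch.nn as nn


def get_children(model: torch.nn.Module):
    # Leaves are returned as one-element lists so callers can always extend().
    children = list(model.children())
    if not children:
        return [model]
    flat_children = []
    for child in children:
        flat_children.extend(get_children(child))
    return flat_children


class Block(nn.Module):
    # Hypothetical custom block that is not an nn.Sequential.
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1)])
        self.act = nn.ReLU()

    def forward(self, x):
        for conv in self.convs:
            x = self.act(conv(x))
        return x


model = nn.Sequential(Block(), nn.Flatten(), nn.Linear(12, 2))
layers = get_children(model)
print([type(layer).__name__ for layer in layers])
# ['Conv2d', 'Conv2d', 'ReLU', 'Flatten', 'Linear']
```

For this model the result matches the leaf filter `[m for m in model.modules() if not list(m.children())]`; the recursive version simply spells the traversal out explicitly.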