PyTorch: get all layers of a model

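You can iterate over all modules of a model (including those inside each Sequential) with the modules() method. It walks the module tree recursively and also yields the model itself and every container, so filter those out if you only want the layers. Here's a simple example: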

>>> import torch.nn as nn
>>> model = nn.Sequential(nn.Linear(2, 2),
...                       nn.ReLU(),
...                       nn.Sequential(nn.Linear(2, 1),
...                                     nn.Sigmoid()))

>>> l = [module for module in model.modules() if not isinstance(module, nn.Sequential)]

>>> l

[Linear(in_features=2, out_features=2, bias=True),
 ReLU(),
 Linear(in_features=2, out_features=1, bias=True),
 Sigmoid()]
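The isinstance check above only removes nn.Sequential; other containers such as nn.ModuleList, or your own nn.Module subclasses, would still show up in the list. A minimal sketch, assuming you want only the leaf layers (modules that have no children of their own), filters on children() instead:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(2, 2),
    nn.ReLU(),
    nn.Sequential(nn.Linear(2, 1), nn.Sigmoid()),
)

# modules() walks the whole tree, starting with the model itself.
# Keep only modules whose children() iterator is empty, i.e. the actual layers.
leaf_layers = [m for m in model.modules() if next(m.children(), None) is None]

for layer in leaf_layers:
    print(layer)
```

This drops the root model and every container regardless of their class. One trade-off: a custom module that owns parameters directly and also has submodules would be skipped, since it has children.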

If you want the layers in a dict keyed by name, this is the simplest way:

named_layers = dict(model.named_modules())

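This returns something like the following (the model itself is stored under the empty key '', and nested layers get dotted keys such as 'parent.child'):

{
    '': <the model itself>,
    'conv1': <some conv layer>,
    'fc1': <some fc layer>,
    # ... and the other layers
}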

Example:

import torchvision.models as models

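# torchvision >= 0.13; older versions use models.inception_v3(pretrained=True)
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)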
named_layers = dict(model.named_modules())
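Because the keys are dotted attribute paths, a nested layer can be looked up directly by name. A small sketch using the Sequential model from the first example instead of Inception, so nothing has to be downloaded (the variable names are illustrative):

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(2, 2),
    nn.ReLU(),
    nn.Sequential(nn.Linear(2, 1), nn.Sigmoid()),
)

named_layers = dict(model.named_modules())

# Keys are dotted paths in pre-order; the root model itself is stored under ''.
print(list(named_layers))  # ['', '0', '1', '2', '2.0', '2.1']

# Look up the Linear layer nested inside the inner Sequential.
inner_linear = named_layers["2.0"]
print(inner_linear)
```

For a named model like Inception the keys follow the attribute names instead of indices, so the same lookup works with those names.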

I needed this for a deeper model where not all blocks were nn.Sequential, so I wrote a recursive function that flattens the model down to its leaf layers:

import torch


def get_children(model: torch.nn.Module) -> list:
    # Get the direct children of this module.
    children = list(model.children())
    if not children:
        # No children: this module is itself a leaf layer.
        return [model]
    # Otherwise, recurse into each child down to the leaf layers.
    flat_children = []
    for child in children:
        flat_children.extend(get_children(child))
    return flat_children
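The same recursion can be written as a generator over named_children(), which also keeps each leaf's dotted name. A sketch, assuming a hypothetical custom block that stores its layers in an nn.ModuleList rather than an nn.Sequential (Block, Net and iter_leaf_layers are illustrative names, not part of PyTorch):

```python
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical block that is not an nn.Sequential."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.ReLU()])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class Net(nn.Module):
    """Hypothetical model mixing a custom block with a plain layer."""

    def __init__(self):
        super().__init__()
        self.block = Block()
        self.head = nn.Linear(4, 1)

    def forward(self, x):
        return self.head(self.block(x))


def iter_leaf_layers(module: nn.Module, prefix: str = ""):
    """Yield (dotted_name, layer) for every module that has no children."""
    has_children = False
    for name, child in module.named_children():
        has_children = True
        child_name = f"{prefix}.{name}" if prefix else name
        # Recurse lazily into the child and pass its leaves through.
        yield from iter_leaf_layers(child, child_name)
    if not has_children:
        # A module without children is a leaf layer.
        yield prefix, module


for name, layer in iter_leaf_layers(Net()):
    print(name, layer)
```

Being a generator, it yields layers one at a time instead of building intermediate lists at every level, and the names match the keys that named_modules() would produce for the same layers.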
