Check the total number of parameters in a PyTorch model
If you want to count the weights and biases in each layer without instantiating the model, you can load the saved file directly and iterate over the resulting collections.OrderedDict (this assumes the file holds a state_dict rather than a pickled model), like so:
import torch

# Assumes 'model.dat' was written with torch.save(model.state_dict(), ...)
tensor_dict = torch.load('model.dat', map_location='cpu')  # OrderedDict
for name, tensor in tensor_dict.items():
    # numel() is the number of elements in the tensor
    print('{}: {}'.format(name, tensor.numel()))
You'll get something like this (note that buffers such as running_mean are listed alongside the weights and biases):
conv1.weight: 312
conv1.bias: 26
batch_norm1.weight: 26
batch_norm1.bias: 26
batch_norm1.running_mean: 26
batch_norm1.running_var: 26
conv2.weight: 2340
conv2.bias: 10
batch_norm2.weight: 10
batch_norm2.bias: 10
batch_norm2.running_mean: 10
batch_norm2.running_var: 10
fcs.layers.0.weight: 135200
fcs.layers.0.bias: 260
fcs.layers.1.weight: 33800
fcs.layers.1.bias: 130
fcs.batch_norm_layers.0.weight: 260
fcs.batch_norm_layers.0.bias: 260
fcs.batch_norm_layers.0.running_mean: 260
fcs.batch_norm_layers.0.running_var: 260
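As a rough, self-contained illustration of the approach above, the sketch below saves the state_dict of a tiny made-up model to a temporary file, loads it back, and compares the element count of the file with the count from model.parameters(). The model layout and the temporary path are assumptions chosen only for the demo. The point it shows is that a state_dict also carries buffers (running statistics and, in recent PyTorch versions, num_batches_tracked), so its total can be larger than the parameter count.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical toy model, used only for illustration.
model = nn.Sequential(nn.Conv2d(1, 2, kernel_size=3), nn.BatchNorm2d(2))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.dat")
    torch.save(model.state_dict(), path)
    state = torch.load(path, map_location="cpu")  # OrderedDict of tensors

# Element count per entry in the saved file (parameters and buffers).
counts = {name: tensor.numel() for name, tensor in state.items()}
total_entries = sum(counts.values())

# Element count of the learnable parameters only.
n_params = sum(p.numel() for p in model.parameters())

for name, n in counts.items():
    print(f"{name}: {n}")
print("state_dict total:", total_entries)
print("parameters total:", n_params)
```

If you only want weights and biases from a saved file, one option is to skip keys that end in running_mean, running_var, or num_batches_tracked, though that relies on the default buffer names.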
To avoid double counting shared parameters, key the counts by torch.Tensor.data_ptr(), which returns the memory address of a tensor's first element. E.g.:
sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
Here's a more verbose implementation that can optionally filter out non-trainable parameters:
import torch

def numel(m: torch.nn.Module, only_trainable: bool = False):
    """
    Returns the total number of parameters used by `m` (only counting
    shared parameters once); if `only_trainable` is True, then only
    includes parameters with `requires_grad = True`
    """
    parameters = list(m.parameters())
    if only_trainable:
        parameters = [p for p in parameters if p.requires_grad]
    # Key by storage address so tensors sharing memory are counted once
    unique = {p.data_ptr(): p for p in parameters}.values()
    return sum(p.numel() for p in unique)
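To make the double-counting issue concrete, here is a small sketch with a hypothetical TiedModel whose output layer reuses the embedding matrix. Note that Module.parameters() already skips a Parameter object that is registered twice, so the data_ptr() keying mainly acts as an extra guard; the place where tied weights really do get counted twice is the state_dict, which lists the tensor under each name.

```python
import torch
from torch import nn

class TiedModel(nn.Module):
    """Hypothetical model with weight tying, for illustration only."""

    def __init__(self, vocab: int = 10, dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)          # weight shape: (vocab, dim)
        self.out = nn.Linear(dim, vocab, bias=False)   # weight shape: (vocab, dim)
        self.out.weight = self.embed.weight            # share one Parameter

model = TiedModel()

# Deduplicated by memory address: the shared matrix is counted once.
by_ptr = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())

# Summing over state_dict entries counts the shared matrix under both names.
naive_state = sum(t.numel() for t in model.state_dict().values())

print(by_ptr, naive_state)
```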
To get the parameter count of each layer, as Keras does, use model.named_parameters(), which returns an iterator over (name, parameter) pairs. Example:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        # Skip frozen (non-trainable) parameters
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

count_parameters(net)  # net is your model instance
Example output:
+-------------------+------------+
|      Modules      | Parameters |
+-------------------+------------+
| embeddings.weight |   922866   |
|    conv1.weight   |  1048576   |
|     conv1.bias    |    1024    |
|     bn1.weight    |    1024    |
|      bn1.bias     |    1024    |
|    conv2.weight   |  2097152   |
|     conv2.bias    |    1024    |
|     bn2.weight    |    1024    |
|      bn2.bias     |    1024    |
|    conv3.weight   |  2097152   |
|     conv3.bias    |    1024    |
|     bn3.weight    |    1024    |
|      bn3.bias     |    1024    |
|    lin1.weight    |  50331648  |
|     lin1.bias     |    512     |
|    lin2.weight    |   265728   |
|     lin2.bias     |    519     |
+-------------------+------------+
Total Trainable Params: 56773369
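If you would rather not install prettytable, or want one number per layer instead of one per tensor, a rough alternative is to group the named_parameters() entries by the first component of their name. The helper name count_by_top_level_module and the toy nn.Sequential model below are made up for this sketch.

```python
from collections import defaultdict

import torch
from torch import nn

def count_by_top_level_module(model: nn.Module) -> dict:
    """Sum trainable parameter counts per top-level submodule name."""
    counts = defaultdict(int)
    for name, parameter in model.named_parameters():
        if parameter.requires_grad:
            # "0.weight" -> "0", "fcs.layers.0.bias" -> "fcs"
            counts[name.split(".")[0]] += parameter.numel()
    return dict(counts)

# Hypothetical toy model; the ReLU has no parameters and is not listed.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
counts = count_by_top_level_module(model)

for module_name, n in counts.items():
    print(f"{module_name}: {n}")
print("Total Trainable Params:", sum(counts.values()))
```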
PyTorch doesn't have a built-in function to calculate the total number of parameters as Keras does, but you can sum the number of elements of every parameter tensor:
pytorch_total_params = sum(p.numel() for p in model.parameters())
If you want to calculate only the trainable parameters:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
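As a quick check of the difference between the two one-liners, the sketch below freezes the first layer of a small made-up model and computes both totals. The model and the choice of which layer to freeze are assumptions for the example only.

```python
import torch
from torch import nn

# Hypothetical two-layer model, used only to show the two counts.
model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))

# Freeze the first layer so its parameters no longer require gradients.
for p in model[0].parameters():
    p.requires_grad = False

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("total:", total_params)          # all parameters, frozen or not
print("trainable:", trainable_params)  # only those with requires_grad=True
```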
Answer inspired by this answer on PyTorch Forums.
Note: I'm answering my own question. If anyone has a better solution, please share with us.