PyTorch: manually setting weight parameters with numpy array for GRU / LSTM
If you want to set a certain weight/bias (or a few), I like doing:

```python
model.state_dict()["your_weight_names_here"][:] = torch.Tensor(your_numpy_array)
```
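To see why that one-liner works, here is a minimal sketch. It assumes a hypothetical single-layer GRU with `input_size = hidden_size = 2` and a made-up NumPy array. The values returned by `state_dict()` are detached tensors that share memory with the parameters, so an in-place slice assignment writes straight into the model. `torch.from_numpy` is used here instead of `torch.Tensor(...)`; both give a float32 tensor for a float32 array.

```python
import numpy as np
import torch

# Hypothetical model: GRU with input_size=2, hidden_size=2, one layer.
rnn = torch.nn.GRU(2, 2, 1)

# Hypothetical weights; weight_ih_l0 has shape (3 * hidden_size, input_size) = (6, 2).
your_numpy_array = np.arange(12, dtype=np.float32).reshape(6, 2)

# state_dict() values share storage with the parameters, so this
# in-place slice assignment modifies rnn.weight_ih_l0 itself.
rnn.state_dict()["weight_ih_l0"][:] = torch.from_numpy(your_numpy_array)

print(rnn.weight_ih_l0)
```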
That is a good question, and you already give a decent answer. However, it reinvents the wheel: there is a very elegant PyTorch internal routine that lets you do the same with less effort, and it works for any network.
The core concept here is PyTorch's `state_dict`. The state dictionary effectively contains the parameters, organized by the tree structure given by the relationship of the `nn.Module`s and their submodules, etc.
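A quick way to see the tree structure is to compare `named_parameters()` with the `state_dict` keys. This sketch uses a hypothetical `nn.Sequential` (whose children are named `"0"`, `"1"`, ...) containing a `BatchNorm1d`, because that module also shows that the state dict holds persistent buffers (its running statistics) in addition to the parameters.

```python
import torch

# Hypothetical network: nn.Sequential names its submodules "0", "1", ...
net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.BatchNorm1d(2))

# Parameter names are built from the module tree, joined with dots.
param_names = [name for name, _ in net.named_parameters()]
print(param_names)  # ['0.weight', '0.bias', '1.weight', '1.bias']

# The state_dict uses the same dotted keys, plus persistent buffers
# such as BatchNorm's running statistics.
sd_keys = list(net.state_dict().keys())
print(sd_keys)
# ['0.weight', '0.bias', '1.weight', '1.bias',
#  '1.running_mean', '1.running_var', '1.num_batches_tracked']
```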
The short answer
If you only want the code to load a value into a tensor using the `state_dict`, then try this line (where `dict` contains a valid `state_dict`):

`model.load_state_dict(dict, strict=False)`

where `strict=False` is crucial if you want to load only some parameter values.
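Applied to the LSTM case from the question, a minimal sketch could look like this. The layer sizes and the NumPy array are hypothetical. The return value of `load_state_dict` lists the keys that were not in the dict you passed, which is a handy sanity check that you only touched what you meant to.

```python
import numpy as np
import torch

# Hypothetical model: single-layer LSTM with input_size=3, hidden_size=4.
lstm = torch.nn.LSTM(3, 4, 1)

# Hypothetical NumPy weights; weight_ih_l0 has shape (4 * hidden_size, input_size) = (16, 3).
w_ih = np.random.randn(16, 3).astype(np.float32)

# Partial state dict with a single key, so strict=False is required.
result = lstm.load_state_dict({"weight_ih_l0": torch.from_numpy(w_ih)}, strict=False)

# Parameters that were left unchanged are reported as missing keys.
print(result.missing_keys)
```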
The long answer - including an introduction to PyTorch's state_dict
Here's an example of what a state dict looks like for a GRU (I chose `input_size = hidden_size = 2` so that I can print the entire state dict):
```python
import torch

rnn = torch.nn.GRU(2, 2, 1)
rnn.state_dict()
# Out[10]:
# OrderedDict([('weight_ih_l0', tensor([[-0.0023, -0.0460],
#                                       [ 0.3373, 0.0070],
#                                       [ 0.0745, -0.5345],
#                                       [ 0.5347, -0.2373],
#                                       [-0.2217, -0.2824],
#                                       [-0.2983, 0.4771]])),
#              ('weight_hh_l0', tensor([[-0.2837, -0.0571],
#                                       [-0.1820, 0.6963],
#                                       [ 0.4978, -0.6342],
#                                       [ 0.0366, 0.2156],
#                                       [ 0.5009, 0.4382],
#                                       [-0.7012, -0.5157]])),
#              ('bias_ih_l0',
#               tensor([-0.2158, -0.6643, -0.3505, -0.0959, -0.5332, -0.6209])),
#              ('bias_hh_l0',
#               tensor([-0.1845, 0.4075, -0.1721, -0.4893, -0.2427, 0.3973]))])
```
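Note that `weight_ih_l0` has 6 rows for `hidden_size = 2`: per the `torch.nn.GRU` docs, the three gate matrices (reset, update, new) are stacked along dim 0, and `torch.nn.LSTM` stacks four (input, forget, cell, output). If your NumPy weights come per gate, a sketch like the following (sizes and values hypothetical) writes one gate block in place. `chunk()` returns views, so `copy_` writes through to the parameter.

```python
import numpy as np
import torch

input_size, hidden_size = 2, 2
gru = torch.nn.GRU(input_size, hidden_size, 1)
lstm = torch.nn.LSTM(input_size, hidden_size, 1)

# GRU: 3 stacked gate blocks; LSTM: 4 stacked gate blocks.
print(gru.weight_ih_l0.shape)   # torch.Size([6, 2])
print(lstm.weight_ih_l0.shape)  # torch.Size([8, 2])

# Split the stacked GRU matrix into one view per gate: reset, update, new.
W_ir, W_iz, W_in = gru.state_dict()["weight_ih_l0"].chunk(3, dim=0)

# Hypothetical NumPy values for the update-gate block only.
W_iz_new = np.ones((hidden_size, input_size), dtype=np.float32)

# W_iz is a view into the parameter's memory, so this updates gru.weight_ih_l0.
W_iz.copy_(torch.from_numpy(W_iz_new))
```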
So the `state_dict` contains all the parameters of the network. If we have "nested" `nn.Module`s, we get the tree represented by the parameter names:
```python
class MLP(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.lin_a = torch.nn.Linear(2, 2)
        self.lin_b = torch.nn.Linear(2, 2)

mlp = MLP()
mlp.state_dict()
# Out[23]:
# OrderedDict([('lin_a.weight', tensor([[-0.2914, 0.0791],
#                                       [-0.1167, 0.6591]])),
#              ('lin_a.bias', tensor([-0.2745, -0.1614])),
#              ('lin_b.weight', tensor([[-0.4634, -0.2649],
#                                       [ 0.4552, 0.3812]])),
#              ('lin_b.bias', tensor([ 0.0273, -0.1283]))])
```
```python
class NestedMLP(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.mlp_a = MLP()
        self.mlp_b = MLP()

n_mlp = NestedMLP()
n_mlp.state_dict()
# Out[26]:
# OrderedDict([('mlp_a.lin_a.weight', tensor([[ 0.2543, 0.3412],
#                                             [-0.1984, -0.3235]])),
#              ('mlp_a.lin_a.bias', tensor([ 0.2480, -0.0631])),
#              ('mlp_a.lin_b.weight', tensor([[-0.4575, -0.6072],
#                                             [-0.0100, 0.5887]])),
#              ('mlp_a.lin_b.bias', tensor([-0.3116, 0.5603])),
#              ('mlp_b.lin_a.weight', tensor([[ 0.3722, 0.6940],
#                                             [-0.5120, 0.5414]])),
#              ('mlp_b.lin_a.bias', tensor([0.3604, 0.0316])),
#              ('mlp_b.lin_b.weight', tensor([[-0.5571, 0.0830],
#                                             [ 0.5230, -0.1020]])),
#              ('mlp_b.lin_b.bias', tensor([ 0.2156, -0.2930]))])
```
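The dotted keys are also how you address a single nested parameter when loading. A minimal sketch, redefining the two classes above (with `super().__init__()`) so it runs on its own; the new values are hypothetical. Calling `load_state_dict` on a submodule works too, with the prefix dropped from the key.

```python
import torch

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin_a = torch.nn.Linear(2, 2)
        self.lin_b = torch.nn.Linear(2, 2)

class NestedMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp_a = MLP()
        self.mlp_b = MLP()

n_mlp = NestedMLP()

# Full dotted key from the root module: targets n_mlp.mlp_b.lin_a.bias.
n_mlp.load_state_dict({"mlp_b.lin_a.bias": torch.tensor([1.0, -1.0])}, strict=False)

# Same idea on a submodule: the key is relative to n_mlp.mlp_a.
n_mlp.mlp_a.load_state_dict({"lin_b.weight": torch.eye(2)}, strict=False)

print(n_mlp.mlp_b.lin_a.bias)
print(n_mlp.mlp_a.lin_b.weight)
```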
So what if you want to change the state dict, and thereby the network's parameters, instead of just extracting it? Use `nn.Module.load_state_dict(state_dict, strict=True)`.
This method loads an entire `state_dict` with arbitrary values into an instantiated model of the same kind, as long as the keys (i.e. the parameter names) are correct and the values (i.e. the parameters) are `torch.Tensor`s of the right shape.
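The simplest full-dict case is copying all weights from one model into another of the same architecture. In this sketch (two hypothetical GRUs), the default `strict=True` is satisfied because every key and shape matches. The values are copied into the target's own parameters, so the two models do not end up sharing memory.

```python
import torch

src = torch.nn.GRU(2, 2, 1)
dst = torch.nn.GRU(2, 2, 1)  # same architecture, different random init

# Complete state dict: all keys present, all shapes match -> strict=True works.
dst.load_state_dict(src.state_dict())

same_values = all(torch.equal(a, b) for a, b in zip(src.parameters(), dst.parameters()))
shares_memory = src.weight_ih_l0.data_ptr() == dst.weight_ih_l0.data_ptr()
print(same_values, shares_memory)  # True False
```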
If the `strict` kwarg is set to `True` (the default), the keys of the dict you load have to exactly match those of the original state dict; only the parameter values may differ. That is, there has to be one new value for each parameter.
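For the GRU example above, we need a tensor of the correct shape for each of `'weight_ih_l0'`, `'weight_hh_l0'`, `'bias_ih_l0'`, `'bias_hh_l0'` (the values are copied into the existing parameters, so the device of the tensors you load does not have to match the model's).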
As we sometimes only want to load some values (as I think you want to do), we can set the `strict` kwarg to `False`, and then we can load partial state dicts, e.g. one that only contains a value for `'weight_ih_l0'`.
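The difference between the two settings can be seen directly with a partial dict. In this sketch (hypothetical GRU and zero tensor), `strict=True` raises a `RuntimeError` about missing keys, while `strict=False` loads the one key and reports the others in the returned `missing_keys`.

```python
import torch

rnn = torch.nn.GRU(2, 2, 1)
partial = {"weight_ih_l0": torch.zeros(6, 2)}

# strict=True (the default): the three missing keys cause an error.
strict_failed = False
try:
    rnn.load_state_dict(partial)
except RuntimeError as err:
    strict_failed = True
    print("strict=True failed:", str(err).splitlines()[0])

# strict=False: weight_ih_l0 is loaded, the other keys are only reported.
result = rnn.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']
print(result.unexpected_keys)  # []
```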
As practical advice, I'd simply create the model you want to load values into, and then print the state dict (or at least a list of the keys and the respective tensor shapes):

```python
print([(k, v.shape) for k, v in model.state_dict().items()])
```

That tells you the exact name of the parameter you want to change. You then simply create a state dict with the respective parameter name and tensor, and load it:
```python
from collections import OrderedDict

new_state_dict = OrderedDict({'tensor_name_retrieved_from_original_dict': new_tensor_value})
model.load_state_dict(new_state_dict, strict=False)
```
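Put together for the GRU from above, the workflow might look like the sketch below. The NumPy values are hypothetical. One thing worth knowing: `strict=False` only tolerates missing or unexpected keys; a tensor with the wrong shape is still rejected with a `RuntimeError`, which the last part checks.

```python
from collections import OrderedDict

import numpy as np
import torch

model = torch.nn.GRU(2, 2, 1)

# Step 1: list parameter names and shapes to find the key you want.
print([(k, tuple(v.shape)) for k, v in model.state_dict().items()])

# Step 2: build a partial state dict from hypothetical NumPy data.
new_tensor_value = torch.from_numpy(np.full((6, 2), 0.5, dtype=np.float32))
new_state_dict = OrderedDict({"weight_ih_l0": new_tensor_value})

# Step 3: load it; strict=False tolerates the missing keys.
model.load_state_dict(new_state_dict, strict=False)

# strict=False does not relax shape checks: bias_ih_l0 has shape (6,), not (5,).
shape_rejected = False
try:
    model.load_state_dict({"bias_ih_l0": torch.zeros(5)}, strict=False)
except RuntimeError:
    shape_rejected = True
print("wrong shape rejected:", shape_rejected)
```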