How do I pass a keyword argument to the forward used by a pre-forward hook?
TorchScript incompatible (as of 1.2.0)
First of all, your example `torch.nn.Module` has some minor mistakes (probably by accident).
Secondly, you can pass anything to `forward`, and a hook registered with `register_forward_pre_hook` will receive, as a tuple, the positional arguments passed to your `torch.nn.Module` (be it a layer, a model or anything else). By default keyword arguments are not passed to the hook, so pass the extra value positionally. You indeed cannot do it without changing the `forward` signature, but why would you want to avoid that? You can simply pass the arguments on to the base class's `forward`, as can be seen below:
```python
import torch


class NeoEmbeddings(torch.nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

    # First argument should be named something like module, as that's what
    # you are registering this hook to
    @staticmethod
    def neo_genesis(module, inputs):  # No need for self as first argument
        net_input, higgs_bosson = inputs  # Simply unpack the tuple of positional args
        # Return both values: a single returned value replaces *all* positional
        # arguments, and forward would then be missing higgs_bosson
        return net_input, higgs_bosson

    def forward(self, inputs, higgs_bosson):
        # Do whatever you want here with both arguments; you can ignore
        # higgs_bosson if it's only needed in the hook, as done here
        return super().forward(inputs)


if __name__ == "__main__":
    x = NeoEmbeddings(10, 5, 1)
    # Call the module with () instead of .forward() so the hooks actually run;
    # pass higgs_bosson positionally, since the hook only sees positional args
    print(x(torch.tensor([0, 2, 5, 8]), 1))
```
You can't do it in a more succinct way, but the limitation is the base class's `forward` method, not the hook itself (and, to be honest, I wouldn't want it to be more succinct, as it would become unreadable IMO).
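To see why the extra value has to be passed positionally and why the hook should hand back every positional argument, here is a deliberately simplified, pure-Python model of how a module's `__call__` runs forward pre-hooks (the default behaviour, without `with_kwargs`). `ToyModule` and `NeoToy` are hypothetical names; this is a sketch of the documented semantics, not PyTorch's actual implementation.

```python
from typing import Any, Callable, List, Tuple


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module that only models pre-hooks."""

    def __init__(self) -> None:
        self._forward_pre_hooks: List[Callable] = []

    def register_forward_pre_hook(self, hook: Callable) -> None:
        self._forward_pre_hooks.append(hook)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        for hook in self._forward_pre_hooks:
            # The hook gets the module and the *positional* args only;
            # kwargs go straight to forward without passing through it
            result = hook(self, args)
            if result is not None:
                # A non-tuple return value is wrapped into a 1-tuple and
                # replaces all positional arguments
                args = result if isinstance(result, tuple) else (result,)
        return self.forward(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class NeoToy(ToyModule):
    def __init__(self) -> None:
        super().__init__()
        self.seen: List[Tuple[Any, ...]] = []  # what the hook received
        self.register_forward_pre_hook(NeoToy.neo_genesis)

    @staticmethod
    def neo_genesis(module: "NeoToy", inputs: Tuple[Any, ...]) -> Tuple[Any, Any]:
        module.seen.append(inputs)
        net_input, higgs_bosson = inputs  # ValueError if higgs_bosson was a kwarg
        return net_input, higgs_bosson  # return both so forward still gets both

    def forward(self, inputs: Any, higgs_bosson: Any) -> Tuple[Any, Any]:
        return inputs, higgs_bosson


m = NeoToy()
print(m([0, 2, 5, 8], 1))  # ([0, 2, 5, 8], 1)
try:
    m([0, 2, 5, 8], higgs_bosson=1)
except ValueError as err:
    print("hook never saw the keyword argument:", err)
```

If a hook returned only one value, the wrapping step would leave `forward` with a single positional argument, which is why returning the full tuple (or `None`, to leave the arguments untouched) is the safer choice.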
TorchScript compatible
If you want to use TorchScript (tested on 1.2.0), you can use composition instead of inheritance. Only a few lines change, and your code may look something like this:
```python
import torch


# Inherit from Module and register the embedding as a submodule
class NeoEmbeddings(torch.nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__()
        # Just use it as a container inside your own class
        self._embedding = torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

    @staticmethod
    def neo_genesis(module, inputs):
        net_input, higgs_bosson = inputs
        return net_input, higgs_bosson  # keep both positional arguments

    def forward(self, inputs: torch.Tensor, higgs_bosson: torch.Tensor):
        return self._embedding(inputs)


if __name__ == "__main__":
    x = torch.jit.script(NeoEmbeddings(10, 5, 1))
    # forward annotates higgs_bosson as a Tensor (unannotated arguments also
    # default to Tensor in TorchScript), so pass it as a tensor
    print(x(torch.tensor([0, 2, 5, 8]), torch.tensor([1])))
```
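Since a forward pre-hook is called only with the module and the tuple of its positional inputs (keyword arguments reach it only if you register it with `with_kwargs=True`, available since PyTorch 2.0), a keyword argument doesn't make much sense here. What would make more sense is to use an instance attribute, for example:
```python
# Hook reading an instance attribute instead of a keyword argument; register it
# via the class (NeoEmbeddings.neo_genesis) so `self` is the module.
# Assumes self.higgs_boson is initialised (e.g. to 0) in __init__.
def neo_genesis(self, inputs):
    net_input = inputs[0]  # inputs is the tuple of positional arguments
    if self.higgs_boson:
        net_input = net_input + self.higgs_boson
    return net_input
```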
Then you can switch that attribute as appropriate. You could also use a context manager for that:
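```python
from contextlib import contextmanager


@contextmanager
def HiggsBoson(module):
    module.higgs_boson = 1
    try:
        yield
    finally:
        module.higgs_boson = 0  # reset even if the body raises


with HiggsBoson(x):
    x(...)  # call the module, not x.forward(...), so the pre-hook runs
```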
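A self-contained sketch of the attribute-plus-context-manager approach, using a plain Python class so it runs without PyTorch. `ToyEmbedding` and `higgs_boson_enabled` are hypothetical names: the class stands in for a module whose pre-hook reads `higgs_boson`, and the context manager restores the *previous* value (rather than hard-coding 0) inside a `finally` block.

```python
from contextlib import contextmanager
from typing import Iterator, List


class ToyEmbedding:
    """Hypothetical stand-in for a module whose pre-hook reads an attribute."""

    def __init__(self) -> None:
        self.higgs_boson = 0  # switched on and off from the outside

    def neo_genesis(self, inputs: List[int]) -> List[int]:
        # Plays the role of the pre-hook: reads the attribute, not a kwarg
        if self.higgs_boson:
            inputs = [i + self.higgs_boson for i in inputs]
        return inputs

    def __call__(self, inputs: List[int]) -> List[int]:
        # Mimics module(...): run the "pre-hook", then forward
        return self.forward(self.neo_genesis(inputs))

    def forward(self, inputs: List[int]) -> List[int]:
        return inputs


@contextmanager
def higgs_boson_enabled(module: ToyEmbedding, value: int = 1) -> Iterator[ToyEmbedding]:
    previous = module.higgs_boson
    module.higgs_boson = value
    try:
        yield module
    finally:
        module.higgs_boson = previous  # restored even if the body raises


x = ToyEmbedding()
print(x([0, 2, 5]))  # [0, 2, 5]
with higgs_boson_enabled(x):
    print(x([0, 2, 5]))  # [1, 3, 6]
print(x.higgs_boson)  # 0
```

Restoring the previous value keeps nested uses of the context manager well-behaved, which a hard-coded reset would not.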
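If you already have that hook defined with the value as a default parameter, e.g. `def neo_genesis(self, inputs, higgs_boson=0)`, and you really need to change that parameter, you can still replace the function's `__defaults__` attribute. A bound method doesn't allow setting it, so go through `__func__` (this changes the default for every instance of the class):
```python
# Assumes the hook is defined as neo_genesis(self, inputs, higgs_boson=0)
x.neo_genesis.__func__.__defaults__ = (1,)  # this corresponds to the `higgs_boson` parameter
x(...)
x.neo_genesis.__func__.__defaults__ = (0,)  # reset to default
```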
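The `__defaults__` trick can be checked in plain Python. `Net` is a hypothetical class whose `neo_genesis` is assumed to take `higgs_boson` as a defaulted parameter; the sketch shows that assigning through the bound method fails while assigning on the underlying function works.

```python
from typing import Any, Tuple


class Net:
    """Hypothetical module-like class; only the hook signature matters here."""

    def neo_genesis(self, inputs: Tuple[Any, ...], higgs_boson: int = 0) -> Tuple[Any, ...]:
        # Called as hook(module, inputs), so higgs_boson always takes its default
        return tuple(i + higgs_boson for i in inputs)


x = Net()
print(x.neo_genesis((1, 2)))  # (1, 2)

# Setting __defaults__ on the *bound* method raises AttributeError...
try:
    x.neo_genesis.__defaults__ = (1,)
except AttributeError as err:
    print("bound method:", err)

# ...so change it on the underlying function instead (affects every instance)
x.neo_genesis.__func__.__defaults__ = (1,)
print(x.neo_genesis((1, 2)))  # (2, 3)
x.neo_genesis.__func__.__defaults__ = (0,)  # reset to the default
```

Because `__defaults__` lives on the shared function object, an instance attribute is usually the less surprising option when several modules of the same class are in use.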