How to batch convert sentence lengths to masks in PyTorch?

One way that I found is:

    torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)

Please share if there are better ways!

Answer by @ypc:

    torch.arange(max_len)[None, :] < lens[:, None]
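To see both versions in action, here is a minimal runnable sketch; the values lens = [3, 5, 4] and max_len = 5 are illustrative assumptions, not part of the question:

    import torch

    # Illustrative batch of 3 sequence lengths, padded width max_len = 5
    lens = torch.tensor([3, 5, 4])
    max_len = 5

    # Explicit expansion, as in the question
    mask_a = torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)

    # Implicit expansion via broadcasting, as in @ypc's answer
    mask_b = torch.arange(max_len)[None, :] < lens[:, None]

    assert torch.equal(mask_a, mask_b)
    print(mask_b)
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True,  True,  True],
    #         [ True,  True,  True,  True, False]])

Both produce a boolean mask of shape (batch_size, max_len) with True at the valid (non-padding) positions.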
Just to provide a bit of explanation to the answer of @ypc (cannot comment due to lack of reputation):

    torch.arange(max_len)[None, :] < lens[:, None]

In a word, the answer uses the broadcasting mechanism to implicitly expand the tensor, instead of explicitly expanding it with .expand() as in the question. Step by step, assuming max_len = 5 and lens = torch.tensor([3, 5, 4]):

- torch.arange(max_len) gives you [0, 1, 2, 3, 4];
- adding [None, :] prepends a 0th dimension to the tensor, making its shape (1, 5), which gives you [[0, 1, 2, 3, 4]];
- similarly, lens[:, None] appends a 1st dimension to the tensor lens, making its shape (3, 1), that is [[3], [5], [4]];
- comparing (or doing anything like +, -, *, /, etc.) a tensor of shape (1, 5) with a tensor of shape (3, 1) follows the rules of broadcasting, so the resulting tensor has shape (3, 5), and the result values are result[i, j] = (j < lens[i]).
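For anyone who wants to trace the intermediate shapes, here is a small sketch of the steps above (same assumed example values, max_len = 5 and lens = [3, 5, 4]):

    import torch

    lens = torch.tensor([3, 5, 4])
    max_len = 5

    row = torch.arange(max_len)   # shape (5,):   [0, 1, 2, 3, 4]
    row = row[None, :]            # shape (1, 5): [[0, 1, 2, 3, 4]]
    col = lens[:, None]           # shape (3, 1): [[3], [5], [4]]

    mask = row < col              # broadcasts to shape (3, 5)
    assert mask.shape == (3, 5)

    # Verify that result[i, j] = (j < lens[i])
    for i in range(len(lens)):
        for j in range(max_len):
            assert mask[i, j].item() == (j < lens[i].item())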