How to batch convert sentence lengths to masks in PyTorch?

torch.arange(max_len)[None, :] < lens[:, None]
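For instance, with hypothetical lengths [3, 5, 4] and max_len = 5, the one-liner produces a boolean mask with one row per sentence:

```python
import torch

# Hypothetical inputs: three sentences of lengths 3, 5, and 4,
# padded to max_len = 5.
lens = torch.tensor([3, 5, 4])
max_len = 5

mask = torch.arange(max_len)[None, :] < lens[:, None]
print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True,  True,  True, False]])
```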


One way that I found is:

torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)

Please share if there are better ways!
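For what it's worth, the explicit expand version and the broadcasting one-liner produce identical masks; a quick sanity check with hypothetical lengths:

```python
import torch

lens = torch.tensor([3, 5, 4])  # hypothetical sentence lengths
max_len = 5

# Explicit expand, as in the question.
mask_a = torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)
# Implicit broadcasting, as in the accepted answer.
mask_b = torch.arange(max_len)[None, :] < lens[:, None]

assert torch.equal(mask_a, mask_b)
```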


Just to provide a bit of explanation for @ypc's answer (cannot comment due to lack of reputation):

torch.arange(max_len)[None, :] < lens[:, None]

In short, the answer uses the broadcasting mechanism to implicitly expand the tensors, achieving the same effect as the explicit expand in the accepted answer. Step by step (assuming max_len = 5 and lens = torch.tensor([3, 5, 4])):

  1. torch.arange(max_len) gives you [0, 1, 2, 3, 4];

  2. indexing with [None, :] inserts a new 0th dimension, making the shape (1, 5), which gives you [[0, 1, 2, 3, 4]];

  3. similarly, lens[:, None] inserts a new 1st dimension into the tensor lens, making the shape (3, 1), that is [[3], [5], [4]];

  4. by comparing (or applying any elementwise op like +, -, *, /, etc. to) a tensor of shape (1, 5) and one of shape (3, 1), broadcasting produces a result of shape (3, 5), with values result[i, j] = (j < lens[i]).
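The steps above can be traced explicitly; the intermediate shapes below assume the same hypothetical lens = [3, 5, 4] and max_len = 5:

```python
import torch

lens = torch.tensor([3, 5, 4])  # example lengths from the steps above
max_len = 5

row = torch.arange(max_len)  # step 1: tensor([0, 1, 2, 3, 4]), shape (5,)
row = row[None, :]           # step 2: shape (1, 5)
col = lens[:, None]          # step 3: shape (3, 1)
mask = row < col             # step 4: broadcast comparison, shape (3, 5)

print(mask.shape)  # torch.Size([3, 5])
print(mask[0])     # tensor([ True,  True,  True, False, False])
```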

Tags:

nlp

pytorch