Creating a one-hot vector from indices given as a tensor

The easiest way I found: index into an identity matrix, where x is a list (or tensor) of class indices and class_count is the number of classes you have.

import torch

def one_hot(x, class_count):
    return torch.eye(class_count)[x, :]  # row i of the identity matrix is the one-hot vector for class i

Use it like this:

x = [0,2,5,4]
class_count = 8
one_hot(x,class_count)
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]])
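The same helper also accepts a LongTensor of indices directly, which is handy when your labels already live in a tensor. A minimal sketch, reusing the one_hot helper defined above (the variable name labels is just for illustration):

labels = torch.tensor([0, 2, 5, 4])  # same indices, but as a tensor instead of a Python list
one_hot(labels, class_count)         # identical output; torch.eye returns a float tensor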



NEW ANSWER: As of PyTorch 1.1, there is a one_hot function in torch.nn.functional. Given a tensor of indices indices and the number of classes n (which must be larger than the largest index), you can create a one-hot version as follows:

import torch

n = 5
indices = torch.randint(0, n, size=(4, 7))
one_hot = torch.nn.functional.one_hot(indices, num_classes=n)  # shape: (4, 7, n)
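Note that torch.nn.functional.one_hot returns an integer (LongTensor) result, so you may need to cast it before mixing it with float tensors. A minimal sketch, reusing the indices and n defined above (the variable name targets is just for illustration):

targets = torch.nn.functional.one_hot(indices, num_classes=n).float()  # cast for use with float-valued tensors or losses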

VERY OLD ANSWER

At the moment, slicing and indexing can be a bit of a pain in PyTorch, in my experience. I assume you don't want to convert your tensors to NumPy arrays. The most elegant way I can think of at the moment is to build a sparse tensor and then convert it to a dense one. That would work as follows:

import torch
from torch.sparse import FloatTensor as STensor

batch_size = 4
seq_length = 6
feat_dim = 16

# One (batch, sequence, feature) index triple per one-hot entry
batch_idx = torch.LongTensor([i for i in range(batch_size) for s in range(seq_length)])
seq_idx = torch.LongTensor(list(range(seq_length)) * batch_size)
feat_idx = torch.LongTensor([[5, 3, 2, 11, 15, 15], [1, 4, 6, 7, 3, 3],
                             [2, 4, 7, 8, 9, 10], [11, 12, 15, 2, 5, 7]]).view(24,)

my_stack = torch.stack([batch_idx, seq_idx, feat_idx])  # indices must be nDim * nEntries
my_final_array = STensor(my_stack, torch.ones(batch_size * seq_length),
                         torch.Size([batch_size, seq_length, feat_dim])).to_dense()

print(my_final_array)
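In more recent PyTorch versions the torch.sparse.FloatTensor constructor has been superseded; as far as I know, torch.sparse_coo_tensor builds the same thing. A minimal sketch reusing the index stack from above:

my_final_array = torch.sparse_coo_tensor(my_stack, torch.ones(batch_size * seq_length),
                                         (batch_size, seq_length, feat_dim)).to_dense()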

Note: PyTorch is currently undergoing some work that will add NumPy-style broadcasting and other functionality within the next two or three weeks, so it's possible there will be better solutions available in the near future.

Hope this helps you a bit.
