What does the gather function do in pytorch in layman terms?

This is based on @Ritesh's answer (thanks @Ritesh!) with some real code.

The torch.gather API is

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

Example 1

When dim = 0,

import torch

dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]])

output = torch.gather(input, dim, index)
# tensor([[10, 14, 18],
#         [13, 17, 12]])
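To see the dim = 0 rule at work, here is a minimal sketch that rebuilds Example 1's output with plain Python loops, using out[i][j] = input[index[i][j]][j], and then checks it against torch.gather. The name `manual` is just illustrative.

```python
import torch

dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]])

# The output has the same shape as index: 2 rows, 3 columns.
manual = torch.empty_like(index)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        # With dim=0, index picks the ROW; the column j stays fixed.
        manual[i, j] = input[index[i, j], j]

print(manual)
# tensor([[10, 14, 18],
#         [13, 17, 12]])
assert torch.equal(manual, torch.gather(input, dim, index))
```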

Example 2

When dim = 1,

import torch

dim = 1
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]])

output = torch.gather(input, dim, index)
# tensor([[10, 11],
#         [14, 15],
#         [18, 16]])
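For dim = 1 there is an equivalent one-liner with advanced indexing: pair each index value with its own row number. This is a sketch for intuition, not how torch.gather is implemented; `rows` and `fancy` are illustrative names.

```python
import torch

input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]])

# Row numbers as a column vector: [[0], [1], [2]].
rows = torch.arange(index.size(0)).unsqueeze(1)

# Advanced indexing broadcasts rows (3x1) against index (3x2),
# so element (i, j) is input[i, index[i, j]], which is the dim=1 rule.
fancy = input[rows, index]

print(fancy)
# tensor([[10, 11],
#         [14, 15],
#         [18, 16]])
assert torch.equal(fancy, torch.gather(input, 1, index))
```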

@Ritesh and @cleros gave great answers (with lots of upvotes), but after reading them I was still a bit confused, and I know why. This post will perhaps help folks like me.

For these sorts of exercises with rows and columns I think it really helps to use a non-square object, so let's start with a larger 4x3 source (torch.Size([4, 3])) using source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]]). This will give us

# This is the source tensor
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

Now let's index along the columns (dim=1) and create index = torch.tensor([[0,0],[1,1],[2,2],[0,1]]), which is a list of lists. Here's the key: since our dimension is columns and the source has 4 rows, the index needs 4 lists, one for each row. (Strictly, PyTorch only requires that index have no more rows than source; with fewer lists you simply get fewer output rows.) Running source.gather(dim=1, index=index) will give us

tensor([[ 1,  1],
        [ 5,  5],
        [ 9,  9],
        [10, 11]])

So, each list within index gives us the columns from which to pull the values. The 1st list of the index ([0,0]) tells us to look at the 1st row of the source and take the 1st column of that row (it's zero-indexed) twice, which gives [1,1]. The 2nd list of the index ([1,1]) tells us to look at the 2nd row of the source and take the 2nd column of that row twice, which gives [5,5]. Jumping to the 4th list of the index ([0,1]), which refers to the 4th and final row of the source, we take the 1st column (10) and then the 2nd column (11), which gives [10,11].
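The row-by-row reading above can be written out directly: row r of the result is row r of the source indexed by list r of index. A minimal sketch; `by_row` is an illustrative name.

```python
import torch

source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
index = torch.tensor([[0, 0], [1, 1], [2, 2], [0, 1]])

# For each row r, pick the columns listed in index[r] from source[r].
rows = [source[r, index[r]] for r in range(index.size(0))]
by_row = torch.stack(rows)

print(by_row)
# tensor([[ 1,  1],
#         [ 5,  5],
#         [ 9,  9],
#         [10, 11]])
assert torch.equal(by_row, source.gather(dim=1, index=index))
```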

Here's a nifty thing: each list of your index has to be the same length, but they may be as long as you like! For example, with index = torch.tensor([[0,1,2,1,0],[2,1,0,1,2],[1,2,0,2,1],[1,0,2,0,1]]), source.gather(dim=1, index=index) will give us

tensor([[ 1,  2,  3,  2,  1],
        [ 6,  5,  4,  5,  6],
        [ 8,  9,  7,  9,  8],
        [11, 10, 12, 10, 11]])

The output always has the same shape as index, so here it has one row per list (matching the source's 4 rows) and as many columns as each list is long. For example, the 2nd list of the index ([2,1,0,1,2]) goes to the 2nd row of the source and pulls, respectively, the 3rd, 2nd, 1st, 2nd and 3rd items, which gives [6,5,4,5,6]. Note that every element in index has to be less than the number of columns of source (in this case 3); otherwise you get an out-of-bounds error.
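A quick sketch checking both points: the result takes its shape from index, and a column index of 3 fails because source only has 3 columns. The exact exception type and message are not guaranteed across PyTorch versions, so the sketch catches both RuntimeError and IndexError; `bad_index` and `failed` are illustrative names.

```python
import torch

source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
index = torch.tensor([[0, 1, 2, 1, 0], [2, 1, 0, 1, 2],
                      [1, 2, 0, 2, 1], [1, 0, 2, 0, 1]])

out = source.gather(dim=1, index=index)
print(out.shape)  # torch.Size([4, 5]), the shape of index, not of source

# source has only 3 columns, so a column index of 3 is out of bounds.
bad_index = torch.tensor([[3], [0], [0], [0]])
failed = False
try:
    source.gather(dim=1, index=bad_index)
except (RuntimeError, IndexError) as err:
    # The exception type and message may vary between PyTorch versions.
    failed = True
    print("gather failed:", err)
```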

Switching to dim=0, we'll now be using the rows as opposed to the columns. Using the same source, we now need an index where the length of each list equals the number of columns in the source. Why? Because each element in the list represents the row from source as we move column by column.

Therefore, with index = torch.tensor([[0,0,0],[0,1,2],[1,2,3],[3,2,0]]), source.gather(dim=0, index=index) gives us

tensor([[ 1,  2,  3],
        [ 1,  5,  9],
        [ 4,  8, 12],
        [10,  8,  3]])

Looking at the 1st list in the index ([0,0,0]), we can see that we're moving across the 3 columns of source picking the 1st element (it's zero-indexed) of each column, which is [1,2,3]. The 2nd list in the index ([0,1,2]) tells us to move across the columns taking the 1st, 2nd and 3rd items, respectively, which is [1,5,9]. And so on.
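Another way to read dim=0 is column by column: column c of the result is column c of the source, indexed by column c of index. A minimal sketch of that view; `by_col` is an illustrative name.

```python
import torch

source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
index = torch.tensor([[0, 0, 0], [0, 1, 2], [1, 2, 3], [3, 2, 0]])

# index[:, c] lists which rows to pull from column c of source.
cols = [source[index[:, c], c] for c in range(index.size(1))]
by_col = torch.stack(cols, dim=1)

print(by_col)
# tensor([[ 1,  2,  3],
#         [ 1,  5,  9],
#         [ 4,  8, 12],
#         [10,  8,  3]])
assert torch.equal(by_col, source.gather(dim=0, index=index))
```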

With dim=1 our index needed one list per row of the source, but each list could be as long, or short, as you like. With dim=0, each list in our index has to be the same length as the number of columns in the source, but we can now have as many lists as we like. Each value in index, however, needs to be less than the number of rows in source (in this case 4).

For example, with index = torch.tensor([[0,0,0],[1,1,1],[2,2,2],[3,3,3],[0,1,2],[1,2,3],[3,2,0]]), source.gather(dim=0, index=index) gives us

tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12],
        [ 1,  5,  9],
        [ 4,  8, 12],
        [10,  8,  3]])

With dim=1 the output has the same shape as index: one row per list in index (normally one list per row of source, and never more) and as many columns as each list is long. Each value in index needs to be less than the number of columns in source.

With dim=0 the output also has the same shape as index: as many rows as there are lists in index, each list being as long as the number of columns in source (it can't be longer). Each value in index needs to be less than the number of rows in source.
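The two summaries mirror each other because transposing swaps rows and columns. As a sketch, the 7-list dim=0 example above gives the same result as a dim=1 gather on the transposed source and index:

```python
import torch

source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
index = torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3],
                      [0, 1, 2], [1, 2, 3], [3, 2, 0]])

along_rows = source.gather(dim=0, index=index)

# Transposing swaps the roles of rows and columns, so the same selection
# can be written as a dim=1 gather on the transposed tensors.
via_transpose = source.t().gather(dim=1, index=index.t()).t()

print(along_rows.shape)  # torch.Size([7, 3])
assert torch.equal(along_rows, via_transpose)
```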

That's it for two dimensions. Moving beyond that will follow the same patterns.
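As a small illustration of the same pattern in 3D, here is a sketch with dim=2: the first two coordinates (page, row) stay fixed and index picks positions along the last axis, i.e. out[i][j][k] = source[i][j][index[i][j][k]]. The tensor values are made up for the example.

```python
import torch

# A 2 x 2 x 3 tensor: source[0] is [[0, 1, 2], [3, 4, 5]],
# source[1] is [[6, 7, 8], [9, 10, 11]].
source = torch.arange(12).reshape(2, 2, 3)

# With dim=2, each innermost list picks columns within one row of one page.
index = torch.tensor([[[2, 0], [1, 1]],
                      [[0, 0], [2, 1]]])

out = source.gather(dim=2, index=index)

expected = torch.tensor([[[2, 0], [4, 4]],
                         [[6, 6], [11, 10]]])
assert torch.equal(out, expected)
```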


The torch.gather function (or torch.Tensor.gather) is a multi-index selection method. Look at the following example from the official docs:

import torch

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#         [ 4,  3]])

Let's start by going through the semantics of the different arguments. The first argument, input, is the source tensor that we want to select elements from. The second, dim, is the dimension (or axis in TensorFlow/NumPy) along which we want to collect. Finally, index holds the indices used to index input along dim. As for the semantics of the operation, this is how the official docs explain it:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
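The three formulas follow one rule for any number of dimensions: keep every coordinate of the output position except the one along dim, which is replaced by the value from index. Below is a minimal, deliberately slow sketch of that rule; `gather_reference` is a hypothetical helper, not a PyTorch function, and it is only meant to be checked against torch.gather, not used in practice.

```python
import itertools
import torch


def gather_reference(input, dim, index):
    """Loop-based gather following the docs' formulas (illustration only).

    For every position pos in index, the output copies input at pos,
    except that the coordinate along `dim` is replaced by index[pos].
    """
    out = torch.empty(index.shape, dtype=input.dtype)
    for pos in itertools.product(*(range(n) for n in index.shape)):
        src = list(pos)
        src[dim] = index[pos].item()  # swap in the gathered coordinate
        out[pos] = input[tuple(src)]
    return out


torch.manual_seed(0)
t = torch.arange(24).reshape(2, 3, 4)
for dim in range(t.dim()):
    # Random indices that are valid along the chosen dimension.
    idx = torch.randint(0, t.size(dim), t.shape)
    assert torch.equal(gather_reference(t, dim, idx), torch.gather(t, dim, idx))
```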

So let's go through the example.

The input tensor is [[1, 2], [3, 4]], and the dim argument is 1, i.e. we want to collect from the second dimension. The indices for the second dimension are given as [0, 0] and [1, 0].

As we "skip" the first dimension (the dimension we collect along is 1), the first dimension of the result is implicitly given by the first dimension of the index. That means that the index holds only the second-dimension (column) indices, not the row indices; the row indices are given by the positions within the index tensor itself. For the example, this means that the first row of the output is a selection of elements from the first row of the input tensor, as given by the first row of the index tensor. As the column indices are [0, 0], we select the first element of the input's first row twice, resulting in [1, 1]. Similarly, the second row of the result comes from indexing the second row of the input tensor with the second row of the index tensor ([1, 0]), resulting in [4, 3].

To illustrate this even further, let's swap the dimension in the example:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  2],
#         [ 3,  2]])

As you can see, the indices are now collected along the first dimension.

For the example you referred to,

current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))

gather will index the rows of the q-values (i.e. the per-sample q-values in a batch of q-values) by the batch-list of actions. The result will be the same as if you had done the following (though it will be much faster than a loop):

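q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
    q_vals.append(qv[ac])  # 0-d tensor: Q-value of the action taken in this sample
# torch.cat can't join 0-d tensors, so stack them and add a column
# to match gather's (batch, 1) result.
q_vals = torch.stack(q_vals).unsqueeze(1)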
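Since the question's network and data aren't available here, this self-contained sketch uses hypothetical stand-ins (`q_values` for Q(obs_batch), `actions` for act_batch) to show the shapes involved and that gather picks one Q-value per row.

```python
import torch

torch.manual_seed(0)
batch_size, n_actions = 4, 3

# Hypothetical stand-ins for Q(obs_batch) and act_batch from the question.
q_values = torch.randn(batch_size, n_actions)  # one row of Q-values per sample
actions = torch.tensor([2, 0, 1, 2])           # action taken in each sample

# unsqueeze(1) turns shape (4,) into (4, 1): one column index per row.
picked = q_values.gather(1, actions.unsqueeze(1))
print(picked.shape)  # torch.Size([4, 1])

# Same values via advanced indexing; squeeze(1) drops the extra column.
assert torch.equal(picked.squeeze(1), q_values[torch.arange(batch_size), actions])
```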

torch.gather creates a new tensor from the input tensor by taking values from each row along the input dimension dim. The values in the torch.LongTensor passed as index specify which value to take from each 'row'. The output tensor has the same shape as the index tensor. The following illustration from the official docs explains it more clearly:

[Figure: pictorial representation of torch.gather from the docs]

(Note: In the illustration, indexing starts from 1 and not 0).

In the first example, the dimension given is along rows (top to bottom), so for position (1,1) of the result, it takes the row value from the index, which is 1. The value at (1,1) in src is 1, so 1 is output at (1,1) in the result. Similarly, for (2,2) the row value from the index is 3. The value at (3,2) in src is 8, so 8 is output, and so on.

Similarly, in the second example indexing is along columns, so for position (2,2) of the result, the column value from the index is 3. The value at (2,3) in src is 6, so 6 is output at (2,2) in the result.
