What does the gather function do in pytorch in layman terms?
This is based on @Ritesh's answer (thanks @Ritesh!) with some real code.
The torch.gather API is:
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
Example 1
When dim = 0:
import torch
dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]])
output = torch.gather(input, dim, index)
# tensor([[10, 14, 18],
#         [13, 17, 12]])
Example 2
When dim = 1:
import torch
dim = 1
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]])
output = torch.gather(input, dim, index)
# tensor([[10, 11],
#         [14, 15],
#         [18, 16]])
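To see the rule without any PyTorch machinery, here is a minimal pure-Python sketch of what gather computes for 2-D inputs, assuming plain nested lists stand in for tensors. The name gather_2d is hypothetical and not part of PyTorch; it only mirrors the indexing rule from the docs (out[i][j] = input[index[i][j]][j] for dim=0, out[i][j] = input[i][index[i][j]] for dim=1) and does no shape or bounds checking of its own.

```python
def gather_2d(src, dim, index):
    """Pure-Python stand-in for torch.gather on 2-D nested lists (hypothetical helper)."""
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            if dim == 0:
                # k picks the row; the column j comes from the position inside index
                out_row.append(src[k][j])
            elif dim == 1:
                # k picks the column; the row i comes from the position inside index
                out_row.append(src[i][k])
            else:
                raise ValueError("dim must be 0 or 1 for a 2-D input")
        out.append(out_row)
    return out


src = [[10, 11, 12], [13, 14, 15], [16, 17, 18]]
print(gather_2d(src, 0, [[0, 1, 2], [1, 2, 0]]))    # [[10, 14, 18], [13, 17, 12]]
print(gather_2d(src, 1, [[0, 1], [1, 2], [2, 0]]))  # [[10, 11], [14, 15], [18, 16]]
```

Both calls reproduce the outputs of Example 1 and Example 2, which is a quick way to convince yourself which index is "replaced" for each dim.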
@Ritesh and @cleros gave great answers (with lots of upvotes), but after reading them I was still a bit confused, and I know why. This post will perhaps help folks like me.
For these sorts of exercises with rows and columns, I think it really helps to use a non-square object, so let's start with a larger 4x3 source (torch.Size([4, 3])) using source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]]). This will give us
# This is the source tensor
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
Now let's start indexing along the columns (dim=1) and create index = torch.tensor([[0,0],[1,1],[2,2],[0,1]]), which is a list of lists. Here's the key: since our dimension is columns, and the source has 4 rows, the index should contain 4 lists, one for each row. Running source.gather(dim=1, index=index) will give us
tensor([[ 1, 1],
[ 5, 5],
[ 9, 9],
[10, 11]])
So, each list within index gives us the columns from which to pull the values. The 1st list of the index ([0,0]) tells us to look at the 1st row of the source and take the 1st column of that row (it's zero-indexed) twice, which is [1,1]. The 2nd list of the index ([1,1]) tells us to look at the 2nd row of source and take the 2nd column of that row twice, which is [5,5]. Jumping to the 4th list of the index ([0,1]), which asks us to look at the 4th and final row of the source, we take the 1st column (10) and then the 2nd column (11), which gives us [10,11].
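The row-by-row reading above can be spelled out as a short pure-Python sketch, assuming plain nested lists in place of tensors: each list in index is paired with the source row at the same position, and its numbers are used as column positions within that row. This only mirrors the dim=1 case for illustration; it is not how PyTorch implements gather.

```python
source = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
index = [[0, 0], [1, 1], [2, 2], [0, 1]]

# dim=1: pair each source row with the index list at the same position,
# then read the columns named by that list from that row
result = [[row[col] for col in cols] for row, cols in zip(source, index)]
print(result)  # [[1, 1], [5, 5], [9, 9], [10, 11]]
```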
Here's a nifty thing: each list of your index has to be the same length, but they may be as long as you like! For example, with index = torch.tensor([[0,1,2,1,0],[2,1,0,1,2],[1,2,0,2,1],[1,0,2,0,1]]), source.gather(dim=1, index=index) will give us
tensor([[ 1, 2, 3, 2, 1],
[ 6, 5, 4, 5, 6],
[ 8, 9, 7, 9, 8],
[11, 10, 12, 10, 11]])
The output will always have the same number of rows as the source, but the number of columns will equal the length of each list in index. For example, the 2nd list of the index ([2,1,0,1,2]) goes to the 2nd row of the source and pulls, respectively, the 3rd, 2nd, 1st, 2nd and 3rd items, which is [6,5,4,5,6]. Note that the value of every element in index has to be less than the number of columns of source (in this case 3), otherwise you get an out of bounds error.
Switching to dim=0, we'll now be using the rows as opposed to the columns. Using the same source, we now need an index where the length of each list equals the number of columns in the source. Why? Because each element in the list names the row of source to read from as we move column by column.
Therefore, with index = torch.tensor([[0,0,0],[0,1,2],[1,2,3],[3,2,0]]), source.gather(dim=0, index=index) will give us
tensor([[ 1, 2, 3],
[ 1, 5, 9],
[ 4, 8, 12],
[10, 8, 3]])
Looking at the 1st list in the index ([0,0,0]), we can see that we're moving across the 3 columns of source, picking the 1st element (it's zero-indexed) of each column, which is [1,2,3]. The 2nd list in the index ([0,1,2]) tells us to move across the columns taking the 1st, 2nd and 3rd items, respectively, which is [1,5,9]. And so on.
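The same column-by-column walk for dim=0 can be written as a small pure-Python sketch, again assuming nested lists stand in for tensors: the number at position col of each index list says which source row to read that column from. It is only an illustration of the rule, not PyTorch's implementation.

```python
source = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
index = [[0, 0, 0], [0, 1, 2], [1, 2, 3], [3, 2, 0]]

# dim=0: walk across the columns; the value at position col in each index list
# is the source row to take that column's element from
result = [[source[row][col] for col, row in enumerate(rows)] for rows in index]
print(result)  # [[1, 2, 3], [1, 5, 9], [4, 8, 12], [10, 8, 3]]
```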
With dim=1 our index had to have a number of lists equal to the number of rows in the source, but each list could be as long, or short, as you like. With dim=0, each list in our index has to be the same length as the number of columns in the source, but we can now have as many lists as we like. Each value in index, however, needs to be less than the number of rows in source (in this case 4).
For example, with index = torch.tensor([[0,0,0],[1,1,1],[2,2,2],[3,3,3],[0,1,2],[1,2,3],[3,2,0]]), source.gather(dim=0, index=index) would give us
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12],
[ 1, 5, 9],
[ 4, 8, 12],
[10, 8, 3]])
With dim=1 the output has the same number of rows as the source, although the number of columns will equal the length of the lists in index. The number of lists in index has to equal the number of rows in source (strictly, recent PyTorch versions only require it to be no larger, in which case only the leading rows are gathered). Each value in index, however, needs to be less than the number of columns in source.
With dim=0 the output has the same number of columns as the source, but the number of rows will equal the number of lists in index. The length of each list in index has to equal the number of columns in source (again, recent versions only require it to be no larger). Each value in index, however, needs to be less than the number of rows in source.
That's it for two dimensions. Moving beyond that will follow the same patterns.
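As a rough generalization of these shape rules to any number of dimensions, here is a sketch of a shape check. The helper gather_output_shape is hypothetical; it encodes the rules as stated in the current PyTorch documentation (input and index have the same number of dimensions, index.size(d) <= input.size(d) for every d != dim, and the output takes the shape of index). It assumes a non-negative dim and does not check the index values themselves.

```python
def gather_output_shape(input_shape, dim, index_shape):
    """Shape torch.gather would return for these shapes (hypothetical helper).

    Encodes the shape rules stated in the current PyTorch docs: input and index
    have the same number of dimensions, and index.size(d) <= input.size(d) for
    every d != dim. Assumes a non-negative dim; index *values* are not checked.
    """
    if len(input_shape) != len(index_shape):
        raise ValueError("input and index must have the same number of dimensions")
    for d, (n_input, n_index) in enumerate(zip(input_shape, index_shape)):
        if d != dim and n_index > n_input:
            raise ValueError(f"index.size({d})={n_index} > input.size({d})={n_input}")
    # Along dim the index may be any length; the output simply takes index's shape.
    return tuple(index_shape)


print(gather_output_shape((4, 3), 1, (4, 5)))  # (4, 5), the dim=1 example with 5-long lists
print(gather_output_shape((4, 3), 0, (7, 3)))  # (7, 3), the dim=0 example with 7 lists
```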
The torch.gather function (or torch.Tensor.gather) is a multi-index selection method. Look at the following example from the official docs:
import torch
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#         [ 4,  3]])
Let's start by going through the semantics of the different arguments: the first argument, input, is the source tensor that we want to select elements from. The second, dim, is the dimension (or axis in TensorFlow/NumPy) that we want to collect along. And finally, index holds the indices used to index input.
As for the semantics of the operation, this is how the official docs explain it:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
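The three documented formulas above can be turned almost literally into a pure-Python sketch, assuming nested lists stand in for a 3-D tensor. The helper gather_3d and the small src/index values are hypothetical, chosen only to show that the index value replaces exactly one coordinate (the one named by dim) while the other two come from the position inside index.

```python
def gather_3d(src, dim, index):
    """Apply the documented 3-D gather rule to nested lists (hypothetical helper)."""
    out = []
    for i, plane in enumerate(index):
        out_plane = []
        for j, row in enumerate(plane):
            out_row = []
            for k, v in enumerate(row):
                if dim == 0:
                    out_row.append(src[v][j][k])  # input[index[i][j][k]][j][k]
                elif dim == 1:
                    out_row.append(src[i][v][k])  # input[i][index[i][j][k]][k]
                elif dim == 2:
                    out_row.append(src[i][j][v])  # input[i][j][index[i][j][k]]
                else:
                    raise ValueError("dim must be 0, 1 or 2 for a 3-D input")
            out_plane.append(out_row)
        out.append(out_plane)
    return out


src = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # shape (2, 2, 2)
index = [[[1, 0], [0, 1]]]                  # shape (1, 2, 2)
print(gather_3d(src, 0, index))  # [[[5, 2], [3, 8]]]
print(gather_3d(src, 1, index))  # [[[3, 2], [1, 4]]]
print(gather_3d(src, 2, index))  # [[[2, 1], [3, 4]]]
```

The same index gives three different results depending on dim, and the output always has the shape of index, here (1, 2, 2).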
So let's go through the example.
The input tensor is [[1, 2], [3, 4]], and the dim argument is 1, i.e. we want to collect from the second dimension. The indices for the second dimension are given as [0, 0] and [1, 0].
As we "skip" the first dimension (the dimension we want to collect along is 1), the first dimension of the result is implicitly given as the first dimension of the index. That means that the indices hold the second dimension, or the column indices, but not the row indices. Those are given by the indices of the index tensor itself.
For the example, this means that the output will have in its first row a selection of the elements of the input tensor's first row, as given by the index tensor's first row. As the column indices are given by [0, 0], we therefore select the first element of the first row of the input twice, resulting in [1, 1]. Similarly, the elements of the second row of the result come from indexing the second row of the input tensor by the elements of the second row of the index tensor, resulting in [4, 3].
To illustrate this even further, let's swap the dimension in the example:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1, 2],
#         [ 3,  2]])
As you can see, the indices are now collected along the first dimension.
For the example you referred to,
current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))
gather will index the rows of the q-values (i.e. the per-sample q-values in a batch of q-values) by the batch-list of actions. The result will be the same as if you had done the following (though gather will be much faster than a Python loop):
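q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
    q_vals.append(qv[ac])  # q-value of the chosen action for this sample
q_vals = torch.stack(q_vals).unsqueeze(1)  # shape (batch, 1), matching the gather result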
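To see why the example calls act_batch.unsqueeze(1), here is a small pure-Python sketch with made-up numbers; q_table and actions are hypothetical stand-ins for Q(obs_batch) and act_batch. Because gather needs index to have the same number of dimensions as the input, the 1-D list of actions is turned into one single-element list per sample, so the result has shape (batch, 1).

```python
# Hypothetical batch of 3 samples, each with q-values for 2 actions: shape (3, 2)
q_table = [[0.1, 0.9], [0.5, 0.2], [0.3, 0.7]]
actions = [1, 0, 1]  # one chosen action per sample: shape (3,)

# Equivalent of actions.unsqueeze(1): shape (3,) -> (3, 1)
index = [[a] for a in actions]

# gather along dim=1: for sample i, read the column(s) listed in index[i] from row i
chosen = [[row[c] for c in cols] for row, cols in zip(q_table, index)]
print(chosen)  # [[0.9], [0.5], [0.7]]
```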
torch.gather creates a new tensor from the input tensor by taking the values from each row along the input dimension dim. The values in the torch.LongTensor passed as index specify which value to take from each 'row'. The output tensor has the same shape as the index tensor. The following illustration from the official docs explains it more clearly:
(Note: In the illustration, indexing starts from 1 and not 0).
In the first example, the dimension given is along rows (top to bottom), so for the (1,1) position of result, it takes the row value from the index for the src, which is 1. At (1,1) in source the value is 1, so it outputs 1 at (1,1) in result.
Similarly, for (2,2) the row value from the index for src is 3. At (3,2) the value in src is 8, hence it outputs 8, and so on.
Similarly, for the second example, indexing is along columns, and hence at the (2,2) position of the result, the column value from the index for src is 3, so 6 is taken from (2,3) in src and output to result at (2,2).