Torch sum a tensor along an axis
The simplest and best solution is to use torch.sum().
To sum all elements of a tensor:
torch.sum(x) # gives back a scalar
To sum over all rows (i.e. for each column):
torch.sum(x, dim=0) # size = [ncol]
To sum over all columns (i.e. for each row):
torch.sum(x, dim=1) # size = [nrow]
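A quick way to confirm what each call returns is to print the shapes. This is only a minimal sketch; the 3x4 tensor `x` below is an illustrative assumption and not part of the original question.

```python
import torch

# Hypothetical example tensor with nrow=3, ncol=4
x = torch.arange(12).reshape(3, 4)

total = torch.sum(x)            # 0-dim tensor holding the sum of all elements
col_sums = torch.sum(x, dim=0)  # collapses the row dimension: one sum per column
row_sums = torch.sum(x, dim=1)  # collapses the column dimension: one sum per row

print(total.shape, col_sums.shape, row_sums.shape)
```

Note that the reduced dimension is dropped by default, so both per-column and per-row results are 1D.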
Alternatively, you can use tensor.sum(axis), where, for a 2D tensor, axis is 0 to sum over rows or 1 to sum over columns.
In [210]: X
Out[210]:
tensor([[  1,  -3,   0,  10],
        [  9,   3,   2,  10],
        [  0,   3, -12,  32]])
In [211]: X.sum(1)
Out[211]: tensor([ 8, 24, 23])
In [212]: X.sum(0)
Out[212]: tensor([ 10, 3, -10, 52])
As we can see from the above outputs, in both cases the output is a 1D tensor. If, on the other hand, you wish to retain the number of dimensions of the original tensor in the output, set the boolean kwarg keepdim to True, as in:
In [217]: X.sum(0, keepdim=True)
Out[217]: tensor([[ 10, 3, -10, 52]])
In [218]: X.sum(1, keepdim=True)
Out[218]:
tensor([[ 8],
        [24],
        [23]])
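One situation where keeping the dimension tends to help is broadcasting the sums back against the original tensor. The sketch below assumes a float copy of the X shown above and divides each row by its own sum; the names `row_sums` and `normalized` are just illustrative.

```python
import torch

# Same values as X above, stored as floats so that division is exact enough
X = torch.tensor([[1., -3., 0., 10.],
                  [9., 3., 2., 10.],
                  [0., 3., -12., 32.]])

row_sums = X.sum(1, keepdim=True)  # shape [3, 1] instead of [3]
normalized = X / row_sums          # broadcasts [3, 4] / [3, 1] row by row

print(normalized.sum(1))           # each row should now sum to roughly 1
```

Without keepdim=True the row sums would have shape [3], which does not broadcast against a [3, 4] tensor along the rows.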