Equivalent of slicing a tensor in PyTorch/ATen C++
1. You can also use .slice
Tensor::slice(int64_t dim, int64_t start, int64_t end, int64_t step)
auto partial_gates = gates.slice(1, 0, 3).chunk(4, 1);
2. From PyTorch 1.5 onward, you can use Tensor::index and Tensor::index_put_
using namespace torch::indexing;
auto partial_gates = gates.index({"...", Slice(None, 2)}).chunk(4, 1);
Tensor::index also supports multidimensional indexing.
General translation for Tensor::index and Tensor::index_put_:
Python             C++ (assuming `using namespace torch::indexing`)
-------------------------------------------------------------------
0                  0
None               None
...                "..." or Ellipsis
:                  Slice()
start:stop:step    Slice(start, stop, step)
True / False       true / false
[[1, 2]]           torch::tensor({{1, 2}})
It's .narrow()
from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp#L364
auto partial_gates = gates.narrow(1,0,2).chunk(4, 1);