What does `tf.strided_slice()` do?
The conceptualization that really helped me understand this was that this function emulates the indexing behavior of numpy arrays.
If you're familiar with numpy arrays, you'll know that you can make slices via `input[start1:end1:step1, start2:end2:step2, ... startN:endN:stepN]`. Basically, it's a very succinct way of writing `for` loops to get certain elements of the array.
(If you're familiar with Python indexing, you know that you can grab an array slice via `input[start:end:step]`. Numpy arrays, which may be nested, make use of the above tuple of slice objects.)
Well, `strided_slice` just allows you to do this fancy indexing without the syntactic sugar. The numpy example from above just becomes:

```python
# input[start1:end1:step1, start2:end2:step2, ... startN:endN:stepN]
tf.strided_slice(input, [start1, start2, ..., startN],
                 [end1, end2, ..., endN], [step1, step2, ..., stepN])
```
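To make the correspondence concrete, here is a minimal NumPy-only sketch. The helper `strided_slice_np` is hypothetical (it is not part of TensorFlow or NumPy) and assumes no masks; it simply zips the three lists into a tuple of Python `slice` objects, which is what the `[start:end:step, ...]` syntax builds behind the scenes.

```python
import numpy as np

def strided_slice_np(x, begin, end, strides):
    """Hypothetical NumPy emulation of tf.strided_slice without masks.

    Builds one slice object per dimension, i.e. the tuple that
    x[b1:e1:s1, b2:e2:s2, ...] would create.
    """
    index = tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))
    return x[index]

x = np.arange(100).reshape(10, 10)
# Same selection as the sugared form x[3:7:1, 5:10:2]
print(strided_slice_np(x, [3, 5], [7, 10], [1, 2]))
```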
The documentation is a bit confusing about this in the sense that:
a) `end - begin` is not strictly the shape of the return value.
The documentation claims otherwise, but this is only true if your strides are all ones. Examples:
```python
import tensorflow as tf

rank1 = tf.constant(list(range(10)))
# The below op is basically:
# rank1[1:10:2] => [1, 3, 5, 7, 9]   (shape [5], not end - begin = [9])
tf.strided_slice(rank1, [1], [10], [2])

# [10,10] grid of the numbers from 0 to 99
rank2 = tf.constant([[i + j * 10 for i in range(10)] for j in range(10)])
# The below op is basically:
# rank2[3:7:1, 5:10:2] => numbers 30 - 69, ending in 5, 7, or 9
# (shape [4, 3], not end - begin = [4, 5])
sliced = tf.strided_slice(rank2, [3, 5], [7, 10], [1, 2])
# The below op is basically:
# rank2[3:7:1] => numbers 30 - 69
sliced = tf.strided_slice(rank2, [3], [7], [1])
```
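Claim (a) can be checked with a small sketch: for positive strides, each output dimension has as many elements as `range(begin, end, stride)` produces, i.e. ceil((end - begin) / stride), which equals `end - begin` only when the stride is 1. The helper `expected_shape` is a hypothetical illustration that assumes positive strides, non-negative in-range indices, and no masks.

```python
def expected_shape(begin, end, strides):
    """Output shape of a mask-free strided slice (positive strides assumed)."""
    # len(range(...)) counts exactly the indices the slice visits.
    return [len(range(b, e, s)) for b, e, s in zip(begin, end, strides)]

print(expected_shape([1], [10], [2]))           # [5]     rank1 example
print(expected_shape([3, 5], [7, 10], [1, 2]))  # [4, 3]  rank2 example
# A naive end - begin would give [9] and [4, 5] instead.
```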
b) It states that "`begin`, `end`, and `strides` will be all length n, where n is in general not the same dimensionality as `input`".
It sounds like dimensionality means rank here, but `input` does have to be a tensor of at least rank n; it can't be lower (see the rank-2 example above, where n = 1).
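Claim (b) can be illustrated with NumPy, under the assumption (consistent with the rank-2 example above) that any dimensions not covered by `begin`/`end`/`strides` are taken whole. The hypothetical helper `pad_to_rank` fills the missing trailing dimensions with `0`, the dimension size, and `1`, so the short and padded specifications select the same elements.

```python
import numpy as np

def pad_to_rank(shape, begin, end, strides):
    """Hypothetical helper: extend begin/end/strides to len(shape) by
    selecting every remaining trailing dimension in full."""
    n = len(begin)
    begin = list(begin) + [0] * (len(shape) - n)
    end = list(end) + list(shape[n:])
    strides = list(strides) + [1] * (len(shape) - n)
    return begin, end, strides

rank2 = np.array([[i + j * 10 for i in range(10)] for j in range(10)])
b, e, s = pad_to_rank(rank2.shape, [3], [7], [1])
print(b, e, s)  # [3, 0] [7, 10] [1, 1]
full = rank2[tuple(slice(*t) for t in zip(b, e, s))]
print(full.shape)  # (4, 10), same as rank2[3:7:1]
```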
N.B. I haven't really explored the masking feature, but that seems beyond the scope of the question.
I experimented a bit with this method, which gave me some insights that I think might be of some use. Let's say we have a tensor:
```python
import numpy as np

a = np.array([[[1, 1.2, 1.3], [2, 2.2, 2.3], [7, 7.2, 7.3]],
              [[3, 3.2, 3.3], [4, 4.2, 4.3], [8, 8.2, 8.3]],
              [[5, 5.2, 5.3], [6, 6.2, 6.3], [9, 9.2, 9.3]]])
# a.shape = (3, 3, 3)
```
`strided_slice()` requires four arguments, `input_`, `begin`, `end` and `strides`, and we pass our `a` as the `input_` argument.
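As with the `tf.slice()` method, the `begin` argument is zero-based, while `end` behaves as if it were shape-based. However, the docs describe both `begin` and `end` as zero-based. The two views agree because `end` is exclusive: an `end` of 3 stops just before index 3, which is exactly where a dimension of size 3 ends.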
The method's functionality is quite simple: it works like iterating over a loop, where `begin` is the position of the element in the tensor at which the loop starts and `end` is where it stops (exclusive).
```python
tf.strided_slice(a, [0, 0, 0], [3, 3, 3], [1, 1, 1])
# output = the tensor itself

tf.strided_slice(a, [0, 0, 0], [3, 3, 3], [2, 2, 2])
# output = [[[ 1.   1.3]
#            [ 7.   7.3]]
#
#           [[ 5.   5.3]
#            [ 9.   9.3]]]
```
`strides` are the steps by which the loop advances. Here `[2, 2, 2]` makes the method produce the values at (0,0,0), (0,0,2), (0,2,0), (0,2,2), (2,0,0), (2,0,2), ... in the `a` tensor.
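The visiting order described above can be reproduced with `itertools.product`, which varies the last dimension fastest. This is a sketch assuming positive strides and in-range indices; it lists the coordinates for `begin=[0, 0, 0]`, `end=[3, 3, 3]`, `strides=[2, 2, 2]` and gathers the corresponding values from `a`.

```python
import itertools
import numpy as np

a = np.array([[[1, 1.2, 1.3], [2, 2.2, 2.3], [7, 7.2, 7.3]],
              [[3, 3.2, 3.3], [4, 4.2, 4.3], [8, 8.2, 8.3]],
              [[5, 5.2, 5.3], [6, 6.2, 6.3], [9, 9.2, 9.3]]])
begin, end, strides = [0, 0, 0], [3, 3, 3], [2, 2, 2]

# One range per dimension; product() increments the last dimension first.
ranges = [range(b, e, s) for b, e, s in zip(begin, end, strides)]
coords = list(itertools.product(*ranges))
print(coords)
# [(0, 0, 0), (0, 0, 2), (0, 2, 0), (0, 2, 2), (2, 0, 0), (2, 0, 2), (2, 2, 0), (2, 2, 2)]

# Gather the values in that order and reshape to the output shape (2, 2, 2).
values = np.array([a[c] for c in coords]).reshape([len(r) for r in ranges])
print(values)
```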
`tf.strided_slice(a, [1, 1, 0], [2, -1, 3], [1, 1, 1])` will produce the same output as `tf.strided_slice(a, [1, 1, 0], [2, 2, 3], [1, 1, 1])`, as the tensor `a` has shape `(3, 3, 3)` and a negative `end` counts from the end of the dimension, so `-1` here means index 2.
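Negative `end` values count from the end of the dimension, as in Python indexing. A minimal sketch of that normalization follows; the helper `normalize_end` is hypothetical, assumes positive strides, and ignores the masks and clamping that TensorFlow also applies. It is checked with NumPy on a stand-in tensor `t` that has the same shape as `a`.

```python
import numpy as np

def normalize_end(end, shape):
    """Hypothetical helper: map negative end indices to non-negative ones."""
    return [e + dim if e < 0 else e for e, dim in zip(end, shape)]

t = np.arange(27).reshape(3, 3, 3)  # stand-in with the same shape as `a`
print(normalize_end([2, -1, 3], t.shape))  # [2, 2, 3]

# Both specifications select the same block
print(np.array_equal(t[1:2, 1:-1, 0:3], t[1:2, 1:2, 0:3]))  # True
```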
The mistake in your argument is that you are directly adding the lists `strides` and `begin` element by element. That would make the function a lot less useful. Instead, it increments the `begin` list one dimension at a time, starting from the last dimension.
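To see why element-wise addition is less useful, compare the two readings for the question's first example (`begin = [1, 0, 0]`, `end = [2, 1, 3]`, all strides 1). Adding `strides` to `begin` in every dimension at once walks a diagonal and stops as soon as any dimension reaches its `end`, while the per-dimension reading covers the whole box. This is an illustrative sketch, not TensorFlow's implementation.

```python
import itertools

begin, end, strides = [1, 0, 0], [2, 1, 3], [1, 1, 1]

# Mistaken reading: add strides to begin element by element at each step.
diagonal = []
point = list(begin)
while all(p < e for p, e in zip(point, end)):
    diagonal.append(tuple(point))
    point = [p + s for p, s in zip(point, strides)]
print(diagonal)  # [(1, 0, 0)]

# Per-dimension reading: the last dimension varies fastest.
box = list(itertools.product(*(range(b, e, s) for b, e, s in zip(begin, end, strides))))
print(box)  # [(1, 0, 0), (1, 0, 1), (1, 0, 2)]
```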
Let's solve the first example part by part: `begin = [1, 0, 0]` and `end = [2, 1, 3]`, and all the `strides` are `1`. Work your way backwards, from the last dimension.
Start with element `[1,0,0]`. Now increase the last dimension only by its stride amount, giving you `[1,0,1]`. Keep doing this until you reach the limit: `[1,0,2]`, then `[1,0,3]` (end of the loop, since `end[2]` is 3 and is excluded). Now, in your next iteration, start by incrementing the second-to-last dimension and resetting the last dimension, giving `[1,1,0]`. Here the second-to-last dimension is equal to `end[1]`, so move to the first dimension (third to last) and reset the rest, giving you `[2,0,0]`. Again you are at the first dimension's limit, so quit the loop. The elements visited are therefore `[1,0,0]`, `[1,0,1]` and `[1,0,2]`.
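The walk-through above behaves like an odometer: bump the last dimension by its stride, and when it reaches `end`, reset it to `begin` and carry into the dimension to its left. Here is a minimal iterative sketch of that procedure; the function name `odometer_indices` is hypothetical, and it assumes positive strides.

```python
def odometer_indices(begin, end, strides):
    """Yield index tuples in the order described above (last dim fastest).

    Assumes positive strides. Yields nothing if any dimension is empty.
    """
    if any(b >= e for b, e in zip(begin, end)):
        return
    current = list(begin)
    while True:
        yield tuple(current)
        dim = len(current) - 1
        # Increment the last dimension; on reaching `end`, reset it and carry left.
        while dim >= 0:
            current[dim] += strides[dim]
            if current[dim] < end[dim]:
                break
            current[dim] = begin[dim]
            dim -= 1
        if dim < 0:  # carried past the first dimension: done
            return

print(list(odometer_indices([1, 0, 0], [2, 1, 3], [1, 1, 1])))
# [(1, 0, 0), (1, 0, 1), (1, 0, 2)]
```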
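The following code is a recursive implementation of what I described above:

```python
import numpy as np

# Assume global `begin`, `end` and `stride` (lists of equal length, positive
# strides), plus the input `old_matrix` and a preallocated `new_matrix`.
old_matrix = np.arange(27).reshape(3, 3, 3)
begin, end, stride = [1, 0, 0], [2, 1, 3], [1, 1, 1]
new_matrix = np.empty([len(range(b, e, s)) for b, e, s in zip(begin, end, stride)],
                      dtype=old_matrix.dtype)

def iterate(active, dim):
    if dim == len(begin):
        # Every dimension is fixed: copy one element. `active` and `begin`
        # are lists, so the output index is computed per dimension.
        out = tuple((i - b) // s for i, b, s in zip(active, begin, stride))
        new_matrix[out] = old_matrix[tuple(active)]
    else:
        # Walk dimension `dim`; deeper dimensions vary faster.
        for i in range(begin[dim], end[dim], stride[dim]):
            new_active = list(active)  # copy so sibling calls don't interfere
            new_active[dim] = i
            iterate(new_active, dim + 1)

iterate(begin, 0)
```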