Increasing each element of a tensor by the predecessor in TensorFlow 2.0
I'll give you a few different methods to implement that. I think the most obvious solution is to use tf.scan:
import tensorflow as tf
def apply_momentum_scan(m, p, axis=0):
    # Put axis first
    axis = tf.convert_to_tensor(axis, dtype=tf.int32)
    perm = tf.concat([[axis], tf.range(axis), tf.range(axis + 1, tf.rank(m))], axis=0)
    m_t = tf.transpose(m, perm)
    # Do computation
    res_t = tf.scan(lambda a, x: a * p + x, m_t)
    # Undo transpose
    perm_t = tf.concat([tf.range(1, axis + 1), [0], tf.range(axis + 1, tf.rank(m))], axis=0)
    return tf.transpose(res_t, perm_t)
However, you can also implement this as a particular matrix product, if you build a matrix of exponential factors:
import tensorflow as tf
def apply_momentum_matmul(m, p, axis=0):
    # Put axis first and reshape
    m = tf.convert_to_tensor(m)
    p = tf.convert_to_tensor(p)
    axis = tf.convert_to_tensor(axis, dtype=tf.int32)
    perm = tf.concat([[axis], tf.range(axis), tf.range(axis + 1, tf.rank(m))], axis=0)
    m_t = tf.transpose(m, perm)
    shape_t = tf.shape(m_t)
    m_tr = tf.reshape(m_t, [shape_t[0], -1])
    # Build factors matrix
    r = tf.range(tf.shape(m_tr)[0])
    p_tr = tf.linalg.band_part(p ** tf.dtypes.cast(tf.expand_dims(r, 1) - r, p.dtype), -1, 0)
    # Do computation
    res_tr = p_tr @ m_tr
    # Reshape back and undo transpose
    res_t = tf.reshape(res_tr, shape_t)
    perm_t = tf.concat([tf.range(1, axis + 1), [0], tf.range(axis + 1, tf.rank(m))], axis=0)
    return tf.transpose(res_t, perm_t)
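To make the "matrix of exponential factors" concrete, here is a minimal plain-Python sketch of the same idea on a single 1-D sequence. The helper names momentum_factors and matvec are made up for illustration and are not part of TensorFlow. The matrix holds p ** (i - j) on and below the diagonal and zeros above it, which is what the band_part call keeps.

```python
def momentum_factors(p, n):
    """Lower-triangular n x n matrix with entry p ** (i - j) for j <= i, else 0."""
    return [[p ** (i - j) if j <= i else 0.0 for j in range(n)] for i in range(n)]


def matvec(mat, vec):
    """Plain matrix-vector product on nested lists."""
    return [sum(a * x for a, x in zip(row, vec)) for row in mat]


x = [0.0, 1.0, 2.0, 3.0, 4.0]
factors = momentum_factors(0.5, len(x))
for row in factors:
    print(row)
print(matvec(factors, x))
# [0.0, 1.0, 2.5, 4.25, 6.125]
```

Row i of the product is the sum of p ** (i - j) * x[j] over j <= i, which is the scan recurrence written out in closed form.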
The matrix-product version can also be rewritten with tf.tensordot to avoid the initial transpose (which is expensive in TensorFlow):
import tensorflow as tf
def apply_momentum_tensordot(m, p, axis=0):
    # Convert inputs (p must be a tensor so that p.dtype is available)
    m = tf.convert_to_tensor(m)
    p = tf.convert_to_tensor(p)
    # Build factors matrix
    r = tf.range(tf.shape(m)[axis])
    p_mat = tf.linalg.band_part(p ** tf.dtypes.cast(tf.expand_dims(r, 1) - r, p.dtype), -1, 0)
    # Do computation (the contracted axis ends up as the last dimension)
    res_t = tf.linalg.tensordot(m, p_mat, axes=[[axis], [1]])
    # Move the last dimension back to its original position
    last_dim = tf.rank(res_t) - 1
    perm_t = tf.concat([tf.range(axis), [last_dim], tf.range(axis, last_dim)], axis=0)
    return tf.transpose(res_t, perm_t)
The three functions would be used in a similar way:
import tensorflow as tf
p = tf.Variable(0.5, dtype=tf.float32)
m = tf.constant([[0, 1, 2, 3, 4],
                 [1, 3, 5, 7, 10],
                 [1, 1, 1, -1, 0]], tf.float32)
# apply_momentum is one of the functions above
print(apply_momentum(m, p, axis=0).numpy())
# [[ 0.    1.    2.    3.    4.  ]
#  [ 1.    3.5   6.    8.5  12.  ]
#  [ 1.5   2.75  4.    3.25  6.  ]]
print(apply_momentum(m, p, axis=1).numpy())
# [[ 0.      1.      2.5     4.25    6.125 ]
#  [ 1.      3.5     6.75   10.375  15.1875]
#  [ 1.      1.5     1.75   -0.125  -0.0625]]
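As a sanity check on the expected outputs, the recurrence can be sketched in plain Python without TensorFlow. The functions momentum_1d and momentum_2d are hypothetical helpers written only for this check. They assume the first element along the axis is kept as is, which is how tf.scan behaves when no initializer is given.

```python
def momentum_1d(values, p):
    """Apply res[i] = res[i - 1] * p + values[i], starting from res[0] = values[0]."""
    result = []
    acc = None
    for v in values:
        acc = v if acc is None else acc * p + v
        result.append(acc)
    return result


def momentum_2d(rows, p, axis):
    """Apply the recurrence along axis 0 (down columns) or axis 1 (along rows)."""
    if axis == 1:
        return [momentum_1d(row, p) for row in rows]
    columns = [momentum_1d(list(col), p) for col in zip(*rows)]
    return [list(row) for row in zip(*columns)]


m = [[0.0, 1.0, 2.0, 3.0, 4.0],
     [1.0, 3.0, 5.0, 7.0, 10.0],
     [1.0, 1.0, 1.0, -1.0, 0.0]]
print(momentum_2d(m, 0.5, axis=0))
print(momentum_2d(m, 0.5, axis=1))
```

Both calls should print the same values as the TensorFlow outputs listed above.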
Using a matrix product has a higher asymptotic cost (the factor matrix is quadratic in the length of the axis), but in practice it can be faster than scanning. Here is a small benchmark:
import tensorflow as tf
import numpy as np
# Make test data
tf.random.set_seed(0)
p = tf.constant(0.5, dtype=tf.float32)
m = tf.random.uniform([100, 30, 50], dtype=tf.float32)
# Axis 0
print(np.allclose(apply_momentum_scan(m, p, 0).numpy(), apply_momentum_matmul(m, p, 0).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 0).numpy(), apply_momentum_tensordot(m, p, 0).numpy()))
# True
%timeit apply_momentum_scan(m, p, 0)
# 11.5 ms ± 610 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit apply_momentum_matmul(m, p, 0)
# 1.36 ms ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit apply_momentum_tensordot(m, p, 0)
# 1.62 ms ± 7.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# Axis 1
print(np.allclose(apply_momentum_scan(m, p, 1).numpy(), apply_momentum_matmul(m, p, 1).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 1).numpy(), apply_momentum_tensordot(m, p, 1).numpy()))
# True
%timeit apply_momentum_scan(m, p, 1)
# 4.27 ms ± 60.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit apply_momentum_matmul(m, p, 1)
# 1.27 ms ± 36.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit apply_momentum_tensordot(m, p, 1)
# 1.2 ms ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# Axis 2
print(np.allclose(apply_momentum_scan(m, p, 2).numpy(), apply_momentum_matmul(m, p, 2).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 2).numpy(), apply_momentum_tensordot(m, p, 2).numpy()))
# True
%timeit apply_momentum_scan(m, p, 2)
# 6.29 ms ± 64.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit apply_momentum_matmul(m, p, 2)
# 1.41 ms ± 21.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit apply_momentum_tensordot(m, p, 2)
# 1.05 ms ± 26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
So, matrix product seems to win. Let's see if this scales:
import tensorflow as tf
import numpy as np
# Make test data
tf.random.set_seed(0)
p = tf.constant(0.5, dtype=tf.float32)
m = tf.random.uniform([1000, 300, 500], dtype=tf.float32)
# Axis 0
print(np.allclose(apply_momentum_scan(m, p, 0).numpy(), apply_momentum_matmul(m, p, 0).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 0).numpy(), apply_momentum_tensordot(m, p, 0).numpy()))
# True
%timeit apply_momentum_scan(m, p, 0)
# 784 ms ± 6.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_matmul(m, p, 0)
# 1.13 s ± 76.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_tensordot(m, p, 0)
# 1.3 s ± 27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Axis 1
print(np.allclose(apply_momentum_scan(m, p, 1).numpy(), apply_momentum_matmul(m, p, 1).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 1).numpy(), apply_momentum_tensordot(m, p, 1).numpy()))
# True
%timeit apply_momentum_scan(m, p, 1)
# 852 ms ± 12.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_matmul(m, p, 1)
# 659 ms ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_tensordot(m, p, 1)
# 741 ms ± 19.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Axis 2
print(np.allclose(apply_momentum_scan(m, p, 2).numpy(), apply_momentum_matmul(m, p, 2).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 2).numpy(), apply_momentum_tensordot(m, p, 2).numpy()))
# True
%timeit apply_momentum_scan(m, p, 2)
# 1.06 s ± 16.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_matmul(m, p, 2)
# 924 ms ± 17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_tensordot(m, p, 2)
# 483 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Well, now it's not so clear anymore. Scanning is still not super fast, but the matrix products are sometimes slower. As you can imagine, if you go to even bigger tensors, the complexity of the matrix products will dominate the timings.
So, if you want the fastest solution and know your tensors are not going to get huge, use one of the matrix product implementations. If you're fine with okay speed but want predictable timing and want to make sure you don't run out of memory (the matrix solutions also take much more memory), use the scanning solution.
Note: The benchmarks above were carried out on CPU; results may vary significantly on GPU.
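To get a feeling for why the matrix product stops paying off, here is a back-of-the-envelope sketch in plain Python. The function estimated_costs is a hypothetical helper. It only counts multiply-adds and the size of the factor matrix, and it ignores parallelism, caching and constant factors, so these are orders of magnitude rather than predicted timings.

```python
from math import prod


def estimated_costs(shape, axis):
    """Rough multiply-add counts and extra memory; hardware effects are ignored."""
    n = shape[axis]
    rest = prod(shape) // n  # number of independent sequences along `axis`
    return {
        "scan_ops": n * rest,          # one multiply-add per element
        "matmul_ops": n * n * rest,    # dense (n x n) @ (n x rest) product
        "factor_matrix_elems": n * n,  # extra memory for the factor matrix
    }


for shape in [(100, 30, 50), (1000, 300, 500)]:
    for axis in range(3):
        print(shape, axis, estimated_costs(shape, axis))
```

The scan does far fewer operations but has to run them one step after another. The matrix product does n times more work, but as one large, well-optimized operation, which would explain why it wins for short axes and loses ground as the axis grows.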
This answer only provides some information and a naive fix for the code, not a solution to the actual problem (see below for why).
First off, the TypeError is a problem of incompatible types in the tensors of your early attempt. Some tensors contain floating-point numbers (double), some contain integers. It would have helped to show the full error message:
TypeError: Input 'y' of 'Mul' Op has type int32 that does not match type float64 of argument 'x'.
That message puts you on the right track (despite the gory details of the stack trace).
Here is a naive fix to get the code to work (with caveats regarding the target problem):
import tensorflow as tf
@tf.function
def vectorize_predec(t, p):
    # Increments p * x_i for every element except the last one along the last axis
    _p = tf.transpose(
        tf.convert_to_tensor(
            [p * t[..., idx] for idx in range(t.shape[-1] - 1)],
            dtype=tf.float64))
    # Shift the increments by one position by prepending a column of zeros
    _p = tf.concat([
        tf.zeros((_p.shape[0], 1), dtype=tf.float64),
        _p
    ], axis=1)
    return t + _p

p = tf.Variable(0.5, dtype='double')
m = tf.constant([[0, 1, 2, 3, 4],
                 [1, 3, 5, 7, 10],
                 [1, 1, 1, -1, 0]], dtype=tf.float64)
n = tf.constant([[0.0, 1.0, 2.5, 4.0, 5.5],
                 [1.0, 3.5, 6.5, 9.5, 13.5],
                 [1.0, 1.5, 1.5, -0.5, -0.5]], dtype=tf.float64)
print(f'Expected: {n}')
result = vectorize_predec(m, p)
print(f'Result: {result}')
tf.test.TestCase().assertAllEqual(n, result)
The main changes:
- The m tensor gets a dtype=tf.float64 to match the original double, so the type error vanishes.
- The function is basically a complete rewrite. The naive idea is to exploit the problem definition, which does not state whether the values in N are calculated before or after updates. Here is a version before update, which is way easier. Solving what seems to be the "real" problem requires working a bit more on the function (see other answers, and I may work more here).
How the function works:
- It calculates the expected increments p * x1, p * x2, etc. into a standard Python list. Note that it stops before the last element of the last dimension, as we will shift the array.
- It converts the list to a tensor with tf.convert_to_tensor, thus adding it to the computation graph. The transpose is necessary to match the original tensor shape (we could avoid it).
- It prepends a column of zeros, so each row starts with a zero along the last axis.
- The result is the sum of the original tensor and the constructed one.
The values become x1 + 0.0 * p, then x2 + x1 * p, etc. This illustrates a few functions and issues to look at (types, shapes), but I admit it cheats and does not solve the actual problem.
Also, this code is not efficient on any hardware. It is just illustrative, and would need to (1) eliminate the Python array, (2) eliminate the transpose, (3) eliminate the concatenate operation. Hopefully great training :-)
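The difference between the "before update" reading used here and the "after update" reading can be sketched on one row in plain Python. The names before_update and after_update are made up for this illustration. Only the first one corresponds to the function in this answer.

```python
def before_update(values, p):
    """Each element is increased by p times the ORIGINAL predecessor."""
    return [v + (p * values[i - 1] if i > 0 else 0.0) for i, v in enumerate(values)]


def after_update(values, p):
    """Each element is increased by p times the already UPDATED predecessor."""
    result = []
    for i, v in enumerate(values):
        result.append(v + (p * result[i - 1] if i > 0 else 0.0))
    return result


row = [0.0, 1.0, 2.0, 3.0, 4.0]
print(before_update(row, 0.5))  # [0.0, 1.0, 2.5, 4.0, 5.5]
print(after_update(row, 0.5))   # [0.0, 1.0, 2.5, 4.25, 6.125]
```

The two versions agree on the first three elements of this row and diverge from the fourth one on, because the updated predecessor then carries earlier increments along.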
Extra notes:
- The problem asks for a solution on tensors of shape (a, b, c). The code you share works on tensors of shape (a, b), so fixing the code will still not solve the problem.
- The problem requires rational numbers. Not sure what the intent is, and this answer leaves this requirement aside.
- The shape of T = [x1, x2, x3, x4] is actually (4,), assuming xi are scalars.
- Why tf.float64? By default, we get tf.float32, and removing the double would make the code work. But the example would lose the point that types matter, hence the choice of an explicit non-default type (and uglier code).