What is the alternative to numpy.newaxis in TensorFlow?
The corresponding command is tf.newaxis (or None, as in NumPy). It does not have an entry of its own in TensorFlow's documentation, but it is briefly mentioned on the doc page of tf.strided_slice.
import tensorflow as tf
x = tf.ones((10, 10, 10))
y = x[:, tf.newaxis]  # or y = x[:, None]
print(y.shape)
# prints (10, 1, 10, 10)
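Because tf.newaxis is simply an alias for None, it can go anywhere in the index, and more than once in the same indexing step. A minimal sketch, reusing the shape of x from the example above (the names front, back and both are only illustrative):

```python
import tensorflow as tf

# tf.newaxis is a plain alias for Python's None
assert tf.newaxis is None

x = tf.ones((10, 10, 10))

front = x[tf.newaxis]                 # new leading axis
back = x[..., tf.newaxis]             # new trailing axis
both = x[tf.newaxis, :, tf.newaxis]   # two new axes in a single indexing step

print(front.shape)  # (1, 10, 10, 10)
print(back.shape)   # (10, 10, 10, 1)
print(both.shape)   # (1, 10, 1, 10, 10)
```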
Using tf.expand_dims is fine too but, as stated on the tf.strided_slice doc page, "Those interfaces are much more friendly, and highly recommended."
a = a[..., tf.newaxis].astype("float32")
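This works as well, provided a is a NumPy array: .astype is a NumPy method, and tf.newaxis is simply None, so NumPy accepts it as an index.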
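The .astype call only works when a is a NumPy array; a tf.Tensor does not normally provide NumPy's .astype, and the usual way to change a tensor's dtype is tf.cast. A minimal sketch comparing both paths, assuming a made-up batch of uint8 grayscale images (the names a_np and a_tf are hypothetical):

```python
import numpy as np
import tensorflow as tf

# Hypothetical batch of 4 grayscale 28x28 images stored as uint8
a_np = np.zeros((4, 28, 28), dtype=np.uint8)

# NumPy array: tf.newaxis is None, so this is ordinary NumPy indexing
a_np_4d = a_np[..., tf.newaxis].astype("float32")

# tf.Tensor: use tf.cast instead of .astype
a_tf = tf.constant(a_np)
a_tf_4d = tf.cast(a_tf[..., tf.newaxis], tf.float32)

print(a_np_4d.shape, a_np_4d.dtype)       # (4, 28, 28, 1) float32
print(a_tf_4d.shape, a_tf_4d.dtype.name)  # (4, 28, 28, 1) float32
```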
I think that would be tf.expand_dims:
tf.expand_dims(a, 1)  # Or tf.expand_dims(a, -1) to insert the new axis last
Basically, you pass the index at which the new axis should be inserted, and the existing axes from that position onward are pushed back by one.
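For an input with D dimensions, the tf.expand_dims documentation allows axis values in the range [-(D+1), D], and a negative axis counts from the end, so -1 always appends a trailing axis. That also means axis 1 and axis -1 only coincide for 1-D inputs. A minimal sketch with an illustrative rank-2 tensor, which also checks that expand_dims at a non-negative axis k matches indexing with k full slices followed by tf.newaxis:

```python
import tensorflow as tf

a = tf.zeros((2, 3))  # illustrative rank-2 tensor

# axis=1: the new axis goes in position 1; the old axis 1 (size 3) moves back
print(tf.expand_dims(a, 1).shape)   # (2, 1, 3)

# axis=-1: a negative axis counts from the end, so the new axis is appended
print(tf.expand_dims(a, -1).shape)  # (2, 3, 1)

# For a non-negative axis k, expand_dims(a, k) matches indexing with
# k full slices followed by tf.newaxis
k = 1
indexed = a[(slice(None),) * k + (tf.newaxis,)]
print(indexed.shape)                # (2, 1, 3)
```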
From the linked docs, here are a few examples of expanding dimensions:
import tensorflow as tf
t = tf.zeros([2])                   # 't' is a tensor of shape [2]
print(tf.expand_dims(t, 0).shape)   # (1, 2)
print(tf.expand_dims(t, 1).shape)   # (2, 1)
print(tf.expand_dims(t, -1).shape)  # (2, 1)
t2 = tf.zeros([2, 3, 5])            # 't2' is a tensor of shape [2, 3, 5]
print(tf.expand_dims(t2, 0).shape)  # (1, 2, 3, 5)
print(tf.expand_dims(t2, 2).shape)  # (2, 3, 1, 5)
print(tf.expand_dims(t2, 3).shape)  # (2, 3, 5, 1)