How do I select certain columns of a 2D tensor in TensorFlow?

The function tf.nn.embedding_lookup(params, ind) retrieves the rows of the params tensor selected by ind.

To select columns instead, first transpose the tensor t, so that its columns become rows. Then look up the rows of tf.transpose(t), which are the columns of t. Finally, transpose the result back.

import tensorflow as tf

t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
ind = tf.constant([0, 2])

result = tf.transpose(tf.nn.embedding_lookup(tf.transpose(t), ind))

with tf.Session() as sess:
    print(sess.run(result))
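
To see why the double transpose works, here is the same three-step idea sketched with plain Python lists instead of tensors. The helper names transpose and select_columns are made up for this illustration and are not TensorFlow functions.

```python
def transpose(rows):
    """Swap rows and columns of a rectangular list of lists."""
    return [list(col) for col in zip(*rows)]


def select_columns(rows, ind):
    """Pick columns by transposing, picking rows, and transposing back."""
    cols_as_rows = transpose(rows)            # columns of `rows` become rows
    picked = [cols_as_rows[i] for i in ind]   # row lookup, the role embedding_lookup plays
    return transpose(picked)                  # restore the original orientation


t = [[1, 2, 3],
     [4, 5, 6]]
print(select_columns(t, [0, 2]))  # [[1, 3], [4, 6]]
```

The row lookup in the middle is the only step that selects anything; the two transposes just make the columns reachable as rows and then put them back.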

Alternatively, tf.gather accepts an axis parameter, so it can select columns directly:

import tensorflow as tf

params = tf.constant([[1, 2, 3], [4, 5, 6]])
indices = [0, 2]
op = tf.gather(params, indices, axis=1)

Evaluating op produces the output:

[[1 3]
 [4 6]]

So far, I have created a workaround that flattens the input and uses gather on the flat tensor:

def gather_cols(params, indices, name=None):
    """Gather columns of a 2D tensor.

    Args:
        params: A 2D tensor.
        indices: A 1D tensor. Must be one of the following types: ``int32``, ``int64``.
        name: A name for the operation (optional).

    Returns:
        A 2D Tensor. Has the same type as ``params``.
    """
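    # tf.op_scope is deprecated; tf.name_scope(name, default_name, values)
    # is the TensorFlow 1.x equivalent.
    with tf.name_scope(name, "gather_cols", [params, indices]) as scope: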
        # Check input
        params = tf.convert_to_tensor(params, name="params")
        indices = tf.convert_to_tensor(indices, name="indices")
        try:
            params.get_shape().assert_has_rank(2)
        except ValueError:
            raise ValueError("'params' must be 2D.")
        try:
            indices.get_shape().assert_has_rank(1)
        except ValueError:
            raise ValueError("'indices' must be 1D.")

        # Define op
        p_shape = tf.shape(params)
        p_flat = tf.reshape(params, [-1])
        i_flat = tf.reshape(tf.reshape(tf.range(0, p_shape[0]) * p_shape[1],
                                       [-1, 1]) + indices, [-1])
        return tf.reshape(tf.gather(p_flat, i_flat),
                          [p_shape[0], -1])

For example, this call:

params = tf.constant([[1, 2, 3],
                      [4, 5, 6]])
indices = [0, 2]
op = gather_cols(params, indices)

produces the expected output:

[[1 3]
 [4 6]]
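
The index arithmetic in gather_cols can be checked without TensorFlow. The following is a minimal plain-Python sketch of the same flatten-and-gather idea; flat_column_indices and gather_cols_plain are hypothetical helper names used only for this illustration. Each row r starts at offset r * n_cols in the flattened data, so column c of row r sits at flat position r * n_cols + c.

```python
def flat_column_indices(n_rows, n_cols, indices):
    """Flat positions of the selected columns in a row-major flattened matrix."""
    return [r * n_cols + c for r in range(n_rows) for c in indices]


def gather_cols_plain(rows, indices):
    """Select columns of a list-of-lists matrix via flat indexing."""
    n_rows, n_cols = len(rows), len(rows[0])
    flat = [x for row in rows for x in row]  # same role as tf.reshape(params, [-1])
    picked = [flat[i] for i in flat_column_indices(n_rows, n_cols, indices)]
    k = len(indices)
    # Regroup the flat result into n_rows rows of k selected values each.
    return [picked[r * k:(r + 1) * k] for r in range(n_rows)]


params = [[1, 2, 3],
          [4, 5, 6]]
print(flat_column_indices(2, 3, [0, 2]))  # [0, 2, 3, 5]
print(gather_cols_plain(params, [0, 2]))  # [[1, 3], [4, 6]]
```

In the TensorFlow version, the row offsets come from tf.range(0, p_shape[0]) * p_shape[1] reshaped to a column, and adding indices broadcasts them against every row offset.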
