How to find an index of the first matching element in TensorFlow
It seems that tf.argmax
works like np.argmax
(according to the test), which will return the first index when there are multiple occurrences of the max value.
You can use tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1)
to get what you want. However, currently the behavior of tf.argmax
is undefined in case of multiple occurrences of the max value.
If you are worried about undefined behavior, you can apply tf.argmin
on the return value of tf.where
as @Igor Tsvetkov suggested.
For example,
# test with tensorflow r1.0
import tensorflow as tf

val = 3
m = tf.placeholder(tf.int32)
m_feed = [
    [0, 0, val, 0, val],
    [val, 0, val, val, 0],
    [0, val, 0, 0, 0],
]

# tf.where on a boolean tensor yields one [row, col] coordinate pair per match.
tmp_indices = tf.where(tf.equal(m, val))
# Group the column coordinates by their row coordinate and keep the smallest
# column per row, i.e. the first match in each row.
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])

with tf.Session() as sess:
    print(sess.run(result, feed_dict={m: m_feed}))  # [2, 0, 1]
Note that tf.segment_min will raise InvalidArgumentError when some row contains no val. In your code, row_elems.index(val) will also raise an exception when row_elems doesn't contain val.
Here is another solution to the problem, assuming there is a hit on every row.
import tensorflow as tf

val = 3
m = tf.constant([
    [0, 0, val, 0, val],
    [val, 0, val, val, 0],
    [0, val, 0, 0, 0],
])

# Number of columns; doubles as the out-of-range sentinel for "no match".
num_cols = tf.shape(m)[1]

# Replace every entry of the matrix with its column index where it matches
# val, or with the out-of-range sentinel where it does not. For the matrix
# above this gives:
#   [[5, 5, 2, 5, 4],
#    [0, 5, 2, 3, 5],
#    [5, 1, 5, 5, 5]]
match_indices = tf.where(
    tf.equal(val, m),
    x=tf.range(num_cols) * tf.ones_like(m),
    y=num_cols * tf.ones_like(m))

# The smallest value in each row is the column of the first match.
result = tf.reduce_min(match_indices, axis=1)

with tf.Session() as sess:
    print(sess.run(result))  # [2, 0, 1]
Looks a little ugly but works (assuming m
and val
are both tensors):
# tf.unstack / tf.stack are the TensorFlow 1.0 names of the former
# tf.unpack / tf.pack.
# For each row of m, take the smallest index at which the row equals val,
# then stack the per-row scalars back into a single tensor.
idx = [
    tf.reduce_min(tf.where(tf.equal(t, val)))
    for t in tf.unstack(m, axis=0)
]
idx = tf.stack(idx, axis=0)
EDIT:
As Yaroslav Bulatov mentioned, you could achieve the same result with tf.map_fn:
def index1d(t):
    """Return the smallest index at which `t` equals `val`.

    `t` is one slice of `m` along its first dimension (as handed over by
    tf.map_fn below); `val` is read from the enclosing scope.
    """
    return tf.reduce_min(tf.where(tf.equal(t, val)))


# dtype is given explicitly because index1d returns int64 indices (the
# output type of tf.where), which need not match the dtype of m.
idx = tf.map_fn(index1d, m, dtype=tf.int64)
Here is a solution that also handles the case where the element is not contained in the matrix (solution from a GitHub repository of DeepMind):
def get_first_occurrence_indices(sequence, eos_idx):
    """Return, for each row, the position just past the first `eos_idx`.

    Args:
        sequence: tensor of shape [batch, length]. Both dimensions are read
            through get_shape(), so they are expected to be statically known.
        eos_idx: scalar value to look for in each row.

    Returns:
        int32 tensor of shape [batch]. Each entry is the index of the first
        occurrence of `eos_idx` in that row plus one, capped at `length`.
        Rows that do not contain `eos_idx` therefore yield `length`.
    """
    batch_size, maxlen = sequence.get_shape().as_list()
    # Convert with the sequence's own dtype so the concat below does not fail
    # on a dtype mismatch (a plain Python int would otherwise become int32).
    # base_dtype strips the reference dtype that a tf.Variable reports.
    eos_idx = tf.convert_to_tensor(eos_idx, dtype=sequence.dtype.base_dtype)
    # Append one column filled with eos_idx so that every row has at least
    # one match and therefore its own segment in the segment_min below.
    tensor = tf.concat(
        [sequence, tf.tile(eos_idx[None, None], [batch_size, 1])], axis=-1)
    index_all_occurrences = tf.where(tf.equal(tensor, eos_idx))
    index_all_occurrences = tf.cast(index_all_occurrences, tf.int32)
    # Smallest column coordinate per row coordinate = first match in the row.
    index_first_occurrences = tf.segment_min(index_all_occurrences[:, 1],
                                             index_all_occurrences[:, 0])
    index_first_occurrences.set_shape([batch_size])
    # +1 turns the index into a length that includes the matched element; the
    # cap maps the appended column (index == maxlen) back to maxlen.
    index_first_occurrences = tf.minimum(index_first_occurrences + 1, maxlen)
    return index_first_occurrences
And:
import tensorflow as tf

mat = tf.Variable(
    [[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7], [0, 0, 0, 0, 0]],
    dtype=tf.int32)
idx = 3
first_occurrences = get_first_occurrence_indices(mat, idx)

sess = tf.InteractiveSession()
# mat is a tf.Variable, so it must be initialized before it can be read.
sess.run(tf.global_variables_initializer())
# The original evaluated the undefined name `first_occurrence`.
print(sess.run(first_occurrences))  # [3, 2, 1, 5]