How can I filter tf.data.Dataset by specific values?
I don't think you need to make the label a 1-dimensional array in the first place. If you declare it as a scalar:
feature = {'label': tf.FixedLenFeature((), tf.string)}
then the parsed label is already a single string, and you won't need to unstack it in your filter_func.
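Here is a minimal end-to-end sketch of the scalar approach. It assumes TensorFlow 2.x with eager execution, where the function is spelled tf.io.FixedLenFeature (tf.FixedLenFeature is the TF 1.x name). The helper make_example and the 'cat'/'dog' labels are hypothetical, used only to build a few serialized records to filter.

```python
import tensorflow as tf

def make_example(label):
    # Hypothetical helper: serialize a tf.train.Example with one string feature 'label'
    feature = {
        'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.encode()]))
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

# Shape () instead of [1]: the parsed label is a scalar string tensor
feature_spec = {'label': tf.io.FixedLenFeature((), tf.string)}

def parse(record):
    return tf.io.parse_single_example(record, feature_spec)

def filter_func(features):
    # Comparing two scalars already yields the scalar bool that Dataset.filter expects
    return tf.equal(features['label'], 'cat')

serialized = [make_example(l) for l in ['cat', 'dog', 'cat']]
dataset = tf.data.Dataset.from_tensor_slices(serialized).map(parse).filter(filter_func)
labels = [f['label'].numpy() for f in dataset]
print(labels)
```

No tf.unstack or tf.reshape is needed here, because the label never has a leading dimension of size 1.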
I am answering my own question: I found the issue!
What I needed to do was call tf.unstack() on the label, like this:
label = tf.unstack(features['label'])
label = label[0]
before passing it to tf.equal():
result = tf.reshape(tf.equal(label, 'some_label_value'), [])
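For comparison, a sketch of the self-answer's fix in a complete pipeline, keeping the label declared with shape [1]. It again assumes TensorFlow 2.x (tf.io.FixedLenFeature); make_example and the sample label values are hypothetical.

```python
import tensorflow as tf

def make_example(label):
    # Hypothetical helper: serialize a tf.train.Example with one string feature 'label'
    feature = {
        'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.encode()]))
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

# Shape [1], as in the question: the parsed label is a 1-element vector
feature_spec = {'label': tf.io.FixedLenFeature([1], tf.string)}

def parse(record):
    return tf.io.parse_single_example(record, feature_spec)

def filter_func(features):
    label = tf.unstack(features['label'])  # Python list holding one scalar tensor
    label = label[0]                       # the scalar string itself
    # Reshape to [] so the predicate is guaranteed to be a scalar bool
    return tf.reshape(tf.equal(label, 'some_label_value'), [])

serialized = [make_example(l) for l in ['some_label_value', 'other', 'some_label_value']]
dataset = tf.data.Dataset.from_tensor_slices(serialized).map(parse).filter(filter_func)
kept = [f['label'].numpy().tolist() for f in dataset]
print(kept)
```

The surviving elements still carry the label as a 1-element vector, since filtering does not change the parsed features.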
I suppose the problem was that the label is defined as an array with a single string element, tf.FixedLenFeature([1], tf.string), so to get that first and only element I had to unstack it (which returns a list of tensors) and then take the element at index 0. Correct me if I'm wrong.
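To check the shape reasoning above, the following sketch inspects a shape-[1] string tensor directly (assuming TensorFlow 2.x in eager mode; the constant stands in for a parsed FixedLenFeature([1], tf.string) value). It shows that tf.unstack returns a one-element list, that plain indexing or tf.squeeze give the same scalar, and that tf.equal on the unsqueezed vector yields a shape-[1] bool rather than a scalar.

```python
import tensorflow as tf

# Stand-in for a parsed FixedLenFeature([1], tf.string) value
label = tf.constant(['some_label_value'])

unstacked = tf.unstack(label)            # Python list with one scalar tensor
via_unstack = unstacked[0]               # scalar, shape ()
via_index = label[0]                     # same scalar via indexing
via_squeeze = tf.squeeze(label, axis=0)  # same scalar via squeeze

# Comparing the vector directly keeps the leading dimension
vector_result = tf.equal(label, 'some_label_value')               # shape (1,)
scalar_result = tf.reshape(vector_result, [])                     # shape ()
print(len(unstacked), via_unstack.shape, vector_result.shape, scalar_result.shape)
```

Any of the three extraction styles should work inside filter_func; indexing with [0] is simply the shortest.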