I create a dataset by reading TFRecords and map the values. I then want to filter the dataset for specific values, but since each element is a dict of tensors, I am not able to get the actual value of a tensor or to check it with tf.cond() / tf.equal. How can I do that?
def mapping_func(serialized_example):
    feature = { 'label': tf.FixedLenFeature([1], tf.string) }
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work
    #result = features['label'] == 'some_label_value'
    # neither this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()
I am answering my own question, since I found the issue. What I needed to do was tf.unstack() the label like this:
label = tf.unstack(features['label'])
label = label[0]
before passing it to tf.equal():
result = tf.reshape(tf.equal(label, 'some_label_value'), [])
I suppose the problem was that the label is declared as an array with a single element of type string, tf.FixedLenFeature([1], tf.string), so to get that one element I had to unstack it (which creates a list) and then take the element at index 0. Correct me if I'm wrong.
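For completeness, here is a minimal sketch of the whole filter with the scalar label extracted by indexing, which should be equivalent to unstacking and taking element 0. It assumes TensorFlow 2.x with eager execution rather than the tf.contrib.data API used above, and it replaces the TFRecord files with a small in-memory dataset of the same shape. TARGET and matching_labels are hypothetical names introduced only for this illustration.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution enabled

TARGET = 'some_label_value'  # hypothetical constant for the wanted label


def filter_func(features):
    # features['label'] has shape [1]; indexing gives a scalar string tensor.
    label = features['label'][0]
    # Comparing two scalars yields the scalar boolean that filter() expects.
    return tf.equal(label, TARGET)


def matching_labels(raw_labels):
    # Hypothetical helper: builds an in-memory dataset whose elements look
    # like the parsed records, i.e. a dict with a shape-[1] string 'label'.
    dataset = tf.data.Dataset.from_tensor_slices(
        {'label': [[label] for label in raw_labels]})
    dataset = dataset.filter(filter_func)
    # Decode the surviving labels back to Python strings.
    return [f['label'].numpy()[0].decode('utf-8') for f in dataset]
```

Calling matching_labels(['some_label_value', 'other', 'some_label_value']) should keep only the two matching entries. Declaring the feature with an empty shape, tf.FixedLenFeature([], tf.string), should also produce a scalar directly, in which case neither the indexing nor tf.unstack() is needed.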