How to get the string value out of a tf.Tensor whose dtype is string

Ko Ohhashi · May 14, 2019 · Viewed 11.8k times

I want to use the tf.data.Dataset.list_files function to feed my dataset.
Because the files are not images, I need to load them manually.
The problem is that tf.data.Dataset.list_files passes each path to my function as a tf.Tensor, and my Python code cannot handle a tensor.

How can I get the string value out of a tf.Tensor whose dtype is string?

def load_audio_file(file_path):
  print("file_path: ", file_path)
  # I want to do something like: string_path = convert_tensor_to_string(file_path)

train_dataset = tf.data.Dataset.list_files(PATH+'clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: load_audio_file(x))

This prints file_path as Tensor("arg0:0", shape=(), dtype=string)

I am using TensorFlow 1.13.1 with eager mode enabled.

Thanks in advance.

Answer

giser_yugang · May 14, 2019

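You can use tf.py_func to wrap load_audio_file(). The wrapped function then runs as ordinary Python and receives the path as a bytes object, which you can decode to a str.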

import tensorflow as tf

tf.enable_eager_execution()

def load_audio_file(file_path):
    # tf.py_func passes the path as bytes; decode it to get a Python str
    path_str = file_path.decode("utf-8")
    print("file_path: ", path_str, type(path_str))
    return file_path

train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))

for one_element in train_dataset:
    print(one_element)

file_path:  clean_4s_val/1.wav <class 'str'>
(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
file_path:  clean_4s_val/3.wav <class 'str'>
(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
file_path:  clean_4s_val/2.wav <class 'str'>
(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)
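
Once the path has been decoded inside the wrapped function, ordinary Python file I/O works on it. The following is only a minimal sketch of what the manual loading step might look like, assuming the files are plain PCM WAV files that the standard-library wave module can read. The helper name read_wav_frames and the tiny demo file are hypothetical and are not part of the original question or answer.

```python
import os
import tempfile
import wave


def read_wav_frames(file_path):
    """Hypothetical helper: accept a bytes or str path and
    return (sample_rate, raw PCM frame bytes)."""
    # Inside tf.py_func the path arrives as bytes, so decode it first.
    if isinstance(file_path, bytes):
        file_path = file_path.decode("utf-8")
    with wave.open(file_path, "rb") as wav:
        sample_rate = wav.getframerate()
        frames = wav.readframes(wav.getnframes())
    return sample_rate, frames


# Write a tiny silent WAV file (mono, 16-bit, 8 kHz, 4 frames) to read back.
demo_path = os.path.join(tempfile.mkdtemp(), "demo.wav")
with wave.open(demo_path, "wb") as wav:
    wav.setnchannels(1)
    wav.setsampwidth(2)
    wav.setframerate(8000)
    wav.writeframes(b"\x00\x00" * 4)

# Pass the path as bytes, the same form the wrapped function receives.
sample_rate, frames = read_wav_frames(demo_path.encode("utf-8"))
print(sample_rate, len(frames))
```

To return audio data from the wrapped function instead of the path, the output dtype passed to tf.py_func would have to match what the function returns.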

UPDATE for TF 2

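The above solution does not work with TF 2 (tested with 2.2.0), even after replacing tf.py_func with tf.py_function, because tf.py_function passes an EagerTensor rather than a bytes object to the wrapped function. It fails with: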

InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'

To make it work in TF 2, make the following changes:

  • Remove tf.enable_eager_execution() (eager execution is enabled by default in TF 2; you can verify this by checking that tf.executing_eagerly() returns True)
  • Replace tf.py_func with tf.py_function
  • Replace every reference to file_path inside the function with file_path.numpy()
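
Putting the three changes together, a TF 2 version might look like the sketch below. It is not a tested drop-in replacement. To keep it self-contained, it creates three empty placeholder files in a temporary directory as a stand-in for clean_4s_val (an assumption made for illustration). It also passes tf.string directly as the output type, rather than the [tf.string] list used above, so that each dataset element is a scalar tensor.

```python
import os
import tempfile

import tensorflow as tf


def load_audio_file(file_path):
    # Under tf.py_function the argument is an EagerTensor;
    # .numpy() returns bytes, which decode to a Python str.
    path_str = file_path.numpy().decode("utf-8")
    print("file_path: ", path_str, type(path_str))
    return file_path


# Hypothetical stand-in for the clean_4s_val directory: three empty files.
data_dir = tempfile.mkdtemp()
for name in ("1.wav", "2.wav", "3.wav"):
    open(os.path.join(data_dir, name), "wb").close()

train_dataset = tf.data.Dataset.list_files(os.path.join(data_dir, "*.wav"))
train_dataset = train_dataset.map(
    lambda x: tf.py_function(load_audio_file, [x], tf.string)
)

# list_files shuffles by default, so sort before inspecting the results.
paths = sorted(p.numpy().decode("utf-8") for p in train_dataset)
print(paths)
```

Outside the mapped function, the same .numpy().decode() pattern converts each element yielded by the dataset into a plain Python string.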