I want to use the tf.data.Dataset.list_files function to feed my datasets.
But because the files are not images, I need to load them manually.
The problem is that tf.data.Dataset.list_files passes each file path as a tf.Tensor, and my Python code cannot handle tensors.
How can I get the string value from a tf.Tensor? The dtype is string.
def load_audio_file(file_path):
    print("file_path: ", file_path)
    # I want to do something like string_path = convert_tensor_to_string(file_path)
    return file_path

train_dataset = tf.data.Dataset.list_files(PATH + 'clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: load_audio_file(x))
Inside load_audio_file, file_path prints as Tensor("arg0:0", shape=(), dtype=string).
I am using TensorFlow 1.13.1 with eager execution enabled.
Thanks in advance.
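Why eager mode does not help here: Dataset.map traces the mapped function once to build a graph, so even with eager execution enabled the function receives a symbolic tensor like the one printed above. The minimal sketch below assumes TF 2.x, uses a small from_tensor_slices dataset in place of list_files, and records the result of tf.executing_eagerly() in a hypothetical observed list to show that the mapped function does not run eagerly.

```python
import tensorflow as tf

observed = []  # hypothetical helper: records what the mapped function sees while being traced

def load_audio_file(file_path):
    # Dataset.map traces this function into a graph, so this runs with eager execution
    # switched off and file_path is a symbolic tensor, not a value you can read.
    observed.append(tf.executing_eagerly())
    print("file_path:", file_path)
    return file_path

print("eager at top level:", tf.executing_eagerly())

# In-memory paths stand in for list_files (assumption: no real files are needed to show tracing).
paths = tf.data.Dataset.from_tensor_slices(["clean_4s_val/1.wav", "clean_4s_val/2.wav"])
mapped = paths.map(load_audio_file)  # tracing happens here, when map() is called
print("eager inside map:", observed)
```

The print inside load_audio_file runs only during tracing, which is why it shows a symbolic tensor rather than each file name.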
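You can use tf.py_func to wrap load_audio_file(). tf.py_func calls the wrapped Python function with NumPy values, so a scalar string tensor arrives as a bytes object that you can decode to a str.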
import tensorflow as tf
tf.enable_eager_execution()

def load_audio_file(file_path):
    # file_path arrives as bytes; decode it to get a str
    print("file_path: ", bytes.decode(file_path), type(bytes.decode(file_path)))
    return file_path

train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))
for one_element in train_dataset:
    print(one_element)

Output:
file_path: clean_4s_val/1.wav <class 'str'>
(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
file_path: clean_4s_val/3.wav <class 'str'>
(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
file_path: clean_4s_val/2.wav <class 'str'>
(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)
UPDATE for TF 2
The above solution does not work in TF 2 (tested with 2.2.0), even when tf.py_func is replaced with tf.py_function; it fails with
InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'
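The error arises because tf.py_function, unlike tf.py_func, passes the wrapped function EagerTensors instead of NumPy values, and the TypeError raised inside it surfaces as an InvalidArgumentError. A minimal sketch, assuming TF 2.x and using a tf.constant to stand in for the argument tf.py_function would pass, reproduces the TypeError and shows that calling .numpy() first yields decodable bytes:

```python
import tensorflow as tf

# An EagerTensor like the one tf.py_function hands to the wrapped function.
file_path = tf.constant("clean_4s_val/1.wav")

try:
    bytes.decode(file_path)  # fails: an EagerTensor is not a bytes object
    raised = False
except TypeError:
    raised = True

# .numpy() on a scalar string EagerTensor returns bytes, which bytes.decode accepts.
decoded = bytes.decode(file_path.numpy())
print(raised, decoded)
```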
To make it work in TF 2, make the following changes:
- Remove tf.enable_eager_execution() (eager execution is enabled by default in TF 2, which you can verify with tf.executing_eagerly() returning True).
- Replace tf.py_func with tf.py_function.
- Replace all in-function references to file_path with file_path.numpy().
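Putting those changes together, a minimal TF 2 sketch might look like the following. It assumes TF 2.x; the temporary directory with empty placeholder .wav files exists only so the example runs on its own, and load_audio_file just returns the decoded path where real audio loading would go. Passing Tout=tf.string (rather than [tf.string]) makes each dataset element a single scalar tensor instead of a 1-tuple.

```python
import os
import tempfile

import tensorflow as tf

def load_audio_file(file_path):
    # In TF 2, tf.py_function passes an EagerTensor; .numpy() yields bytes.
    path = bytes.decode(file_path.numpy())
    print("file_path:", path, type(path))
    # Real audio loading would go here; this sketch just returns the path string.
    return path

# Assumption for the demo: empty placeholder .wav files in a temporary directory.
data_dir = os.path.join(tempfile.mkdtemp(), "clean_4s_val")
os.makedirs(data_dir)
for name in ("1.wav", "2.wav", "3.wav"):
    open(os.path.join(data_dir, name), "wb").close()

train_dataset = tf.data.Dataset.list_files(os.path.join(data_dir, "*.wav"))
train_dataset = train_dataset.map(
    lambda x: tf.py_function(load_audio_file, [x], tf.string))

loaded = [bytes.decode(element.numpy()) for element in train_dataset]
print(sorted(os.path.basename(p) for p in loaded))
```

Note that outputs of tf.py_function have an unknown static shape; if later stages of the pipeline need a shape, it has to be set explicitly after the map.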