Extract target from Tensorflow PrefetchDataset

jda · Jun 17, 2020

I am still learning tensorflow and keras, and I suspect this question has a very easy answer I'm just missing due to lack of familiarity.

I have a PrefetchDataset object:

> print(tf_test)
$ <PrefetchDataset shapes: ((None, 99), (None,)), types: (tf.float32, tf.int64)>

...made up of features and a target. I can iterate over it using a for loop:

> for example in tf_test:
>     print(example[0].numpy())
>     print(example[1].numpy())
>     break
$ [[-0.31 -0.94 -1.12 ... 0.18 -0.27]
   [-0.22 -0.54 -0.14 ... 0.33 -0.55]
   [-0.60 -0.02 -1.41 ... 0.21 -0.63]
   ...
   [-0.03 -0.91 -0.12 ... 0.77 -0.23]
   [-0.76 -1.48 -0.15 ... 0.38 -0.35]
   [-0.55 -0.08 -0.69 ... 0.44 -0.36]]
  [0 0 1 0 1 0 0 0 1 0 1 1 0 1 0 0 0
   ...
   0 1 1 0]

However, this is very slow. What I'd like to do is access the tensor corresponding to the class labels and turn that into a numpy array, or a list, or any sort of iterable that can be fed into scikit-learn's classification report and/or confusion matrix:

> y_pred = model.predict(tf_test)
> print(y_pred)
$ [[0.01]
   [0.14]
   [0.00]
   ...
   [0.32]
   [0.03]
   [0.00]]
> y_pred_list = [int(x[0] >= 0.5) for x in y_pred]      # value >= 0.5 is a positive prediction
> y_true = []                                           # what I need help with
> print(sklearn.metrics.confusion_matrix(y_true, y_pred_list))

...OR access the data such that it could be used in tensorflow's confusion matrix:

> labels = []                                           # what I need help with
> predictions = y_pred_list                             # could we just use a tensor?
> print(tf.math.confusion_matrix(labels, predictions))

In both cases, the general ability to grab the target data from the original object in a manner that isn't computationally expensive would be very helpful (and might help with my underlying intuitions re: tensorflow and keras).

Any advice would be greatly appreciated.

Answer

markemus · Jul 31, 2020

You can convert it to a list with list(ds) and then recompile it as a normal Dataset with tf.data.Dataset.from_tensor_slices(list(ds)). From there your nightmare begins again but at least it's a nightmare that other people have had before.

Note that for more complex datasets (e.g. nested dictionaries) you will need more preprocessing after calling list(ds), but this should work for the example you asked about.
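As a rough illustration of the list(ds) idea for the (features, labels) case in the question, here is a minimal sketch that walks the batches once and flattens the labels into a plain Python list. The helper name collect_labels and the toy batches are made up for this example; the toy data only stands in for list(tf_test) so the snippet runs without TensorFlow installed.

```python
def collect_labels(batches):
    """Flatten the labels of an iterable of (features, labels) batches into one list."""
    y_true = []
    for _features, labels in batches:
        # Eager tf.Tensor objects expose .numpy(); plain sequences are used as-is.
        values = labels.numpy() if hasattr(labels, "numpy") else labels
        y_true.extend(int(v) for v in values)
    return y_true


# Hypothetical stand-in for list(tf_test): two batches of (features, labels).
# With the real dataset this would be: batches = list(tf_test)
batches = [
    ([[0.1, 0.2], [0.3, 0.4]], [0, 1]),
    ([[0.5, 0.6]], [1]),
]

y_true = collect_labels(batches)
print(y_true)  # [0, 1, 1]
```

With the real dataset, collect_labels(tf_test) should give a list that scikit-learn's confusion_matrix accepts as y_true. One caveat: if the input pipeline contains a shuffle step that reshuffles on every iteration, the label order from this pass may not match the order model.predict saw, so it is worth checking that the dataset is not shuffled before comparing the two.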

This is far from a satisfying answer but unfortunately the class is entirely undocumented and none of the standard Dataset tricks work.