I am still learning tensorflow and keras, and I suspect this question has a very easy answer I'm just missing due to lack of familiarity.
I have a PrefetchDataset object:
> print(tf_test)
$ <PrefetchDataset shapes: ((None, 99), (None,)), types: (tf.float32, tf.int64)>
...made up of features and a target. I can iterate over it using a for loop:
> for example in tf_test:
>     print(example[0].numpy())
>     print(example[1].numpy())
>     exit()
$ [[-0.31 -0.94 -1.12 ... 0.18 -0.27]
[-0.22 -0.54 -0.14 ... 0.33 -0.55]
[-0.60 -0.02 -1.41 ... 0.21 -0.63]
...
[-0.03 -0.91 -0.12 ... 0.77 -0.23]
[-0.76 -1.48 -0.15 ... 0.38 -0.35]
[-0.55 -0.08 -0.69 ... 0.44 -0.36]]
[0 0 1 0 1 0 0 0 1 0 1 1 0 1 0 0 0
...
0 1 1 0]
However, this is very slow. What I'd like to do is access the tensor corresponding to the class labels and turn that into a numpy array, or a list, or any sort of iterable that can be fed into scikit-learn's classification report and/or confusion matrix:
> y_pred = model.predict(tf_test)
> print(y_pred)
$ [[0.01]
[0.14]
[0.00]
...
[0.32]
[0.03]
[0.00]]
> y_pred_list = [int(x[0] >= 0.5) for x in y_pred] # value >= 0.5 counts as a positive prediction
> y_true = [] # what I need help with
> print(sklearn.metrics.confusion_matrix(y_true, y_pred_list))
...OR access the data such that it could be used in tensorflow's confusion matrix:
> labels = [] # what I need help with
> predictions = y_pred_list # could we just use a tensor?
> print(tf.math.confusion_matrix(labels, predictions))
In both cases, the general ability to grab the target data from the original object in a manner that isn't computationally expensive would be very helpful (and might help with my underlying intuitions re: tensorflow and keras).
Any advice would be greatly appreciated.
You can convert it to a list with list(ds) and then recompile it as a normal Dataset with tf.data.Dataset.from_tensor_slices(list(ds)). From there your nightmare begins again but at least it's a nightmare that other people have had before.
Note that for more complex datasets (e.g. nested dictionaries) you will need more preprocessing after calling list(ds), but this should work for the example you asked about.
This is far from a satisfying answer but unfortunately the class is entirely undocumented and none of the standard Dataset tricks work.
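For the specific goal in the question (getting the labels into something scikit-learn accepts), here is a minimal sketch. It assumes the dataset yields (features, labels) batches, as the printed shapes suggest, and that it is iterated in eager mode. `collect_labels` and `threshold` are hypothetical helper names, not part of TensorFlow or Keras. The loop still runs the input pipeline once, so it is unlikely to be free, but it only converts the label tensors.

```python
import numpy as np


def collect_labels(dataset):
    """Concatenate the label batches of an iterable of (features, labels) pairs.

    Hypothetical helper: works on anything that yields two-element batches,
    which is assumed to include the dataset from the question.
    """
    batches = []
    for _, labels in dataset:
        # Eager tf.Tensor objects expose .numpy(); plain arrays pass through.
        if hasattr(labels, "numpy"):
            labels = labels.numpy()
        batches.append(np.asarray(labels))
    return np.concatenate(batches, axis=0)


def threshold(probs, cutoff=0.5):
    """Turn an (n, 1) array of probabilities into a flat array of 0/1 labels."""
    return (np.asarray(probs).reshape(-1) >= cutoff).astype(int)


# Stand-in batches shaped like the question's data: (batch, 99) and (batch,).
fake_batches = [
    (np.zeros((2, 99), dtype=np.float32), np.array([0, 1])),
    (np.zeros((1, 99), dtype=np.float32), np.array([1])),
]
y_true = collect_labels(fake_batches)
y_pred_list = threshold([[0.01], [0.70], [0.50]])
```

With the objects from the question this would be called as collect_labels(tf_test) and threshold(model.predict(tf_test)), and both results can be passed to sklearn.metrics.confusion_matrix. One caveat: if the pipeline reshuffles on every iteration, the labels and predictions may not line up, in which case predicting batch by batch inside the same loop is the safer option.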