访问tensorflow PrefetchDataset中的数据

问题描述 投票:-1回答:1

我仍在学习tensorflow和keras,我怀疑这个问题有一个很简单的答案,由于缺乏熟悉性,我只是想念它。

我有一个PrefetchDataset对象:

> print(tf_test)
$ <PrefetchDataset shapes: ((None, 99), (None,)), types: (tf.float32, tf.int64)>

...由功能和目标组成。我可以使用for循环对其进行迭代:

> for example in tf_test:
>     print(example[0].numpy())
>     print(example[1].numpy())
>     exit()
$ [[-0.31 -0.94 -1.12 ... 0.18 -0.27]
   [-0.22 -0.54 -0.14 ... 0.33 -0.55]
   [-0.60 -0.02 -1.41 ... 0.21 -0.63]
   ...
   [-0.03 -0.91 -0.12 ... 0.77 -0.23]
   [-0.76 -1.48 -0.15 ... 0.38 -0.35]
   [-0.55 -0.08 -0.69 ... 0.44 -0.36]]
  [0 0 1 0 1 0 0 0 1 0 1 1 0 1 0 0 0
   ...
   0 1 1 0]

但是,这非常慢。我想做的是访问与类标签相对应的张量,并将其转换为numpy数组或列表,或者将其输入到scikit-learn的分类报告和/或混淆矩阵中的任何可迭代对象:] >

> y_pred = model.predict(tf_test)
> print(y_pred)
$ [[0.01]
   [0.14]
   [0.00]
   ...
   [0.32]
   [0.03]
   [0.00]]
> y_pred_list = [int(x[0]) for x in y_pred]             # assumes value >= 0.5 is positive prediction
> y_true = []                                           # what I need help with
> print(sklearn.metrics.confusion_matrix(y_true, y_pred_list)

...或访问数据,以便可以在tensorflow的混淆矩阵中使用它:

> labels = []                                           # what I need help with
> predictions = y_pred_list                             # could we just use a tensor?
> print(tf.math.confusion_matrix(labels, predictions)

在两种情况下,以一种在计算上并不昂贵的方式从原始对象中获取目标数据的一般能力将非常有帮助(并且可能有助于我的基本直觉:tensorflow和keras)。

任何建议将不胜感激。

我仍在学习tensorflow和keras,我怀疑这个问题有一个非常简单的答案,由于缺乏熟悉性,我只是想念它。我有一个PrefetchDataset对象:> print(tf_test)$ <...>

python tensorflow machine-learning keras prefetch
1个回答
0
投票

您的帖子的某些部分对我来说很模糊,所以我的回答代表这一部分:

© www.soinside.com 2019 - 2024. All rights reserved.