TensorFlow Dataset.take(n) raises "RuntimeError: input_dataset: Attempting to capture an EagerTensor without building a function."


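I am writing a data import pipeline to train a model in TensorFlow. It needs examples from three different datasets (each consisting of images and labels) to be yielded together but kept separate, in the following form:

((img_ds1, label_ds1), (img_ds2, label_ds2), (img_ds3, label_ds3))

The data is stored in TFRecord files and is currently imported with the following function, which should be fairly straightforward: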

import os

import tensorflow as tf

# parseExampleA, parseExampleB and parseExampleC are defined elsewhere in the library.
def loadDataset(root, ds_type, SEED=1994, shuffle_size=1000, batch_size=32, f_list=(), f_kwargs=()):

    # extract list of tfrecord files and import them
    file_list = [os.path.join(root, f) for f in os.listdir(root)]
    ds = tf.data.TFRecordDataset(file_list)
    print(f'parsing {ds_type} dataset')
    if ds_type == "a":
        ds = ds.map(parseExampleA)
    elif ds_type == "b":
        ds = ds.map(parseExampleB)
    elif ds_type == "c":
        ds = ds.map(parseExampleC)

    # apply each preprocessing function to the image, keeping the label unchanged
    for f, kwargs in zip(f_list, f_kwargs):
        ds = ds.map(lambda *x: (f(x[0], **kwargs), x[1]))

    return ds.shuffle(shuffle_size, seed=SEED).batch(batch_size).repeat()
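
For reference, the preprocessing loop above can be tried on a small in-memory dataset without any TFRecord files. This is only a sketch: scale and shift are hypothetical stand-ins for the real preprocessing functions, and apply_preprocessing is an illustrative helper, not part of the original library. It binds each function and its keyword arguments through a small factory so that every mapped step keeps its own pair.

```python
import tensorflow as tf

# Hypothetical preprocessing functions (stand-ins for fun1, fun2, ...).
def scale(img, factor=1.0):
    return img * factor

def shift(img, offset=0.0):
    return img + offset

def _bind(f, kwargs):
    """Return a map function that applies f to the image and keeps the label."""
    def mapped(img, label):
        return f(img, **kwargs), label
    return mapped

def apply_preprocessing(ds, f_list, f_kwargs):
    """Chain one Dataset.map call per (function, kwargs) pair."""
    for f, kwargs in zip(f_list, f_kwargs):
        ds = ds.map(_bind(f, kwargs))
    return ds

# Toy data: four 2x2 "images" filled with ones, with labels 0..3.
images = tf.ones([4, 2, 2], dtype=tf.float32)
labels = tf.range(4)
toy_ds = tf.data.Dataset.from_tensor_slices((images, labels))
toy_ds = apply_preprocessing(
    toy_ds, [scale, shift], [{"factor": 2.0}, {"offset": 1.0}]
)

first_img, first_label = next(iter(toy_ds))
```

Each image is first multiplied by 2 and then shifted by 1, while the label passes through untouched.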

In the script, the three datasets are assigned like this:

from myLibrary import loadDataset

preproc_fun = [fun1, fun2, fun3]
preproc_kwargs = [{"k1":var1, "k2":var2}, {etc.}, {etc.}]
ds_a = loadDataset(root, 'a', f_list=preproc_fun, f_kwargs=preproc_kwargs)
ds_b = loadDataset(root, 'b', f_list=preproc_fun, f_kwargs=preproc_kwargs)
ds_c = loadDataset(root, 'c', f_list=preproc_fun, f_kwargs=preproc_kwargs)
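
To make the intended element structure concrete, here is a minimal sketch that zips three toy datasets in eager mode. make_toy_ds is a hypothetical stand-in for loadDataset (it builds small in-memory tensors instead of reading TFRecord files), and the shapes and offsets are arbitrary assumptions chosen only for illustration.

```python
import tensorflow as tf

def make_toy_ds(offset, batch_size=2):
    """Hypothetical stand-in for loadDataset: repeated batches of (image, label)."""
    images = tf.zeros([6, 4, 4, 1]) + offset
    labels = tf.range(6) + offset
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    return ds.batch(batch_size).repeat()

ds_a, ds_b, ds_c = make_toy_ds(0), make_toy_ds(100), make_toy_ds(200)

# Keep only two batches of each dataset, repeat them, and zip the three streams.
zipped = tf.data.Dataset.zip(
    (ds_a.take(2).repeat(), ds_b.take(2).repeat(), ds_c.take(2).repeat())
)

element = next(iter(zipped))
(img_a, label_a), (img_b, label_b), (img_c, label_c) = element
```

Every element of the zipped dataset is a tuple of three (image batch, label batch) pairs, one per source dataset.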

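Now, here is what happens: if I perform a ds.take(n) operation inside the loadDataset function, it runs and does what I intend. However, whenever I try to do the same outside the function, I get the error in the title. Frankly, I cannot make sense of it, because .take(n) should simply return another dataset, and I have used it many times without any similar problem. The full error output is as follows: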

Traceback (most recent call last):
  File "/Users/leo/Documents/repos/ChestXRayEnsembling/ChestXRAY/multidecoder/train_script.py", line 156, in <module>
    x=tf.data.Dataset.zip((chex_ds.take(2).repeat(), nih_ds.take(2).repeat(), vin_ds.take(2).repeat())),
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1697, in take
    return TakeDataset(self, count, name=name)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 5145, in __init__
    variant_tensor = gen_dataset_ops.take_dataset(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 7711, in take_dataset
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/framework/op_def_library.py", line 777, in _apply_op_helper
    _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/framework/op_def_library.py", line 550, in _ExtractInputsAndAttrs
    values = ops.convert_to_tensor(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/profiler/trace.py", line 183, in wrapped
    return func(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 1586, in convert_to_tensor
    raise RuntimeError(
RuntimeError: input_dataset: Attempting to capture an EagerTensor without building a function.
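
The frames through _apply_op_helper in the traceback suggest that take was called while eager execution was not active, on a dataset that had been built eagerly. What switched modes in the script is not shown; an enclosing tf.Graph().as_default() block or a call to tf.compat.v1.disable_eager_execution() are possible candidates, which is an assumption rather than something the traceback confirms. The sketch below reproduces the same kind of error under that assumption and shows a check that can be placed right before the failing line.

```python
import tensorflow as tf

# Built while eager execution is active (the TF 2.x default).
eager_ds = tf.data.Dataset.range(10)

# Calling take in the same (eager) mode works as expected.
taken_values = [int(v) for v in eager_ds.take(2)]

captured_message = None
with tf.Graph().as_default():
    # Inside a plain graph context, eager execution is off.
    in_graph_mode = not tf.executing_eagerly()
    try:
        # The dataset's underlying tensor is an EagerTensor, so using it
        # to build a new op in this graph is expected to fail.
        eager_ds.take(2)
    except RuntimeError as err:
        captured_message = str(err)

print("graph mode inside the block:", in_graph_mode)
print("error:", captured_message)
```

Printing tf.executing_eagerly() immediately before the call at line 156 of train_script.py, and once more inside loadDataset, would show whether the two places run in different modes.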

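I am working in a conda environment on an M1 Mac Pro, with Python 3.10.8 and tensorflow-metal 0.7. The reason I use only a small part of each dataset is that I want to check that everything works before actually training my model.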

python tensorflow deep-learning tensorflow-datasets data-import