Tensorflow:如何查找 tf.data.Dataset API 对象的大小

问题描述 投票:0回答:5

我理解Dataset API是一种迭代器,它不会将整个数据集加载到内存中,因此它无法找到数据集的大小。我正在谈论存储在文本文件或 tfRecord 文件中的大型数据集。这些文件通常使用

tf.data.TextLineDataset
或类似的东西来读取。使用
tf.data.Dataset.from_tensor_slices
查找加载的数据集的大小很简单。

我询问数据集大小的原因如下: 假设我的数据集大小是 1000 个元素。批量大小 = 50 个元素。然后训练步骤/批次(假设 1 epoch)= 20。在这 20 个步骤中,我希望将我的学习率从 0.1 指数衰减到 0.01,如

tf.train.exponential_decay(
    learning_rate = 0.1,
    global_step = global_step,
    decay_steps = 20,
    decay_rate = 0.1,
    staircase=False,
    name=None
)

在上面的代码中,我有“and”想要设置

decay_steps = number of steps/batches per epoch = num_elements/batch_size
。只有预先知道数据集中的元素数量,才能计算出这一点。

提前知道大小的另一个原因是使用

tf.data.Dataset.take()
tf.data.Dataset.skip()
方法将数据分成训练集和测试集。

PS:我并不是在寻找暴力方法,例如迭代整个数据集并更新计数器来计算元素数量,或者放置非常大的批量大小,然后查找结果数据集的大小等。

python tensorflow tensorflow-datasets
5个回答
4
投票

您可以使用以下方式轻松获取数据样本的数量:

dataset.__len__()

你可以这样获取每个元素:

for step, element in enumerate(dataset.as_numpy_iterator()):
...     print(step, element)

您还可以获得一个样本的形状:

dataset.element_spec

如果你想获取特定元素,也可以使用分片方法。


1
投票

我意识到这个问题已经有两年了,但也许这个答案会有用。

如果您使用

tf.data.TextLineDataset
读取数据,那么获取样本数量的一种方法可能是计算您正在使用的所有文本文件中的行数。

考虑以下示例:

import random
import string
import tensorflow as tf

filenames = ["data0.txt", "data1.txt", "data2.txt"]

# Generate synthetic data.
for filename in filenames:
    with open(filename, "w") as f:
        lines = [random.choice(string.ascii_letters) for _ in range(random.randint(10, 100))]
        print("\n".join(lines), file=f)

dataset = tf.data.TextLineDataset(filenames)

尝试使用

len
获取长度会引发
TypeError
:

len(dataset)

但是可以相对快速地计算出文件中的行数。

# https://stackoverflow.com/q/845058/5666087
def get_n_lines(filepath):
    i = -1
    with open(filepath) as f:
        for i, _ in enumerate(f):
            pass
    return i + 1

n_lines = sum(get_n_lines(f) for f in filenames)

在上面,

n_lines
等于使用

迭代数据集时找到的元素数量
for i, _ in enumerate(dataset):
    pass
n_lines == i + 1

0
投票

您可以选择手动指定数据集的大小吗?

我如何加载数据:

sample_id_hldr = tf.placeholder(dtype=tf.int64, shape=(None,), name="samples")

sample_ids = tf.Variable(sample_id_hldr, validate_shape=False, name="samples_cache")
num_samples = tf.size(sample_ids)

data = tf.data.Dataset.from_tensor_slices(sample_ids)
# "load" data by id:
# return (id, data) for each id
data = data.map(
    lambda id: (id, some_load_op(id))
)

在这里,您可以通过使用占位符初始化

sample_ids
一次来指定所有样本 ID。
您的样本 ID 可能是例如文件路径或简单数字(
np.arange(num_elems)
)

然后可以在

num_samples
中获得元素的数量。


0
投票

这是我解决问题的方法,将以下行添加到您的数据集中

tf.data.experimental.assert_cardinality(len_of_data)
这将解决问题,

ast = Audioset(df) # the generator class
db = tf.data.Dataset.from_generator(ast, output_types=(tf.float32, tf.float32, tf.int32))
db = db.apply(tf.data.experimental.assert_cardinality(len(ast))) # number of samples
db = db.batch(batch_size)

数据集长度根据batch_size改变,只需运行

len(db)
即可获得数据集长度。 检查这里了解更多详情


0
投票
 print(len(list(dataset.as_numpy_iterator())))  

会做需要的事情

© www.soinside.com 2019 - 2024. All rights reserved.