如何在数据集上使用 tf.repeat 和另一个内置高级函数?

问题描述 投票:0回答:1

我想做类似 NER 任务代码的事情,它将单词的 WordPieces 与该单词的标签对齐:

import tensorflow as tf

tokens = tf.ragged.constant([[4], [2, 5, 9]], dtype=tf.int32)
tags = tf.ragged.constant([3, 5], dtype=tf.int32)

flat_tokens = tf.reshape(tokens, [-1])
duplicated_tags = tf.repeat(tags, [tf.shape(tok)[0] for tok in tokens])

print(flat_tokens.numpy())  # -> [4 2 5 9]
print(duplicated_tags.numpy())  # -> [3 5 5 5]

但是输入

tokens
tags
tf.repeat
作为数据集,应该是
TextLineDataset
的输出。有什么简约的方法可以做到吗?

tensorflow tensorflow2.0 tensorflow-lite tensorflow-datasets
1个回答
0
投票

也许是这样的:

import tensorflow as tf

tokens = tf.data.Dataset.from_tensor_slices(tf.ragged.constant([[4], [2, 5, 9]], dtype=tf.int32))
tags = tf.data.Dataset.from_tensor_slices(tf.ragged.constant([3, 5], dtype=tf.int32))

ds = tf.data.Dataset.zip((tokens, tags)).map(lambda x, y: (x, tf.repeat(y, repeats=tf.shape(x)[0])))
tokens = ds.map(lambda a, b: a).flat_map(tf.data.Dataset.from_tensor_slices)
tags = ds.map(lambda a, b: b).flat_map(tf.data.Dataset.from_tensor_slices)

print(list(tokens.as_numpy_iterator()))
print(list(tags.as_numpy_iterator()))
[4, 2, 5, 9]
[3, 5, 5, 5]
© www.soinside.com 2019 - 2024. All rights reserved.