我有
categories
作为列表整数列表,如下所示:
categories = [
[0,2,4,6,8],
[1,3,5,7,9]
]
我有一个标签张量
y
和 num_batches
整数(作为类):
y = tf.constant([0, 1, 1, 2, 5, 4, 7, 9, 3, 3])
我想用可用的
y
列表替换 categories
中的值与某些索引(比方说 0-偶数,1-奇数),这样最终结果将是:
cat_labels = tf.constant([0, 1, 1, 0, 1, 0, 1, 1, 1, 1])
我可以通过迭代
y
中的每个值来获取它,如下所示:
cat_labels = tf.Variable(tf.identity(y))
for idx in range(len(categories)):
for i, _y in enumerate(y):
if _y in categories[idx]: # if _y value is in categories[idx]
cat_labels[i].assign(idx) # replace all of them with idx
但显然,当该块封装在
@tf.function
父函数中时,不允许进行迭代。
有没有一种方法可以在不迭代或转换为 numpy 并应用
np.isin
的情况下应用逻辑,同时获得 tf.function
的加速?
编辑:似乎有像here这样的解决方法,但是如果您能在这个用例的上下文中进行解释,我们将不胜感激。
您可以尝试按以下方式使用
tf.map_fn
:
y = tf.constant([0., 1., 1., 2., 5., 4., 7., 9., 3., 3.], dtype=tf.float32)
categories = [[0,2,4,6,8],[1,3,5,7,9]]
c = tf.convert_to_tensor(categories, dtype=tf.float32)
cat_labels = tf.map_fn( # apply an operation on all of the elements of Y
lambda x:tf.gather_nd( # get index of category: 0 or 1 or anything else
tf.cast( # cast dtype of the result of the inner function
tf.where( # get index of the element of Y in categories
tf.equal(c, x)), # search an element of Y within categories
dtype=tf.float32),[0,0]), y)
tf.print(cat_labels, summarize=-1)
# [0 1 1 0 1 0 1 1 1 1]