相当于 TensorFlow 的 np.isin

问题描述 投票:0回答:1

我有

categories
作为列表整数列表,如下所示:

categories = [
 [0,2,4,6,8],
 [1,3,5,7,9]  
]

我有一个标签张量

y
num_batches
整数(作为类):

y = tf.constant([0, 1, 1, 2, 5, 4, 7, 9, 3, 3])

我想用可用的

y
列表替换
categories
中的值与某些索引(比方说 0-偶数,1-奇数),这样最终结果将是:

cat_labels = tf.constant([0, 1, 1, 0, 1, 0, 1, 1, 1, 1])

我可以通过迭代

y
中的每个值来获取它,如下所示:

cat_labels = tf.Variable(tf.identity(y))
for idx in range(len(categories)):
    for i, _y in enumerate(y):
        if _y in categories[idx]: # if _y value is in categories[idx]
            cat_labels[i].assign(idx) # replace all of them with idx

但显然,当该块封装在

@tf.function
父函数中时,不允许进行迭代。

有没有一种方法可以在不迭代或转换为 numpy 并应用

np.isin
的情况下应用逻辑,同时获得
tf.function
的加速?

编辑:似乎有像here这样的解决方法,但是如果您能在这个用例的上下文中进行解释,我们将不胜感激。

tensorflow
1个回答
1
投票

您可以尝试按以下方式使用

tf.map_fn

y = tf.constant([0., 1., 1., 2., 5., 4., 7., 9., 3., 3.], dtype=tf.float32)
categories = [[0,2,4,6,8],[1,3,5,7,9]]
c = tf.convert_to_tensor(categories, dtype=tf.float32)


cat_labels = tf.map_fn(                          # apply an operation on all of the elements of Y
     lambda x:tf.gather_nd(                      # get index of category: 0 or 1 or anything else
                        tf.cast(                 # cast dtype of the result of the inner function
                            tf.where(            # get index of the element of Y in categories
                                tf.equal(c, x)), # search an element of Y within categories
                                    dtype=tf.float32),[0,0]), y)

tf.print(cat_labels, summarize=-1)
# [0 1 1 0 1 0 1 1 1 1]
© www.soinside.com 2019 - 2024. All rights reserved.