是否存在特定于tensorflow.not_equal()的形状/轴等效项?

问题描述 投票:0回答:1

我有一个形状为(100, 257, 121)的3级张量,我们将其称为y_pred。

我从形状为(257, 121)的位置提取了一个2级张量,我们称其为y_element。

是否有一种类似于tensorflow.not_equal()的方法,该方法将y_element与沿y_pred的轴0的每个其他等级2张量元素进行比较,并返回形状为(100)的布尔张量?

调用tensorflow.not_equal(y_pred, y_element)确实会返回布尔值的张量,但形状与y_pred相同,表明它的工作类似于将y_element张量与y_pred中的所有3109700值进行比较。

tensorflow graph tensorflow2.0 equality tensor
1个回答
1
投票
y_pred = tf.Variable(tf.ones((100, 257, 121)))
y_element = tf.Variable(tf.ones((257, 121)))
y_element[-1,:].assign(tf.zeros(121))

tf.reduce_all(tf.equal(y_pred, tf.expand_dims(y_element,0)), axis=[1,2])

这段代码基于您介绍的方法,它逐元素比较轴零上的2个张量。它返回一个形状等于第一个轴的数组(在本例中为100)。将等级3的张量的每个元素与等级2的张量进行比较。如果所有张量都相等,则返回true,否则返回false

© www.soinside.com 2019 - 2024. All rights reserved.