广播np.dot vs tf.matmul用于张量矩阵乘法(形状必须是等级2但是等级3错误)

问题描述 投票:1回答:2

假设我有以下张量:

X = np.zeros((3,201, 340))
Y = np.zeros((340, 28))

制作X,Y的点积是成功的numpy,并产生一个张量的形状(3,201,28)。但是对于tensorflow,我得到以下错误:Shape must be rank 2 but is rank 3 error ...

最小代码示例:

X = np.zeros((3,201, 340))
Y = np.zeros((340, 28))
print(np.dot(X,Y).shape) # successful (3, 201, 28)
tf.matmul(X, Y) # errornous

知道如何用tensorflow实现相同的结果吗?

python numpy tensorflow matrix-multiplication tensor
2个回答
3
投票

既然,你正在使用tensors,那么使用tensordot比使用np.dot更好(性能)。 NumPy允许它(numpy.dot)通过降低性能在tensors上工作,似乎tensorflow根本不允许它。

所以,对于NumPy,我们会使用np.tensordot -

np.tensordot(X, Y, axes=((2,),(0,)))

对于tensorflow,它将与tf.tensordot -

tf.tensordot(X, Y, axes=((2,),(0,)))

Related post to understand tensordot


1
投票

Tensorflow不允许像numpy那样乘以具有不同等级的矩阵。

为了解决这个问题,您可以重塑矩阵。这基本上通过“将矩阵堆叠”一个在另一个之上来投射具有等级2的等级3到1的矩阵。

你可以使用这个:tf.reshape(tf.matmul(tf.reshape(Aijk,[i*j,k]),Bkl),[i,j,l])

其中i,j和k是矩阵1的维数,k和l是矩阵2的维数。

取自here

© www.soinside.com 2019 - 2024. All rights reserved.