Batched matrix multiplication with Eigen tensors

Question (votes: 0, answers: 2)

I want to perform a batched matrix multiplication on slices of a large tensor.

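Suppose I have A with shape [N, 1, 4] and B with shape [N, 4, 4]. I want to first slice them along the batch dimension to get [b, 1, 4] and [b, 4, 4], where the selected batch entries are not necessarily contiguous, and then multiply them batch by batch to get a result of shape [b, 4]. Is there a way to do this with Eigen?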
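To make the requested operation concrete, here is a minimal reference sketch in plain C++ without Eigen. It is only an illustration of the intended result, not part of the question: the function name `referenceGatheredMatmul`, the flat row-major storage, and the parameters `K` and `M` (4 and 4 in the question) are assumptions made for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference (not an Eigen API). A is [N, 1, K] and B is [N, K, M],
// both stored flat in row-major order. Returns a flat row-major [b, M] result,
// where b = batchIndices.size() and row k uses batch entry batchIndices[k].
inline std::vector<double> referenceGatheredMatmul(
    const std::vector<double>& A,
    const std::vector<double>& B,
    std::size_t K,
    std::size_t M,
    const std::vector<std::size_t>& batchIndices)
{
    assert(K > 0 && M > 0);
    assert(A.size() % K == 0);
    const std::size_t N = A.size() / K;
    assert(B.size() == N * K * M);

    std::vector<double> C(batchIndices.size() * M, 0.0);
    for (std::size_t k = 0; k < batchIndices.size(); ++k) {
        const std::size_t n = batchIndices[k];  // selected batch entry, any order
        assert(n < N);
        for (std::size_t j = 0; j < M; ++j) {
            for (std::size_t i = 0; i < K; ++i) {
                // C[k, j] += A[n, 0, i] * B[n, i, j]
                C[k * M + j] += A[n * K + i] * B[(n * K + i) * M + j];
            }
        }
    }
    return C;
}
```

Any Eigen-based solution can be checked against this loop on small inputs, since it spells out which entries of A and B contribute to each output element.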

c++ eigen eigen3
2 Answers
0 votes

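I'm not sure whether this is an efficient way to perform batched matrix multiplication on Eigen tensors, but one solution could be to map the tensor pages to matrices and perform a general matrix multiplication: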

#include <vector>
#include <Eigen/Dense>
#include <unsupported/Eigen/CXX11/Tensor>

typedef Eigen::Tensor<double, 3> Tensor3d;

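// Each tensor stores its batch ("page") index in the last dimension. With Eigen's
// default column-major layout, every page is then a contiguous block of
// dimension(0) * dimension(1) values that can be mapped as a matrix without copying.
inline void batchedTensorMultiplication(const Tensor3d& A, const Tensor3d& B, const std::vector<int>& batchIndices, Tensor3d& C)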
{
    Eigen::DenseIndex memStepA = A.dimension(0) * A.dimension(1);
    Eigen::DenseIndex memStepB = B.dimension(0) * B.dimension(1);
    Eigen::DenseIndex memStepC = C.dimension(0) * C.dimension(1);
    int outputBatchIndex = 0;

    for (int batchIndex : batchIndices)
    {
        Eigen::Map<const Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> pageA(A.data() + batchIndex * memStepA, A.dimension(0), A.dimension(1));
        Eigen::Map<const Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> pageB(B.data() + batchIndex * memStepB, B.dimension(0), B.dimension(1));
        Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> pageC(C.data() + outputBatchIndex * memStepC, C.dimension(0), C.dimension(1));

        outputBatchIndex++;

        pageC.noalias() = pageA * pageB;
    }
}

int main() 
{
    constexpr int N = 50;
    std::vector<int> batchIndices = { 0,1,2,3,4,9,10,11,12,13 };

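    Tensor3d A(1, 4, N), B(4, 4, N), C(1, 4, (int)batchIndices.size());

    // Fill the inputs; a freshly constructed tensor holds uninitialized values.
    A.setRandom();
    B.setRandom();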

    batchedTensorMultiplication(A, B, batchIndices, C);

    return 0;
}

0 votes

You can try using the Eigen Tensor chip operation in a loop:

#include <unsupported/Eigen/CXX11/Tensor>

template <typename Device>
Eigen::Tensor<float, 3> batched_matrix_multiplication(const Device &device, const Eigen::Tensor<float, 3>& A, const Eigen::Tensor<float, 3>& B)
{
    typedef Eigen::Tensor<float, 3>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{DimPair(1, 0)};
    const int batch_size = A.dimension(0);
    const int dim1 = A.dimension(1);
    const int dim2 = B.dimension(2);
    Eigen::Tensor<float, 3> output(batch_size, dim1, dim2);
    for (int i = 0; i < batch_size; ++i) {
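        // chip<0>(i) selects batch entry i as a 2-D tensor; contracting dimension 1 of A
        // with dimension 0 of B is an ordinary matrix product.
        output.chip<0>(i).device(device) = A.chip<0>(i).contract(B.chip<0>(i), dims);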
    }
    return output;
}