在Matlab中收缩张量

问题描述 投票:2回答:4

我正在寻找一种在Matlab中收缩张量的两个指数的方法。

假设我有一个维度张量[17,10,17,12]我正在寻找一个函数,该函数用相同的索引求和第一维和第三维,并留下一个维度矩阵[10,12](类似于一个跟踪在两个方面)。

我目前正在研究张量网络,我主要使用“permute”和“reshape”这两个函数。如果一个人正在签订多个契约并且从一开始就不小心,那么人们最终可能会想要以[i,j,i,k]形式的一个张量收缩。

当然,人们可以以一种不会发生这种情况的方式返回并签订契约,但我仍然对更强大的解决方案感兴趣。

编辑:

有效的东西:

A = rand(17,10,17,12);
A_contracted = zeros(10,12);
for i = [1:10]
    for j = [1:12]
        for k = [1:17]
            A_contracted(i,j) = A_contracted(i,j) + A(k,i,k,j);
        end
    end

end
matlab tensor
4个回答
2
投票

这是一种方法:

A_contracted = permute(sum( ...
   A.*((1:size(A,1)).'==reshape(1:size(A,3), 1, 1, [])), [1 3]), [2 4 1 3]);

以上使用implicit expansion以及在sum中同时沿多个维度操作的可能性,这是最近的Matlab特性。对于较旧的Matlab版本,

A_contracted = permute(sum(sum( ...
   A.*bsxfun(@eq, (1:size(A,1)).', reshape(1:size(A,3), 1, 1, [])),1),3), [2 4 1 3]);

2
投票

[我觉得我开始听起来像是一张破纪录......]

您应该首先将代码实现为循环,然后尝试使用permutereshape进行优化。但请注意,permute需要复制数据,因此往往会增加工作量,而不是减少工作量。 MATLAB的最新版本不再对循环缓慢,因此复制数据不再总是一个有用的黑客来加快速度。

例如,问题中的循环可以简化为:

A_contracted = zeros(size(A,2),size(A,4));
for k = 1:size(A,1)
    A_contracted = A_contracted + squeeze(A(k,:,k,:));
end

(我也推广到任意大小)。

Luis' answer相比,我看到矢量化方法赢得了小阵列,例如OP中的一个(17x10x17x12),0.09 ms vs 0.19 ms。但是在非常短的时间内,它可能不值得努力。但是,对于较大的阵列(我试过17x100x17x120),我看到循环方法赢得1.3毫秒vs 2.6毫秒。

数据越多,使用普通旧循环的优势就越大。对于170x100x170x120,它是0.04秒对0.45秒。


测试代码:

A = rand(17,100,17,120);
assert(all(method2(A)==method1(A),'all'))
timeit(@()method1(A))
timeit(@()method2(A))

function A_contracted = method1(A)
A_contracted = permute(sum( ...
   A.*((1:size(A,1)).'==reshape(1:size(A,3), 1, 1, [])), [1 3]), [2 4 1 3]);
end

function A_contracted = method2(A)
A_contracted = zeros(size(A,2),size(A,4));
for k = 1:size(A,1)
    A_contracted = A_contracted + squeeze(A(k,:,k,:));
end
end

1
投票

我的教授提出了另一个涉及重塑和矩阵乘法的解决方案(在下面用方法3表示)。

  1. 采用合同索引大小的单位矩阵
  2. 将其重塑为矢量
  3. 重塑你想要相应收缩的张量
  4. 乘以向量和张量
  5. 重塑契约张量

示例代码与Luis's(method1)和Cris's回答(method2)相比较:

A = rand(17,10,17,10);

timeit(@()method1(A))
timeit(@()method2(A))
timeit(@()method3(A))

function A_contracted = method1(A)
A_contracted = permute(sum( ...
   A.*((1:size(A,1)).'==reshape(1:size(A,3), 1, 1, [])), [1 3]), [2 4 1 3]);
end


function A_contracted = method2(A)
A_contracted = zeros(size(A,2),size(A,4));
for k = 1:size(A,1)
    A_contracted = A_contracted + squeeze(A(k,:,k,:));
end
end


function A_contracted = method3(A)
sa_1 = size(A,1);
Unity = eye(size(A, 1));
Unity = reshape(Unity, [1,sa_1*sa_1]);
A1 = permute(A, [1,3,2,4]);
A2 = reshape(A1, [sa_1*sa_1, size(A1, 3)* size(A1,4)]);
UnA = Unity*A2;
A_contracted = reshape(UnA, [size(A1,3), size(A1,4)]);
end

方法3在方法1和方法2上比小维度支配一个数量级,并且对于更大的维度也胜过方法1,但是对于更大维度的for循环而言,它被一个数量级打败。

方法3具有(在某种程度上是个人的)优势,在我的物理课程中对应用程序更直观,因为收缩本身不是张量本身,而是指标。方法3可以容易地适用于结合该特征。


0
投票

满容易

squeeze(sum(sum(a,3),1))

sum(a,n)求和数组的第n维,squeeze除去任何单个维数

© www.soinside.com 2019 - 2024. All rights reserved.