How to force Julia to use multiple threads for matrix multiplication?

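I want to compute a power of a relatively small matrix, but the matrix consists of rational numbers of type Rational{BigInt}. By default, Julia uses only a single thread for this computation. I would like to check whether multithreaded matrix multiplication improves performance. How can I do that?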

Below is an example that computes the 4th power of a 32×32 matrix. When I run it on an i7-12700K, it uses only one thread:

using Random
using LinearAlgebra

Random.seed!(42)

M = BigInt.(rand(Int128, 32, 32)) .// BigInt.(rand(Int128, 32, 32));

BLAS.set_num_threads(8)

@time M^4;

Output:

19.976082 seconds (1.24 M allocations: 910.103 MiB, 0.19% gc time)

Only with Float64 and a large matrix do I see Julia correctly using multiple threads:

N = rand(2^14,2^14)

@time N^4;

32.764584 seconds (1.71 M allocations: 4.113 GiB, 0.08% gc time, 1.14% compilation time)
Tags: multithreading julia linear-algebra blas arbitrary-precision
Answer

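As pointed out in the comments above, BLAS is not involved here at all. BLAS only implements Float32, Float64, ComplexF32 and ComplexF64; for Rational{BigInt} Julia falls back to its generic pure-Julia matrix multiplication, which runs on a single thread, so BLAS.set_num_threads has no effect.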
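To check which path a given matrix takes, a minimal sketch can test the element type against LinearAlgebra.BlasFloat, the union of the four BLAS element types. The helper uses_blas_eltype is a hypothetical name introduced here, and the check is a rough rule of thumb (it looks only at the element type, not at memory layout).

```julia
using LinearAlgebra

# Hypothetical helper: true when the element type is one that LinearAlgebra
# can hand to BLAS (Float32, Float64, ComplexF32, ComplexF64).
uses_blas_eltype(A::AbstractMatrix) = eltype(A) <: LinearAlgebra.BlasFloat

R = BigInt.(rand(Int8, 4, 4)) .// BigInt(3)   # Rational{BigInt} entries
F = rand(4, 4)                                # Float64 entries

println(eltype(R), " -> BLAS? ", uses_blas_eltype(R))
println(eltype(F), " -> BLAS? ", uses_blas_eltype(F))
```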

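That said, here is a very simple multithreaded function that splits the rows of the result across Julia's threads (Julia must be started with several threads, e.g. julia -t4):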

julia> M3 = @time M^3;
  8.113582 seconds (1.24 M allocations: 644.222 MiB, 0.60% gc time)

julia> function mul_th(A::AbstractMatrix, B::AbstractMatrix)
         C = similar(A, size(A,1), size(B,2))
         size(A,2) == size(B,1) || error("sizes don't match up")
         Threads.@threads for i in axes(A,1)
           for j in axes(B,2)
             acc = zero(eltype(C))
             for k in axes(A,2)
               acc += A[i,k] * B[k,j]
             end
             C[i,j] = acc
           end
         end
         C
       end;

julia> M3 == @time mul_th(mul_th(M, M), M)
  2.313267 seconds (1.24 M allocations: 639.237 MiB, 2.29% gc time, 5.94% compilation time)
true

julia> Threads.nthreads()  # running e.g. julia -t4
4
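The function above gives each thread whole rows of the result. Since Julia stores matrices column-major, one might instead give each thread whole columns. The sketch below is a hypothetical variant, mul_th_cols (a name introduced here, not part of the answer), that does this and also promotes the element type of the two inputs; it is checked against the built-in product on small Rational{BigInt} matrices. Whether it is faster than mul_th is an assumption to verify with @time: with BigInt entries the arithmetic itself most likely dominates, not the memory layout.

```julia
using LinearAlgebra

# Hypothetical variant of mul_th: each thread computes whole output columns,
# so every task writes a contiguous block of C (column-major storage).
function mul_th_cols(A::AbstractMatrix, B::AbstractMatrix)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("sizes don't match up"))
    T = promote_type(eltype(A), eltype(B))
    C = similar(A, T, size(A, 1), size(B, 2))
    Threads.@threads for j in axes(B, 2)
        for i in axes(A, 1)
            acc = zero(T)
            for k in axes(A, 2)
                acc += A[i, k] * B[k, j]
            end
            C[i, j] = acc
        end
    end
    return C
end

# Exact rational arithmetic, so the result must equal the generic product.
A = BigInt.(rand(-100:100, 6, 5)) .// BigInt.(rand(1:100, 6, 5))
B = BigInt.(rand(-100:100, 5, 4)) .// BigInt.(rand(1:100, 5, 4))
println(mul_th_cols(A, B) == A * B)
println(Threads.nthreads())   # threads Julia was started with (julia -t N)
```

For a 32×32 matrix both row and column splitting yield 32 independent iterations, which is enough to keep the 12700K's cores busy.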

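Various packages can generate such a function for you, for example

using Einsum; mul(A,B) = @vielsum C[i,k] := A[i,j] * B[j,k]
or
using Tullio; mul(A,B) = @tullio C[i,k] := A[i,j] * B[j,k]  threads=10

Here @vielsum is Einsum's multithreaded variant of @einsum, and threads=10 lowers Tullio's size threshold for splitting the work across threads; that threshold otherwise assumes cheap element operations, whereas every Rational{BigInt} operation is expensive.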

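Higher powers are much slower because the numbers involved get much larger. M^4 is computed by repeated squaring as M2 * M2, and almost all of its time is spent in that single product: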

julia> M2 = @time M^2;
  0.133534 seconds (621.57 k allocations: 51.243 MiB, 3.22% gc time)

julia> M3 = @time M^3;
  8.084701 seconds (1.24 M allocations: 644.222 MiB, 0.64% gc time)

julia> M4 = @time M^4;  # uses Base.power_by_squaring
 20.915199 seconds (1.24 M allocations: 910.664 MiB, 0.84% gc time)

julia> @time M2 * M2;  # all the time is here:
 20.659935 seconds (621.57 k allocations: 859.421 MiB, 0.69% gc time)

julia> using Statistics  # provides mean

julia> mean(x -> abs(x.den), M)
6.27823462640995725259881990669421930274423828125e+37

julia> mean(x -> abs(x.den), M2)
4.845326324048551470412760353413448348641588891008324543404627136353750441508056e+2349
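Because M^4 is evaluated by repeated squaring, a threaded multiply can be plugged into a hand-written power function so that every product in the chain runs on several threads. The sketch below is a hypothetical helper, pow_by_squaring (a name introduced here, not Base.power_by_squaring itself), that takes the multiply function as an argument. It is checked with the built-in * on a small matrix; passing the mul_th function from above instead would thread each product. The final loop prints the number of decimal digits of the largest denominator, which makes the growth of the numbers visible.

```julia
# Hypothetical helper: A^p for integer p >= 1 by repeated squaring,
# using the supplied multiply function `mul` for every product.
function pow_by_squaring(mul, A::AbstractMatrix, p::Integer)
    p >= 1 || throw(ArgumentError("p must be at least 1"))
    # Square away trailing zero bits of p (for p = 4: A -> A^2 -> A^4).
    while iseven(p)
        A = mul(A, A)
        p >>= 1
    end
    result = A
    p >>= 1
    # Remaining bits: keep squaring, multiply into the result on set bits.
    while p > 0
        A = mul(A, A)
        if isodd(p)
            result = mul(result, A)
        end
        p >>= 1
    end
    return result
end

M = BigInt.(rand(-9:9, 3, 3)) .// BigInt.(rand(1:9, 3, 3))
println(pow_by_squaring(*, M, 4) == M^4)
# With the threaded function from the answer: pow_by_squaring(mul_th, M, 4)

# Digits in the largest denominator of M, M^2 and M^4.
for p in (1, 2, 4)
    println(p, ": ", maximum(x -> ndigits(denominator(x)), pow_by_squaring(*, M, p)))
end
```

For p = 4 this performs exactly two multiplications, the same as M^4 itself, so any speedup comes only from the threaded multiply.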