用 numpy 操作替换通过域的循环和操作

问题描述 投票:0回答:1

我需要循环遍历 (i,j) 域并对它们进行一些数学运算,但循环本身需要很长时间才能运行,我想知道是否有一种方法可以用矢量化计算替换 for 循环?

import numpy as np

# Variable Init
tole = 1e-4
residual = 1
Nx = int(1e3)
Ny = Nx
sigma = np.zeros((Nx,Ny))
R = np.zeros((Nx,Ny))

# Sample constants
a = 10
b = 3
c = 2
d = 1
w = 0.5

# Output: Sigma and residual after convergence
while residual > tole:
    for i in range(1,Nx-1):
        for j in range(1,Ny-1):
            temp = a * sigma[j, i + 1] + b * sigma[j, i - 1] + c * sigma[j + 1, i] + d * sigma[j - 1, i]

            R[j, i] = abs(sigma[j,i] - temp)
            sigma[j, i] = (1 - w) * sigma[j,i] + w * temp
    residual = np.sum(R)/(Nx*Ny)


我当前的 numpy 解决方案:

import numpy as np

# Variable Init
tole = 1e-4
residual = 1
Nx = int(1e3)
Ny = Nx
sigma = np.zeros((Nx,Ny))
R = np.zeros((Nx,Ny))

# Sample constants
a = 10
b = 3
c = 2
d = 1
w = 0.5

# Output: Sigma and residual after convergence
while residual > tole:
    
    temp = a * sigma[:-1, 1:] + b * sigma[1:, :-1] + c * sigma[1:, :-1] + d * sigma[:-1, 1:]
    R = sigma[:-1,:-1] - temp
    sigma[:-1,:-1] = (1 - w) * sigma[:-1,:-1] + w * temp
    residual = np.sum(R)/(Nx*Ny)

它不应该在这个示例中显示,因为大部分已被删除,但在完整版本中,残差根本没有改变,这让我认为西格玛在某种程度上没有迭代自身。但不太确定问题出在哪里。

python numpy
1个回答
0
投票

对于具有因变量的这些类型的任务,我建议(这里是“sigma”)尝试

import numba
import numpy as np


@numba.njit
def calculate(
    sigma,
    R,
    residual=1.0,
    tole=1e-4,
    a=10.0,
    b=3.0,
    c=2.0,
    d=1.0,
    w=0.5,
):
    while residual > tole:
        for i in range(1, Nx - 1):
            for j in range(1, Ny - 1):
                temp = (
                    a * sigma[j, i + 1]
                    + b * sigma[j, i - 1]
                    + c * sigma[j + 1, i]
                    + d * sigma[j - 1, i]
                )
                R[j, i] = abs(sigma[j, i] - temp)
                sigma[j, i] = (1 - w) * sigma[j, i] + w * temp
        residual = np.sum(R) / (Nx * Ny)
    return residual


# Variable Init
tole = 1e-4
residual = 1

Nx = int(1e3)
Ny = Nx

sigma = np.zeros((Nx, Ny), dtype="float32")
R = np.zeros((Nx, Ny), dtype="float32")

r = calculate(sigma=sigma, R=R, residual=residual, tole=tole)
print(r)
© www.soinside.com 2019 - 2024. All rights reserved.