Optimizing a Python SmoothStep function with multiple conditions for Numba vectorization


I implemented a function that uses SmoothStep to build a smooth rectangle function:

import numpy as np
from numba import njit
import matplotlib.pyplot as plt

@njit
def GenSmoothStep( vX: np.ndarray, lowVal: float, highVal: float, vY: np.ndarray, rollOffWidth: float = 0.1 ):
    
    lowClip  = max(lowVal - rollOffWidth, 0)
    highClip = min(highVal + rollOffWidth, 1)

    for ii in range(vX.size):
        valX = vX.flat[ii]
        if valX < lowClip:
            vY.flat[ii] = 0.0
        elif valX < lowVal:
            # Smoothstep [lowClip, lowVal]
            valXN = (lowVal - valX) / (lowVal - lowClip)
            vY.flat[ii] = 1 - (valXN * valXN * (3 - (2 * valXN)))
        elif valX > highClip:
            vY.flat[ii] = 0.0
        elif valX > highVal:
            # Smoothstep [highVal, highClip]
            valXN = (valX - highVal) / (highClip - highVal)
            vY.flat[ii] = 1 - (valXN * valXN * (3 - (2 * valXN)))
        else:
            vY.flat[ii] = 1.0

numGridPts = 1000

lowVal  = 0.15
highVal = 0.75
rollOffWidth = 0.3

vX = np.linspace(0, 1, numGridPts)
vY = np.empty_like(vX)

GenSmoothStep(vX, lowVal, highVal, vY, rollOffWidth = rollOffWidth)

plt.plot(vX, vY)
plt.show()

The function contains several conditions, which makes it unfriendly to vectorization.
Are there some simple steps that would make it more Numba-friendly?
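One way to remove the if/elif chain from the loop body is to compute, for every element, the clamped position inside both transitions and multiply a rising smoothstep by one minus a falling one. The sketch below is illustrative: smooth_rect_branchless is a hypothetical name, it assumes a 1-D vX, rollOffWidth > 0 and 0 < lowVal < highVal < 1 (so neither transition has zero width), and whether LLVM actually emits SIMD code for it depends on the Numba version and CPU, so it should be benchmarked against the original. The try/except only lets the sketch run without JIT where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without Numba (no JIT speed-up)
    def njit(func):
        return func


@njit
def smooth_rect_branchless(vX, lowVal, highVal, vY, rollOffWidth=0.1):
    # Same clipped transition edges as GenSmoothStep in the question
    lowClip = max(lowVal - rollOffWidth, 0.0)
    highClip = min(highVal + rollOffWidth, 1.0)
    # Assumes both widths are > 0 (0 < lowVal < highVal < 1, rollOffWidth > 0)
    lowWidth = lowVal - lowClip
    highWidth = highClip - highVal
    for ii in range(vX.size):
        x = vX[ii]
        # Normalized positions inside the rising and falling transitions, clamped to [0, 1]
        tUp = min(max((x - lowClip) / lowWidth, 0.0), 1.0)
        tDown = min(max((x - highVal) / highWidth, 0.0), 1.0)
        # Rising smoothstep times (1 - falling smoothstep); the loop body has no if/elif chain
        sUp = tUp * tUp * (3.0 - 2.0 * tUp)
        sDown = tDown * tDown * (3.0 - 2.0 * tDown)
        vY[ii] = sUp * (1.0 - sDown)
```

It is called the same way as GenSmoothStep, e.g. smooth_rect_branchless(vX, 0.15, 0.75, vY, rollOffWidth=0.3) with vY = np.empty_like(vX).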

Tags: python, numpy, performance, vectorization, numba

Answer

IIUC, you just want to combine two smoothstep functions:

import matplotlib.pyplot as plt
import numpy as np
from numba import njit


@njit
def smoothstep(edge0, edge1, x):
    # Map x linearly so that [edge0, edge1] -> [0, 1], clamp, then apply 3t^2 - 2t^3
    x = np.clip((x - edge0) / (edge1 - edge0), 0, 1)
    return x * x * (3.0 - 2.0 * x)


numGridPts = 1000

lowVal = 0.15
highVal = 0.75

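vX = np.linspace(0, 1, numGridPts)
# Rising edge on [0, lowVal] times falling edge on [highVal, 1]; these equal the
# question's clipped edges for rollOffWidth = 0.3 (0.15 - 0.3 -> 0, 0.75 + 0.3 -> 1)
vY = smoothstep(0, lowVal, vX) * (1 - smoothstep(highVal, 1, vX))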

plt.plot(vX, vY)
plt.show()
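The edges 0 and 1 above only coincide with the question's lowClip and highClip because rollOffWidth is 0.3 in this example. A minimal sketch that keeps rollOffWidth, computing the edges exactly as GenSmoothStep does; smooth_rect is a hypothetical name, and it assumes rollOffWidth > 0 and 0 < lowVal < highVal < 1 so neither transition has zero width (which would make edge1 - edge0 zero inside smoothstep):

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without Numba (no JIT speed-up)
    def njit(func):
        return func


@njit
def smoothstep(edge0, edge1, x):
    # Same helper as in the answer: clamp to [0, 1], then apply 3t^2 - 2t^3
    x = np.clip((x - edge0) / (edge1 - edge0), 0.0, 1.0)
    return x * x * (3.0 - 2.0 * x)


@njit
def smooth_rect(vX, lowVal, highVal, rollOffWidth=0.1):
    # Transition edges clipped to [0, 1], as in the question's GenSmoothStep
    lowClip = max(lowVal - rollOffWidth, 0.0)
    highClip = min(highVal + rollOffWidth, 1.0)
    rising = smoothstep(lowClip, lowVal, vX)
    falling = smoothstep(highVal, highClip, vX)
    # Rising edge times (1 - falling edge); each factor is 1 where the other transitions
    return rising * (1.0 - falling)
```

With rollOffWidth = 0.3 this evaluates the same curve as the one-liner above; for other widths it should agree with GenSmoothStep up to floating-point rounding, because smoothstep satisfies s(1 - t) = 1 - s(t), so the question's 1 - s(...) on the rising side is just an increasing smoothstep written differently.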

This displays the following plot:
[Figure: vY plotted against vX; image not preserved]
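A short derivation of why the product of two clamped smoothsteps reproduces the question's branches. Here r stands for rollOffWidth and a, b, c, d are shorthand introduced only for this derivation. The key step is that smoothstep is point-symmetric about t = 1/2; since b < c, at most one factor is inside its transition at any x, and the clamping produces the 0 and 1 plateaus.

```latex
s(t) = t^2 (3 - 2t), \qquad s(1 - t) = 1 - s(t)

a = \max(\mathrm{lowVal} - r,\ 0), \quad b = \mathrm{lowVal}, \quad
c = \mathrm{highVal}, \quad d = \min(\mathrm{highVal} + r,\ 1)

\text{rising branch: } 1 - s\!\left(\frac{b - x}{b - a}\right)
  = s\!\left(1 - \frac{b - x}{b - a}\right)
  = s\!\left(\frac{x - a}{b - a}\right)

\text{falling branch: } 1 - s\!\left(\frac{x - c}{d - c}\right)

y(x) = s\!\left(\operatorname{clip}\!\left(\tfrac{x - a}{b - a},\ 0,\ 1\right)\right)
       \left[1 - s\!\left(\operatorname{clip}\!\left(\tfrac{x - c}{d - c},\ 0,\ 1\right)\right)\right]
```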
