如何最好地将 kwarg 添加到控制执行计算的 ndarray 轴的函数中?

问题描述 投票:0回答:0

考虑一个对 n 维数组进行操作的函数。我想修改它以采用 (n+1)-d 数组,允许我指定一个轴,并让它对轴的所有 n-d 切片执行操作。例如,我想采用像

这样的功能
def some_operation(arr):
    #does something to the n-d array
    return result

让高维数组像

一样工作
def f(arr,axis=0):
   # the array is (n+1)-d, so we operate on a given axis

   n = range(arr.shape[axis])

   if axis == 0:
       result = np.array([some_operation(arr[j,:,:,...,:]) for j in n])
   elif axis == 1:
       result = np.array([some_operation(arr[:,j,:,...,:]) for j in n])
   # and so on

   return result

但是有一个辅助函数或装饰器来简化代码,比如

@along_axis
def some_operation(arr)
    #does same thing to a n-d array
    # but also now works on (n+1)-d arrays
    return result

我认为我正在寻找的是

np.apply_along_axis()
here,但具有处理 n-d 数组函数的能力。我还认为这篇关于如何动态切片指定轴的帖子可用于处理上面看到的一些索引问题。可以肯定的是,我强烈希望操作避免复制。我不确定我上面使用的列表理解是否是最好的做事方式。

这是我正在拍摄的一个工作示例。首先,这是我做的包装纸

import numpy as np
from functools import wraps

def along_axis(func):
"""
Wraps a function of a 2d array. Enables performing that function's operations along all 2d slices down an axis of a 3d array. When given a 2d array, the axis parameter is ignored.
"""
    @wraps(func)
    def wrapper(arr, *args, axis=0, **kwargs):
        if arr.ndim == 3:
            arr = np.moveaxis(arr,axis,-1)
            return np.array([func(arr[:,:,t],*args,**kwargs) for t in range(arr.shape[-1])])
        return func(arr,*args,**kwargs)
    return wrapper

现在我将把它应用到一个作用于二维数组的函数。这个特别是受这篇关于图像径向方式的帖子的启发。

@along_axis
def radial_mean(image):
"""
Computes radial means of a 2d image. After wrapping it, we can input 3d arrays as well.
"""
    # create a radial mesh at an image's center
    X,Y = np.meshgrid(np.arange(image.shape[1]),np.arange(image.shape[0]))
    R = np.sqrt((X - image.shape[1]//2)**2 + (Y - image.shape[0]//2)**2)
    
    # for each r in R
    r  = np.arange(int(R.max()))
    
    # compute the radial mean
    f = np.vectorize(lambda r : image[(R >= r-0.5) & (R < r+0.5)].mean())
    return f(r)

现在我可以像以前一样计算许多图像或单个图像的径向均值。

images = np.random.rand(1000,16,16)
print(radial_mean(images, axis=0).shape)
print(radial_mean(images[0], axis=0).shape)

> (100,11)
> (11,)

这种方法有意义吗?有没有更好的办法?谢谢。

python numpy numpy-ndarray python-decorators array-broadcasting
© www.soinside.com 2019 - 2024. All rights reserved.