计算 JAX 卷积的雅可比行列式

问题描述 投票:0回答:1

我正在使用 JAX 来生成卷积

def gaussian_kernel(size: int, std: float):
    """Generates a 2D Gaussian kernel."""
    x, y = jnp.mgrid[-size:size+1, -size:size+1]
    g = jnp.exp(-(x**2 + y**2) / (2 * std**2))
    return g / g.sum()
    
def gaussian_blur(image, kernel_size=5, sigma=1.0):
    """Applies Gaussian blur to a 2D image."""
    kernel = gaussian_kernel(kernel_size, sigma)
    blurred_image = convolve2d(image, kernel, mode='same')
    return blurred_image 

基本上,只是普通的模糊。

从数学上讲,我不明白卷积的导数相对于输入像素会是什么样子。

例如,更改输入像素 y 对输出像素 x 有何影响。

我该如何定义这个?我如何从 JAX 中提取它。我什至不知道从哪里开始!

嗯,我希望能够使用 JAX 提取输出像素相对于输入像素的梯度。

python math jax autodiff
1个回答
0
投票

您可以使用

jax.jacobian
计算雅可比:

image_out = gaussian_blur(image)
image_jac = jax.jacobian(gaussian_blur)(image)

对于形状为

image
的输入
(M, N)
image_out
也将具有形状
(M, N)
,并且
image_jac
将具有形状
(M, N, M, N)

image_jac[i, j, k, l]
告诉您
image_out[i, j]
相对于
image[k, l]
的偏导数。

© www.soinside.com 2019 - 2024. All rights reserved.