计算图像的雅可比行列式:如何正确重塑 numpy 数组?

问题描述 投票:0回答:1

我有一批

x
形状为
[k, width, height, channel_count]
的图像。该批次由函数
f
进行转换。结果具有相同的形状,我需要计算此变换的散度(即雅可比行列式的迹)。 (强调这一点:转换正在批量进行;它是神经网络的输出,我无法更改它)。

我正在使用

jax.jacfwd
来计算雅可比行列式。输出的形状为
[k, width, height, channel_count, k, width, height, channel_count]
。这是第一个问题,我实际上需要每个图像的雅可比行列式。因此输出应该具有形状
[k, width, height, channel_count, width, height, channel_count]
。我不知道如何使用
jax.jacfwd
来实现此目的,因为我没有每个图像转换(只有每个批次转换)。

即使我有所需的输出,我如何计算跟踪(每个图像)?输出的形状应为

[k]
。我想我需要将雅可比输出重塑为
[k, width * height * channel_count, width * height * channel_count]
,但我该怎么做?

备注:请注意,我还需要知道变换本身的值。由于没有

jax.val_and_jacfwd
,我将实际值作为辅助变量返回,这应该没问题。我的问题的解决方案应该仍然允许这样做。

编辑

这是一个最小的可重现示例:

import jax
import torch

def f(x):
    # s = model.apply(params, x) ### The actual code queries a network model
    # for reproducibility, here is a simple transform:
    s = jax.numpy.empty(x.shape)
    for i in range(x.shape[0]):
        for u in range(x.shape[1]):
            for v in range(x.shape[2]):
                for c in range(x.shape[3]):
                    s = s.at[i, u, v, c].set(x[i, u, v, c] * x[i, u, v, c])
    return [s, s]

jac = jax.jacfwd(f, has_aux = True)

k = 3
width = 2
height = 2
channel_count = 1

x = torch.empty((k, width, height, channel_count))

### The actual loads x from a batch
it = 1
for i in range(x.shape[0]):
    for u in range(x.shape[1]):
        for v in range(x.shape[2]):
            x[i, u, v, 0] = it
            it += 1

f_jac, f_val = jac(x.numpy())
print(f_jac.shape)

输出为

(3, 2, 2, 1, 3, 2, 2, 1)
。这显然不是我想要的。我不想“区分一幅图像与另一幅图像的像素”。我真正想要的是“每张图像”雅可比行列式。因此,输出应该是
(3, 2, 2, 1, 2, 2, 1)
的形状。

但是让我再次强调一下:我仍然需要

f
需要一批图像,因为单独为批次中的每个图像调用
model.apply
会非常慢。

编辑2

顺便说一句,如果有一种直接的方法来计算

f
的散度 - 而无需之前计算整个雅可比行列式 - 我肯定会更喜欢。

编辑3

仅关于“每批次雅可比行列式”的事情:这是一个更简单的示例:

import jax
import torch

def f(x):
    s = jax.numpy.empty(x.shape)
    s = s.at[0].set(x[0] * x[0])
    s = s.at[1].set(x[1] * x[1])
    s = s.at[2].set(x[2] * x[2])
    return [s, s]

jac = jax.jacfwd(f, has_aux = True)

x = torch.empty(3)
x[0] = 1
x[1] = 2
x[2] = 3

f_jac, f_val = jac(x.numpy())
print(f_jac.shape)
print(f_jac)
print(f_val)

目标是

f_jac
具有形状
[3, 1, 1]
。 (一维函数的雅可比行列式只是一个标量)。

python numpy neural-network jax automatic-differentiation
1个回答
0
投票

听起来您想要雅可比行列式的批次尺寸,您可以通过使您的

f
在单个批次上工作,然后将雅可比式包装在
vmap
中来实现:

def f(x):
    s = jax.numpy.empty(x.shape)
    for u in range(x.shape[0]):
        for v in range(x.shape[1]):
            for c in range(x.shape[2]):
                    s = s.at[u, v, c].set(x[u, v, c] * x[u, v, c])
    return [s, s]

jac = jax.vmap(jax.jacfwd(f, has_aux = True))

然后我得到的输出是这样的:

f_jac, f_val = jac(x.numpy())
print(f_jac.shape) # (3, 2, 2, 1, 2, 2, 1)
print(f_val.shape) # (3, 2, 2, 1)
© www.soinside.com 2019 - 2024. All rights reserved.