我有一批
x
形状为[k, width, height, channel_count]
的图像。该批次由函数 f
进行转换。结果具有相同的形状,我需要计算此变换的散度(即雅可比行列式的迹)。 (强调这一点:转换正在批量进行;它是神经网络的输出,我无法更改它)。
我正在使用
jax.jacfwd
来计算雅可比行列式。输出的形状为[k, width, height, channel_count, k, width, height, channel_count]
。这是第一个问题,我实际上需要每个图像的雅可比行列式。因此输出应该具有形状 [k, width, height, channel_count, width, height, channel_count]
。我不知道如何使用 jax.jacfwd
来实现此目的,因为我没有每个图像转换(只有每个批次转换)。
即使我有所需的输出,我如何计算跟踪(每个图像)?输出的形状应为
[k]
。我想我需要将雅可比输出重塑为[k, width * height * channel_count, width * height * channel_count]
,但我该怎么做?
备注:请注意,我还需要知道变换本身的值。由于没有
jax.val_and_jacfwd
,我将实际值作为辅助变量返回,这应该没问题。我的问题的解决方案应该仍然允许这样做。
编辑:
这是一个最小的可重现示例:
import jax
import torch
def f(x):
# s = model.apply(params, x) ### The actual code queries a network model
# for reproducibility, here is a simple transform:
s = jax.numpy.empty(x.shape)
for i in range(x.shape[0]):
for u in range(x.shape[1]):
for v in range(x.shape[2]):
for c in range(x.shape[3]):
s = s.at[i, u, v, c].set(x[i, u, v, c] * x[i, u, v, c])
return [s, s]
jac = jax.jacfwd(f, has_aux = True)
k = 3
width = 2
height = 2
channel_count = 1
x = torch.empty((k, width, height, channel_count))
### The actual loads x from a batch
it = 1
for i in range(x.shape[0]):
for u in range(x.shape[1]):
for v in range(x.shape[2]):
x[i, u, v, 0] = it
it += 1
f_jac, f_val = jac(x.numpy())
print(f_jac.shape)
输出为
(3, 2, 2, 1, 3, 2, 2, 1)
。这显然不是我想要的。我不想“区分一幅图像与另一幅图像的像素”。我真正想要的是“每张图像”雅可比行列式。因此,输出应该是 (3, 2, 2, 1, 2, 2, 1)
的形状。
但是让我再次强调一下:我仍然需要
f
需要一批图像,因为单独为批次中的每个图像调用model.apply
会非常慢。
编辑2:
顺便说一句,如果有一种直接的方法来计算
f
的散度 - 而无需之前计算整个雅可比行列式 - 我肯定会更喜欢。
编辑3:
仅关于“每批次雅可比行列式”的事情:这是一个更简单的示例:
import jax
import torch
def f(x):
s = jax.numpy.empty(x.shape)
s = s.at[0].set(x[0] * x[0])
s = s.at[1].set(x[1] * x[1])
s = s.at[2].set(x[2] * x[2])
return [s, s]
jac = jax.jacfwd(f, has_aux = True)
x = torch.empty(3)
x[0] = 1
x[1] = 2
x[2] = 3
f_jac, f_val = jac(x.numpy())
print(f_jac.shape)
print(f_jac)
print(f_val)
目标是
f_jac
具有形状[3, 1, 1]
。 (一维函数的雅可比行列式只是一个标量)。
听起来您想要雅可比行列式的批次尺寸,您可以通过使您的
f
在单个批次上工作,然后将雅可比式包装在vmap
中来实现:
def f(x):
s = jax.numpy.empty(x.shape)
for u in range(x.shape[0]):
for v in range(x.shape[1]):
for c in range(x.shape[2]):
s = s.at[u, v, c].set(x[u, v, c] * x[u, v, c])
return [s, s]
jac = jax.vmap(jax.jacfwd(f, has_aux = True))
然后我得到的输出是这样的:
f_jac, f_val = jac(x.numpy())
print(f_jac.shape) # (3, 2, 2, 1, 2, 2, 1)
print(f_val.shape) # (3, 2, 2, 1)