[使用jax,我尝试计算每个样本的梯度,对其进行处理,然后将其转换为标准形式以计算常规参数更新。我的工作代码如下
differentiate_per_sample = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
gradients = differentiate_per_sample(params, x, y)
# some code
gradients_summed_over_samples = []
for layer in gradients:
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
gradients_summed_over_samples.append((dw, db))
其中gradients
的格式为list(tuple(DeviceArray(...), DeviceArray(...)), ...)
。
现在我尝试将循环重写为vmap(不确定最终是否会加速)
def sum_samples(layer):
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
vmap(sum_samples)(gradients)
但是sum_samples
仅被调用一次,而不是为列表中的每个元素调用。
列表是问题还是我了解其他错误?
jax.vmap
将仅映射到jax数组输入上,而不映射为数组或元组列表的输入。此外,vmapped函数无法就地修改输入;因此,不能使用vmapped函数。函数应该返回一个值,该返回值将与其他返回值堆叠在一起以构造输出]
例如,您可以修改定义的函数并像这样使用它:
import jax.numpy as np from jax import random def sum_samples(layer): (dw, db) = layer (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0)) return np.array([dw, db]) key = random.PRNGKey(1701) data = random.uniform(key, (10, 2, 20)) result = vmap(sum_samples)(data) print(result.shape) # (10, 2)
旁注:如果您使用的是这种方法,则上面的vmapped函数可以更简洁地表示为:
def sum_samples(layer):
return layer.sum(1)