我正在使用 JAX 处理视觉任务,并且我面临将中间 JAX TRACER 对象作为图像传递给 CLIP 模型以计算损失的问题。 CLIP 模型需要 NumPy 数组作为输入,因此 JAX TRACER 对象不直接兼容。
这是代码的简化版本:
def train_step(state, batch, rng):
"""Train Step"""
inputs, targets = batch
def clip_loss_fn(params):
model_fn = lambda x: state.apply_fn({"params": params}, x)
ray_origins, ray_directions = inputs
rgb, *_ = perform_volume_rendering(
model_fn, ray_origins, ray_directions, rng
)
# Here's where the issue arises
clip_input = processor(
text=["a bulldozer"], images=[rgb], return_tensors="jax", padding=True
)
outputs = img_txt_clip(**clip_input)
logits_per_image = outputs.logits_per_image
return jnp.mean(logits_per_image)
train_loss, gradients = jax.value_and_grad(clip_loss_fn)(state.params)
gradients = lax.pmean(gradients, axis_name="batch")
new_state = state.apply_gradients(grads=gradients)
train_loss = jnp.mean(train_loss)
return train_loss, new_state
我已经尝试使用 FlaxCLIPModel 来与 JAX 兼容,但是将 JAX TRACER 对象作为图像传递给 CLIP 模型会引发错误。将 JAX TRACER 对象转换为 NumPy 数组是低效的。
对于将 JAX TRACER 对象有效地转换为 NumPy 数组或使 CLIP 模型接受 JAX TRACER 对象作为输入的任何建议或解决方案,我将不胜感激。
提前感谢您的帮助!
这个问题在 JAX 的常见问题解答中得到了回答:JAX 常见问题解答:如何将 JAX Tracer 转换为 NumPy 数组?.
不可能将 JAX 跟踪器转换为 NumPy 数组,因为跟踪器不是数组,而是给定
shape
和 dtype
的所有可能的 NumPy 数组的抽象表示。
但是,当这个问题出现时,通常意味着提问者有兴趣在运行时调用主机端非 JAX 代码,这表现为无法将跟踪器转换为数组的错误。如果这实际上是您想要做的,那么最好的途径可能是
jax.pure_callback
,如JAX 中的外部回调中所述。但请注意,回调可能会导致执行缓慢,因为在主机通信和计算发生时,设备上的进程将暂停。
更好的选择是找到您尝试调用的函数的 JAX 兼容实现(即支持 JAX Tracers 作为输入的实现),它可以直接在设备上执行,而不需要回调主机。