将 JAX 跟踪器传递给 Huggingface CLIP 转换器以计算损失

问题描述 投票:0回答:1

我正在使用 JAX 处理视觉任务,并且我面临将中间 JAX TRACER 对象作为图像传递给 CLIP 模型以计算损失的问题。 CLIP 模型需要 NumPy 数组作为输入,因此 JAX TRACER 对象不直接兼容。

这是代码的简化版本:

def train_step(state, batch, rng):
    """Train Step"""
    inputs, targets = batch

    def clip_loss_fn(params):
        model_fn = lambda x: state.apply_fn({"params": params}, x)
        ray_origins, ray_directions = inputs
        rgb, *_ = perform_volume_rendering(
            model_fn, ray_origins, ray_directions, rng
        )

        # Here's where the issue arises
        clip_input = processor(
            text=["a bulldozer"], images=[rgb], return_tensors="jax", padding=True
        )
        outputs = img_txt_clip(**clip_input)
        logits_per_image = outputs.logits_per_image
        return jnp.mean(logits_per_image)

    train_loss, gradients = jax.value_and_grad(clip_loss_fn)(state.params)
    gradients = lax.pmean(gradients, axis_name="batch")
    new_state = state.apply_gradients(grads=gradients)
    train_loss = jnp.mean(train_loss)
    return train_loss, new_state

我已经尝试使用 FlaxCLIPModel 来与 JAX 兼容,但是将 JAX TRACER 对象作为图像传递给 CLIP 模型会引发错误。将 JAX TRACER 对象转换为 NumPy 数组是低效的。

对于将 JAX TRACER 对象有效地转换为 NumPy 数组或使 CLIP 模型接受 JAX TRACER 对象作为输入的任何建议或解决方案,我将不胜感激。

提前感谢您的帮助!

huggingface-transformers clip jax flax
1个回答
0
投票

这个问题在 JAX 的常见问题解答中得到了回答:JAX 常见问题解答:如何将 JAX Tracer 转换为 NumPy 数组?.

不可能将 JAX 跟踪器转换为 NumPy 数组,因为跟踪器不是数组,而是给定

shape
dtype
的所有可能的 NumPy 数组的抽象表示。

但是,当这个问题出现时,通常意味着提问者有兴趣在运行时调用主机端非 JAX 代码,这表现为无法将跟踪器转换为数组的错误。如果这实际上是您想要做的,那么最好的途径可能是

jax.pure_callback
,如JAX 中的外部回调中所述。但请注意,回调可能会导致执行缓慢,因为在主机通信和计算发生时,设备上的进程将暂停。

更好的选择是找到您尝试调用的函数的 JAX 兼容实现(即支持 JAX Tracers 作为输入的实现),它可以直接在设备上执行,而不需要回调主机。

© www.soinside.com 2019 - 2024. All rights reserved.