Jax ValueError: Incompatible shapes for broadcasting: shapes


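I'm trying to write a weighted cross-entropy loss to train my model with Jax. However, I think there is a problem with my input dimensions. Here is my code: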

import jax.numpy as np
from functools import partial
import jax

@partial(np.vectorize, signature="(c),(),()->()")
def weighted_cross_entropy_loss(logits, label, weights):
    one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[0])
    return -np.sum(weights* logits*one_hot_label)

a=np.array([[1,2,3,4,5,6,7],[2,3,4,5,6,7,8]])
label=np.array([1,2])
weights=np.array([1,2,3,4,5,6,7])
print(weighted_cross_entropy_loss(a,label,weights))

Here is my error message:

Traceback (most recent call last):
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 147, in broadcast_shapes
    return _broadcast_shapes_cached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 153, in _broadcast_shapes_cached
    return _broadcast_shapes_uncached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 169, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(2,), (2,), (7,)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/PATH/test.py", line 15, in <module>
    print(weighted_cross_entropy_loss(a,label,weights))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py", line 274, in wrapped
    broadcast_shape, dim_sizes = _parse_input_dimensions(
                                 ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py", line 123, in _parse_input_dimensions
    broadcast_shape = lax.broadcast_shapes(*shapes)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 149, in broadcast_shapes
    return _broadcast_shapes_uncached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 169, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(2,), (2,), (7,)]

I'm fairly new to this. Can someone tell me what's going on? Any help would be appreciated.

python valueerror jax
1 Answer

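label has length 2 and weights has length 7, which means they cannot be broadcast together. Your signature "(c),(),()->()" declares weights as a scalar, so np.vectorize treats its whole shape (7,) as a batch dimension and tries to broadcast it against the batch shape (2,) of a and label; that is exactly the shapes=[(2,), (2,), (7,)] reported in the error.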
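A minimal sketch of the check and of one possible fix. It assumes weights is meant to hold one weight per class (the question does not say so) and uses jax.numpy.broadcast_shapes to reproduce the batch-shape failure. Declaring weights with the same core dimension (c) as the logits makes its batch shape (), which broadcasts with (2,).

```python
import jax
import jax.numpy as jnp
from functools import partial

logits = jnp.array([[1, 2, 3, 4, 5, 6, 7], [2, 3, 4, 5, 6, 7, 8]])  # shape (2, 7)
label = jnp.array([1, 2])                                        # shape (2,)
weights = jnp.array([1, 2, 3, 4, 5, 6, 7])                        # shape (7,)

# With signature "(c),(),()->()", vectorize strips each input's core dims
# and broadcasts the remaining batch dims:
#   logits  (2, 7) minus core (c) -> (2,)
#   label   (2,)   minus core ()  -> (2,)
#   weights (7,)   minus core ()  -> (7,)  <- treated as a batch of 7 scalars
try:
    jnp.broadcast_shapes((2,), (2,), (7,))
except ValueError as err:
    print("batch shapes do not broadcast:", err)

# Assumption: one weight per class, so weights gets the core dim (c) too.
@partial(jnp.vectorize, signature="(c),(),(c)->()")
def weighted_cross_entropy_loss(logits, label, weights):
    one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[0])
    return -jnp.sum(weights * logits * one_hot_label)

loss = weighted_cross_entropy_loss(logits, label, weights)
print(loss)  # one value per example, shape (2,)
```

With this signature the batch shapes are (2,), (2,) and (), so the call returns one loss per row instead of raising.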

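It isn't clear to me from your question what result you expect, but you can read more about how broadcasting works in NumPy (and in JAX, which implements NumPy's semantics) at https://numpy.org/doc/stable/user/basics.broadcasting.html.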
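To illustrate the broadcasting rules from that page, here is a short sketch. It also shows a batched variant that skips np.vectorize entirely and relies on plain broadcasting; that variant is an alternative formulation assumed for illustration, not the asker's code. Because the one-hot encoding of the whole label array has shape (2, 7), the (7,) weights line up with its trailing axis.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[1, 2, 3, 4, 5, 6, 7], [2, 3, 4, 5, 6, 7, 8]])  # (2, 7)
label = jnp.array([1, 2])                                        # (2,)
weights = jnp.array([1, 2, 3, 4, 5, 6, 7])                        # (7,)

# Broadcasting compares shapes from the trailing dimension backwards; two
# dimensions are compatible when they are equal or one of them is 1.
print(jnp.broadcast_shapes((2, 7), (7,)))  # (2, 7): trailing 7 matches 7
print(jnp.broadcast_shapes((2, 1), (7,)))  # (2, 7): the size-1 axis is stretched

# Without vectorize: one_hot over the whole batch has shape (2, 7), so the
# (7,) weights broadcast across both rows; summing over the class axis only
# gives one loss per example.
one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[-1])
per_example_loss = -jnp.sum(weights * logits * one_hot_label, axis=-1)
print(per_example_loss)
```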
