I'm trying to write a weighted cross-entropy loss to train my model with JAX. However, I think there is an issue with my input dimensions. Here is my code:
import jax.numpy as np
from functools import partial
import jax
@partial(np.vectorize, signature="(c),(),()->()")
def weighted_cross_entropy_loss(logits, label, weights):
    one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[0])
    return -np.sum(weights * logits * one_hot_label)
a=np.array([[1,2,3,4,5,6,7],[2,3,4,5,6,7,8]])
label=np.array([1,2])
weights=np.array([1,2,3,4,5,6,7])
print(weighted_cross_entropy_loss(a,label,weights))
Here is my error message:
Traceback (most recent call last):
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 147, in broadcast_shapes
    return _broadcast_shapes_cached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 153, in _broadcast_shapes_cached
    return _broadcast_shapes_uncached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 169, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(2,), (2,), (7,)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/PATH/test.py", line 15, in <module>
    print(weighted_cross_entropy_loss(a,label,weights))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py", line 274, in wrapped
    broadcast_shape, dim_sizes = _parse_input_dimensions(
                                 ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py", line 123, in _parse_input_dimensions
    broadcast_shape = lax.broadcast_shapes(*shapes)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 149, in broadcast_shapes
    return _broadcast_shapes_uncached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 169, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(2,), (2,), (7,)]
I'm quite new to this. Can someone tell me what's going on? Any help would be appreciated.
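label has length 2 and weights has length 7, which means they cannot be broadcast together. The signature "(c),(),()->()" declares both as per-example scalars, so their shapes (2,) and (7,) are compared as batch dimensions; that is the shapes=[(2,), (2,), (7,)] reported in the error.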
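The mismatch can be reproduced in isolation. The snippet below is only an illustrative sketch: it passes the three batch shapes from the error message straight to jnp.broadcast_shapes, without going through np.vectorize at all.

```python
import jax.numpy as jnp

# Batch shapes left over once the core dimensions declared in
# signature="(c),(),()->()" are removed:
#   logits  (2, 7) -> (2,)   (core dimension c = 7 is consumed)
#   label   (2,)   -> (2,)   (scalar core, so the whole shape is batch)
#   weights (7,)   -> (7,)   (scalar core, so the whole shape is batch)
batch_shapes = [(2,), (2,), (7,)]

try:
    jnp.broadcast_shapes(*batch_shapes)
    compatible = True
except ValueError as err:
    # Sizes 2 and 7 differ and neither is 1, so broadcasting fails.
    compatible = False
    print(err)
```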
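It's not clear to me from your question what your expected outcome is, but you can read more about how broadcasting works in NumPy (and in JAX, which implements NumPy's semantics) at https://numpy.org/doc/stable/user/basics.broadcasting.html.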
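If the intent is one weight per class, one possible fix is to give weights the same core dimension as logits by changing the signature to "(c),(),(c)->()". This is a sketch under that assumption, not necessarily what the question intended. It also assumes the inputs are raw logits and applies jax.nn.log_softmax, which the original code does not do; the names logits, labels and losses are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Assumption: `weights` holds one weight per class, so it shares the core
# dimension `c` with `logits` instead of being a per-example scalar.
@partial(jnp.vectorize, signature="(c),(),(c)->()")
def weighted_cross_entropy_loss(logits, label, weights):
    one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[0])
    # Assumption: inputs are raw logits, so convert them to log-probabilities
    # first (the code in the question multiplies the logits directly).
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.sum(weights * log_probs * one_hot_label)


logits = jnp.array([[1, 2, 3, 4, 5, 6, 7], [2, 3, 4, 5, 6, 7, 8]], dtype=jnp.float32)
labels = jnp.array([1, 2])
weights = jnp.array([1, 2, 3, 4, 5, 6, 7], dtype=jnp.float32)

# The batch shapes are now (2,), (2,) and (), which broadcast to (2,).
losses = weighted_cross_entropy_loss(logits, labels, weights)
print(losses.shape)  # one loss per example: (2,)
```

If the weights were instead meant to be per example, they would need length 2 to match label, and the original signature could stay as it is.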