我正在尝试用 pytorch 实现 DQN 算法。由于我使用输出的 argmax 来生成新状态,所以我只能优化这个 argmax 动作。
我尝试生成两个假的X和Y向量,比如错误只会在x[i]处,例如:
# i = 1 x = [ 0, 1, 0, 0 ] y = [ 0, 1.25, 0, 0 ] loss = loss_fn(x, y)
我认为这应该可行,因为其他输出的梯度将为 0,但我想知道是否有更好的解决方案。