如何在 pytorch 中只优化输出的一个元素?

问题描述 投票:0回答:0

我正在尝试用 pytorch 实现 DQN 算法。由于我使用输出的 argmax 来生成新状态,所以我只能优化这个 argmax 动作。

我尝试生成两个假的X和Y向量,比如错误只会在x[i]处,例如:

# i = 1
x = [
  0,
  1,
  0,
  0
]
y = [
  0, 
  1.25,
  0,
  0
]
loss = loss_fn(x, y)

我认为这应该可行,因为其他输出的梯度将为 0,但我想知道是否有更好的解决方案。

python pytorch artificial-intelligence reinforcement-learning dqn
© www.soinside.com 2019 - 2024. All rights reserved.