I'm trying to build an AI with PyTorch, but I get this error:
RuntimeError: gather_out_cpu(): Expected dtype int64 for index
Here is my function:
def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]
    target = self.gamma * next_outputs + batch_reward
    td_loss = F.smooth_l1_loss(outputs, target)
    self.optimizer.zero_grad()
    # retain_variables was renamed to retain_graph in later PyTorch releases
    td_loss.backward(retain_graph=True)
    self.optimizer.step()
You need to change the dtype of the batch_action tensor before passing it to torch.gather.
def learn(...):
    batch_action = batch_action.type(torch.int64)
    outputs = ...
    ...

# or
outputs = self.model(batch_state).gather(1, batch_action.type(torch.int64).unsqueeze(1)).squeeze(1)
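To see the failure and the fix side by side, here is a small standalone sketch. The Q-value tensor and the helper pick_q are hypothetical stand-ins for self.model(batch_state) and the gather line above; the point is only that a float index tensor is rejected by gather, while the same values cast to int64 work.

```python
import torch

# Hypothetical Q-values for a batch of 3 states and 2 actions
q_values = torch.tensor([[0.1, 0.9],
                         [0.5, 0.2],
                         [0.3, 0.7]])

# Actions accidentally stored as floats (e.g. built with torch.Tensor(...))
batch_action = torch.tensor([1.0, 0.0, 1.0])

def pick_q(q, actions):
    """Select Q(s, a) for each row; gather requires an int64 index tensor."""
    return q.gather(1, actions.unsqueeze(1)).squeeze(1)

try:
    pick_q(q_values, batch_action)
    float_index_failed = False
except RuntimeError as err:
    # Same class of error as in the question: expected dtype int64 for index
    float_index_failed = True
    print("float index rejected:", err)

# Cast first; batch_action.long() is shorthand for .type(torch.int64)
selected = pick_q(q_values, batch_action.type(torch.int64))
print(selected)  # picks 0.9, 0.5 and 0.7
```

Casting with .type(torch.int64) only fixes the symptom, though: if the "actions" are really rewards (see the next answer), the cast silently truncates them into bogus indices.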
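This is a very late response (the question was asked 1 year and 8 months ago), but I ran into a similar error in one of my implementations. I got the code from a course I'm currently taking, and it is very similar to yours. In my case, I found that when I later called the learn function, the arguments for the batch_reward and batch_action parameters were swapped: the value meant for batch_reward went to batch_action, and vice versa.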
(The code from the question:)
def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]
(My code:)
def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]
    target = self.gamma * next_outputs + batch_reward
    td_loss = F.smooth_l1_loss(outputs, target)
    self.optimizer.zero_grad()
    td_loss.backward(retain_graph=True)
    self.optimizer.step()

def update(self, reward, new_signal):
    new_state = torch.Tensor(new_signal).float().unsqueeze(0)
    # Transitions are pushed as (state, next_state, action, reward)
    self.memory.push((self.last_state, new_state, torch.LongTensor([int(self.last_action)]), torch.Tensor([self.last_reward])))
    action = self.select_action(new_state)
    if len(self.memory.memory) > 100:
        batch_state, batch_next_state, batch_action, batch_reward = self.memory.sample(100)
        # The argument order here must match learn()'s parameter order
        self.learn(batch_state, batch_next_state, batch_reward, batch_action)
    self.last_action = action
    self.last_state = new_state
    self.last_reward = reward
    self.reward_window.append(reward)
    if len(self.reward_window) > 1000:
        del self.reward_window[0]
    return action
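The memory class isn't shown in the answer, so the following is a hypothetical minimal ReplayMemory, assuming it stores (state, next_state, action, reward) tuples as in the push call above and regroups them field by field with zip. It illustrates why the unpacking order after sample() must mirror the push order: the int64 actions and the float rewards come back in exactly the positions they were pushed in.

```python
import random
import torch

class ReplayMemory:
    """Hypothetical minimal replay buffer; sample() returns fields in push() order."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, event):
        # event is assumed to be (state, next_state, action, reward)
        self.memory.append(event)
        if len(self.memory) > self.capacity:
            del self.memory[0]

    def sample(self, batch_size):
        samples = random.sample(self.memory, batch_size)
        # zip(*samples) regroups the tuples by field, preserving the push order
        return [torch.cat(field, 0) for field in zip(*samples)]

memory = ReplayMemory(capacity=1000)
for step in range(5):
    state = torch.rand(1, 3)
    next_state = torch.rand(1, 3)
    memory.push((state, next_state, torch.LongTensor([step % 2]), torch.Tensor([0.5])))

# Unpack in the same order the fields were pushed
batch_state, batch_next_state, batch_action, batch_reward = memory.sample(4)
print(batch_action.dtype, batch_reward.dtype)
```

If the unpacking order or the later learn(...) call gets out of sync with this order, the float reward tensor ends up as gather's index, which produces the error from the question.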
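When learn was called from update, the batch_reward argument was passed to the batch_action parameter, and vice versa. That in turn caused the error that both you and I hit when running the code. So the fix is fairly simple: the wrong arguments were being passed to the wrong parameters.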
Bottom line: check the code that passes arguments to the function. The arguments may not match the parameters they are meant for.
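One way to make this kind of mix-up hard to reintroduce is to call learn with keyword arguments and to check the action dtype at the top of the function. The sketch below uses a hypothetical standalone learn that only performs that check (it is not the course's method); the tensors are made-up example data.

```python
import torch

def learn(batch_state, batch_next_state, batch_reward, batch_action):
    """Hypothetical stand-in for learn(); it only validates the inputs."""
    # Fail early with a readable message if rewards and actions were swapped
    if batch_action.dtype != torch.int64:
        raise TypeError(f"batch_action must be int64, got {batch_action.dtype}")
    return batch_action.shape[0]

batch_state = torch.rand(4, 3)
batch_next_state = torch.rand(4, 3)
batch_action = torch.tensor([0, 1, 1, 0])           # int64 by default
batch_reward = torch.tensor([0.1, -1.0, 0.5, 0.0])  # float32 by default

# Keyword arguments make the call independent of positional order
n = learn(batch_state=batch_state, batch_next_state=batch_next_state,
          batch_reward=batch_reward, batch_action=batch_action)

# A swapped positional call is now caught before it reaches gather
try:
    learn(batch_state, batch_next_state, batch_action, batch_reward)
    swap_detected = False
except TypeError as err:
    swap_detected = True
    print(err)
```

The explicit check is a design choice: it turns an obscure error deep inside gather into a message that names the offending parameter.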