I'm getting this error when using PyTorch: RuntimeError: gather_out_cpu(): Expected dtype int64 for index


I'm trying to build an AI with PyTorch, but I get this error:

RuntimeError: gather_out_cpu(): Expected dtype int64 for index

Here is my function:

def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]
    target = self.gamma * next_outputs + batch_reward
    td_loss = F.smooth_l1_loss(outputs, target)
    self.optimizer.zero_grad()
    td_loss.backward(retain_variables = True)
    self.optimizer.step()
python-3.x pytorch artificial-intelligence
2 Answers
4 votes

You need to change the data type of the batch_action tensor before passing it to torch.gather:

def learn(...):
    # cast the index tensor to int64 before it reaches gather
    batch_action = batch_action.type(torch.int64)
    outputs = ...
    ...

# or do the cast inline:
outputs = self.model(batch_state).gather(1, batch_action.type(torch.int64).unsqueeze(1)).squeeze(1)
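
For context, torch.gather only accepts an int64 (long) index tensor, and .long() is an equivalent, slightly shorter cast than .type(torch.int64). A minimal sketch with made-up tensors, not taken from the question's project:

import torch

scores = torch.randn(2, 5)
idx = torch.tensor([[1], [3]], dtype=torch.float32)   # wrong dtype for an index

# scores.gather(1, idx)        # RuntimeError: Expected dtype int64 for index
scores.gather(1, idx.long())   # .long() is equivalent to .type(torch.int64)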

0 votes

This may seem like a very late response (the question was asked 1 year and 8 months ago), but I ran into a similar error in one of my implementations. I got the code from a course I am currently taking, and it is very similar. In my case, I found that when I called the learn function later in my code, the arguments passed to the batch_reward and batch_action parameters had been swapped. That is, the values meant for batch_reward were going to batch_action, and vice versa.

(Below is the code from the question)

def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]

(Below is my code)

def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]
    target = self.gamma*next_outputs + batch_reward
    td_loss = F.smooth_l1_loss(outputs, target)
    self.optimizer.zero_grad()
    td_loss.backward(retain_graph = True)
    self.optimizer.step()

def update(self, reward, new_signal):
    new_state = torch.Tensor(new_signal).float().unsqueeze(0)
    self.memory.push((self.last_state, new_state, torch.LongTensor([int(self.last_action)]), torch.Tensor([self.last_reward])))
    action = self.select_action(new_state)
    if len(self.memory.memory) > 100:
        batch_state, batch_next_state, batch_action, batch_reward = self.memory.sample(100)
        self.learn(batch_state, batch_next_state, batch_reward, batch_action)
    self.last_action = action
    self.last_state = new_state
    self.last_reward = reward
    self.reward_window.append(reward)
    if len(self.reward_window) > 1000:
        del self.reward_window[0]
    return action

When the learn function is called from inside the update function, the argument meant for batch_reward is received by batch_action, and vice versa. Note that update stores actions as a LongTensor and rewards as a float Tensor in the push call, so after the swap a float tensor reaches gather as its index, which is exactly what raises the error above. So I suppose the fix for this error is fairly simple: the wrong arguments were being passed to the wrong parameters. Bottom line / solution: check the code that passes arguments to your function. There may be a mismatch between arguments and parameters.
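
As a minimal sketch of this failure mode (the tensor names and shapes below are made up for illustration, not taken from the project above), using a float reward tensor where gather expects an int64 action index reproduces the same RuntimeError:

import torch

q_values = torch.randn(4, 3)                        # Q-values for 4 states x 3 actions
batch_action = torch.tensor([0, 2, 1, 0])           # int64 by default: a valid gather index
batch_reward = torch.tensor([1.0, 0.5, 0.0, 1.0])   # float32: rewards, not a valid index

# Correct call: pick the chosen action's Q-value in each row
chosen = q_values.gather(1, batch_action.unsqueeze(1)).squeeze(1)

# Swapped arguments: the float reward tensor becomes the index
# q_values.gather(1, batch_reward.unsqueeze(1))
# -> RuntimeError: gather_out_cpu(): Expected dtype int64 for index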
