I'm trying to use a prioritized replay buffer for my DQN agent, and I'm running into the following problem.
My environment has a (40, 40, 1) state representation. When I try to add a transition to the buffer, it gives me:
RuntimeError: expand(torch.DoubleTensor{[40, 40, 1]}, size=[3]): the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (3)
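I can reproduce the error outside the buffer class. As far as I can tell, the stored row is a flat vector of length `state_size` (3 by default), while my state is a 3-D (40, 40, 1) array, so the element assignment fails. The buffer and state tensors below are just placeholders for the repro:

```python
import torch

# Minimal repro (assumption): buffer row is a flat float vector of
# length state_size=3, the state is an image-like (40, 40, 1) array.
state_buf = torch.empty(10, 3, dtype=torch.float)   # like self.state
state = torch.rand(40, 40, 1, dtype=torch.double)   # like my env state

try:
    state_buf[0] = torch.as_tensor(state)           # same call as in add()
    msg = None
except RuntimeError as e:
    msg = str(e)

print(msg)
```

This prints the same "number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (3)" message as my buffer.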
Prioritized replay buffer code:
```python
class PrioritizedReplayBuffer:
    def __init__(self, state_size=3, action_size=1, buffer_size=10000, eps=1e-2, alpha=0.1, beta=0.1):
        self.tree = SumTree(size=buffer_size)

        # PER params
        self.eps = eps
        self.alpha = alpha
        self.beta = beta
        self.max_priority = eps

        # transition: state, action, reward, next_state, done
        self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.int)

        self.count = 0
        self.real_size = 0
        self.size = buffer_size

    def add(self, transition):
        state, action, reward, next_state, done = transition

        # store transition index with maximum priority in sum tree
        self.tree.add(self.max_priority, self.count)

        # store transition in the buffer
        self.state[self.count] = torch.as_tensor(state)
        self.action[self.count] = torch.as_tensor(action)
        self.reward[self.count] = torch.as_tensor(reward)
        self.next_state[self.count] = torch.as_tensor(next_state)
        self.done[self.count] = torch.as_tensor(done)

        # update counters
        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)
```
Any help would be greatly appreciated. Thanks.