我试图检查NamedTuple“ Transition”的斜率是否等于列表“ self.memory”中的任何对象。
这是我尝试运行的代码:
from typing import NamedTuple
import random
import torch as t
Transition = NamedTuple('Transition', state=t.Tensor, action=int, reward=int, next_state=t.Tensor, done=int, hidden=t.Tensor)
class ReplayMemory:
def __init__(self, capacity):
self.memory = []
self.capacity = capacity
self.position = 0
def store(self, *args):
print(self.memory == Transition(*args))
if Transition(*args) in self.memory:
return
if len(self.memory) < self.capacity:
self.memory.append(None)
self.memory[self.position] = Transition(*args)
...
这是输出:
False
False
和我得到的错误:
...
if Transition(*args) in self.memory:
RuntimeError: bool value of Tensor with more than one value is ambiguous
在我看来这很奇怪,因为打印告诉我“ ==”操作返回一个布尔值。
如何正确完成?
谢谢
编辑:
* args是一个由]组成的元组
torch.Size([16, 12])
int
int
torch.Size([16, 12])
int
torch.Size([4])
我试图检查NamedTuple“ Transition”的斜率是否等于列表“ self.memory”中的任何对象。这是我尝试运行的代码:从输入import NamedTuple import import random import ...
我相信您应该明确定义平等。