如何检查NamedTuple是否在列表中?

问题描述 投票:-1回答:1

我试图检查NamedTuple“ Transition”的斜率是否等于列表“ self.memory”中的任何对象。

这是我尝试运行的代码:

from typing import NamedTuple
import random
import torch as t

Transition = NamedTuple('Transition', state=t.Tensor, action=int, reward=int, next_state=t.Tensor, done=int, hidden=t.Tensor)


class ReplayMemory:

    def __init__(self, capacity):
        self.memory = []
        self.capacity = capacity
        self.position = 0

    def store(self, *args):
        print(self.memory == Transition(*args))
        if Transition(*args) in self.memory:
            return
    if len(self.memory) < self.capacity:
        self.memory.append(None)
    self.memory[self.position] = Transition(*args)
    ...

这是输出:

False
False

和我得到的错误:

   ...
        if Transition(*args) in self.memory:
    RuntimeError: bool value of Tensor with more than one value is ambiguous

在我看来这很奇怪,因为打印告诉我“ ==”操作返回一个布尔值。

如何正确完成?

谢谢

编辑:

* args是一个由]组成的元组

torch.Size([16, 12])
int
int
torch.Size([16, 12])
int
torch.Size([4])

我试图检查NamedTuple“ Transition”的斜率是否等于列表“ self.memory”中的任何对象。这是我尝试运行的代码:从输入import NamedTuple import import random import ...

python list if-statement namedtuple
1个回答
0
投票

我相信您应该明确定义平等。

© www.soinside.com 2019 - 2024. All rights reserved.