Python `__getattr__` + `torch.nn.Module` 产生无限递归

问题描述 投票:0回答:1

我编写了一个简单的包装器来向给定的 PyTorch 神经网络添加特殊方法。 虽然下面的实现对于字符串、列表等一般对象效果很好。当将其应用于

RecursionError
时,我得到了
torch.nn.Module
。看来在后一种情况下,
self.instance
方法内部对
__getattr__
的调用不成功,所以又回落到
__getattr__
,导致无限循环(我也尝试了
self.__dict__['instane']
,但没有运气)。

我认为这种行为源于

__getattr__
__setattr__
方法
torch.nn.Module
的实现,但在检查了它们的实现后,我仍然不知道如何实现。

我想详细了解发生了什么以及如何修复我的实现中的错误。

(我知道link中有类似的问题,但它没有回答我的问题。)

这是重现我的情况的最小实现。

import torch

class MyWrapper(torch.nn.Module):
    def __init__(self, instance):
        super().__init__()
        self.instance = instance

    def __getattr__(self, name):
        print("trace", name)
        return getattr(self.instance, name)

# Working example
obj = "test string"
obj_wrapped = MyWrapper(obj)
print(obj_wrapped.split(" ")) # trace split\n ['test', 'string']

# Failing example
net = torch.nn.Linear(12, 12)
net.test_attribute = "hello world"
b = MyWrapper(net)

print(b.test_attribute) # RecursionError: maximum recursion depth exceeded
b.instance # RecursionError: maximum recursion depth exceeded
python pytorch getattr infinite-recursion
1个回答
0
投票

该错误与

nn.Module
(或其任何超类/子类)没有太大关系。这是由于属性查找在 Python 类中的工作方式所致。

当您重写

__getattr__
类中的
MyWrapper
特殊方法时,当您在
self.instance
中执行
__getattr__
时,它会进入无限递归情况以获取名为
instance
的属性,它正在查找
__getattr__ 
一次又一次失败。

修复:

您可以从以下事实中获得帮助:Python 允许您使用超类的

__getattr__
方法(可以使用
super
方法轻松访问)。因此,如果我们使用超类的
__getattr__
来正确获取
instance
分辨率,那么我们仍然可以使用
getattr
来获取下一个
name
查找。例如,类似以下内容应该有效:


    In [259]: class MyWrapper(torch.nn.Module):
         ...:     def __init__(self, instance):
         ...:         super().__init__()
         ...:         self.instance = instance
         ...: 
         ...:     def __getattr__(self, name):
         ...:         instance = super().__getattr__("instance")
         ...:         return getattr(instance, name)
         ...:         
    
    In [260]: # Your failing example - now working
         ...: net = torch.nn.Linear(12, 12)
         ...: net.test_attribute = "hello world"
         ...: b = MyWrapper(net)
    
    In [261]: print(b.test_attribute)
    hello world

© www.soinside.com 2019 - 2024. All rights reserved.