我编写了一个简单的包装器来向给定的 PyTorch 神经网络添加特殊方法。 虽然下面的实现对于字符串、列表等一般对象效果很好。当将其应用于
RecursionError
时,我得到了 torch.nn.Module
。看来在后一种情况下,self.instance
方法内部对__getattr__
的调用不成功,所以又回落到__getattr__
,导致无限循环(我也尝试了self.__dict__['instane']
,但没有运气)。
我认为这种行为源于
__getattr__
和 __setattr__
方法 torch.nn.Module
的实现,但在检查了它们的实现后,我仍然不知道如何实现。
我想详细了解发生了什么以及如何修复我的实现中的错误。
(我知道link中有类似的问题,但它没有回答我的问题。)
这是重现我的情况的最小实现。
import torch
class MyWrapper(torch.nn.Module):
def __init__(self, instance):
super().__init__()
self.instance = instance
def __getattr__(self, name):
print("trace", name)
return getattr(self.instance, name)
# Working example
obj = "test string"
obj_wrapped = MyWrapper(obj)
print(obj_wrapped.split(" ")) # trace split\n ['test', 'string']
# Failing example
net = torch.nn.Linear(12, 12)
net.test_attribute = "hello world"
b = MyWrapper(net)
print(b.test_attribute) # RecursionError: maximum recursion depth exceeded
b.instance # RecursionError: maximum recursion depth exceeded
该错误与
nn.Module
(或其任何超类/子类)没有太大关系。这是由于属性查找在 Python 类中的工作方式所致。
当您重写
__getattr__
类中的 MyWrapper
特殊方法时,当您在 self.instance
中执行 __getattr__
时,它会进入无限递归情况以获取名为 instance
的属性,它正在查找 __getattr__
一次又一次失败。
修复:
您可以从以下事实中获得帮助:Python 允许您使用超类的
__getattr__
方法(可以使用 super
方法轻松访问)。因此,如果我们使用超类的 __getattr__
来正确获取 instance
分辨率,那么我们仍然可以使用 getattr
来获取下一个 name
查找。例如,类似以下内容应该有效:
In [259]: class MyWrapper(torch.nn.Module):
...: def __init__(self, instance):
...: super().__init__()
...: self.instance = instance
...:
...: def __getattr__(self, name):
...: instance = super().__getattr__("instance")
...: return getattr(instance, name)
...:
In [260]: # Your failing example - now working
...: net = torch.nn.Linear(12, 12)
...: net.test_attribute = "hello world"
...: b = MyWrapper(net)
In [261]: print(b.test_attribute)
hello world