我不明白pickle在加载文件时如何找到它的属性。我想在不同的模块中为 pickle 提供不同的类定义,但它对我不起作用。
这是一个例子。
# mod_a.py
import pickle
class A(object):
def __init__(self):
self.mod_a = True
#------------------------------------------------------------------
def load(file):
print('Unpickle',repr(A))
return pickle.load(file)
if __name__ == '__main__':
# Create a pickle file
a = A()
with open('a.ar','wb') as f:
pickle.dump(a,f,protocol=2)
和
# mod_b.py
import pickle
class A(object):
def __init__(self):
self.mod_b = True
#------------------------------------------------------------------
def load(file):
print('Unpickle',repr(A))
return pickle.load(file)
if __name__ == '__main__':
# Create a pickle file
a = A()
with open('b.ar','wb') as f:
pickle.dump(a,f,protocol=2)
最后
import mod_a
import mod_b
# from mod_a import A
# from mod_b import A
with open('a.ar','rb') as f:
print(mod_a.load(f))
我首先创建文件
a.ar
和 b.ar
,然后运行 mod_c.py
,这给了我这个错误:
Unpickle <class 'mod_a.A'>
Traceback (most recent call last):
File "C:\proj_py\GTC3\tmp\mod_c.py", line 8, in <module>
print(mod_a.load(f))
^^^^^^^^^^^^^
File "C:\proj_py\GTC3\tmp\mod_a.py", line 10, in load
return pickle.load(file)
^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'A' on <module '__main__' from 'C:\\proj_py\\GTC3\\tmp\\mod_c.py'>
我原以为pickle会在定义了
A
的模块中找到load()
的定义,这在print()语句中显示。然而,这并没有发生。
为了克服该错误,我可以取消注释
mod_c.py
中的导入语句之一。例如,如果我使用from mod_b import A
,则恢复mod_b
中定义的类A,这不是我调用mod_a.load(f)
的意图。
我可以修改
mod_a
和 mod_b
以便它们的 load()
函数恢复各自模块中的类吗?
编辑:------------------------- 按照 @KamilCuk 的提示,以下修改后的文件似乎可以解决问题。
# mod_a.py
import pickle
import sys
class _Unpickler(pickle.Unpickler):
def find_class(self, module, name):
if hasattr(sys.modules[__name__],name):
return getattr(sys.modules[__name__],name)
else:
return super(_Unpickler,self).find_class(module, name)
class A(object):
def __init__(self):
self.mod_a = True
def load(file):
return _Unpickler(file).load()
if __name__ == '__main__':
a = A()
with open('a.ar','wb') as f:
pickle.dump(a,f,protocol=2)
和
# mod_b.py
import pickle
import sys
class _Unpickler(pickle.Unpickler):
def find_class(self, module, name):
if hasattr(sys.modules[__name__],name):
return getattr(sys.modules[__name__],name)
else:
return super(_Unpickler,self).find_class(module, name)
class A(object):
def __init__(self):
self.mod_b = True
def load(file):
return _Unpickler(file).load()
if __name__ == '__main__':
a = A()
with open('b.ar','wb') as f:
pickle.dump(a,f,protocol=2)
与
# mod_c.py
import mod_a
import mod_b
with open('a.ar','rb') as f:
print( mod_a.load(f) )
# print( mod_b.load(f) )
现在
mod_a.load(f)
取消 mod_a
版本,mod_b.load(f)
取消 mod_b
版本。
(当然,这并不一定需要使用
A
的初始定义来解封对象——我认为 pickle 文件没有记录足够的信息来做到这一点。这段代码所做的就是明确 的特定定义A
已使用。)
我作为问题更新输入的解决方案对我来说效果很好,因此我建议将其作为答案。
每个模块都需要遵循这个模式
# mod_a.py
import pickle
import sys
class _Unpickler(pickle.Unpickler):
def find_class(self, module, name):
if hasattr(sys.modules[__name__],name):
return getattr(sys.modules[__name__],name)
else:
return super(_Unpickler,self).find_class(module, name)
class A(object):
def __init__(self):
self.mod_b = True
def load(file):
return _Unpickler(file).load()
if __name__ == '__main__':
a = A()
with open('b.ar','wb') as f:
pickle.dump(a,f,protocol=2)
通过这种方式,客户端代码可以控制所搜索的模块,如下所示:
import mod_a
with open('a.ar','rb') as f:
print( mod_a.load(f) )