pickle加载时如何查找属性?

问题描述 投票:0回答:1

我不明白pickle在加载文件时如何找到它的属性。我想在不同的模块中为 pickle 提供不同的类定义,但它对我不起作用。

这是一个例子。

# mod_a.py
import pickle

class A(object):
    def __init__(self):
        self.mod_a = True
        
#------------------------------------------------------------------     
def load(file):
    print('Unpickle',repr(A))
    return pickle.load(file)
    
if __name__ == '__main__':
    # Create a pickle file
    a = A()
    with open('a.ar','wb') as f:
        pickle.dump(a,f,protocol=2)

# mod_b.py
import pickle

class A(object):
    def __init__(self):
        self.mod_b = True
        
#------------------------------------------------------------------     
def load(file):
    print('Unpickle',repr(A))
    return pickle.load(file)
    
if __name__ == '__main__':
    # Create a pickle file    
    a = A()
    with open('b.ar','wb') as f:
        pickle.dump(a,f,protocol=2)

最后

import mod_a
import mod_b
 
# from mod_a import A 
# from mod_b import A 

with open('a.ar','rb') as f:
    print(mod_a.load(f))

我首先创建文件

a.ar
b.ar
,然后运行
mod_c.py
,这给了我这个错误:

Unpickle <class 'mod_a.A'>
Traceback (most recent call last):
  File "C:\proj_py\GTC3\tmp\mod_c.py", line 8, in <module>
    print(mod_a.load(f))
          ^^^^^^^^^^^^^
  File "C:\proj_py\GTC3\tmp\mod_a.py", line 10, in load
    return pickle.load(file)
           ^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'A' on <module '__main__' from 'C:\\proj_py\\GTC3\\tmp\\mod_c.py'>

我原以为pickle会在定义了

A
的模块中找到
load()
的定义,这在print()语句中显示。然而,这并没有发生。

为了克服该错误,我可以取消注释

mod_c.py
中的导入语句之一。例如,如果我使用
from mod_b import A 
,则恢复
mod_b
中定义的类A,这不是我调用
mod_a.load(f)
的意图。

我可以修改

mod_a
mod_b
以便它们的
load()
函数恢复各自模块中的类吗?

编辑:------------------------- 按照 @KamilCuk 的提示,以下修改后的文件似乎可以解决问题。

# mod_a.py
import pickle
import sys

class _Unpickler(pickle.Unpickler):

    def find_class(self, module, name):
        if hasattr(sys.modules[__name__],name):
            return getattr(sys.modules[__name__],name)
        else:
            return super(_Unpickler,self).find_class(module, name)

class A(object):
    def __init__(self):
        self.mod_a = True
        
def load(file):
    return _Unpickler(file).load()
    
if __name__ == '__main__':
    
    a = A()
    with open('a.ar','wb') as f:
        pickle.dump(a,f,protocol=2)

# mod_b.py
import pickle
import sys

class _Unpickler(pickle.Unpickler):

    def find_class(self, module, name):
        if hasattr(sys.modules[__name__],name):
            return getattr(sys.modules[__name__],name)
        else:
            return super(_Unpickler,self).find_class(module, name)
            
class A(object):
    def __init__(self):
        self.mod_b = True
        
def load(file):
    return _Unpickler(file).load()
    
if __name__ == '__main__':
    
    a = A()
    with open('b.ar','wb') as f:
        pickle.dump(a,f,protocol=2)

# mod_c.py
import mod_a
import mod_b

with open('a.ar','rb') as f:
    print( mod_a.load(f) )
    # print( mod_b.load(f) )

现在

mod_a.load(f)
取消
mod_a
版本,
mod_b.load(f)
取消
mod_b
版本。

(当然,这并不一定需要使用

A
的初始定义来解封对象——我认为 pickle 文件没有记录足够的信息来做到这一点。这段代码所做的就是明确
 的特定定义A
已使用。)

python scope pickle
1个回答
0
投票

我作为问题更新输入的解决方案对我来说效果很好,因此我建议将其作为答案。

每个模块都需要遵循这个模式

# mod_a.py 

import pickle
import sys

class _Unpickler(pickle.Unpickler):

    def find_class(self, module, name):
        if hasattr(sys.modules[__name__],name):
            return getattr(sys.modules[__name__],name)
        else:
            return super(_Unpickler,self).find_class(module, name)
            
class A(object):
    def __init__(self):
        self.mod_b = True
        
def load(file):
    return _Unpickler(file).load()
    
if __name__ == '__main__':
    
    a = A()
    with open('b.ar','wb') as f:
        pickle.dump(a,f,protocol=2)

通过这种方式,客户端代码可以控制所搜索的模块,如下所示:

import mod_a

with open('a.ar','rb') as f:
    print( mod_a.load(f) )
© www.soinside.com 2019 - 2024. All rights reserved.