目标: 我希望能够动态地导入一个子包中的所有功能与“直接调用”
用法: 我的项目:
project/
|-- main.py
|-- src/
|---- __init__.py
|---- foo.py
|---- bar.py
foo.py
只有一个功能:
def foo_funct():
print("foo")
bar.py
只有一个功能:
def bar_funct():
print("bar")
最后main.py
:
from src import *
(...)
foo_funct()
bar_funct()
(...)
评论:
__init__.py
是这样的
import os
__all__ = [i.replace(".py", "") for i in os.listdir(os.getcwd()+"/src/") if "__" not in i]
我将能够调用foo.foo_funct()
或bar.bar_funct()
但不foo_funct()
或bar_funct()
__init__.py
是这样的:
from src.foo import *
from src.bar import *
我将能够调用foo_funct()
或bar_funct()
但每一个新的子包,我将不得不修改我的__init__.py
from src import *
不是最Python的方法,并且假设它可能是非常危险的有因可能的命名冲突,如a.tree_funct()
和b.tree_funct()
直接调用,有没有达到我的目标的方法?就个人而言,我更喜欢让事情变得明确了,只是导入是包API的一部分进入__init__
明确的名称。您的项目将不会改变如此迅速,动态地导入所有的东西都__init__.py
将是一个节省时间。
但是,如果你想这样做,那么你有几个选择这里。如果您需要支持Python版本比3.7老,那么你可以通过在globals()
dictionary戳更新包命名空间。列出所有.py
文件,并使用importlib.import_module()
导入(或__import__()
如果你需要2.7之前支持的Python版本):
__all__ = []
def _load_all_submodules():
from pathlib import Path
from importlib import import_module
g = globals()
package_path = Path(__file__).resolve().parent
for pyfile in package_path.glob('*.py'):
module_name = pyfile.stem
if module_name == '__init__':
continue
module = import_module(f'.{module_name}', __package__)
names = getattr(
module, '__all__',
(n for n in dir(module) if n[:1] != '_'))
for name in names:
g[name] = getattr(module, name)
__all__.append(name)
_load_all_submodules()
del _load_all_submodules
上述保持命名空间干净;该_load_all_submodules()
功能运行后它从包装中取出。它使用__file__
全球以确定当前路径,并从那里发现任何同级.py
文件。
如果你只需要支持的Python 3.7及以上,你可以定义module-level __getattr__()
and __dir__()
functions实现动态查询。
在你的包__init__.py
文件使用这些钩子可能看起来像:
def _find_submodules():
from pathlib import Path
from importlib import import_module
package_path = Path(__file__).resolve().parent
return tuple(p.stem for p in package_path.glob('*.py') if p.stem != '__init__')
__submodules__ = _find_submodules()
del _find_submodules
def __dir__():
from importlib import import_module
names = []
for module_name in __submodules__:
module = import_module(f'.{module_name}', __package__)
try:
names += module.__all__
except AttributeError:
names += (n for n in dir(module) if n[:1] != '_')
return sorted(names)
__all__ = __dir__()
def __getattr__(name):
from importlib import import_module
for module_name in __submodules__:
module = import_module(f'.{module_name}', __package__)
try:
# cache the attribute so future imports don't call __getattr__ again
obj = getattr(module, name)
globals()[name] = obj
return obj
except AttributeError:
pass
raise AttributeError(name)