Trouble understanding how to properly use SQLAlchemy in a Python multiprocessing pool


I am extracting a large number of files (~850,000) and loading them into a Postgres database using SQLAlchemy. I can load the files into the database without multiprocessing, but the load takes more than a day, and I would like it to go faster.

I created a process pool with the following code:

def load_submissions(submissions_path:str, db_url:str) -> None:
    submissions = {}
    for file in os.listdir(submissions_path):

        file_path = os.path.abspath(os.path.join(submissions_path, file))
        file = os.path.basename(file)
        cik = numbers_from_filename(file)
        # only take the left ten digits of the cik
        cik = cik[:10] if cik else None
        
        if cik:
            if cik in submissions:
                submissions[cik].append(file_path)
            else:
                submissions[cik] = [file_path]
        else:
            log.info(f'Could not find a CIK in the filename: {file_path}')
    log.debug(f'Submissions loaded')

    tasks = [(files, db_url) for files in submissions.values()]
    log.debug(f'Creating pool with {multiprocessing.cpu_count()} processes')
    try:
        num_processes = min(len(submissions), multiprocessing.cpu_count())
        with multiprocessing.get_context("spawn").Pool(processes=num_processes) as pool:
            log.debug(f'Loading submissions with multiprocessing.Pool()')
            pool.starmap(load_submission, tasks)

    except Exception as e:
        log.exception(e)

This calls the load_submission function, which in turn uses an init function to create the database engine. Both are shown here:

def init(db_url: str) -> sqlalchemy.engine.Engine:
    engine = create_engine(db_url, poolclass=sqlalchemy.pool.QueuePool)
    # check to see if the tables exists, if not create them
    if not inspect(engine).has_table(engine, 'Company'):
        Base.metadata.tables['Company'].create(engine)
    if not inspect(engine).has_table(engine, 'FormerName'):
        Base.metadata.tables['FormerName'].create(engine)
    if not inspect(engine).has_table(engine, 'Address'):
        Base.metadata.tables['Address'].create(engine)
    if not inspect(engine).has_table(engine, 'Ticker'):
        Base.metadata.tables['Ticker'].create(engine)
    if not inspect(engine).has_table(engine, 'ExchangeCompany'):
        Base.metadata.tables['ExchangeCompany'].create(engine)
    if not inspect(engine).has_table(engine, 'Filing'):
        Base.metadata.tables['Filing'].create(engine)
    return engine

def load_submission(submission_files: list[str], db_url: str) -> None:

    submission = extract_submission(submission_files)

    c = submission.get('cik')
    if not c:
        raise KeyError(f'No CIK found in submission')
    c = c.zfill(10)  

    log.debug(f'Process {os.getpid()} is loading submission for CIK: {c}')
    
    engine = db.init(db_url)

    log.debug(f'Engine initialized for CIK: {c}')

    with db.session(engine) as session:        
        log.debug(f'Session opened for CIK: {c}')

        company = db.Company(
            cik = c,
            name = submission.get('name'),
            sic = submission.get('sic'),
            entityType = submission.get('entityType'),
            insiderTransactionForOwnerExists = submission.get('insiderTransactionForOwnerExists'),
            insiderTransactionForIssuerExists = submission.get('insiderTransactionForIssuerExists'),
            ein = submission.get('ein'),
            description = submission.get('description'),
            website = submission.get('website'),
            investorWebsite = submission.get('investorWebsite'),
            category = submission.get('category'),
            fiscalYearEnd = submission.get('fiscalYearEnd'),
            stateOfIncorporation = submission.get('stateOfIncorporation'),
            phone = submission.get('phone'),
            flags = submission.get('flags')
        )
        session.merge(company)
        session.commit()
        log.debug(f'Company added to database: {c}')

        for ticker in submission['tickers']:
            t = db.Ticker(
                ticker = ticker,
                cik = c
            )
            session.merge(t)

        for exchange in submission['exchanges']:

            ec = db.ExchangeCompany(
                exchange = exchange,
                cik = c
            )
            session.merge(ec)

        addresses = submission['addresses']
        for description, address in addresses.items():
            a = db.Address(
                cik = c,
                description = description,
                street1 = address.get('street1'),
                street2 = address.get('street2'),
                city = address.get('city'),
                stateOrCountry = address.get('stateOrCountry'),
                zipCode = address.get('zipCode')
            )
            session.merge(a)

        former_names = submission['formerNames']
        for former_name in former_names:
            f = db.FormerName(
                formerName = former_name.get('name'),
                _from = datetime.fromisoformat(former_name.get('from')) if former_name.get('from') else None,
                to = datetime.fromisoformat(former_name.get('to')) if former_name.get('to') else None,
                cik = c
            )
            session.merge(f)
        session.commit()
        log.debug(f'Ticker, ExchangeCompany, Address, and FormerName added to database: {c}')
        session.close()
        log.debug(f'Session closed for CIK: {c}')

        df = pd.DataFrame(submission['filings']['recent'])

        # if the dataframe is empty, return
        if df.empty:
            engine.dispose()
            return

        # add url column and cik column to dataframe
        df['url'] = df.apply(lambda row: f'{ARCHIVES_URL}/{str(int(c))}/{row["accessionNumber"].replace("-", "")}/{row["accessionNumber"]}.txt', axis=1)
        df['cik'] = c

        # Convert date columns to datetime objects and handle empty strings
        date_columns = ['filingDate', 'reportDate', 'acceptanceDateTime']
        for col in date_columns:
            df[col] = pd.to_datetime(df[col], errors='coerce')

        # create data types dict to pass to to_sql
        dtypes = {
            'cik': sqltypes.VARCHAR(),
            'accessionNumber': sqltypes.VARCHAR(),
            'filingDate': sqltypes.Date(),
            'reportDate': sqltypes.Date(),
            'acceptanceDateTime': sqltypes.DateTime(),
            'act': sqltypes.VARCHAR(),
            'fileNumber': sqltypes.VARCHAR(),
            'filmNumber': sqltypes.VARCHAR(),
            'items': sqltypes.VARCHAR(),
            'size': sqltypes.INTEGER(),
            'isXBRL': sqltypes.Boolean,
            'isInlineXBRL': sqltypes.Boolean,
            'primaryDocument': sqltypes.VARCHAR(),
            'primaryDocumentDescription': sqltypes.VARCHAR()
        }

        # write dataframe to database
        df.to_sql('Filing', engine, if_exists='append', index=False, dtype=dtypes)
        log.debug(f'Filings added to database: {c}')

    log.debug(f'Session closed for CIK: {c}')
    engine.dispose()
    log.debug(f'Engine disposed for CIK: {c}')

When I call the load_submissions function, I get the following error:

ERROR:eminor.eminor:Error sending result: ''. Reason: 'AttributeError("Can't pickle local object 'create_engine.<locals>.connect'")'
Traceback (most recent call last):
  File "C:\Users\lucky\Repos\lazy_prices\src\python\src\eminor\eminor.py", line 275, in load_submissions
    pool.starmap(load_submission, tasks)
  File "C:\Program Files\Python311\Lib\multiprocessing\pool.py", line 375, in starmap
    return self._map_async(func, iterable, starmapstar, chunksize).get()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\multiprocessing\pool.py", line 774, in get
    raise self._value
multiprocessing.pool.MaybeEncodingError: Error sending result: ''. Reason: 'AttributeError("Can't pickle local object 'create_engine.<locals>.connect'")'

Why is it complaining about a pickling error involving SQLAlchemy's create_engine function?

I tried moving the init call out of load_submission to see whether instantiating the engine outside the process pool and giving every process a copy of the same engine would work, but that failed as well.
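For reference, a common way to write the per-process-engine pattern is the pool's initializer hook, so each worker builds exactly one engine and reuses it across tasks instead of sharing the parent's engine (whose live connections cannot be pickled). A rough sketch; the names _init_worker, _worker_engine, _load_one, and run_pool are illustrative, not from my actual code:

import multiprocessing

from sqlalchemy import create_engine, text

_worker_engine = None  # set once per worker process by the initializer

def _init_worker(db_url: str) -> None:
    # Runs once in every child process. Engines hold live sockets, so
    # they cannot be pickled and shipped from the parent; each worker
    # must create its own.
    global _worker_engine
    _worker_engine = create_engine(db_url)

def _load_one(files: list[str]) -> None:
    # Stand-in for load_submission: reuse the per-process engine rather
    # than creating and disposing a new one for every task.
    with _worker_engine.connect() as conn:
        conn.execute(text('SELECT 1'))

def run_pool(all_tasks: list[list[str]], db_url: str) -> None:
    procs = min(len(all_tasks), multiprocessing.cpu_count())
    with multiprocessing.get_context('spawn').Pool(
            processes=procs,
            initializer=_init_worker,
            initargs=(db_url,)) as pool:
        pool.map(_load_one, all_tasks)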

I expected the load_submission function to run in multiple processes at the same time, loading the files into the database concurrently.

Thanks in advance for your help!

python postgresql sqlalchemy multiprocessing pickle
1 Answer

TL;DR: if code called inside a multiprocessing pool raises an exception that pickle cannot handle, the pool fails to ship the error back to the parent and reports this message instead.
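To illustrate the mechanism, here is a minimal, hypothetical reproduction (the Unpicklable class is invented for the demo). A worker's exception has to be pickled to travel back to the parent process; when that pickling fails, the pool raises multiprocessing.pool.MaybeEncodingError in its place:

import multiprocessing

class Unpicklable(Exception):
    def __init__(self):
        super().__init__('boom')
        # A lambda defined here is a local object, so instances of this
        # exception cannot be pickled, much like an exception that holds
        # a reference to a live database connection.
        self.callback = lambda: None

def worker(_):
    raise Unpicklable()

if __name__ == '__main__':
    with multiprocessing.get_context('spawn').Pool(processes=2) as pool:
        # Raises multiprocessing.pool.MaybeEncodingError: the pool could
        # not pickle the worker's exception to send it back.
        pool.map(worker, range(2))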

So this ended up being a simple fix, but it was not obvious from the error message.

The error was that I was calling SQLAlchemy's inspect API with too many arguments. That caused the function to raise an exception, which then surfaced as the error raised inside the multiprocessing pool.
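Concretely, in the init function above, inspect(engine) already binds the inspector to the engine, so Inspector.has_table() takes just the table name (plus an optional schema). A corrected sketch, assuming the same Base and table names as in the question:

import sqlalchemy
from sqlalchemy import create_engine, inspect

def init(db_url: str) -> sqlalchemy.engine.Engine:
    engine = create_engine(db_url, poolclass=sqlalchemy.pool.QueuePool)
    inspector = inspect(engine)
    for name in ('Company', 'FormerName', 'Address', 'Ticker',
                 'ExchangeCompany', 'Filing'):
        # has_table() takes the table name only; passing the engine as an
        # extra first argument is what raised the original exception.
        if not inspector.has_table(name):
            # Base is the declarative base from the question's models.
            Base.metadata.tables[name].create(engine)
    return engine

Equivalently, Base.metadata.create_all(engine, checkfirst=True) creates any missing tables in a single call.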

I am making some assumptions about why this fix worked, and I have not really verified that the behavior comes from how pickle and the multiprocessing pool capture exceptions, but fixing the bug in my code made everything else run fine.
