从另一个文件导入的函数很慢

问题描述 投票:0回答:1

我有一个这样的函数,如果我直接运行它而不从另一个

helper.py
文件导入它,它可以很好地加载。

我不确定是什么导致加载缓慢。

helper_file.py

from transformers import BertTokenizer 
def embed_answers(ans, length): 
    sentence_embeddings = []
    embeddings = BertTokenizer.from_pretrained('...')
    sentence_embeddings.extend(embeddings.encode(ans, max_length=length, padding='max_length') 
    return sentence embeddings 

def get_dataset(vec_type): 
    vec_dict = {"large": 1000, "medium": 500, "small": 150} 
    if vec_type.lower() not in vec_dict: 
        raise Exception("Invalid vector type!")
    df = pd.read_hdf('...', mode='r')
    vec_length = vec_dict[vec_type]
    df['embeddings_col'] = df['answer'].apply(embed_answers, length=vec_length) 
    return df 

当我从

get_dataset
文件导入并调用
main.py
时,它会在未完全加载的情况下崩溃。但直接从
main.py
运行该函数就可以了。

不确定问题是什么,感谢任何想法,谢谢!

python pandas dataframe apply
1个回答
0
投票

我不知道它是否可以解决您的问题,但每次调用时加载标记器

embed_answers
都是浪费资源(时间和内存)。尝试利用矢量化。

from transformers import BertTokenizer

def embed_answers(ans, length): 
    embeddings = BertTokenizer.from_pretrained('...')
    inputs = embeddings(ans, max_length=length, padding='max_length') 
    return inputs['input_ids']

def get_dataset(vec_type): 
    vec_dict = {"large": 1000, "medium": 500, "small": 150} 
    if vec_type.lower() not in vec_dict: 
        raise Exception("Invalid vector type!")
    df = pd.read_hdf('...', mode='r')
    vec_length = vec_dict[vec_type]
    df['embeddings_col'] = embed_answers(df['answer'], length=vec_length) 
    return df
© www.soinside.com 2019 - 2024. All rights reserved.