当我尝试将 Chroma 客户端传递给使用
OpenAIEmbeddings
的 Langchain 时,我收到一个 ValueError:
ValueError: Expected EmbeddingFunction.__call__ to have the following signature: odict_keys(['self', 'input']), got odict_keys(['self', 'args', 'kwargs'])
如何解决此错误?
该错误似乎与 langchain 的嵌入功能实现不符合 Chroma 最新更新引入的新要求有关,因为该问题是在升级 Chroma 后出现的。
我的代码:
import chromadb
from langchain_openai import OpenAIEmbeddings
client = chromadb.PersistentClient()
collection = client.get_or_create_collection(
name='chroma',
embedding_function=OpenAIEmbeddings()
)
我有 langchain==0.1.1,langchain-openai==0.0.3 和 chromadb==0.4.22。查看 github 问题,似乎将 chromadb 降级到 0.4.15 可以解决该问题,但由于这些库将在未来几个月内升级更多,我不想降级 chroma,而是找到一个适用于当前版本的解决方案。
从版本 0.4.16(?) 开始,Chroma 需要一个嵌入模型来定义返回嵌入列表的
__call__()
方法。它在错误中显示的迁移链接中说了同样多的内容。所以我发现的最简单的解决方案是创建一个继承自 OpenAIEmbeddings
的自定义类。然后我们就可以使用它的 embed_documents()
方法,而不是尝试编写我们自己的嵌入函数。
import chromadb
from langchain_openai import OpenAIEmbeddings
class CustomOpenAIEmbeddings(OpenAIEmbeddings):
def __init__(self, openai_api_key, *args, **kwargs):
super().__init__(openai_api_key=openai_api_key, *args, **kwargs)
def _embed_documents(self, texts):
return super().embed_documents(texts) # <--- use OpenAIEmbedding's embedding function
def __call__(self, input):
return self._embed_documents(input) # <--- get the embeddings
client = chromadb.PersistentClient()
collection = client.get_or_create_collection(
name='chroma',
embedding_function=CustomOpenAIEmbeddings(
openai_api_key="your very secret OpenAI api key"
) # <-- pass the new object instead of OpenAIEmbeddings()
)
使用 OpenAI 的 Embedding 对象也可以(可以通过
self.client
访问)。基本上,我们可以通过在循环中调用 CustomOpenAIEmbeddings
方法来定义 Embedding.create()
,如下所示此示例用例。
class CustomOpenAIEmbeddings(OpenAIEmbeddings):
def __init__(self, openai_api_key, *args, **kwargs):
super().__init__(openai_api_key=openai_api_key, *args, **kwargs)
def _embed_documents(self, texts):
embeddings = [
self.client.create(input=text, model="text-embedding-ada-002").data[0].embedding
for text in texts
]
return embeddings
def __call__(self, input):
return self._embed_documents(input)