我正在使用 SentenceTransformers 库计算一些嵌入。然而,在对句子进行编码并在检查其值的总和时计算其嵌入时,我得到了不同的结果。例如:
在:
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
transformer_models = [
'M-CLIP/M-BERT-Distil-40',
]
sentences = df['content'].tolist()
for transformer_model in tqdm(transformer_models, desc="Transformer Models"):
tqdm.write(f"Processing with Transformer Model: {transformer_model}")
model = SentenceTransformer(transformer_model)
embeddings = model.encode(sentences)
print(f"Embeddings Checksum for {transformer_model}:", np.sum(embeddings))
出:
Embeddings Checksum for M-CLIP/M-BERT-Distil-40: 1105.9185
或者
Embeddings Checksum for M-CLIP/M-BERT-Distil-40: 1113.5422
我注意到当我重新启动并清除 jupyter 笔记本的输出,然后重新运行完整笔记本时会发生这种情况。知道如何解决这个问题吗?
我尝试在嵌入计算之后和之前设置随机种子:
import torch
import numpy as np
import random
import tensorflow as tf
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
RANDOM_SEED = 42
# Setting seeds
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
# Ensuring PyTorch determinism
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
transformer_models = ['M-CLIP/M-BERT-Distil-40']
sentences = df['content'].tolist()
for transformer_model in tqdm(transformer_models, desc="Transformer Models"):
# Set the seed again right before loading the model
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
tqdm.write(f"Processing with Transformer Model: {transformer_model}")
model = SentenceTransformer(transformer_model, device='cpu') # Force to use CPU
embeddings = model.encode(sentences, show_progress_bar=False) # Disable progress bar and parallel tokenization
print(f"Embeddings Checksum for {transformer_model}:", np.sum(embeddings))
但是我也遇到了同样不一致的行为。
更新
我现在尝试的并且似乎有效的是,现在我将所有计算的嵌入存储在文件中。然而,我发现奇怪的是,当这样做时,我得到了不同的结果。有人有过这样的经历吗?
这似乎是许多人面临的一个持续存在的问题。您可以关注此问题,并可能尝试一些可能对其他人有效的解决方案(例如,使用 .apply 方法而不是 .encode,尝试不同的精度设置,如 FP16、32、64 等。此外,您可以尝试指定/确保输入的一致标记化和填充。)