I have text data for #sentimentanalysis. It has three labels: -1, 0, and 1. I want to create embeddings and compute the centroid of each class, so that new data can be assigned a label based on cosine similarity to the centroids. Any ideas? I am trying to create the embeddings with MPNet.
Here is the code I have tried:
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support
# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
# Function to classify data based on cosine similarity and threshold
def classify_data(embeddings, centroids, threshold):
    similarity_scores = torch.cosine_similarity(embeddings.unsqueeze(1), centroids.unsqueeze(0), dim=2)
    return similarity_scores.argmax(dim=1)
# Load model from HuggingFace Hub
tokenizer_mpnet = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model_mpnet = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
# Example DataFrame with 'sentence' and 'label' columns
data = {
    'sentence': [
        'This is a positive sentence',
        'Each sentence is neutral',
        'Another example negative sentence',
        'More sentences to test',
    ],
    'label': [1, 0, -1, 0],  # Assuming sentiment labels: 1 for positive, 0 for neutral, -1 for negative
}
df = pd.DataFrame(data)
# Split the dataset into training and testing sets
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
# Tokenize sentences using MPNet
encoded_input_mpnet = tokenizer_mpnet(train_df['sentence'].tolist(), padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings using MPNet
with torch.no_grad():
    model_output_mpnet = model_mpnet(**encoded_input_mpnet)
# Perform pooling for embeddings using MPNet
sentence_embeddings_mpnet = mean_pooling(model_output_mpnet, encoded_input_mpnet['attention_mask'])
# Normalize embeddings
sentence_embeddings_mpnet = F.normalize(sentence_embeddings_mpnet, p=2, dim=1)
# Compute the centroids of each class for MPNet
centroids_mpnet = []
for label in [-1, 0, 1]:
    centroid = sentence_embeddings_mpnet[train_df['label'] == label].mean(dim=0)
    centroids_mpnet.append(centroid)
# Tokenize testing sentences using MPNet
encoded_input_test = tokenizer_mpnet(test_df['sentence'].tolist(), padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings for testing sentences using MPNet
with torch.no_grad():
    model_output_test = model_mpnet(**encoded_input_test)
# Perform pooling for testing embeddings using MPNet
sentence_embeddings_test = mean_pooling(model_output_test, encoded_input_test['attention_mask'])
# Normalize testing embeddings
sentence_embeddings_test = F.normalize(sentence_embeddings_test, p=2, dim=1)
# Classify new data based on cosine similarity and threshold for MPNet
threshold_mpnet = 0.33
predicted_labels_mpnet = classify_data(sentence_embeddings_test, torch.stack(centroids_mpnet), threshold_mpnet)
# Calculate precision, recall, and F1-score for each class for MPNet
precision_mpnet, recall_mpnet, f1_score_mpnet, _ = precision_recall_fscore_support(test_df['label'], predicted_labels_mpnet, average=None)
print("MPNet Precision:")
print(precision_mpnet)
print("MPNet Recall:")
print(recall_mpnet)
print("MPNet F1-score:")
print(f1_score_mpnet)
I am getting a KeyError at this line inside the for label in [-1, 0, 1]: loop:
centroid = sentence_embeddings_mpnet[train_df['label'] == label].mean(dim=0)
Am I using the labels correctly?
Hi, welcome to Stack Overflow!
Some classes may be missing from your training data, so when you try to compute the centroid for such a class, it throws a KeyError.
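You can confirm which classes survived the split by counting the labels in the training set. A minimal sketch with made-up data, standing in for your train_df after train_test_split:

```python
import pandas as pd

# A 3-row training split like the one your 4-row example produces;
# class 1 happens to be absent from it
train_df = pd.DataFrame({'label': [0, -1, 0]})

print(train_df['label'].value_counts())

# Check each expected class explicitly
for label in [-1, 0, 1]:
    present = (train_df['label'] == label).any()
    print(label, 'present' if present else 'missing')
```

With only four rows and test_size=0.2, it is very likely that one of the three classes ends up entirely in the test split.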
You can avoid this by adding a check, as in this code:
centroids_mpnet = []
for label in [-1, 0, 1]:
    if len(train_df[train_df['label'] == label]) > 0:
        centroid = sentence_embeddings_mpnet[train_df['label'] == label].mean(dim=0)
        centroids_mpnet.append(centroid)
    else:
        print(f"No instances of class {label} found in training data.")
One more thing: the error may also occur because the dtype of the 'label' column does not match the elements of the list [-1, 0, 1]. If your 'label' column holds strings, you should use ['-1', '0', '1'] instead of [-1, 0, 1].
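A quick way to check for this mismatch is to inspect the column's dtype and, if needed, cast it to int so the integer comparisons work. A minimal sketch (df here stands in for your DataFrame):

```python
import pandas as pd

# Labels stored as strings will never match integer comparisons
df = pd.DataFrame({'label': ['1', '0', '-1', '0']})
print(df['label'].dtype)         # object, i.e. the labels are strings
print((df['label'] == 1).any())  # False: no row matches the int 1

# Cast the column to int so comparisons against [-1, 0, 1] succeed
df['label'] = df['label'].astype(int)
print((df['label'] == 1).any())  # True
```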
If you still see the error, inspect your DataFrame; some unexpected data may be causing the problem.
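Also worth checking once the KeyError is gone: classify_data returns centroid indices 0, 1, 2 (positions in your [-1, 0, 1] loop), while test_df['label'] holds -1, 0, 1, so precision_recall_fscore_support would compare mismatched label sets. A small sketch of mapping the indices back to sentiment labels before scoring:

```python
# Same order used when building centroids_mpnet
label_values = [-1, 0, 1]

# Example output of classify_data; it returns a tensor of indices,
# so in your code call .tolist() on it first
predicted_indices = [0, 2, 1]

# Map each centroid index back to its sentiment label before computing metrics
predicted_labels = [label_values[i] for i in predicted_indices]
print(predicted_labels)  # [-1, 1, 0]
```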
I hope this helps.