import numpy as np
from IPython.display import display, HTML
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import pandas as pd
import re
import nltk
def word_similarity_scatter_plot(index_to_word,weight,plot_title,fig,axes):
labels = []
tokens = []
for key,value in index_to_word.items():
tokens.append(weight[key])
labels.append(value)
#TSNE : Compressing the weights to 2 dimensions to plot the data
tsne_model = TSNE(perplexity=40, n_components=2, init='pca', n_iter=2500, random_state=23)
new_values = tsne_model.fit_transform(tokens)
x = []
y = []
for value in new_values:
x.append(value[0])
y.append(value[1])
#plt.figure(figsize=(5, 5))
for i in range(len(x)):
axes.scatter(x[i],y[i])
axes.annotate(labels[i],
xy=(x[i], y[i]),
xytext=(5, 2),
textcoords='offset points',
ha='right',
va='bottom')
#plt.title(plot_title)
axes.set_title(plot_title, loc='center')
虽然我尝试在 google colab 中运行 word2vec 算法,但我没有得到单词相似度散点图和错误状态
AttributeError Traceback (most recent call last)
<ipython-input-46-ce4c95c25962> in <cell line: 447>()
450 loss_epoch.update( {dim: epoch_loss} )
451
--> 452 word_similarity_scatter_plot(
453 index_to_word,
454 weights_1[epochs -1],
2 frames
<ipython-input-46-ce4c95c25962> in word_similarity_scatter_plot(index_to_word, weight, plot_title, fig, axes)
387 #TSNE : Compressing the weights to 2 dimensions to plot the data
388 tsne_model = TSNE(perplexity=40, n_components=2, init='pca', n_iter=2500, random_state=23)
--> 389 new_values = tsne_model.fit_transform(tokens)
390
391 x = []
/usr/local/lib/python3.10/dist-packages/sklearn/manifold/_t_sne.py in fit_transform(self, X, y)
1116 """
1117 self._validate_params()
-> 1118 self._check_params_vs_input(X)
1119 embedding = self._fit(X)
1120 self.embedding_ = embedding
/usr/local/lib/python3.10/dist-packages/sklearn/manifold/_t_sne.py in _check_params_vs_input(self, X)
826
827 def _check_params_vs_input(self, X):
--> 828 if self.perplexity >= X.shape[0]:
829 raise ValueError("perplexity must be less than n_samples")
830
AttributeError: 'list' object has no attribute 'shape'
我应该怎么做才能解决错误?提前致谢。