我编写的代码接受您输入的内容并将其与单词列表进行比较。当用户的单词与保存的单词之一匹配时,会将其标记为1。
def bag_of_words(s, words):
bag = [0 for _ in range(len(words))]
s_words = nltk.word_tokenize(s)
s_words = [stemmer.stem(word.lower()) for word in s_words]
for se in s_words:
for i, w in enumerate(words):
if w == se:
bag[i] = 1
print(bag)
return numpy.array(bag)
[当输入(var = s)为“ Hello”时,它返回:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
然后,我将其设置为通过使用model.predict()来预测要使用的类”>
results = model.predict(bag_of_words(inp, words))
但是我认为我的模型有问题,
tf.keras.backend.clear_session() model = tf.keras.Sequential() model.add(keras.layers.Flatten()) model.add(keras.layers.Dense(8, activation=tf.nn.relu)) model.add(keras.layers.Dense(8, activation=tf.nn.relu)) model.add(keras.layers.Dense(len(output[0]), activation=tf.nn.softmax)) try: keras.models.load_model("savedModel") except: model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]) model.fit(training, output, epochs=1000, batch_size=8) model.save("savedModel")
var“ results”的输出是一个大型数组:
[[0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.21546566 0.17384247 0.2511855 0.11917856 0.13687664 0.10345111] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667]]
它对每个可能的单词有46个不同的预测,但是我希望它对输入的单词仅具有一个预测。对于Ex:
[0.21546566 0.17384247 0.2511855 0.11917856 0.13687664 0.10345111]
我编写的代码接受您输入的内容并将其与单词列表进行比较。当用户的单词与保存的单词之一匹配时,会将其标记为1. def bag_of_words(s,words):bag ...
问题在于,通过移除Flatten()层,它应该可以按预期工作。拼合层用于拼合具有多个维度的输入(例如图像),但是您只有一个维度。