tensorflowjs/discord.js 没有权重值

问题描述 投票:0回答:1

我想编写一个小型的不和谐聊天机器人来进行一些闲聊。我遇到的问题是我的训练代码没有在我的权重(模型)中设置任何值。

train.js

const tf = require('@tensorflow/tfjs');
const fs = require('fs');

// Define the training data
const trainingData = [
  { input: 'hi', output: 'Hello!' },
  { input: 'hello', output: 'Hi there!' },
  { input: 'hey', output: 'Hey there!' },
  { input: 'hi there', output: 'Hello!' },
  { input: 'what\'s up', output: 'Not much, how about you?' },
  { input: 'good morning', output: 'Good morning to you too!' },
  { input: 'good afternoon', output: 'Good afternoon to you too!' },
  { input: 'good evening', output: 'Good evening to you too!' },
  { input: 'howdy', output: 'Howdy there partner!' },
  { input: 'yo', output: 'What\'s up?' },
  { input: 'hey there', output: 'Hey there! How can I help you today?' }
];

// Define the neural network model
const model = tf.sequential();
model.add(tf.layers.dense({ units: 8, inputShape: [1], activation: 'relu' }));
model.add(tf.layers.dense({ units: 8, activation: 'relu' }));
model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
model.compile({ optimizer: 'adam', loss: 'binaryCrossentropy' });

// Train the model
const inputs = trainingData.map((item) => item.input);
const outputs = trainingData.map((item) => item.output);
const xs = tf.tensor2d(inputs.map((input) => [input.charCodeAt(0)]));
const ys = tf.tensor2d(outputs.map((output) => [output === 'Hello!' ? 1 : 0]));

async function train() {
  console.log('Training model...');
  for (let i = 0; i < inputs.length; i++) {
    console.log(`Training input: ${inputs[i]}, output: ${outputs[i]}`);
  }
  await model.fit(xs, ys, {
    epochs: 500,
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        console.log(`Epoch ${epoch}: loss = ${logs.loss}`);
      }
    }
  });
  const weights = model.getWeights();
  console.log('Model weights:', weights);
  const metadata = { weights };
  fs.writeFile('./model.json', JSON.stringify(metadata), (err) => {
    if (err) {
      console.error('Error saving model:', err);
    } else {
      console.log('Model saved to file!');
    }
  });
}

async function run() {
  console.log('Running training script...');
  await train();
}
run();

权重的输出如下所示:

{
    kept: false,
    isDisposedInternal: false,
    shape: [ 1 ],
    dtype: 'float32',
    size: 1,
    strides: [],
    dataId: {},
    id: 11,
    rankType: '1',
    trainable: true,
    name: 'dense_Dense3/bias'
  }

如您所见,没有任何价值。这是我尝试访问它的代码:

client.on('messageCreate', async (message) => {
  if (message.channel.id === '1084912525914681354') if(message.author.id !== client.user.id) {
    ////code to get answer from model.json with tensorflow and string similarity

    const modelPath = './model.json';
    const modelData = fs.readFileSync(modelPath);
    const modelJSON = JSON.parse(modelData);
    const weights = modelJSON.weights.map((weight) => {
      console.log(weight)
        return tf.tensor2d(weight.values, weight.shape);

    });
    const model = tf.sequential();
    model.add(tf.layers.dense({ units: 8, inputShape: [1], activation: 'relu' }));
    model.add(tf.layers.dense({ units: 8, activation: 'relu' }));
    model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
    model.setWeights(weights);
    
    // Get the user's message
    const input = message.content.toLowerCase();
    
    // Get the model's predicted output
    const outputTensor = model.predict(tf.tensor2d([input.charCodeAt(0)]));
    const outputData = await outputTensor.data();
    const predictedOutput = outputData[0] > 0.5 ? 'Hello!' : '';
    
    // Calculate the string similarity between the user's message and the model's predicted output
    const similarity = stringSimilarity.compareTwoStrings(input, predictedOutput);
    
    // If the similarity is above a certain threshold, send the predicted output as a reply
    if (similarity > 0.6) {
      message.reply(predictedOutput);
}else message.reply("Sorry, i have no answer for that.")
  }
});

问题出在

const weights = modelJSON.weights.map((weight) => {

我不知道我能做些什么来解决这个问题。我有点绝望了,哈哈

javascript node.js tensorflow discord.js chatbot
1个回答
0
投票

你修复了它,如果你修复了它,请在那里分享代码

© www.soinside.com 2019 - 2024. All rights reserved.