如何使用张量流的ncf模型进行预测?

问题描述 投票:0回答:1

嗨,我是张量流和神经网络的新手。试图了解tensorflow官方模型库中的ncf recommendation model

我的理解是,您将使用输入层和学习层来构建模型。然后,创建一批数据以训练模型,然后使用测试数据评估模型。这是在此file中完成的。

但是,我在理解输入层时遇到了麻烦。

它显示在代码中

user_input = tf.keras.layers.Input(
      shape=(1,), name=movielens.USER_COLUMN, dtype=tf.int32)

据我所知,您一次只能输入一个参数。

但是我只能使用以下虚拟数据来调用predict_on_batch

user_input = np.full(shape=(256,),fill_value=1, dtype=np.int32)
item_input = np.full(shape=(256,),fill_value=1, dtype=np.int32)
valid_pt_mask_input = np.full(shape=(256,),fill_value=True, dtype=np.bool)
dup_mask_input = np.full(shape=(256,),fill_value=1, dtype=np.int32)
label_input = np.full(shape=(256,),fill_value=True, dtype=np.bool)
test_input_list = [user_input,item_input,valid_pt_mask_input,dup_mask_input,label_input]

tf.print(keras_model.predict_on_batch(test_input_list))

当我运行以下代码时:

    user_input = np.full(shape=(1,),fill_value=1, dtype=np.int32)
    item_input = np.full(shape=(1,),fill_value=1, dtype=np.int32)
    valid_pt_mask_input = np.full(shape=(1,),fill_value=True, dtype=np.bool)
    dup_mask_input = np.full(shape=(1,),fill_value=1, dtype=np.int32)
    label_input = np.full(shape=(1,),fill_value=True, dtype=np.bool)
    test_input_list = [user_input,item_input,valid_pt_mask_input,dup_mask_input,label_input]

    classes = _model.predict(test_input_list)
    tf.print(classes)

我收到此错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError:  Input to reshape is a tensor with 1 values, but the requested shape requires a multiple of 256
     [[{{node model_1/metric_layer/StatefulPartitionedCall/StatefulPartitionedCall/Reshape_1}}]] [Op:__inference_predict_function_2828]

有人可以帮助我如何使用此模型通过单个输入进行预测吗?以及为什么在进行预测时user_id要求item_id?您是否应该提供模型返回的用户列表的用户列表?

python tensorflow neural-network recommendation-engine tensorflow-model-garden
1个回答
0
投票

我以前没有使用过ncf模型,但是看起来您输入的训练数据是1个具有256个特征的样本,而不是256个样本,每个均具有1个特征。只需翻转您的Numpy数组,确保要素矩阵为2D,并且要素数量为第一维。

user_input = np.full(shape=(1,256),fill_value=1, dtype=np.int32)

...其他。 (嗯,标签应保持一维不变)]

类似地,在预测输入中确保特征矩阵为2D:

user_input = np.full(shape=(1,1),fill_value=1, dtype=np.int32)
© www.soinside.com 2019 - 2024. All rights reserved.