Tensorflow Estimator在每次调用predict时给出不同的预测结果。

问题描述 投票:0回答:1

我用TF估计器为自己的数据集训练了一个分类器,但每次预测调用后都会得到不同的预测结果。我检查了一下数据集,数据例子的顺序没有问题,每次预测调用都一样,但是模型会给出不同的分类结果。我很困惑,不知道是不是我做错了什么。

这是我读取输入的代码。

    def parse_predict_record(example):
        features = {"user_id": tf.FixedLenFeature([], tf.int64),"ad_info": tf.VarLenFeature(tf.string)}
        data = tf.parse_single_example(example, features)
        uid = data['user_id']
        ad_info = tf.sparse_tensor_to_dense(data['ad_info'], default_value='0')
        test_info = tf.sparse_tensor_to_dense(tf.string_split(ad_info, "#"), default_value="0")
        test_info = tf.string_to_number(test_info, out_type=tf.int32)

        feature_dict = {"user_id": uid, "ad_info": test_info}
        print("feature_dict=", feature_dict)
        return feature_dict

    files, cnt = get_files(FLAGS.predict_path)
    predict_dataset = tf.data.TFRecordDataset(files).map(lambda x: parse_predict_record(x)) \
        .padded_batch(FLAGS.batch_size, padded_shapes={'user_id': [], 'ad_info': [None, None]}) \
        .prefetch(32)
    iterator = predict_dataset.make_one_shot_iterator()
    return iterator.get_next()

预测的代码。

        print("start predict")
        result = model.predict(input_fn=predict_input_fn,
                               hooks=[tf.train.LoggingTensorHook([ 'user_id', 'ad_info', 'predict_id'], every_n_iter=500)])
        prediction_res = []
        for prediction in result:
            # print("predictions=", prediction)
            user_id = prediction['user_id']
            predict_label = prediction['predict_label']
            print("usre_id=", user_id)
            print("predict_label=", predict_label)

tensorflow tensorflow-estimator
1个回答
0
投票

我有答案了,这个对我很有效。

        checkpoint_path = model.latest_checkpoint()
        print("checkpoint_path=", checkpoint_path)
        result = model.predict(input_fn=predict_input_fn,
                               hooks=[tf.train.LoggingTensorHook([ 'user_id', 'ad_info',
                                                                   'predict_id'], every_n_iter=1000)],
                               checkpoint_path=checkpoint_path)
        prediction_res = []
        for prediction in result:
            # print("predictions=", prediction)
            user_id = prediction['user_id']
            predict_label = prediction['predict_label']
© www.soinside.com 2019 - 2024. All rights reserved.