Saving multiple embeddings to a checkpoint in TensorFlow


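I am training an embedding model and want to save several embeddings into one checkpoint file so I can visualize them in my local TensorBoard Projector. I tried the TF1 solution from the accepted answer to this question, but it did not work. Here is the code I am using: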

import os

import numpy as np
import tensorflow as tf
from tensorboard.plugins import projector

# X_train and X_test are assumed to be defined earlier (not shown).
for folder in os.listdir():
    if '.' not in folder and int(folder[5]) >= 4:
        models = sorted([f for f in os.listdir(folder) if f.endswith('.h5')])
        print(folder)
        print(f'Found {len(models)} models')

        if models:
            checkpoint = tf.train.Checkpoint()
            checkpoint.save(os.path.join(folder, "loaded_embedding.ckpt"))

        for m in models:
            print(f'Evaluating model:\t{m}')
            model = tf.keras.models.load_model(os.path.join(folder, m))
            # Derive a per-model name suffix from the file name
            if m.startswith('End'):
                suffix = 'end_model'
            else:
                suffix = m.split('model_')[-1].replace('.h5', '')
            # Embed the training and test data and stack them into one matrix
            res_train = model.predict(X_train)
            res_test = model.predict(X_test)
            results = np.concatenate((res_train, res_test), axis=0)

            # Prepare the data for the local TensorBoard Projector
            embeddings = tf.Variable(results)
            config = projector.ProjectorConfig()
            embedding = config.embeddings.add()

            # The tensor name must match the variable's key in the checkpoint,
            # which ends with `/.ATTRIBUTES/VARIABLE_VALUE`
            embedding.tensor_name = f"embedding_{suffix}/.ATTRIBUTES/VARIABLE_VALUE"
            embedding.metadata_path = "meta_loaded.tsv"

            projector.visualize_embeddings(folder, config)
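For context on the `tensor_name` values: in the code above, `tf.train.Checkpoint()` is created without any tracked objects and saved before any `tf.Variable` exists, so the saved checkpoint cannot contain an `embedding_{suffix}` key. The following is a minimal sketch, using random stand-in arrays (the names `emb_a`, `emb_b`, `embedding_a`, `embedding_b` are hypothetical), of how the keyword arguments passed to `tf.train.Checkpoint` become the variable keys that the projector's `tensor_name` has to match.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Two stand-in embedding matrices (random data, for illustration only).
emb_a = tf.Variable(np.random.rand(10, 4).astype("float32"))
emb_b = tf.Variable(np.random.rand(10, 4).astype("float32"))

# Each keyword argument becomes a named dependency of the checkpoint;
# the keyword (not the variable's own name) appears in the saved key.
ckpt = tf.train.Checkpoint(embedding_a=emb_a, embedding_b=emb_b)

log_dir = tempfile.mkdtemp()
save_path = ckpt.save(os.path.join(log_dir, "embedding.ckpt"))

# List the keys stored in the checkpoint; the embedding keys look like
# "embedding_a/.ATTRIBUTES/VARIABLE_VALUE".
keys = [name for name, _shape in tf.train.list_variables(save_path)]
for k in keys:
    print(k)
```

The listing also contains bookkeeping entries such as the checkpoint's `save_counter`; only the embedding keys are relevant for `tensor_name`.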

Here is my projector_config.pbtxt file:

embeddings {
  tensor_name: "embedding/.ATTRIBUTES/VARIABLE_VALUE"
  metadata_path: "meta_40Test_ThreeQuarterMargin.tsv"
}
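One way to see why the file holds a single `embeddings` entry is sketched below. It assumes, as in current TensorBoard releases, that `projector.visualize_embeddings(logdir, config)` writes the given config to `projector_config.pbtxt` in `logdir`, replacing any existing file. The loop mimics the question's pattern of a fresh one-entry `ProjectorConfig` per model; the suffixes `model_a` and `model_b` are hypothetical.

```python
import os
import tempfile

from tensorboard.plugins import projector

log_dir = tempfile.mkdtemp()

# Mimic the question's loop: a fresh single-entry config per model,
# written with visualize_embeddings on every iteration.
for suffix in ["model_a", "model_b"]:  # hypothetical model suffixes
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = f"embedding_{suffix}/.ATTRIBUTES/VARIABLE_VALUE"
    projector.visualize_embeddings(log_dir, config)

# Each call rewrote projector_config.pbtxt, so only the last entry is left.
with open(os.path.join(log_dir, "projector_config.pbtxt")) as f:
    pbtxt = f.read()
print(pbtxt)
```

Under that assumption, getting several entries into the file means adding them all to one `ProjectorConfig` and calling `visualize_embeddings` once.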

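I expected to see several embedding entries. However, when I open the data in the TensorBoard Projector to view each model's embedding tensor, it only shows "Found 1 tensor". How can I fix the code so that all model embeddings are stored in a single checkpoint?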
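A possible restructuring, offered as a sketch rather than a confirmed fix: create all embedding variables first, track them together in one `tf.train.Checkpoint` saved once, and build a single `ProjectorConfig` with one entry per variable that is written once. The helper `save_embeddings_for_projector` and the `results_by_suffix` dict are hypothetical, and random arrays stand in for the `model.predict` outputs.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorboard.plugins import projector


def save_embeddings_for_projector(log_dir, results_by_suffix, metadata_path):
    """Save every embedding into ONE checkpoint and ONE projector config.

    results_by_suffix is a hypothetical dict mapping a model suffix
    (e.g. "end_model") to that model's 2-D embedding array.
    """
    # 1. Create all variables first, keyed by the name they should have
    #    in the checkpoint (e.g. "embedding_end_model").
    variables = {
        f"embedding_{suffix}": tf.Variable(results)
        for suffix, results in results_by_suffix.items()
    }

    # 2. Track every variable in a single Checkpoint and save it once,
    #    after the variables exist.
    checkpoint = tf.train.Checkpoint(**variables)
    save_path = checkpoint.save(os.path.join(log_dir, "embedding.ckpt"))

    # 3. Build one config with one entry per variable and write it once.
    config = projector.ProjectorConfig()
    for key in variables:
        embedding = config.embeddings.add()
        embedding.tensor_name = f"{key}/.ATTRIBUTES/VARIABLE_VALUE"
        embedding.metadata_path = metadata_path
    projector.visualize_embeddings(log_dir, config)
    return save_path, config


# Usage with random stand-in data instead of real model predictions.
log_dir = tempfile.mkdtemp()
fake_results = {
    "end_model": np.random.rand(20, 8).astype("float32"),
    "epoch_10": np.random.rand(20, 8).astype("float32"),
}
save_path, config = save_embeddings_for_projector(
    log_dir, fake_results, "meta_loaded.tsv"
)
print(len(config.embeddings))  # prints 2
```

In the original loop, this would correspond to storing each model's `results` under its `suffix` inside `for m in models`, then calling the helper once per folder after that loop finishes.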

python-3.x tensorflow word-embedding
© www.soinside.com 2019 - 2024. All rights reserved.