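I am training an embedding model and want to save multiple embeddings to a single checkpoint file so that I can visualize them in my local TensorBoard Projector. I tried the TF1 solution from the accepted answer to this question, but it did not work. Here is the code I am using: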
import os

import numpy as np
import tensorflow as tf
from tensorboard.plugins import projector

# X_train and X_test are defined earlier (not shown)
for folder in os.listdir():
    if '.' not in folder and int(folder[5]) >= 4:
        models = sorted([f for f in os.listdir(folder) if f.endswith('.h5')])
        print(folder)
        print(f'Found {len(models)} models')
        if models:
            checkpoint = tf.train.Checkpoint()
            checkpoint.save(os.path.join(folder, f"loaded_embedding.ckpt"))
            for m in models:
                print(f'Evaluating model:\t{m}')
                model = tf.keras.models.load_model(os.path.join(folder, m))
                if m.startswith('End'):
                    suffix = 'end_model'
                else:
                    suffix = m.split('model_')[-1].replace('.h5', '')
                res_train = model.predict(X_train)
                res_test = model.predict(X_test)
                results = np.concatenate((res_train, res_test), axis=0)
                # process to save the data for local Tensorboard
                embeddings = tf.Variable(results)
                config = projector.ProjectorConfig()
                embedding = config.embeddings.add()
                # The name of the tensor will be suffixed by `tensor_name`
                embedding.tensor_name = f"embedding_{suffix}/.ATTRIBUTES/VARIABLE_VALUE"
                embedding.metadata_path = "meta_loaded.tsv"
                projector.visualize_embeddings(folder, config)
Here is my projector_config.pbtxt file:
embeddings {
  tensor_name: "embedding/.ATTRIBUTES/VARIABLE_VALUE"
  metadata_path: "meta_40Test_ThreeQuarterMargin.tsv"
}
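I expected to see multiple embedding entries. When I open the data in the TensorBoard Projector and try to view the tensor for each model's embedding, it only shows "Found 1 tensor". How can I fix the code so that all of the model embeddings are stored in a single checkpoint?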
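For reference, here is a minimal sketch of one direction that might work, not a confirmed fix. It assumes the problem is that the checkpoint is saved before any variable is attached to it and that a fresh `ProjectorConfig` is written on every iteration. The helper names `embedding_name`, `tensor_name` and `save_embeddings` are hypothetical and only serve the illustration; the metadata file name is taken from the code above. The idea is to collect one `tf.Variable` per model, attach them all to a single `tf.train.Checkpoint`, save once, and write one config with one `embeddings` entry per variable.

```python
import os


def embedding_name(model_filename):
    """Build a per-model variable name with the same suffix rule as the code above."""
    if model_filename.startswith("End"):
        suffix = "end_model"
    else:
        suffix = model_filename.split("model_")[-1].replace(".h5", "")
    return f"embedding_{suffix}"


def tensor_name(name):
    """Name under which a variable attached to tf.train.Checkpoint is stored."""
    return f"{name}/.ATTRIBUTES/VARIABLE_VALUE"


def save_embeddings(log_dir, embeddings_by_name, metadata_path="meta_loaded.tsv"):
    """Save all embeddings in one checkpoint and write a single projector config.

    embeddings_by_name maps a variable name to a 2-D NumPy array (assumption).
    """
    # Imported here so the helpers above stay usable without TensorFlow installed.
    import tensorflow as tf
    from tensorboard.plugins import projector

    # Attach every embedding to the same Checkpoint object before saving.
    variables = {
        name: tf.Variable(values) for name, values in embeddings_by_name.items()
    }
    checkpoint = tf.train.Checkpoint(**variables)
    checkpoint.save(os.path.join(log_dir, "loaded_embedding.ckpt"))

    # One config for the whole folder, with one entry per saved variable.
    config = projector.ProjectorConfig()
    for name in variables:
        entry = config.embeddings.add()
        entry.tensor_name = tensor_name(name)
        entry.metadata_path = metadata_path
    projector.visualize_embeddings(log_dir, config)
```

In the loop over `models`, this would mean storing `results` in a dict keyed by `embedding_name(m)` and calling `save_embeddings(folder, embeddings_by_name)` once after the loop finishes, instead of calling `projector.visualize_embeddings` for every model. If this works as assumed, projector_config.pbtxt should then contain one `embeddings { ... }` block per model.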