在里面添加@tf.function 时循环很慢,为什么?

问题描述 投票:0回答:0

希望你一切顺利! 我需要你的帮助解决一个我不明白的小问题

所以我有一个像这样的“train_step”函数:

@tf.function
def train_step(timestep_values,noised_image,noise):
    # calculate loss and update parameters
    with tf.GradientTape() as tape:
        prediction = model(noised_image, timestep_values)
        loss_value = loss_of(noise, prediction)
    gradients = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    tf.print("end-train-step")

像这样的主循环:

EPOCHS = 1
for e in range(EPOCHS):
    for batch in X_train:
        rng, tsrng = np.random.randint(0, 100000, size=(2,))
        timestep_values = generate_timestamp(tsrng, batch.shape[0])
        noised_image, noise = forward_noise(rng, batch, timestep_values)
        train_step(timestep_values,noised_image,noise)
        print("end-of-batch")
    print(f"Epoch {e+1}/{EPOCHS}")

我的问题是

tf.print("end-train-step")
打印速度非常快,但是
print("end-of-batch")
没有显示,至少在等待 2-3 分钟后显示(在 collab/Kaggle 上)你知道为什么吗?

我不明白为什么train_step函数执行得很快,但是当返回loss_value时,一切都变慢了,为什么这个转换这么慢?

提前致谢!

我试图删除函数中的所有“tf.print”和“print”,但这没有用,我还尝试在我的计算机上的 CPU (r5 3600) 上执行这段代码,它比 collab 或 kaggle 更快

python tensorflow deep-learning google-colaboratory kaggle
© www.soinside.com 2019 - 2024. All rights reserved.