希望你一切顺利! 我需要你的帮助解决一个我不明白的小问题
所以我有一个像这样的“train_step”函数:
@tf.function
def train_step(timestep_values,noised_image,noise):
# calculate loss and update parameters
with tf.GradientTape() as tape:
prediction = model(noised_image, timestep_values)
loss_value = loss_of(noise, prediction)
gradients = tape.gradient(loss_value, model.trainable_variables)
opt.apply_gradients(zip(gradients, model.trainable_variables))
tf.print("end-train-step")
像这样的主循环:
EPOCHS = 1
for e in range(EPOCHS):
for batch in X_train:
rng, tsrng = np.random.randint(0, 100000, size=(2,))
timestep_values = generate_timestamp(tsrng, batch.shape[0])
noised_image, noise = forward_noise(rng, batch, timestep_values)
train_step(timestep_values,noised_image,noise)
print("end-of-batch")
print(f"Epoch {e+1}/{EPOCHS}")
我的问题是
tf.print("end-train-step")
打印速度非常快,但是 print("end-of-batch")
没有显示,至少在等待 2-3 分钟后显示(在 collab/Kaggle 上)你知道为什么吗?
我不明白为什么train_step函数执行得很快,但是当返回loss_value时,一切都变慢了,为什么这个转换这么慢?
提前致谢!
我试图删除函数中的所有“tf.print”和“print”,但这没有用,我还尝试在我的计算机上的 CPU (r5 3600) 上执行这段代码,它比 collab 或 kaggle 更快