使用 jax lax 扫描,其输入在扫描内的迭代中不会改变,但每次调用扫描时都不同

问题描述 投票:0回答:0

Jax lax scan 对一个函数进行操作,该函数接受两个参数,一个进位和一个输入序列。我想知道如果某些输入在扫描的迭代中没有改变,应该如何调用扫描。天真地,我可以创建一系列相同的输入,但这看起来很浪费/多余,更重要的是,这并不总是可能的,因为 scan 只能扫描数组。例如,我想传递给我的函数的输入之一是包含我的模型及其参数的训练状态(例如,from flax.training import train_state),它不能放入数组中。正如我在标题中所说,每次我调用扫描时,这些输入也可能会发生变化(例如,模型参数会发生变化)。

关于如何最好地做到这一点有什么想法吗?

谢谢。

jit jax
© www.soinside.com 2019 - 2024. All rights reserved.