from jax import random,vmap
from jax import numpy as jnp
import pprint
def f(s,layers,do,dx):
x = jnp.zeros((do,dx))
for i,layer in enumerate(layers):
x=x.at[i].set( layer( s[i] ) )
return x
class net(nn.Module):
dx: int
do: int
def setup(self):
self.layers = [ nn.Dense( self.dx, use_bias=False )
for _ in range(self.do) ]
def __call__(self, s):
x = vmap(f,in_axes=(0,None,None,None))(s,self.layers,self.do,self.dx)
return x
if __name__ == '__main__':
seed = 123
key = random.PRNGKey( seed )
key,subkey = random.split( key )
outer_batches = 4
s_observations = 5 # AKA the inner batch
x_features = 2
s_features = 3
s_shape = (outer_batches,s_observations, s_features)
s = random.uniform( subkey, s_shape )
key,subkey = random.split( key )
model = net(x_features,s_observations)
p = model.init( subkey, s )
x = model.apply( p, s )
params = p['params']
pkernels = jnp.array([params[key]['kernel'] for key in params.keys()])
x_=jnp.zeros((outer_batches,s_observations,x_features))
g = vmap(vmap(lambda a,b: a@b),in_axes=(0,None))
x_=g(s,pkernels)
print('s shape:',s.shape)
print('p shape:',pkernels.shape)
print('x shape:',x.shape)
print('x_ shape:',x_.shape)
print('sum of difference:',jnp.sum(x-x_))
嗨。我的模型中需要一些“特定于批次的”参数。这里,有一个长度为
do
的“内部批次”,因此该批次中的每个元素都有一个 flax.linen.Dense
实例。外部批次只是将多个数据实例传递到这些层中。我通过在 flax.linen.Dense
方法中创建 setup
实例列表来实现此目的。然后在 __call__
方法中,我迭代这些层以填充数组。此迭代被封装在一个函数中,并且该函数被包装在 jax.vmap
中。
我想用对
__call__
的调用来替换 jax.vmap
方法中的 for 循环。当我将列表传递给 vmap
时,我收到错误,当我尝试将多个 Dense
实例放入 jax 数组时,我收到错误。除了使用列表来包含多个 Dense
实例之外,还有其他选择吗?一个限制是我应该能够在模型初始化时创建任意数量的 Dense
实例。
vmap
可用于在批量数据上映射单个函数。您正在尝试使用它来映射批量数据上的多个函数,但它无法做到这一点。
一般来说,解决方法是定义一个可以传递给
vmap
的单个参数化层。在您给出的示例中,每一层都是相同的,因此为了实现您正在寻找的结果,您可以编写如下内容:
def f(s,layer,dx):
return layer(s)
class net(nn.Module):
dx: int
do: int
def setup(self):
self.layer = nn.Dense( self.dx, use_bias=False )
def __call__(self, s):
x = vmap(f,in_axes=(0,None,None))(s,self.layer,self.dx)
return x
如果每层有不同的参数化,那么您也可以通过将这些参数传递给
vmap
来在 vmap
内实现此目的。