如何在亚麻模型中的多个Dense实例上进行vmap?尝试避免循环密集实例列表

问题描述 投票:0回答:1
from jax import random,vmap
from jax import numpy as jnp
import pprint

def f(s,layers,do,dx):
    x = jnp.zeros((do,dx))
    for i,layer in enumerate(layers):
        x=x.at[i].set( layer( s[i] ) )
    return x

class net(nn.Module):
    dx: int 
    do: int 
    def setup(self):
        self.layers = [ nn.Dense( self.dx, use_bias=False )
                        for _ in range(self.do) ]
    def __call__(self, s):
        x = vmap(f,in_axes=(0,None,None,None))(s,self.layers,self.do,self.dx)
        return x

if __name__ == '__main__':
    seed = 123
    key = random.PRNGKey( seed )
    key,subkey = random.split( key )
    outer_batches = 4
    s_observations = 5 # AKA the inner batch
    x_features = 2
    s_features = 3
    s_shape = (outer_batches,s_observations, s_features)
    s = random.uniform( subkey, s_shape )

    key,subkey = random.split( key )    
    model = net(x_features,s_observations)
    p = model.init( subkey, s )
    x = model.apply( p, s )    

    params = p['params']
    pkernels = jnp.array([params[key]['kernel'] for key in params.keys()])
    x_=jnp.zeros((outer_batches,s_observations,x_features))
    
    g = vmap(vmap(lambda a,b: a@b),in_axes=(0,None))
    
    x_=g(s,pkernels)
    print('s shape:',s.shape)
    print('p shape:',pkernels.shape)
    print('x shape:',x.shape)
    print('x_ shape:',x_.shape)
    print('sum of difference:',jnp.sum(x-x_))

嗨。我的模型中需要一些“特定于批次的”参数。这里,有一个长度为

do
的“内部批次”,因此该批次中的每个元素都有一个
flax.linen.Dense
实例。外部批次只是将多个数据实例传递到这些层中。我通过在
flax.linen.Dense
方法中创建
setup
实例列表来实现此目的。然后在
__call__
方法中,我迭代这些层以填充数组。此迭代被封装在一个函数中,并且该函数被包装在
jax.vmap
中。

我想用对

__call__
的调用来替换
jax.vmap
方法中的 for 循环。当我将列表传递给
vmap
时,我收到错误,当我尝试将多个
Dense
实例放入 jax 数组时,我收到错误。除了使用列表来包含多个
Dense
实例之外,还有其他选择吗?一个限制是我应该能够在模型初始化时创建任意数量的
Dense
实例。

vectorization jax flax
1个回答
1
投票

vmap
可用于在批量数据上映射单个函数。您正在尝试使用它来映射批量数据上的多个函数,但它无法做到这一点。

一般来说,解决方法是定义一个可以传递给

vmap
的单个参数化层。在您给出的示例中,每一层都是相同的,因此为了实现您正在寻找的结果,您可以编写如下内容:

def f(s,layer,dx):
  return layer(s)

class net(nn.Module):
    dx: int 
    do: int 
    def setup(self):
        self.layer = nn.Dense( self.dx, use_bias=False )
    def __call__(self, s):
        x = vmap(f,in_axes=(0,None,None))(s,self.layer,self.dx)
        return x

如果每层有不同的参数化,那么您也可以通过将这些参数传递给

vmap
来在
vmap
内实现此目的。

© www.soinside.com 2019 - 2024. All rights reserved.