How can I convert a flax.linen.Module to a torch.nn.Module?


I want to convert the flax.linen.Module taken from here and reproduced below into a torch.nn.Module.

However, I am struggling to figure out how to replace

  1. the flax.linen.Dense calls;
  2. the flax.linen.Conv calls;
  3. the custom class Dense.

For (1.), I guess I need to use torch.nn.Linear. But what do I need to specify as in_features and out_features?
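To illustrate: flax.linen.Dense infers its input width from the data at initialization, while torch.nn.Linear needs it up front. If I read the model correctly, the first Dense applied to the time embedding maps embed_dim features to embed_dim features (the Fourier projection outputs embed_dim // 2 sines plus embed_dim // 2 cosines), so I assume the equivalent would be:

import torch.nn as nn

embed_dim = 256  # default from the Flax model below

# flax.linen.Dense(embed_dim) infers in_features at init time; PyTorch
# wants it stated explicitly. The Fourier embedding has embed_dim features.
time_dense = nn.Linear(in_features=embed_dim, out_features=embed_dim)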

For (2.), I guess I need to use torch.nn.Conv2d. But, again, what do I need to specify as in_channels and out_channels?
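Again to illustrate: flax.linen.Conv(features, kernel_size) infers in_channels from its input, and Flax uses NHWC layout while PyTorch uses NCHW. Since the final Flax convolution outputs a single channel, I assume the input images also have one channel, in which case the first encoder convolution would presumably become:

import torch.nn as nn

channels = (32, 64, 128, 256)  # defaults from the Flax model below

# Assuming single-channel inputs; padding='VALID' corresponds to padding=0.
conv1 = nn.Conv2d(in_channels=1, out_channels=channels[0],
                  kernel_size=3, stride=1, padding=0, bias=False)

# Deeper encoder convs take the previous layer's out_channels, e.g. for h2:
conv2 = nn.Conv2d(in_channels=channels[0], out_channels=channels[1],
                  kernel_size=3, stride=2, padding=0, bias=False)

For the decoder convolutions on concatenated tensors, in_channels would be the sum of the concatenated channel counts; and since those layers use input_dilation=(2, 2), I believe they correspond to fractionally strided (transposed) convolutions, i.e. torch.nn.ConvTranspose2d rather than a plain Conv2d.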

I think I know how to port the GaussianFourierProjection class and how to mimic the "swish" activation function (see my sketch below). Of course, it would be most helpful if someone familiar with both frameworks could provide the corresponding torch.nn.Module as an answer. But an answer that only explains how to replace (1.)-(3.) would already help a lot. Any help is greatly appreciated!
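For reference, this is roughly what I have in mind for GaussianFourierProjection (an untested sketch): the fixed random weights become a registered buffer, so they are stored with the model but never trained (mirroring stop_gradient), and swish is available directly as torch.nn.SiLU, since SiLU(x) = x * sigmoid(x) is exactly the swish activation.

import math
import torch
import torch.nn as nn

class GaussianFourierProjection(nn.Module):
  """Gaussian random features for encoding time steps (PyTorch sketch)."""
  def __init__(self, embed_dim: int, scale: float = 30.0):
    super().__init__()
    # A buffer is saved in the state dict but receives no gradient,
    # mirroring the fixed, stop-gradient'ed W in the Flax version.
    self.register_buffer('W', torch.randn(embed_dim // 2) * scale)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x_proj = x[:, None] * self.W[None, :] * 2 * math.pi
    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)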


#@title Defining a time-dependent score-based model (double click to expand or collapse)

import jax.numpy as jnp
import numpy as np
import flax
import flax.linen as nn
from typing import Any, Tuple
import functools
import jax

class GaussianFourierProjection(nn.Module):
  """Gaussian random features for encoding time steps."""  
  embed_dim: int
  scale: float = 30.
  @nn.compact
  def __call__(self, x):    
    # Randomly sample weights during initialization. These weights are fixed 
    # during optimization and are not trainable.
    W = self.param('W', jax.nn.initializers.normal(stddev=self.scale), 
                 (self.embed_dim // 2, ))
    W = jax.lax.stop_gradient(W)
    x_proj = x[:, None] * W[None, :] * 2 * jnp.pi
    return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)


class Dense(nn.Module):
  """A fully connected layer that reshapes outputs to feature maps."""  
  output_dim: int  
  
  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.output_dim)(x)[:, None, None, :]    


class ScoreNet(nn.Module):
  """A time-dependent score-based model built upon U-Net architecture.
  
  Args:
      marginal_prob_std: A function that takes time t and gives the standard
        deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
      channels: The number of channels for feature maps of each resolution.
      embed_dim: The dimensionality of Gaussian random feature embeddings.
  """
  marginal_prob_std: Any
  channels: Tuple[int, ...] = (32, 64, 128, 256)
  embed_dim: int = 256
  
  @nn.compact
  def __call__(self, x, t): 
    # The swish activation function
    act = nn.swish
    # Obtain the Gaussian random feature embedding for t   
    embed = act(nn.Dense(self.embed_dim)(
        GaussianFourierProjection(embed_dim=self.embed_dim)(t)))
        
    # Encoding path
    h1 = nn.Conv(self.channels[0], (3, 3), (1, 1), padding='VALID',
                   use_bias=False)(x)    
    ## Incorporate information from t
    h1 += Dense(self.channels[0])(embed)
    ## Group normalization
    h1 = nn.GroupNorm(4)(h1)    
    h1 = act(h1)
    h2 = nn.Conv(self.channels[1], (3, 3), (2, 2), padding='VALID',
                   use_bias=False)(h1)
    h2 += Dense(self.channels[1])(embed)
    h2 = nn.GroupNorm()(h2)        
    h2 = act(h2)
    h3 = nn.Conv(self.channels[2], (3, 3), (2, 2), padding='VALID',
                   use_bias=False)(h2)
    h3 += Dense(self.channels[2])(embed)
    h3 = nn.GroupNorm()(h3)
    h3 = act(h3)
    h4 = nn.Conv(self.channels[3], (3, 3), (2, 2), padding='VALID',
                   use_bias=False)(h3)
    h4 += Dense(self.channels[3])(embed)
    h4 = nn.GroupNorm()(h4)    
    h4 = act(h4)

    # Decoding path
    h = nn.Conv(self.channels[2], (3, 3), (1, 1), padding=((2, 2), (2, 2)),
                  input_dilation=(2, 2), use_bias=False)(h4)    
    ## Skip connection from the encoding path
    h += Dense(self.channels[2])(embed)
    h = nn.GroupNorm()(h)
    h = act(h)
    h = nn.Conv(self.channels[1], (3, 3), (1, 1), padding=((2, 3), (2, 3)),
                  input_dilation=(2, 2), use_bias=False)(
                      jnp.concatenate([h, h3], axis=-1)
                  )
    h += Dense(self.channels[1])(embed)
    h = nn.GroupNorm()(h)
    h = act(h)
    h = nn.Conv(self.channels[0], (3, 3), (1, 1), padding=((2, 3), (2, 3)),
                  input_dilation=(2, 2), use_bias=False)(
                      jnp.concatenate([h, h2], axis=-1)
                  )    
    h += Dense(self.channels[0])(embed)    
    h = nn.GroupNorm()(h)  
    h = act(h)
    h = nn.Conv(1, (3, 3), (1, 1), padding=((2, 2), (2, 2)))(
        jnp.concatenate([h, h1], axis=-1)
    )

    # Normalize output
    h = h / self.marginal_prob_std(t)[:, None, None, None]
    return h
Tags: python, machine-learning, pytorch, neural-network, flax
1 Answer

I think what you are looking for is lazy layers. The Conv and Linear layers have lazy implementations for which you do not need to specify the input channels. See this page for an example and more details: https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html#torch.nn.modules.lazy.LazyModuleMixin
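A minimal sketch of what that looks like (my own illustration, with assumed shapes): the lazy variants defer in_channels / in_features until the first forward pass, which mirrors how Flax infers input shapes during initialization.

import torch
import torch.nn as nn

# Lazy layers leave the input side unspecified until the first call.
conv = nn.LazyConv2d(out_channels=32, kernel_size=3, stride=1, bias=False)
dense = nn.LazyLinear(out_features=256)

x = torch.randn(2, 1, 28, 28)  # NCHW batch; the shapes here are assumptions
h = conv(x)                    # in_channels=1 is inferred on this call
t = torch.randn(2, 256)
e = dense(t)                   # in_features=256 is inferred on this call
print(h.shape, e.shape)        # torch.Size([2, 32, 26, 26]) torch.Size([2, 256])

After the first forward pass the modules are fully materialized and behave like ordinary Conv2d / Linear layers.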

Hope this helps!
