我想将取自 here(原文链接)并复制在本文下方的 flax.linen.Module 转换为 torch.nn.Module。
但是,我发现很难弄清楚需要如何替换以下三处:
(1.) flax.linen.Dense 调用;
(2.) flax.linen.Conv 调用;
(3.) 自定义 Dense 模块的调用。
对于 (1.),我想我需要使用
torch.nn.Linear
。但是我需要将什么指定为 in_features
和 out_features
?
对于(2.),我想我需要使用
torch.nn.Conv2d
。但同样,我需要将什么指定为 in_channels
和 out_channels
?
我想我知道如何移植
GaussianFourierProjection
类以及如何模仿“swish 激活函数”。当然,如果有人熟悉这两个模块,能够直接给出对应的 torch.nn.Module
作为答案,那将非常有帮助。但即使只能回答如何替换 (1.) - (3.),也已经很有帮助了。非常感谢任何帮助!
#@title Defining a time-dependent score-based model (double click to expand or collapse)
import jax.numpy as jnp
import numpy as np
import flax
import flax.linen as nn
from typing import Any, Tuple
import functools
import jax
class GaussianFourierProjection(nn.Module):
    """Encode scalar time steps with fixed random Fourier features.

    Attributes:
        embed_dim: Requested embedding size. ``embed_dim // 2`` random
            frequencies are drawn, and the output concatenates their sine and
            cosine features, so an odd value yields ``embed_dim - 1`` features.
        scale: Standard deviation of the normal initializer used for the
            random frequencies (parameter ``'W'``).
    """
    embed_dim: int
    scale: float = 30.

    @nn.compact
    def __call__(self, x):
        # The frequencies live in parameter 'W', sampled by the normal
        # initializer. stop_gradient blocks gradients through them, so
        # gradient-based training does not update them.
        num_freqs = self.embed_dim // 2
        init_fn = jax.nn.initializers.normal(stddev=self.scale)
        freqs = jax.lax.stop_gradient(self.param('W', init_fn, (num_freqs, )))
        # Broadcast x (one value per row) against the frequency vector, then
        # scale to radians.
        angles = x[:, None] * freqs[None, :] * 2 * jnp.pi
        return jnp.concatenate((jnp.sin(angles), jnp.cos(angles)), axis=-1)
class Dense(nn.Module):
    """Fully connected layer whose output is reshaped to broadcast over feature maps.

    Applies ``nn.Dense(output_dim)`` and inserts two singleton axes right
    after the leading (batch) axis. For a 2-D input this turns a
    ``(batch, output_dim)`` result into ``(batch, 1, 1, output_dim)``, which
    broadcasts against a channels-last ``(batch, H, W, output_dim)`` map.
    """
    output_dim: int

    @nn.compact
    def __call__(self, x):
        projected = nn.Dense(self.output_dim)(x)
        # Add singleton height/width axes so the result can be summed with a
        # feature map.
        return projected[:, None, None, :]
class ScoreNet(nn.Module):
    """A time-dependent score-based model built upon a U-Net architecture.

    Every encoder/decoder stage follows the same pattern:
    convolution -> add a per-channel projection of the time embedding ->
    GroupNorm -> swish. Decoder stages upsample with ``input_dilation=(2, 2)``
    and receive the matching encoder activation by concatenation along the
    last (channel) axis.

    NOTE: the hard-coded paddings are sized for 28x28 inputs. With them, the
    encoder maps spatial size 28 -> 26 -> 12 -> 5 -> 2 and the decoder maps
    2 -> 5 -> 12 -> 26 -> 28, so the skip concatenations line up. Other
    spatial sizes will likely produce mismatched shapes.

    Args:
        marginal_prob_std: A function that takes time t and gives the standard
            deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
        channels: The number of channels for feature maps of each resolution
            (only the first four entries are used).
        embed_dim: The dimensionality of Gaussian random feature embeddings.
    """
    # Callable t -> std; its result is indexed as [:, None, None, None] below.
    marginal_prob_std: Any
    # Variable-length tuple of ints; indices 0-3 are read.
    channels: Tuple[int, ...] = (32, 64, 128, 256)
    # Width of the time embedding.
    embed_dim: int = 256

    @nn.compact
    def __call__(self, x, t):
        """Estimate the score of ``x`` at times ``t``.

        Args:
            x: Channels-last input batch, presumably (batch, 28, 28, C); see
                the class docstring for the spatial-size constraint.
            t: One time value per batch element, shape (batch,).

        Returns:
            Array of shape (batch, H, W, 1), divided elementwise by
            ``marginal_prob_std(t)``.
        """
        # The swish activation function.
        act = nn.swish
        # Embed t with Gaussian random features, then a learned projection.
        # (Submodule creation order is kept as-is so flax's auto-generated
        # parameter names stay the same.)
        embed = act(nn.Dense(self.embed_dim)(
            GaussianFourierProjection(embed_dim=self.embed_dim)(t)))

        # ---- Encoding path ----
        h1 = nn.Conv(self.channels[0], (3, 3), (1, 1), padding='VALID',
                     use_bias=False)(x)
        # Incorporate information from t (broadcast over height and width).
        h1 += Dense(self.channels[0])(embed)
        # Group normalization: 4 groups here; the later GroupNorm() calls use
        # flax's default num_groups.
        h1 = nn.GroupNorm(4)(h1)
        h1 = act(h1)

        h2 = nn.Conv(self.channels[1], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h1)
        h2 += Dense(self.channels[1])(embed)
        h2 = nn.GroupNorm()(h2)
        h2 = act(h2)

        h3 = nn.Conv(self.channels[2], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h2)
        h3 += Dense(self.channels[2])(embed)
        h3 = nn.GroupNorm()(h3)
        h3 = act(h3)

        h4 = nn.Conv(self.channels[3], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h3)
        h4 += Dense(self.channels[3])(embed)
        h4 = nn.GroupNorm()(h4)
        h4 = act(h4)

        # ---- Decoding path ----
        # input_dilation=(2, 2) upsamples before the 3x3 convolution.
        h = nn.Conv(self.channels[2], (3, 3), (1, 1), padding=((2, 2), (2, 2)),
                    input_dilation=(2, 2), use_bias=False)(h4)
        # Incorporate information from t.
        h += Dense(self.channels[2])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)

        # Skip connection from the encoding path: concatenate h3 on the
        # channel axis.
        h = nn.Conv(self.channels[1], (3, 3), (1, 1), padding=((2, 3), (2, 3)),
                    input_dilation=(2, 2), use_bias=False)(
            jnp.concatenate([h, h3], axis=-1)
        )
        h += Dense(self.channels[1])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)

        # Skip connection from the encoding path: concatenate h2.
        h = nn.Conv(self.channels[0], (3, 3), (1, 1), padding=((2, 3), (2, 3)),
                    input_dilation=(2, 2), use_bias=False)(
            jnp.concatenate([h, h2], axis=-1)
        )
        h += Dense(self.channels[0])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)

        # Skip connection from the encoding path: concatenate h1, then
        # project to a single output channel (this conv keeps its bias).
        h = nn.Conv(1, (3, 3), (1, 1), padding=((2, 2), (2, 2)))(
            jnp.concatenate([h, h1], axis=-1)
        )

        # Normalize output by the perturbation-kernel std at each time step.
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
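我认为您要找的是惰性层(lazy modules)。Conv 和 Linear 层都有惰性实现(torch.nn.LazyConv2d 和 torch.nn.LazyLinear),无需指定输入通道数/输入特征数,它们会在第一次前向传播时根据输入自动推断。例如,nn.Dense(features) 可替换为 torch.nn.LazyLinear(features),nn.Conv(features, kernel_size, ...) 可替换为 torch.nn.LazyConv2d(features, kernel_size, ...);注意 flax 的卷积使用通道在后(NHWC)的布局,而 PyTorch 的卷积使用 NCHW。示例和更多详细信息请参见:https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html#torch.nn.modules.lazy.LazyModuleMixin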
希望有帮助!