How does the sample function in SeqGAN's Generator use the LSTM to generate sequences?


I have been looking at this GitHub repository for SeqGAN, https://github.com/HeroKillerEver/SeqGAN-Pytorch, but I don't quite understand why the LSTM is useful in the generator's sample function, particularly in the line

output, (_, _) = self.lstm(embedding, (h, c))

given that h and c do not seem to be updated during the process.

I think this is essentially the same as just generating numbers at random.
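
For context, here is a minimal standalone sketch (my own, not from the repository) of how nn.LSTM treats the (h, c) pair: every call returns a fresh state pair, and the module only carries state across calls if the caller feeds that pair back in.

import torch
import torch.nn as nn

# Toy sizes chosen for illustration only.
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
x = torch.randn(4, 1, 8)        # batch of 4, one time step, feature size 8
h0 = torch.zeros(1, 4, 16)      # (num_layers, batch_size, hidden_size)
c0 = torch.zeros(1, 4, 16)

out, (h1, c1) = lstm(x, (h0, c0))   # out: (4, 1, 16); h1, c1 are the new states
# Discarding (h1, c1), as forward() below does, restarts from the given state
# on the next call; passing them back in, as step() below does inside sample(),
# lets the LSTM carry context from one generated token to the next.

The Generator code from the repository is reproduced below.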

import torch
import torch.nn as nn
import torch.nn.functional as F

import util  # helper module from the repository (provides to_var)


class Generator(nn.Module):
    """Generator"""
    def __init__(self, vocab_size, embedding_size, hidden_dim, num_layers):
        super(Generator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

    def forward(self, x):
        """
        x: (None, sequence_len) LongTensor
        """
        embedding = self.embedding(x) # (None, sequence_len, embedding_size)
        batch_size = x.size(0)
        h0, c0 = self.init_hidden(self.num_layers, batch_size, self.hidden_dim)
        output, (_, _) = self.lstm(embedding, (h0, c0))  # (None, sequence_len, hidden_dim)
        logits = self.linear(output) # (None, sequence_len, vocab_size)
        logits = logits.transpose(1, 2) # (None, vocab_size, sequence_len)

        return logits  # (None, vocab_size, sequence_len)


    def step(self, x, h, c):
        """
        Args:
            x: (batch_size,  1), sequence of tokens generated by generator
            h: (num_layers, batch_size, hidden_dim), lstm hidden state
            c: (num_layers, batch_size, hidden_dim), lstm cell state
        """
        embedding = self.embedding(x) # (batch_size, 1, embedding_size)
        self.lstm.flatten_parameters()
        output, (h, c) = self.lstm(embedding, (h, c)) # (batch_size, 1, hidden_dim)
        logits = self.linear(output).squeeze_(1)  # (batch_size, vocab_size)

        return logits, h, c

    def sample(self, batch_size, sequence_len, x=None):
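        """
        Sample token sequences from the generator.

        Args:
            batch_size: number of sequences to generate
            sequence_len: length of each generated sequence
            x: optional (batch_size, given_len) LongTensor prefix; if None,
               generation starts from an all-zero start token
        """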

        flag = False
        if x is None:
            x = util.to_var(torch.zeros(batch_size, 1).long())
            flag = True

        h, c = self.init_hidden(self.num_layers, batch_size, self.hidden_dim)
        samples = []
        if flag:
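            # Unconditional case: repeatedly sample the next token from the
            # softmax and feed it back in, threading (h, c) through each step.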
            for _ in range(sequence_len):
                logits, h, c = self.step(x, h, c)
                probs = F.softmax(logits, dim=1)
                sample = probs.multinomial(1) # (batch_size, 1)
                samples.append(sample)
        else:
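            # Conditional case: run the given prefix through step() to build up
            # (h, c), copy the prefix into the output, then sample the
            # remaining sequence_len - given_len tokens one at a time.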
            given_len = x.size(1)
            lis = x.chunk(x.size(1), dim=1)
            for i in range(given_len):
                logits, h, c = self.step(lis[i], h, c)
                samples.append(lis[i])
            x = F.softmax(logits, dim=1).multinomial(1)
            for i in range(given_len, sequence_len):
                samples.append(x)
                logits, h, c = self.step(x, h, c)
                x = F.softmax(logits, dim=1).multinomial(1)
        output = torch.cat(samples, 1)
        return output # (batch_size, sequence_len)


    def init_hidden(self, num_layers, batch_size, hidden_dim):
        """
        initialize h0, c0
        """
        h = util.to_var(torch.zeros(num_layers, batch_size, hidden_dim))
        c = util.to_var(torch.zeros(num_layers, batch_size, hidden_dim))

        return h, c
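
For reference, a minimal usage sketch of the class above (the hyperparameters are placeholders, not the repository's actual configuration; it assumes the Generator class and the repository's util module are importable):

vocab_size, embedding_size, hidden_dim, num_layers = 5000, 32, 64, 1
gen = Generator(vocab_size, embedding_size, hidden_dim, num_layers)

# Unconditional sampling: x is None, so generation starts from an all-zero
# token and (h, c) are threaded through step() for sequence_len iterations.
samples = gen.sample(batch_size=16, sequence_len=20)             # (16, 20) LongTensor

# Conditional sampling: the first tokens are copied from the given prefix,
# then the remaining tokens are drawn from the softmax over the logits.
prefix = samples[:, :5]
completed = gen.sample(batch_size=16, sequence_len=20, x=prefix)  # (16, 20)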

Please help me Q_Q, thank you!

deep-learning pytorch lstm sequence generative-adversarial-network