Tensorflow 2.0蒙版的说明

问题描述 投票:0回答:1

从使用Keras子类化API时的Tensorflow文档中,他们给出了此示例,说明如何将掩码传递给实现掩码的其他层。我想知道这是否是明确要求的,或者在Embedding层具有mask_zero = True之后是否可以正确处理。

class MyLayer(layers.Layer):

  def __init__(self, **kwargs):
    super(MyLayer, self).__init__(**kwargs)
    self.embedding = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)
    self.lstm = layers.LSTM(32)

  def call(self, inputs):
    x = self.embedding(inputs)
    # Note that you could also prepare a `mask` tensor manually.
    # It only needs to be a boolean tensor
    # with the right shape, i.e. (batch_size, timesteps).
    mask = self.embedding.compute_mask(inputs)
    output = self.lstm(x, mask=mask)  # The layer will ignore the masked values
    return output

layer = MyLayer()
x = np.random.random((32, 10)) * 100
x = x.astype('int32')
layer(x)

我的困惑来自文档的另一个区域,其中指出:

遮罩

此层支持对数量可变的输入数据进行屏蔽时间步伐。要将掩码引入数据,请使用嵌入层mask_zero参数设置为True。

这似乎意味着,如果mask_zero = True,则无需在后续层上执行其他命令。

tensorflow keras masking
1个回答
1
投票

如果您了解Masking层,它也支持在开始使用蒙版后,所有其他层都会自动获得蒙版。

报价:

对于输入张量中的每个时间步长(张量中的维数1),如果该时间步长上的输入张量中的所有值都等于mask_value,则该时间步长将在所有下游层中被屏蔽(跳过)(只要他们支持遮罩。

如果任何下游层不支持掩码但仍接收到这样的输入掩码,则会引发异常。

other link也表示相同。遮罩将传播到所有层。

报价:

[当使用功能性API或顺序API时,对于能够使用它们的任何层(例如,RNN层),由嵌入或遮罩层生成的遮罩将通过网络传播。 Keras将自动获取与输入相对应的蒙版,并将其传递给知道如何使用它的任何层。

第二个链接确实包含有关遮罩的详细信息。

注意,您显示的代码是用于custom嵌入的。如果您要教您如何“创建并传递”蒙版,则要创建一个将创建蒙版的图层。它基本上显示了普通嵌入层的功能。

因此,我们可以得出结论,如果您使用的是普通的Embedding层,则只需mask_zero=True,一切都会顺畅进行。

© www.soinside.com 2019 - 2024. All rights reserved.