建立模型后如何添加属性并使用?

问题描述 投票:0回答:1

我已经构建了一个神经网络模型,我想合并自定义函数,encodeImage和encodeText,用于预处理数据。理想情况下,我希望这些函数在模型定义期间和训练之后(构建后)都可以调用。但是,将它们直接包含在模型定义中会限制它们只能在即时 (JIT) 编译之前使用。模型构建后进行的调用会导致函数未定义

# The Custom Attributes I wan to add in the Model
    def encode_image(self, image):
      return self.visual(image.type(self.dtype))

    def encode_text(self, text):
      x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

      x = x + self.positional_embedding.type(self.dtype)
      x = x.permute(1, 0, 2)  # NLD -> LND
      x = self.transformer(x)
      x = x.permute(1, 0, 2)  # LND -> NLD
      x = self.ln_final(x).type(self.dtype)

      # x.shape = [batch_size, n_ctx, transformer.width]
      # take features from the eot embedding (eot_token is the highest number in each sequence)
      x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

      return x
  # Image Classifier Neural Network
  class ImageClassifier(nn.Module):
      def __init__(self, n_qubits, n_layers, encode_image):
          super().__init__()
          self.model = nn.Sequential(
              qlayer,
              ClassicalLayer(2)
          )
  
      def forward(self, x):
          result = self.model(x)
          return result
deep-learning pytorch neural-network
1个回答
0
投票

您希望脚本函数仅包含模型操作 - 即没有预处理、I/O、设备传输等。

数据加载/预处理逻辑应与模型逻辑分开。例如,标记化不应出现在模型代码中。

逻辑分离后,可以通过在

module
类中实现它们并添加
@torch.jit.export
来添加其他功能。默认情况下,torch 脚本将编译
forward
方法以及包含
@torch.jit.export
装饰器的任何其他方法。请参阅 pytorch 中的示例 docs

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99

    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10

    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)

    def unused_method(self, x):
        return x - 20

# `m` will contain compiled methods:
#     `forward`
#     `another_forward`
#     `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())
© www.soinside.com 2019 - 2024. All rights reserved.