How can I save the configuration of a Keras TextVectorization layer that uses a custom standardization function to a pickle file and reload it?

I have a Keras TextVectorization layer that uses a custom standardization function:

import re
import string

import tensorflow as tf
from tensorflow import keras

def custom_standardization(input_string, preserve=['[', ']'], add=['¿']):
    # Characters to strip: ASCII punctuation, plus the extra characters in
    # `add`, minus the characters in `preserve`.
    strip_chars = string.punctuation
    for item in add:
        strip_chars += item

    for item in preserve:
        strip_chars = strip_chars.replace(item, '')

    # Lowercase the text, then delete every character in strip_chars.
    lowercase = tf.strings.lower(input_string)
    output = tf.strings.regex_replace(lowercase, f'[{re.escape(strip_chars)}]', '')

    return output

target_vectorization = keras.layers.TextVectorization(
    max_tokens=vocab_size,
    output_mode='int',
    output_sequence_length=sequence_length + 1,
    standardize=custom_standardization,
)
target_vectorization.adapt(train_spanish_texts)
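To see exactly which characters custom_standardization removes, here is a pure-Python approximation that uses the standard re module instead of tf.strings. It is only an illustration: the helper name strip_like_custom_standardization is hypothetical, it uses tuples instead of list defaults, and case folding of non-ASCII letters may differ from tf.strings.lower().

```python
import re
import string


def strip_like_custom_standardization(text, preserve=('[', ']'), add=('¿',)):
    """Hypothetical pure-Python mirror of custom_standardization (no TensorFlow)."""
    # Same character set as the TF version: punctuation + extras - preserved.
    strip_chars = string.punctuation + ''.join(add)
    for item in preserve:
        strip_chars = strip_chars.replace(item, '')
    # Python's re stands in for TensorFlow's RE2 regex engine here.
    return re.sub(f'[{re.escape(strip_chars)}]', '', text.lower())


print(strip_like_custom_standardization('¿Hola, [Mundo]!'))  # hola [mundo]
```

The square brackets survive while "¿", "," and "!" are removed, which is what keeps tokens such as [start] and [end] intact in a translation setup.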

I want to save the adapted layer's configuration so that an inference model can use it.

One approach, as described here, is to save the weights and the config to a pickle file and reload them.

However, target_vectorization.get_config() returns

{'name': 'text_vectorization_5',
 'trainable': True,
 ...
 'standardize': <function __main__.custom_standardization(input_string, preserve=['[', ']'], add=['¿'])>,
 ...
 'vocabulary_size': 15000}

and this is what gets saved to the pickle file.

Trying to load this config with

keras.layers.TextVectorization.from_config(pickle.load(open('ckpts/spanish_vectorization.pkl', 'rb'))['config'])

results in

TypeError: Could not parse config: <function custom_standardization at 0x2a1973a60>

because the file does not contain any information about this custom standardization function.
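Python's pickle stores a plain module-level function by reference: it records only the module and the qualified name, and re-imports them when loading. The sketch below uses a standard-library function (string.capwords, chosen only because it is always importable) to show this. Note that the error message above prints a real function object, which suggests that unpickling itself succeeded and that the TypeError is raised afterwards by from_config().

```python
import pickle
import string

# Pickle a module-level function: only a reference is stored, not its code.
data = pickle.dumps(string.capwords)

print(len(data))  # a few dozen bytes, far too small to hold the function body
print(b'capwords' in data, b'string' in data)

# Unpickling re-imports string.capwords and returns the very same object.
restored = pickle.loads(data)
print(restored is string.capwords)
```

The same applies to __main__.custom_standardization: wherever the pickle is loaded, a function with that name must already be defined.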

In this case, what is a good way to save the TextVectorization weights and config for use in an inference model?

nlp tensorflow2.0 tf.keras keras-layer language-translation
1 answer

Problem

This appears to be an issue with serializing the custom standardization callable. See the documentation here: (tf.keras.layers.TextVectorization).

Solution

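The documentation states that you should register the layer as a Keras serializable object, using a wrapper class with the following decorator (tf.keras.saving.register_keras_serializable).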
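Registration helps because a saved config can then refer to a function by a registered name instead of holding the function object itself. The pure-Python sketch below illustrates only that idea: the register decorator, the _REGISTRY dict and the get_config/from_config helpers are hypothetical stand-ins, and although the 'package>name' key loosely mirrors Keras's naming convention, this is not how Keras implements it.

```python
import json

# Hypothetical registry mapping a string key to a function.
_REGISTRY = {}


def register(package):
    """Hypothetical decorator: store fn under 'package>name'."""
    def decorator(fn):
        key = f'{package}>{fn.__name__}'
        _REGISTRY[key] = fn
        fn.registered_name = key
        return fn
    return decorator


@register(package='custom_layers')
def shout(text):
    return text.upper()


def get_config(standardize):
    # Save only the registered name: a plain string that JSON or pickle can store.
    return {'standardize': standardize.registered_name}


def from_config(config):
    # Resolve the name back to the function at load time.
    return _REGISTRY[config['standardize']]


config = get_config(shout)
text = json.dumps(config)  # '{"standardize": "custom_layers>shout"}'
restored = from_config(json.loads(text))
print(restored('hello'))  # HELLO
```

The lookup only succeeds if the decorated function has been defined (and therefore registered) in the loading process, which is why the registration code must also run before loading.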

I have tested a minimal working example with your custom function; it works with Python 3.9.12 and Keras/TensorFlow 2.15:

import tensorflow as tf
from tensorflow import keras
import string
import re
import pickle

def custom_standardization(input_string, preserve=['[', ']'], add=['¿']):
    strip_chars = string.punctuation
    for item in add:
        strip_chars += item

    for item in preserve:
        strip_chars = strip_chars.replace(item, '')

    lowercase = tf.strings.lower(input_string)
    output = tf.strings.regex_replace(lowercase, f'[{re.escape(strip_chars)}]', '')

    return output

@keras.utils.register_keras_serializable(package='custom_layers', name='TextVectorizer')
class TextVectorizer(keras.layers.Layer):
    def __init__(self, custom_standardization, **kwargs):
        super(TextVectorizer, self).__init__(**kwargs)
        self.custom_standardization = custom_standardization
        self.vectorizer = tf.keras.layers.TextVectorization(
            standardize=self.custom_standardization,
            max_tokens=1000,  # You can adjust this parameter based on your dataset
            output_mode='int'
        )

    def call(self, inputs):
        return self.vectorizer(inputs)

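    def get_config(self):
        # Note: custom_standardization is not stored in the config, so
        # from_config() cannot rebuild this layer without extra code.
        config = super(TextVectorizer, self).get_config()
        return config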

# Example usage
text_vectorizer = TextVectorizer(custom_standardization)

# Adapt the TextVectorization layer to your training data
train_data = tf.constant(['Hello [World]!', 'Another [example].'])
text_vectorizer.vectorizer.adapt(train_data)

# Build the layer to initialize the TextVectorization layer
text_vectorizer.build(input_shape=(None,))

# Create a model to include the TextVectorization layer
model = tf.keras.Sequential([text_vectorizer])
model.build(input_shape=())
# Save the weights of the model
model.save_weights('text_vectorizer_weights.tf')

# Create a new TextVectorizer instance to receive the saved weights
loaded_text_vectorizer = TextVectorizer(custom_standardization)
loaded_text_vectorizer.build(input_shape=(None,))

# Create a model to include the loaded TextVectorization layer
loaded_model = tf.keras.Sequential([loaded_text_vectorizer])

# Adapt the TextVectorization layer to the same training data
loaded_text_vectorizer.vectorizer.adapt(train_data)

# Load the weights into the new model
loaded_model.load_weights('text_vectorizer_weights.tf')

# Compile the model after loading the weights
loaded_model.compile()

# Test the loaded layer
text_input = tf.constant(['Hello [World]!'])
output = loaded_model(text_input)
print(output)
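An alternative that needs neither a wrapper class nor re-adapting at inference time is to pickle only plain data (the vocabulary returned by get_vocabulary() plus the constructor arguments) and rebuild the layer, passing the function in directly. This is a sketch assuming the inference code can define or import custom_standardization itself; the file name vectorizer.pkl and the dict keys are illustrative. It also sidesteps from_config(), which, judging by the TypeError in the question, does not accept the raw function object that get_config() stores under 'standardize'.

```python
import pickle
import re
import string

import tensorflow as tf


def custom_standardization(input_string, preserve=['[', ']'], add=['¿']):
    # Same standardization as in the question.
    strip_chars = string.punctuation
    for item in add:
        strip_chars += item
    for item in preserve:
        strip_chars = strip_chars.replace(item, '')
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, f'[{re.escape(strip_chars)}]', '')


# Training side: adapt, then pickle plain data only (no function objects).
train_data = tf.constant(['Hello [World]!', 'Another [example].'])
kwargs = dict(max_tokens=1000, output_mode='int', output_sequence_length=6)
vectorizer = tf.keras.layers.TextVectorization(standardize=custom_standardization, **kwargs)
vectorizer.adapt(train_data)

with open('vectorizer.pkl', 'wb') as f:
    pickle.dump({'kwargs': kwargs, 'vocabulary': vectorizer.get_vocabulary()}, f)

# Inference side: rebuild the layer, passing the function and vocabulary directly.
with open('vectorizer.pkl', 'rb') as f:
    saved = pickle.load(f)

restored = tf.keras.layers.TextVectorization(
    standardize=custom_standardization,
    vocabulary=saved['vocabulary'],
    **saved['kwargs'],
)

sample = tf.constant(['Hello [World]!'])
print(vectorizer(sample))
print(restored(sample))
```

The vocabulary list from get_vocabulary() starts with the padding token '' and the OOV token '[UNK]'; the layer recognizes these when the list is passed back, so both layers should map the sample to the same integer ids.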

© www.soinside.com 2019 - 2024. All rights reserved.