如何在不导入tensorflow的情况下从保存的模型.h5中访问激活函数?

问题描述 投票:0回答:2

每个层的激活函数是否存储在由model.save()生成的.h5文件中?还是已经 "烘焙 "到权重中了?

我正在编写一个AWS Lambda函数,用于每五分钟从多个回归模型中生成时间序列预测。不幸的是,TensorFlow是一个太大的库,无法加载到AWS Lambda函数中,所以我正在编写自己的Python代码来加载保存的.h5模型文件,并根据权重和输入数据生成预测。这是我目前的情况。

def generate_predictions(model_path, df):
    model_info = h5py.File(model_path, 'r')
    model_weights = model_info['model_weights']
    # Initialize predictions matrix with preprocessed inputs
    predictions = preprocessing.scale(df[inputs])
    layer_list = list(model_weights.keys())
    for layer in layer_list:
        weights = model_weights[layer][layer]['kernel:0'][:]
        bias = model_weights[layer][layer]['bias:0'][:]
        predictions = predictions.dot(weights)
        predictions += bias
        # How to retrieve activation function for layer?
        # predictions = activation_function(predictions)

    return predictions

我知道我可能需要一些casewitch语句 来处理各种激活函数。

python tensorflow keras h5py
2个回答
1
投票

如果你保存完整的模型与 model.save你可以访问每一层和它的激活功能。

from tensorflow.keras.models import load_model
model = load_model('model.h5')

for l in model.layers:
  try:
    print(l.activation)
  except: # some layers don't have any activation
    pass
<function tanh at 0x7fa513b4a8c8>
<function softmax at 0x7fa513b4a510>

例如,在这里。softmax 是在最后一层使用的。

如果你不想导入tensorflow,也可以从h5py中读取。

import h5py
import json

model_info = h5py.File('model.h5', 'r')

model_config = json.loads(model_info.attrs.get('model_config').decode('utf-8'))

for k in model_config['config']['layers']:
  if 'activation' in k['config']:
      print(f"{k['class_name']}: {k['config']['activation']}")
LSTM: tanh
Dense: softmax

这里,最后一层是一个密集层,它有softmax激活。


0
投票

模型配置可以通过顶层组的一个名为 "model_config "的属性访问,这个属性似乎包含了完整的模型配置JSON,由model.to_json()产生。

import json
import h5py
model_info = h5py.File('model.h5', 'r')
model_config_json = json.loads(model_info.attrs['model_config'])
© www.soinside.com 2019 - 2024. All rights reserved.