如何获取从 keras.layers 子类化的自定义图层中的所有图层

问题描述 投票:0回答:1

我正在尝试从 tf.keras.layers 子类化的自定义层中获取所有层,但我对此遇到了困难。最终目标是创建一个包含来自 tf.keras.layers 的层的 DAG(有向无环图)。* 这是一个示例:

from tensorflow import keras
...

class ResidualBlock(keras.layers.Layer):
    def __init__(
        self,
        filters0: int,
        filters1: int,
        activation: str = "leaky_relu",
        **kwargs,
    ) -> None:
        super().__init__()
        self.conv0 = keras.layers.Conv2D(
            filters0, 1, activation=activation, **kwargs)
        self.conv1 = keras.layers.Conv2D(
            filters1, 3, activation=activation,  **kwargs)

    def call(self, inputs, training=False):
        x = self.conv0(inputs, training=training)
        x = self.conv1(x, training=training)
        x = inputs + x
        return x

rb = ResidualBlock(2, 3)

new_model = Sequential([rb, keras.layers.Dense(200)])

convert_to_DAG(new_model)

我想要得到这样的东西:

[{'type': 'ResidualBlock', 'children': ['conv2D_1']}, 
{'type': 'conv2D_1', 'children': ['conv2D_2', 'residual'}, 
{'type': 'conv2D_2', 'children': ['Dense_1']}, 
...
]

我已经看过所有相关答案,例如: 如何访问tensorflow keras中自定义层的递归层,它访问来自tf模型子类的模型中的层,而不是tf groups.Layer

以下代码来自 检查张量流 keras 模型中的下一层,它破坏了基于节点的模型,但它不会递归地遵循每一层到其基础层/操作(我需要)。

def get_layer_summary_with_connections(layer, relevant_nodes): 
    info = {}
    connections = []
    for node in layer._inbound_nodes: 
        if relevant_nodes and node not in relevant_nodes:
            continue
        for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound(): 
            connections.append(inbound_layer.name)
    name = layer.name
    info['type'] = layer.__class__.__name__
    info['parents'] = connections
    return info

最终结果应该是一个包含所有基础层+操作的 DAG,如下所示: DAG end result. All Layers are base layers + operations

感谢您的帮助。如果有什么不清楚的地方我可以澄清

tensorflow keras tensorflow2.0 tf.keras keras-layer
1个回答
0
投票

看看

tf.Module.submodules
,它应该是当前模块属性的所有
tf.Module
的列表,并且它是递归搜索的。
Layer
Model
都继承自
Module

在您的情况下,尝试添加以下内容来列出子模块,并确认

ResidualBlock
的子模块确实是模型的子模块。

print(new_model.submodules)
print(rb.submodules)

for rbs in rb.submodules:
    print("module: ", rbs, "isInModel?", rbs in new_model.submodules)

有了这些信息,应该可以创建 DAG。

© www.soinside.com 2019 - 2024. All rights reserved.