如何在Tensorflow-2.0中绘制tf.keras模型图?

问题描述 投票:0回答:5

我升级到Tensorflow 2.0,没有

tf.summary.FileWriter("tf_graphs", sess.graph)
。我正在查看有关此问题的其他一些 StackOverflow 问题,他们说使用
tf.compat.v1.summary etc
。当然,一定有一种方法可以在 Tensorflow 版本 2 中对 tf.keras 模型进行图形化和可视化。它是什么?我正在寻找如下所示的张量板输出。谢谢!

python-3.x tensorflow tensorboard tensorflow2.0 tf.keras
5个回答
29
投票

您可以可视化任何

tf.function
修饰函数的图形,但首先,您必须跟踪其执行情况。

可视化 Keras 模型的图形意味着可视化它的

call
方法。

默认情况下,此方法未经过

tf.function
修饰,因此您必须将模型调用包装在正确修饰的函数中并执行它。

import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)


@tf.function
def traceme(x):
    return model(x)


logdir = "log"
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=True)
# Forward pass
traceme(tf.zeros((1, 28, 28, 1)))
with writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)

7
投票

根据 docs,一旦模型训练完成,您就可以使用 Tensorboard 来可视化图形。

首先,定义模型并运行它。然后,打开 Tensorboard 并切换到 Graph 选项卡。


最小可编译示例

这个例子取自文档。首先,定义您的模型和数据。

# Relevant imports.
%load_ext tensorboard

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
from packaging import version

import tensorflow as tf
from tensorflow import keras

# Define the model.
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

(train_images, train_labels), _ = keras.datasets.fashion_mnist.load_data()
train_images = train_images / 255.0

接下来,训练你的模型。在这里,您需要为 Tensorboard 定义回调以用于可视化统计数据和图表。

# Define the Keras TensorBoard callback.
logdir="logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

# Train the model.
model.fit(
    train_images,
    train_labels, 
    batch_size=64,
    epochs=5, 
    callbacks=[tensorboard_callback])

训练后,在笔记本中运行

%tensorboard --logdir logs

并切换到导航栏中的“图表”选项卡:

您将看到一个看起来很像这样的图表:


2
投票

这是tf2.x的解决方案,具有子类模型/层的图形可视化

import tensorflow as tf print("TensorFlow version:", tf.__version__) from tensorflow.keras.layers import Dense, Flatten, Conv2D from tensorflow.keras import Model,Input class MyModel(Model): def __init__(self, dim): super(MyModel, self).__init__() self.conv1 = Conv2D(16, 3, activation='relu') self.conv2 = Conv2D(32, 3, activation='relu') self.conv3 = Conv2D(8, 3, activation='relu') self.flatten = Flatten() self.d1 = Dense(128, activation='relu') self.d2 = Dense(1) def call(self, x): x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = self.flatten(x) x = self.d1(x) return self.d2(x) def build_graph(self): x = Input(shape=(dim)) return Model(inputs=[x], outputs=self.call(x)) dim = (28, 28, 1) # Create an instance of the model model = MyModel((dim)) model.build((None, *dim)) model.build_graph().summary() tf.keras.utils.plot_model(model.build_graph(), to_file="model.png", expand_nested=True, show_shapes=True)

输出是

TensorFlow version: 2.5.0 Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 28, 28, 1)] 0 _________________________________________________________________ conv2d (Conv2D) (None, 26, 26, 16) 160 _________________________________________________________________ conv2d_1 (Conv2D) (None, 24, 24, 32) 4640 _________________________________________________________________ conv2d_2 (Conv2D) (None, 22, 22, 8) 2312 _________________________________________________________________ flatten (Flatten) (None, 3872) 0 _________________________________________________________________ dense (Dense) (None, 128) 495744 _________________________________________________________________ dense_1 (Dense) (None, 1) 129 ================================================================= Total params: 502,985 Trainable params: 502,985 Non-trainable params: 0
这也是一个

图形可视化


1
投票
这是目前对我有用的东西(TF 2.0.0),基于

tf.keras.callbacks.TensorBoard 代码:

# After model has been compiled from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.keras.backend import get_graph tb_path = '/tmp/tensorboard/' tb_writer = tf.summary.create_file_writer(tb_path) with tb_writer.as_default(): if not model.run_eagerly: summary_ops_v2.graph(get_graph(), step=0)
    

1
投票
另一种选择是使用此网站:

https://lutzroeder.github.io/netron/

使用 .h5 或 .tflite 文件生成图表。

它所基于的 github 存储库可以在这里找到(它也有一个 python 接口):

https://github.com/lutzroeder/netron

© www.soinside.com 2019 - 2024. All rights reserved.