How to graph a tf.keras model in TensorFlow 2.0?

Colin Steidtmann · Jun 20, 2019 · Viewed 7.3k times

I upgraded to TensorFlow 2.0 and there is no longer a tf.summary.FileWriter("tf_graphs", sess.graph). I looked through some other StackOverflow questions on this, and they said to use tf.compat.v1.summary etc. Surely there must be a way to graph and visualize a tf.keras model in TensorFlow 2. What is it? I'm looking for a TensorBoard output like the one below. Thank you!

(Screenshot of the desired TensorBoard graph output)

Answer

nessuno · Jun 21, 2019

You can visualize the graph of any tf.function-decorated function, but first you have to trace its execution.
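To make "tracing" concrete, here is a minimal sketch using a hypothetical function square_plus_one (not part of the answer). The Python body runs only while TensorFlow traces it into a graph, so the plain Python side effect (appending to a list) happens once; a second call with the same input shape and dtype reuses the traced graph. get_concrete_function then exposes that graph and its operations directly.

```python
import tensorflow as tf

# Hypothetical list that records each time TensorFlow traces the function.
trace_log = []


@tf.function
def square_plus_one(x):
    # Plain Python code like this append runs only while the function is traced.
    trace_log.append(x.shape)
    return x * x + 1.0


# First call: TensorFlow traces the Python body into a graph, then runs it.
square_plus_one(tf.constant([1.0, 2.0]))
# Second call with the same shape and dtype: the traced graph is reused.
square_plus_one(tf.constant([3.0, 4.0]))
traces_after_calls = len(trace_log)
print(traces_after_calls)  # 1

# get_concrete_function exposes the traced graph and its operations.
concrete = square_plus_one.get_concrete_function(
    tf.TensorSpec(shape=[2], dtype=tf.float32)
)
op_types = [op.type for op in concrete.graph.get_operations()]
print(op_types)
```

The op list is what TensorBoard renders as nodes in the Graphs tab: an input placeholder, the multiply, the add, and so on.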

Visualizing the graph of a Keras model means visualizing its call method.

By default, this method is not decorated with tf.function, so you have to wrap the model call in a tf.function-decorated function and execute it.

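```python
import tensorflow as tf

# A small classifier; input_shape=(28, 28) means each sample is a 28x28 array.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)


# Wrapping the model call in a tf.function lets TensorFlow trace it into a graph.
@tf.function
def traceme(x):
    return model(x)


logdir = "log"
writer = tf.summary.create_file_writer(logdir)
# Start recording the graph (and profiler data) of the next traced call.
tf.summary.trace_on(graph=True, profiler=True)
# Forward pass: the first call traces traceme. The input is a batch of one
# 28x28 sample, matching the model's declared input shape.
traceme(tf.zeros((1, 28, 28)))
with writer.as_default():
    # Write the recorded trace to logdir so TensorBoard can display it.
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)
```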
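If only the graph is needed, a variation of the code above can leave the profiler off, in which case trace_export does not need profiler_outdir. This sketch (the directory name log_graph_only is a hypothetical choice) also checks that an event file was actually written and lists the op types in the traced graph via get_concrete_function, which is a quick way to confirm the model was traced without opening TensorBoard.

```python
import os

import tensorflow as tf

# Same model as in the answer; inputs are (batch, 28, 28).
model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)


@tf.function
def traceme(x):
    return model(x)


logdir = "log_graph_only"  # hypothetical directory name
writer = tf.summary.create_file_writer(logdir)

# graph=True is what records the graph; the profiler is left off here.
tf.summary.trace_on(graph=True, profiler=False)
traceme(tf.zeros((1, 28, 28)))  # first call traces the function
with writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0)
writer.flush()

# The graph is stored in a TensorBoard event file inside logdir.
event_files = [f for f in os.listdir(logdir) if f.startswith("events.out.tfevents")]
print(event_files)

# The same traced graph can also be inspected programmatically.
concrete = traceme.get_concrete_function(
    tf.TensorSpec(shape=(1, 28, 28), dtype=tf.float32)
)
op_types = sorted({op.type for op in concrete.graph.get_operations()})
print(op_types)
```

After running it, `tensorboard --logdir log_graph_only` should show the traced model under the Graphs tab; the same command with `--logdir log` applies to the answer's original code.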