I have a trained model (Faster R-CNN) which I exported using export_inference_graph.py to use for inference. I'm trying to understand the difference between the created frozen_inference_graph.pb, saved_model.pb, and model.ckpt* files. I've also seen .pbtxt representations.
I tried reading through this but couldn't really find the answers: https://www.tensorflow.org/extend/tool_developers/
What does each of these files contain? Which ones can be converted to which others? What is the intended purpose of each?
frozen_inference_graph.pb is a frozen graph that can no longer be trained. It is a serialized GraphDef in which the weights are stored as constants, and it can be loaded with this code:
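    import tensorflow as tf  # TensorFlow 1.x API

    def load_graph(frozen_graph_filename):
        # Read the serialized GraphDef from disk and parse it.
        with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        return graph_def

    # Import the parsed GraphDef into the default graph.
    tf.import_graph_def(load_graph("frozen_inference_graph.pb"))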
The saved model is generated by tf.saved_model.builder and has to be loaded into a session. Like the frozen graph, it holds the full graph together with all trained weights, but here the weights are generally kept as variables, so the model can still be trained. saved_model.pb is itself a serialized protobuf, but it is loaded through its containing folder with the snippet below rather than parsed directly. The list passed as the second argument holds the tags (tag constants) the model was saved with, which can be read with saved_model_cli. This format is also often used for serving predictions, for example on Google ML Engine:
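    with tf.Session() as sess:
        # Pass the folder that contains saved_model.pb, not the file itself.
        # The tags must match the ones the model was saved with (SERVING is the usual one).
        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], "foldername to saved_model.pb, only folder")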
model.ckpt files are checkpoints generated during training. They are used to resume training, or as a backup when something goes wrong after a long training run. If you have a saved model and a frozen graph, you can ignore them.
.pbtxt files are basically the same as the models discussed above, but in human-readable text form rather than binary. These can be ignored as well.
To answer your conversion question: a saved model can be transformed into a frozen graph and vice versa, although a saved model built from a frozen graph is still not trainable; only the storage format is that of a saved model. Checkpoints can be read in and loaded into a session, and from there you can build a saved model.
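The conversions described above can be written down as a small lookup, which makes it easy to see which formats are reachable from which. This is only a sketch in plain Python: the format labels and the can_convert helper are made up for illustration, and the map encodes only the conversions mentioned in this answer, not everything TensorFlow supports.

```python
from collections import deque

# Direct conversions mentioned in this answer (labels are illustrative,
# not TensorFlow identifiers).
CONVERSIONS = {
    "checkpoint": ["saved_model"],
    "saved_model": ["frozen_graph"],
    "frozen_graph": ["saved_model"],
}


def can_convert(src, dst):
    """Return True if dst is reachable from src through the conversions above."""
    seen, queue = {src}, deque([src])
    while queue:
        current = queue.popleft()
        if current == dst:
            return True
        for nxt in CONVERSIONS.get(current, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return False
```

For example, can_convert("checkpoint", "frozen_graph") is True (by way of a saved model), while can_convert("frozen_graph", "checkpoint") is False, since no conversion in that direction is described here.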
Hope I helped, any questions, ask away!
ADDITION:
How to freeze a graph, starting from a saved model folder structure. This post is old, so the method I used before might not work anymore; it will most likely still work with TensorFlow 1.x.
Start off by downloading this file from the TensorFlow library; this code snippet should then do the trick:
    import os

    import freeze_graph  # the file you just downloaded
    from tensorflow.python.saved_model import tag_constants  # might be unnecessary

    # `path` is the root folder of the saved model; set it before running.
    freeze_graph.freeze_graph(
        input_graph=None,
        input_saver=None,
        input_binary=None,
        input_checkpoint=None,
        output_node_names="dense_output/BiasAdd",
        restore_op_name=None,
        filename_tensor_name=None,
        output_graph=os.path.join(path, "frozen_graph.pb"),
        clear_devices=None,
        initializer_nodes=None,
        input_saved_model_dir=path,
        saved_model_tags=tag_constants.SERVING
    )
output_node_names = name of the node of the final operation; if you end on a dense layer, it will be <layer_name>/BiasAdd
output_graph = path of the output graph file
input_saved_model_dir = root folder of the saved model
saved_model_tags = saved model tags; in your case this can be None, although I did use a tag.
ANOTHER ADDITION:
The code to load models is already provided above. To actually predict you need a session: for a saved model the session is already created when you load it, for a frozen model it is not.
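Saved model:
    with tf.Session() as sess:
        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], "foldername to saved_model.pb, only folder")
        # input_tensor and output_tensor are the model's input and output tensors.
        prediction = sess.run(output_tensor, feed_dict={input_tensor: test_images})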
Frozen model:
    tf.import_graph_def(load_graph("frozen_inference_graph.pb"))
    with tf.Session() as sess:
        prediction = sess.run(output_tensor, feed_dict={input_tensor: test_images})
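The snippets above use input_tensor and output_tensor without defining them. A minimal sketch of how they might be looked up by name follows. It assumes the graph was imported with tf.import_graph_def and its default name scope "import" (pass prefix="" for a model loaded with the saved model loader). The helpers tensor_name and lookup_tensors are hypothetical, and the node names in the usage comment ("image_tensor", "detection_boxes") are typical of Object Detection API exports but are assumptions here, so check your own graph.

```python
def tensor_name(node_name, prefix="import", output_index=0):
    """Build a tensor name of the form '<prefix>/<node>:<index>'."""
    scoped = f"{prefix}/{node_name}" if prefix else node_name
    return f"{scoped}:{output_index}"


def lookup_tensors(graph, node_names, prefix="import"):
    """Return {node_name: tensor} using graph.get_tensor_by_name."""
    return {
        name: graph.get_tensor_by_name(tensor_name(name, prefix))
        for name in node_names
    }


# Usage with TensorFlow 1.x (node names are assumptions, see above):
#   graph = tf.get_default_graph()
#   tensors = lookup_tensors(graph, ["image_tensor", "detection_boxes"])
#   input_tensor = tensors["image_tensor"]
#   output_tensor = tensors["detection_boxes"]
```

The helpers only call get_tensor_by_name on whatever graph object they are given, so they do not import TensorFlow themselves.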
To further understand what your input and output layers are, you can inspect them with TensorBoard. Simply add the following line of code inside your session:
    tf.summary.FileWriter("path/to/folder/to/save/logs", sess.graph)
This line creates a log file that you can open with TensorBoard from the CLI/PowerShell. To see how to run TensorBoard, check out this previously posted question.
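As a quick alternative to TensorBoard, candidate input and output nodes can also be guessed directly from the GraphDef returned by load_graph above. The sketch below is a heuristic, not an official API: find_io_candidates is a hypothetical helper that treats Placeholder nodes as inputs and nodes that nothing else consumes as outputs. It relies only on the name, op and input fields that GraphDef nodes have.

```python
def find_io_candidates(graph_def):
    """Heuristically guess input and output node names of a GraphDef.

    Inputs:  nodes whose op is 'Placeholder'.
    Outputs: nodes that no other node lists as an input.
    """
    consumed = set()
    for node in graph_def.node:
        for inp in node.input:
            # Inputs look like 'name', 'name:1' or '^name' (control dependency).
            consumed.add(inp.lstrip("^").split(":")[0])
    inputs = [n.name for n in graph_def.node if n.op == "Placeholder"]
    outputs = [n.name for n in graph_def.node if n.name not in consumed]
    return inputs, outputs


# Usage (TensorFlow 1.x):
#   inputs, outputs = find_io_candidates(load_graph("frozen_inference_graph.pb"))
```

The result is only a starting point. Graphs can contain unused nodes that show up as outputs, so it is worth confirming the names in TensorBoard.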