TensorFlow: Convert a pb file to TFLite using Python

Nael Marwan · May 31, 2018

I have a model saved as a pb file after training. I want to use TensorFlow Mobile, and it's important to work with a TFLite file. The problem is that most of the converter examples I found by googling are commands run in a terminal or cmd. Can you please share an example of converting to a TFLite file using Python code?

Answer

Pannag Sanketi · Jun 1, 2018

You can convert to TFLite directly in Python. You have to freeze the graph and use toco_convert. Just as with the command-line tool, you need to know the input and output names and shapes before calling the API.

An example code snippet

Copied from the documentation, where a "frozen" graph (one with no variables) is defined directly in your code:

import tensorflow as tf

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
  tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
# Use a context manager so the file is flushed and closed.
with open("test.tflite", "wb") as f:
  f.write(tflite_model)

In the example above there is no freeze step because the graph has no variables. If your graph has variables and you run toco without freezing it first (i.e., without converting those variables to constants), toco will complain.
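Before handing a GraphDef to toco, it can help to check whether any variables are left. Below is a minimal sketch, assuming a graph's variables show up as nodes whose op type is one of `Variable`, `VariableV2` (TF 1.x reference variables) or `VarHandleOp` (resource variables); the helper `unfrozen_variables` is hypothetical, and the op list is not exhaustive.

```python
from typing import Iterable, List, Tuple

# Op types that hold variables in TF 1.x graphs (assumed, non-exhaustive list).
VARIABLE_OPS = frozenset({"Variable", "VariableV2", "VarHandleOp"})


def unfrozen_variables(nodes: Iterable[Tuple[str, str]]) -> List[str]:
    """Return the names of (name, op) pairs whose op is still a variable op."""
    return [name for name, op in nodes if op in VARIABLE_OPS]


# With a real graph you would pass: ((n.name, n.op) for n in graph_def.node)
graph_nodes = [("img", "Placeholder"), ("w", "VariableV2"), ("w/read", "Identity")]
print(unfrozen_variables(graph_nodes))  # ['w']
```

An empty list suggests the graph is already frozen; a non-empty one means you need the freeze step before calling toco.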

If you have a frozen GraphDef and know the inputs and outputs

Then you don't need a session. You can call the toco API directly:

path_to_frozen_graphdef_pb = '...'
# tf.Tensor objects for the model's inputs and outputs; toco reads their
# names, shapes and dtypes.
input_tensors = [...]
output_tensors = [...]
frozen_graph_def = tf.GraphDef()
with open(path_to_frozen_graphdef_pb, 'rb') as f:
  frozen_graph_def.ParseFromString(f.read())
tflite_model = tf.contrib.lite.toco_convert(frozen_graph_def, input_tensors, output_tensors)
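Here `tflite_model` holds the serialized model, which you still have to write to disk. A minimal sketch of doing that with a sanity check, assuming the converter returns bytes and that the output is a FlatBuffer carrying the TFLite schema's file identifier `"TFL3"` (FlatBuffers stores it in bytes 4..8, right after the root-table offset). `looks_like_tflite` and `save_tflite` are hypothetical helper names.

```python
def looks_like_tflite(data: bytes) -> bool:
    """Heuristic: a TFLite FlatBuffer has the identifier "TFL3" at bytes 4..8."""
    return len(data) >= 8 and data[4:8] == b"TFL3"


def save_tflite(data: bytes, path: str) -> None:
    """Write converted model bytes to `path`, refusing obviously wrong input."""
    if not looks_like_tflite(data):
        raise ValueError("data does not carry the TFL3 file identifier")
    with open(path, "wb") as f:
        f.write(data)
```

Usage would be `save_tflite(tflite_model, "model.tflite")`. The check is only a cheap guard against writing, say, a GraphDef by mistake; it does not validate the model itself.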

If you have a non-frozen GraphDef and know the inputs and outputs

Then you have to load the graph into a session and freeze it before calling toco:

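path_to_graphdef_pb = '...'
graph_def = tf.GraphDef()
with open(path_to_graphdef_pb, 'rb') as f:
  graph_def.ParseFromString(f.read())
output_node_names = ["..."]
input_tensors = [...]
output_tensors = [...]

# tf.Session expects a tf.Graph, not a GraphDef, so import the GraphDef first.
graph = tf.Graph()
with graph.as_default():
  tf.import_graph_def(graph_def, name="")

with tf.Session(graph=graph) as sess:
  # A GraphDef stores no variable values: restore the trained values
  # from a checkpoint in this session before freezing.
  frozen_graph_def = tf.graph_util.convert_variables_to_constants(
      sess, sess.graph_def, output_node_names)
# Note here we are passing frozen_graph_def obtained in the previous step to toco.
tflite_model = tf.contrib.lite.toco_convert(frozen_graph_def, input_tensors, output_tensors)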
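Note that `output_node_names` takes node names such as `"out"`, whereas tensor names carry an output index such as `"out:0"` (and entries in a NodeDef's input list may also carry a `^` control-dependency prefix). A small sketch of converting between the two; `node_name` is a hypothetical helper.

```python
def node_name(tensor_name: str) -> str:
    """Map a tensor name or NodeDef input reference to the producing node's name.

    "out:0"  -> "out"   (strip the output index)
    "^init"  -> "init"  (strip the control-dependency marker)
    """
    name = tensor_name[1:] if tensor_name.startswith("^") else tensor_name
    # Only strip a trailing ":<digits>"; scoped names like "a/b" are kept whole.
    head, sep, tail = name.rpartition(":")
    return head if sep and tail.isdigit() else name
```

With the tensors above, something like `output_node_names = [node_name(t.name) for t in output_tensors]` keeps the two lists consistent.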

If you don't know the inputs/outputs of the graph

This can happen if you did not define the graph yourself, e.g., you downloaded it from somewhere or used a high-level API such as tf.estimator that hides the graph from you. In that case, you need to load the graph and poke around to figure out the inputs and outputs before calling toco. See my answer to this SO question.
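One common heuristic for the "poke around" step, sketched under assumptions: inputs are `Placeholder` nodes, and outputs are nodes that no other node consumes (ignoring `Const`, `NoOp` and unused placeholders). It works best on a frozen graph; non-frozen graphs contain unconsumed bookkeeping nodes (e.g. save/restore ops) that will show up as false positives. `Node` is a hypothetical stand-in; entries of a real `graph_def.node` expose the same `.name`, `.op` and `.input` fields.

```python
from collections import namedtuple
from typing import Iterable, List, Tuple

# Hypothetical stand-in for a NodeDef with the fields this sketch reads.
Node = namedtuple("Node", ["name", "op", "input"])


def _producer(ref: str) -> str:
    """Strip a leading "^" (control dependency) and a trailing ":N" output index."""
    ref = ref.lstrip("^")
    head, sep, tail = ref.rpartition(":")
    return head if sep and tail.isdigit() else ref


def guess_inputs_outputs(nodes: Iterable) -> Tuple[List[str], List[str]]:
    """Guess input and output node names of a (preferably frozen) graph."""
    nodes = list(nodes)
    consumed = {_producer(ref) for n in nodes for ref in n.input}
    inputs = [n.name for n in nodes if n.op == "Placeholder"]
    outputs = [n.name for n in nodes
               if n.name not in consumed
               and n.op not in ("Const", "NoOp", "Placeholder")]
    return inputs, outputs


graph_nodes = [
    Node("img", "Placeholder", []),
    Node("c1", "Const", []),
    Node("add", "Add", ["img", "c1"]),
    Node("out", "Identity", ["add:0"]),
]
print(guess_inputs_outputs(graph_nodes))  # (['img'], ['out'])
```

With a real graph you would call `guess_inputs_outputs(graph_def.node)` and then look the tensors up as `name + ":0"`; treat the result as a starting point to confirm, not a definitive answer.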