TensorFlow Lite converts TensorFlow models into an optimized FlatBuffer format so that they can be run by the TensorFlow Lite interpreter.
TensorFlow Lite provides a converter that turns models into the TensorFlow Lite FlatBuffer file (.tflite) format. The supported inputs are SavedModels, frozen GraphDefs (models generated by freeze_graph.py), tf.keras HDF5 models, and any model taken from a tf.Session (Python API only). However, not every kind of model TensorFlow produces can be handled directly by TF Lite. The relationship is shown below.
This raises a question: what is the difference between all the kinds of models TensorFlow provides? Google's documentation is here.
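A checkpoint consists of three files: .meta, .data and .index. The .meta file holds the meta-graph, the .data file holds the concrete values of the weights and other variables, and the .index file records where each variable is stored in the .data file. For the relationship between checkpoints and .pb models, you may refer to this article.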
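To make the three-file layout concrete, here is a minimal stdlib-only sketch that, given a checkpoint prefix such as model.ckpt-20000, collects the files belonging to it. It assumes the usual TF 1.x naming, where the data part is sharded as <prefix>.data-00000-of-00001; the helper name checkpoint_files is hypothetical and not part of TensorFlow.

```python
import glob
import os


def checkpoint_files(prefix):
    """Return the .index, .data-* and (if present) .meta files of a checkpoint.

    `prefix` is what Saver.save() returns, e.g. "/tmp/model.ckpt-20000".
    Raises FileNotFoundError if the .index file or every data shard is missing.
    """
    index = prefix + ".index"
    meta = prefix + ".meta"  # written by Saver.save() unless write_meta_graph=False
    # Data may be split into several shards: <prefix>.data-00000-of-00002, ...
    data = sorted(glob.glob(glob.escape(prefix) + ".data-*-of-*"))
    if not os.path.exists(index) or not data:
        raise FileNotFoundError(f"incomplete checkpoint: {prefix}")
    files = {"index": index, "data": data}
    if os.path.exists(meta):
        files["meta"] = meta
    return files
```

Passing a prefix rather than a single file name mirrors how TensorFlow itself refers to checkpoints, which is also what freeze_graph's --input_checkpoint expects.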
A .pb file can hold just the graph or, once frozen, the graph together with its weights, just like .meta + .data in a checkpoint. For the exact relationship between Graph, GraphDef and MetaGraph, refer to this article. "pb" is short for ProtoBuf (Protocol Buffers), and a ProtoBuf file can be saved in two formats: text (usually .pbtxt) and binary (.pb).
This is a rather old model format that TensorFlow used. More information can be found with any search engine.
| TF version | Converter API | Remarks |
|---|---|---|
| 1.5 | tf.contrib.lite.toco_convert | TF Lite developer preview and Eager execution preview; prebuilt binaries built against CUDA 9 and cuDNN 7 |
| 1.6 | tf.contrib.lite.toco_convert | prebuilt binaries use AVX instructions, which may cause problems on old machines |
| 1.7-1.8 | tf.contrib.lite.toco_convert | I ran into some TF Lite bugs with version 1.5 and was told they were not fixed until version 1.7 |
| 1.9-1.11 | tf.contrib.lite.TocoConverter | |
| 1.12 | tf.contrib.lite.TFLiteConverter | I decided to reinstall TF 1.12 from source after some failures, thanks to Google's documentation |
| 1.13-1.14 | tf.lite.TFLiteConverter | |
| 2.0 | tf.lite.TFLiteConverter | from_frozen_graph is only available as tf.compat.v1.lite.TFLiteConverter |
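If one script has to run on several TensorFlow versions, the table above can be encoded as a lookup. This is a minimal sketch that reproduces only the table's rows: versions older than 1.5 are rejected, 1.x versions past 1.14 are assumed to behave like 1.13-1.14, and anything from 2.0 on maps to tf.lite.TFLiteConverter. It returns the API name as a string instead of importing TensorFlow, and the function name is hypothetical.

```python
def converter_api(version):
    """Map a TensorFlow version string such as "1.12.0" to the converter API from the table."""
    major, minor = (int(part) for part in version.split(".")[:2])
    if major >= 2:
        return "tf.lite.TFLiteConverter"  # frozen graphs need tf.compat.v1.lite here
    if major < 1 or minor < 5:
        raise ValueError(f"no TF Lite converter listed for TensorFlow {version}")
    if minor <= 8:
        return "tf.contrib.lite.toco_convert"
    if minor <= 11:
        return "tf.contrib.lite.TocoConverter"
    if minor == 12:
        return "tf.contrib.lite.TFLiteConverter"
    return "tf.lite.TFLiteConverter"
```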
First, we need checkpoints, but they are not enough on their own: we also need the graph of the network, which is stored in ProtoBuf format (.pb or .pbtxt). Then we use a tool provided by Google called freeze_graph to freeze the checkpoints (which contain the weights) and the ProtoBuf (which contains the graph) into one file. Let's name it frozen_graph.pb.
When you use tf.estimator, checkpoints are saved for you automatically [11]. Below is an example from the official documentation.
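import tensorflow as tf
my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs=20*60,  # Save checkpoints every 20 minutes.
    keep_checkpoint_max=10,       # Retain the 10 most recent checkpoints.
)
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,  # feature columns defined earlier (not shown)
    hidden_units=[10, 10],
    model_dir='models/iris',             # checkpoints are written here
    config=my_checkpointing_config)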
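As a rough illustration of what these two settings mean together, the sketch below (plain Python, not TensorFlow code) simulates the retention rule: a checkpoint is written every save_checkpoints_secs, and once more than keep_checkpoint_max exist, the oldest one is deleted. It is a simplified model that ignores any extra checkpoints written at the start or end of training.

```python
from collections import deque


def retained_checkpoints(train_seconds, save_every_secs=20 * 60, keep_max=10):
    """Return the save times (in minutes) of the checkpoints still on disk."""
    kept = deque(maxlen=keep_max)  # appending beyond maxlen drops the oldest entry
    for t in range(save_every_secs, train_seconds + 1, save_every_secs):
        kept.append(t // 60)
    return list(kept)
```

With the settings above, four hours of training produce 12 checkpoints, of which only the last 10 (minutes 60 to 240) survive, so the retained history covers about the last 200 minutes.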
Another way to generate checkpoint files is to use tf.train.Saver(). For more details, refer to [12].
import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer())
inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 1)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    # Save the variables to disk (writes model.ckpt.meta, .index and .data-*).
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)
    # GraphDef file in binary .pb format
    tf.train.write_graph(sess.graph_def, "./", "graph.pb", as_text=False)
    # GraphDef file in text .pbtxt format
    tf.train.write_graph(sess.graph_def, "./", "graph.pbtxt", as_text=True)
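Saver.save() also maintains a small text file named checkpoint in the same directory, recording the most recent prefix (this is the file tf.train.latest_checkpoint reads). As a stdlib-only sketch, assuming the usual one-entry-per-line text format used in the test below, the latest prefix can be extracted like this, which is handy for filling in values such as model.ckpt-20000 in freeze scripts. The helper name is hypothetical.

```python
import os
import re


def latest_checkpoint_prefix(checkpoint_dir):
    """Return the model_checkpoint_path recorded in <checkpoint_dir>/checkpoint.

    Relative paths are resolved against checkpoint_dir.
    Returns None if the file or the entry is missing.
    """
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            # Matches only the model_checkpoint_path entry, not all_model_checkpoint_paths.
            m = re.match(r'\s*model_checkpoint_path:\s*"(.*)"\s*$', line)
            if m:
                path = m.group(1)
                return path if os.path.isabs(path) else os.path.join(checkpoint_dir, path)
    return None
```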
When using tf.slim, the training loop (slim.learning.train) automatically generates a graph.pbtxt as well as checkpoints for you.
With the explanation from [7], it is not hard to write the following script:
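# my_freeze_graph.sh
# assumes checkpoint_dir is already set and ends with a slash
python tools/freeze_graph.py \
  --input_graph=${checkpoint_dir}"saved_graph.pb" \
  --input_binary=true \
  --input_checkpoint=${checkpoint_dir}"model.ckpt-20000" \
  --output_graph=${checkpoint_dir}"frozen_graph.pb" \
  --output_node_names="Predictions/Softmax"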
Then running sh my_freeze_graph.sh does the job.
But I ran into problems with this method later on: when checked with summarize_graph, the generated file had no inputs, and I had no idea where the problem was.
The code is quite similar to the one above. For more details, please refer to [10].
python tools/freeze_graph.py \
  --input_graph=/tmp/mobilenet_v1_224.pb \
  --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \
  --input_binary=true \
  --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
  --output_node_names=MobilenetV1/Predictions/Reshape_1
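Because the two freeze_graph invocations differ only in a few flags, a small Python helper can assemble the argument list. This is a sketch that uses only the flags already shown in this post (--input_graph, --input_binary, --input_checkpoint, --output_graph, --output_node_names). It assumes freeze_graph.py lives at tools/freeze_graph.py as in my_freeze_graph.sh, and it infers --input_binary from the file extension, on the assumption that .pbtxt files are text and everything else is binary. The function name is hypothetical.

```python
import sys


def freeze_graph_argv(input_graph, input_checkpoint, output_graph, output_node_names,
                      script="tools/freeze_graph.py"):
    """Build the command line for freeze_graph.py.

    output_node_names may be a string or a list; freeze_graph expects a
    comma-separated string. --input_binary is inferred from the extension.
    """
    if not isinstance(output_node_names, str):
        output_node_names = ",".join(output_node_names)
    input_binary = "false" if input_graph.endswith(".pbtxt") else "true"
    return [
        sys.executable, script,
        f"--input_graph={input_graph}",
        f"--input_binary={input_binary}",
        f"--input_checkpoint={input_checkpoint}",
        f"--output_graph={output_graph}",
        f"--output_node_names={output_node_names}",
    ]
```

The resulting list can be passed to subprocess.run(argv, check=True) from the directory that contains tools/, which avoids the shell-quoting and trailing-backslash pitfalls of hand-written scripts.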
Here we use a tool provided in the TensorFlow source code to find the input and output nodes of our network from the frozen graph generated above.
First, we have to build the tool. If you are not familiar with the commands below, you may refer to blogs and documents about building TensorFlow from source.
$ bazel build tensorflow/tools/graph_transforms:summarize_graph
Then we check the input and output nodes:
$ bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph="/tmp/models/lenet-model/frozen_graph.pb"
Found 1 possible inputs: (name=input, type=float(1), shape=[1,224,224,3])
No variables spotted.
Found 1 possible outputs: (name=MobilenetV1/Predictions/Reshape_1, op=Reshape)
Found 4254920 (4.25M) const parameters, 0 (0) variable parameters, and 0 control_edges
Op types used: 166 Const, 138 Identity, 81 Mul, 54 Add, 27 Relu6, 27 Rsqrt, 27 Sub, 15 Conv2D, 13 DepthwiseConv2dNative, 2 Reshape, 1 AvgPool, 1 BiasAdd, 1 Placeholder, 1 Softmax, 1 Squeeze
To use with tensorflow/tools/benchmark:benchmark_model try these arguments:
bazel run tensorflow/tools/benchmark:benchmark_model -- --graph=/home/edai/work_files/models/MobileNet_v1/frozen_graph.pb --show_flops --input_layer=input --input_layer_type=float --input_layer_shape=1,224,224,3 --output_layer=MobilenetV1/Predictions/Reshape_1
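Copying node names out of this report by hand is error-prone, so here is a small stdlib-only parser for lines like the ones above. It assumes the output format printed by summarize_graph in this TensorFlow version ("Found N possible inputs: (name=..., ...)"); that format is not a stable API, so treat this as a convenience sketch. An empty input list corresponds to the "No inputs spotted." case.

```python
import re

# Captures the node name that follows "(name=" up to the next comma or ")".
_NODE = re.compile(r"\(name=([^,)]+)")


def summarize_graph_nodes(report):
    """Extract (input_names, output_names) from summarize_graph's text output."""
    inputs, outputs = [], []
    for line in report.splitlines():
        if re.match(r"Found \d+ possible inputs:", line):
            inputs.extend(_NODE.findall(line))
        elif re.match(r"Found \d+ possible outputs:", line):
            outputs.extend(_NODE.findall(line))
    return inputs, outputs
```

The two lists can be passed straight to the converter as input_arrays and output_arrays.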
When you use the ‘slim’ module (a high-level API provided by TensorFlow) to train your own model, you may find no input nodes, as follows:
$ bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph="/tmp/Training_ckpt/graph.pbtxt"
No inputs spotted.
Found 24 variables: (name=global_step, type=int64(9), shape=[]) (name=SqueezeNet/conv1/weights, type=float(1), shape=[2,2,3,96])
Then we have to use a tool provided by TensorFlow to add the input node. Since you have been using ‘slim’ anyway, download the source code of the ‘slim’ module from the tensorflow/models repository: it lives in models/tree/master/research/slim, and the file you will use is models/tree/master/research/slim/export_inference_graph.py. Use it as follows:
# assumes checkpoint_dir, DATASET_DIR, checkpoint_select and output_nodes_names are set beforehand
echo "**********************************************************"
echo "add input node and export graph architecture"
echo "**********************************************************"
python export_inference_graph.py \
  --model_name=squeezenet \
  --output_file=${checkpoint_dir}/unfrozen_graph.pb \
  --dataset_dir=${DATASET_DIR}
echo "**********************************************************"
echo "freeze graph"
echo "**********************************************************"
python freeze_graph.py \
  --input_graph=${checkpoint_dir}/unfrozen_graph.pb \
  --input_binary=true \
  --input_checkpoint=${checkpoint_dir}/model.ckpt${checkpoint_select} \
  --output_node_names=${output_nodes_names} \
  --output_graph=${checkpoint_dir}/frozen_graph.pb
You will then get a new graph called unfrozen_graph.pb, which contains the input and output nodes, and the freeze step generates a new frozen_graph.pb file for you. You can then go back to 3.5.1 to fetch your input and output nodes.
The files in the models/tree/master/research/slim folder depend on each other in complicated ways (for example, export_inference_graph.py looks up --model_name in nets/nets_factory.py), so I highly recommend doing your own development directly inside this folder. For example, the flag --model_name=squeezenet refers to the SqueezeNet model that I built earlier.
Now we have a frozen_graph.pb file and know its input and output nodes. Since the checkpoint weights are already baked into the frozen graph, the last step is to convert it with the code below:
# TF 1.12, as used here; on TF 1.13-1.14 use tf.lite, on TF 2.x use tf.compat.v1.lite
import tensorflow.contrib.lite as lite

input_arrays = ['input_node']    # your input node, here is an example
output_arrays = ['output/Relu']  # your output node, here is an example
converter = lite.TFLiteConverter.from_frozen_graph('frozen_graph.pb', input_arrays, output_arrays)
tflite_model = converter.convert()
# Write the FlatBuffer bytes to disk; the with-block closes the file properly.
with open('converted_model.tflite', 'wb') as f:
    f.write(tflite_model)
Finally, we get the desired file converted_model.tflite. To use it on Android devices, you may refer to [10].
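Before copying the file to a device, a cheap sanity check is possible without TensorFlow: the TF Lite FlatBuffer schema declares the file identifier "TFL3", and FlatBuffers store that identifier in bytes 4 to 8 of the buffer, right after the 4-byte root offset. The sketch below checks only this header, so it catches empty, truncated or wrong files (for example a frozen .pb passed by mistake) but says nothing about whether the model itself is valid. Both function names are hypothetical.

```python
def looks_like_tflite(data):
    """Return True if `data` starts like a TF Lite FlatBuffer ('TFL3' at offset 4)."""
    return len(data) >= 8 and data[4:8] == b"TFL3"


def check_tflite_file(path):
    """Raise ValueError if the file at `path` lacks the TF Lite file identifier."""
    with open(path, "rb") as f:
        header = f.read(8)
    if not looks_like_tflite(header):
        raise ValueError(f"{path} does not look like a .tflite file")
```

For example, check_tflite_file('converted_model.tflite') should pass silently right after the conversion above.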
