我正在尝试将 TF 模型转换为 TFLite。模型以 .pb 格式保存，我使用以下代码对其进行转换：
import os
import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2
# Directory the serialized graph is later read from ('export_dir/0').
export_dir = os.path.join('export_dir', '0')
# Only the parent directory is created here, exactly as before; exist_ok
# removes the check-then-create race of os.path.exists() + os.mkdir().
os.makedirs('export_dir', exist_ok=True)
# Switch on v2 control flow and v2 TensorShape behavior before any graph
# is imported or wrapped below.
tf.compat.v1.enable_control_flow_v2()
tf.compat.v1.enable_v2_tensorshape()
# I took this function from a tutorial on the TF website
def wrap_frozen_graph(graph_def, inputs, outputs):
    """Wrap a GraphDef in a callable pruned to the given inputs and outputs.

    The graph is imported with an empty name prefix, so tensor names keep
    the exact names stored in ``graph_def``.

    Args:
        graph_def: The GraphDef to import.
        inputs: Feeds forwarded unchanged to ``prune``.
        outputs: Fetches forwarded unchanged to ``prune``.

    Returns:
        Whatever ``prune(inputs, outputs)`` returns for the wrapped import.
    """
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    return wrapped_import.prune(inputs, outputs)
graph_def = tf.compat.v1.GraphDef()
# Read the serialized graph through a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
# NOTE(review): the file name 'saved_model.pb' suggests a SavedModel
# protobuf rather than a frozen GraphDef -- confirm the file really holds a
# GraphDef before parsing it as one.
with open(os.path.join(export_dir, 'saved_model.pb'), 'rb') as model_file:
    loaded = graph_def.ParseFromString(model_file.read())

# Prune the imported graph down to the tensors used for inference.
concrete_func = wrap_frozen_graph(
    graph_def,
    inputs=[
        'extern_data/placeholders/data/data:0',
        'extern_data/placeholders/data/data_dim0_size:0',
    ],
    outputs=['output/output_batch_major:0'],
)

# Pin the static shapes the converter sees; None leaves that dimension
# unspecified.
concrete_func.inputs[0].set_shape([None, 50])
concrete_func.inputs[1].set_shape([None])
concrete_func.outputs[0].set_shape([None, 100])

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.experimental_new_converter = True
# NOTE(review): post_training_quantize looks like a legacy converter
# attribute -- confirm the converter in use honors it; if not,
# converter.optimizations = [tf.lite.Optimize.DEFAULT] is presumably the
# intended setting.
converter.post_training_quantize = True
# Allow both TFLite builtin ops and selected TensorFlow ops in the output.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
converter.allow_custom_ops = True
tflite_model = converter.convert()
# Save the model.
# exist_ok makes directory creation race-free, replacing the separate
# os.path.exists() check followed by os.mkdir().
os.makedirs('tflite', exist_ok=True)
output_model = os.path.join('tflite', 'model.tflite')
# Binary mode: the value returned by converter.convert() is written as-is.
with open(output_model, 'wb') as f:
    f.write(tflite_model)
现在，我在自己的代码中找不到任何 VarHandleOp，但我发现它确实存在于 TensorFlow 中（https://www.tensorflow.org/api_docs/python/tf/raw_ops/VarHandleOp）。那么，为什么 TFLite 无法识别它呢？
开满天机
相关分类