继续浏览精彩内容
慕课网APP
程序员的梦工厂
打开
继续
感谢您的支持,我会继续努力的
赞赏金额会直接到老师账户
将二维码发送给自己后长按识别
微信支付
支付宝支付

TF Saver 保存/加载训练好模型(网络+参数)的那些事儿

刘小米92
关注TA
已关注
手记 13
粉丝 6
获赞 19

对于神经网络的读取和存储,TF推荐我们使用:When you want to save and load variables, the graph, and the graph's metadata--basically, when you want to save or restore your model--we recommend using SavedModel 推荐使用savedmodel. SavedModel is a language-neutral, recoverable, hermetic serialization format. SavedModel enables higher-level systems and tools to produce, consume, and transform TensorFlow models. TensorFlow provides several mechanisms for interacting with SavedModel, including tf.saved_model APIs, Estimator APIs and a CLI.

一般模式:

训练时候保存

Create a saver.

saver = tf.train.Saver(...variables...)

Remember the training_op we want to run by adding it to a collection.

tf.add_to_collection('train_op', train_op)
sess = tf.Session()
for step in xrange(1000000):
sess.run(train_op)
if step % 1000 == 0:

Saves checkpoint, which by default also exports a meta_graph
    # named 'my-model-global_step.meta'.
    saver.save(sess, 'my-model', global_step=step)

另一个代码加载模型训练好的的网络和参数来测试,或进一步训练

with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess,tf.train.latest_checkpoint( './')
...

模型训练完毕之后,你可能需要在产品上使用它。那么tensorflow model是什么?tensorflow模型主要包含网络的结构的定义或者叫graph和训练好的网络结构里的参数。

因此tensorflow model包含2个文件:

a)Meta graph:

使用protocol buffer来保存整个tensorflow graph.例如所有的variables, operations, collections等等。这个文件使用.meta后缀

b) Checkpoint file:

二进制文件包含所有的weights,biases,gradients和其他variables的值。这个文件使用.ckpt后缀,后来变成有2个文件:

mymodel.data-00000-of-00001
mymodel.index
第一个文件.data文件就是保存训练的variables我们将要使用它。

和这些文件一起,tensorflow还有一个文件叫checkpoint用来简单保存最近一次保存checkpoint文件的记录

导入训练好的模型,有两件事必须做:

1)创造网络:最直接的方法是再逐层写一遍与原模型一样的网络结构,但是meta文件已经把原始网络保存起来了,因此可以直接导入为你创建网络。saver=tf.train.import_meta_graph('my_model-1000.meta')

2)加载模型:通过saver恢复网络的参数,saver.restore(sess,tf.train.latest_checkpoint('./'))

注意:这个checkpoint文件,在导入训练好模型的时候会有如下问题: 当代码知道你指定了log directory 而且在你之前训练胡时候已经在该log目录下生成了checkpoint 模型,它就会直接从你的log目录下去restore checkpoint 文件,而不是the original TF-Slim checkpoint 模型。这一点很多人会忽视。

error like:

2018-01-07 15:17:30.960482: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conduct_encoder/layer9-conv9/E_weights not found in checkpoint

方案:当你改变代码重新训练模型之后,应该删除之前的 log文件。或者你不要有保存log文件的代码。

问题2:当我们保存多个模型在一个目录下面时,checkpoint文件只有一个,默认后来的覆盖最初的checkpoint内容,导致你加载不到自己想要加载的模型,而是加载了最后训练成的模型。

方案:因此如果保存一个训练模型时候,尽量给它一个单另的文件夹,比如

saver(sess, 'First_net/my_model')

打开App,阅读手记
0人推荐
发表评论
随时随地看视频慕课网APP