TensorFlow:如何保存/恢复模型?

TensorFlow:如何保存/恢复模型?

当你在坦索弗洛训练了一个模特之后:

  1. 你是如何拯救受过训练的模特的?
  2. 稍后如何恢复这个保存的模型?


扬帆大鱼
浏览 1880回答 3
3回答

拉风的咖菲猫

我正在改进我的答案,为保存和恢复模型添加更多的细节。在(和之后)TensorFlow版本0.11:保存模型:import tensorflow as tf#Prepare to feed input, i.e. feed_dict and placeholdersw1 = tf.placeholder("float", name="w1")w2 = tf.placeholder("float", name="w2")b1= tf.Variable(2.0,name="bias")feed_dict ={w1:4,w2:8}#Define a test operation that we will restorew3 = tf.add(w1,w2)w4 = tf.multiply(w3,b1,name="op_to_restore")sess = tf.Session()sess.run(tf.global_variables_initializer())#Create a saver object which will save all the variablessaver = tf.train.Saver()#Run the operation by feeding inputprint sess.run(w4,feed_dict)#Prints 24 which is sum of (w1+w2)*b1 #Now, save the graphsaver.save(sess, 'my_test_model',global_step=1000)恢复模型:import tensorflow as tfsess=tf.Session()    #First let's load meta graph and restore weightssaver = tf.train.import_meta_graph('my_test_model-1000.meta')saver.restore(sess,tf.train.latest_checkpoint('./'))# Access saved Variables directlyprint(sess.run('bias:0'))# This will print 2, which is the value of bias that we saved# Now, let's access and create placeholders variables and# create feed-dict to feed new datagraph = tf.get_default_graph()w1 = graph.get_tensor_by_name("w1:0")w2 = graph.get_tensor_by_name("w2:0")feed_dict ={w1:13.0,w2:17.0}#Now, access the op that you want to run. op_to_restore = graph.get_tensor_by_name("op_to_restore:0")print sess.run(op_to_restore,feed_dict)#This will print 60 which is calculated 这个和一些更高级的用例在这里已经解释得很好了。保存和恢复TensorFlow模型的快速完整教程
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python