猿问

Tensorflow 保存多个会话之一

有一个 Python 脚本,我在其中实例化了神经网络类的两个对象。每个对象定义自己的会话并提供保存图形的方法。


import tensorflow as tf

import os, shutil


class TestNetwork:


    def __init__(self, id):

        self.id = id


        tf.reset_default_graph()


        self.s = tf.placeholder(tf.float32, [None, 2], name='s')

        w_initializer, b_initializer = tf.random_normal_initializer(0., 1.0), tf.constant_initializer(0.1)

        self.k = tf.layers.dense(self.s, 2, kernel_initializer=w_initializer,

                    bias_initializer=b_initializer, name= 'k')


        '''Defines self.session and initialize the variables'''

        session_conf = tf.ConfigProto(

            allow_soft_placement = True,

            log_device_placement = False)

        self.session = tf.Session(config = session_conf)

        self.session.run(tf.global_variables_initializer())




    def save_model(self, output_dir):

        '''Save the network graph and weights to disk'''

        if os.path.exists(output_dir):

            # if provided output_dir already exists, remove it

            shutil.rmtree(output_dir)


        builder = tf.saved_model.builder.SavedModelBuilder(output_dir)

        builder.add_meta_graph_and_variables(

            self.session,

            [tf.saved_model.tag_constants.SERVING],

            clear_devices=True)

        # create a new directory output_dir and store the saved model in it

        builder.save()



t1 = TestNetwork(1)

t2 = TestNetwork(2)



t1.save_model("t1_model")

t2.save_model("t2_model")

我得到的错误是


类型错误:无法将 feed_dict 键解释为张量:名称“save/Const:0”指的是不存在的张量。图中不存在“save/Const”操作。


我读到一些说这个错误是由于tf.train.Saver.


因此,我在__init__方法的末尾添加了以下行:


self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 5)

但是我仍然收到错误。


拉莫斯之舞
浏览 194回答 1
1回答
随时随地看视频慕课网APP

相关分类

Python
我要回答