猿问

使用Bazel解析tf.estimator模型时出错

我正在尝试使用tf.estimator和export_savedmodel()创建*.pb模型。这是一个对鸢尾花（Iris）数据集进行分类的简单分类器（4个特征，3个类别）：


import tensorflow as tf



# Training schedule and split sizes used by input_function():
#   num_epoch - how many times the training set is repeated (Dataset.repeat)
#   num_train - training rows; also the shuffle buffer and training batch size
#   num_test  - test rows; the whole test set is evaluated as one batch
num_epoch, num_train, num_test = 500, 120, 30


# 1 Define input function

def input_function(x, y, is_train):
    """Build the tf.data pipeline fed to the estimator.

    The features are wrapped in a dict under the key ``"thisisinput"``,
    which is the key the ``numeric_column`` in ``main`` reads. ``x`` and
    ``y`` are sliced along their first dimension by ``from_tensor_slices``.
    NOTE(review): ``x`` is presumably an (N, 4) float array, matching
    ``numeric_column(shape=4)``; confirm against the data loader.

    Args:
        x: Feature rows.
        y: Labels, one per row of ``x``.
        is_train: When true, shuffle, repeat ``num_epoch`` times and batch
            by ``num_train``. Otherwise batch the data by ``num_test``.

    Returns:
        A ``tf.data.Dataset`` of ``(features_dict, labels)`` batches.
    """
    features = {"thisisinput": x}
    ds = tf.data.Dataset.from_tensor_slices((features, y))

    # Evaluation: no shuffling or repetition, one pass in num_test-sized batches.
    if not is_train:
        return ds.batch(num_test)

    # Training: shuffle buffer and batch size both equal num_train.
    ds = ds.shuffle(num_train)
    ds = ds.repeat(num_epoch)
    return ds.batch(num_train)



def my_serving_input_fn():
    """Serving input receiver that parses serialized ``tf.train.Example`` strings.

    The exported model exposes a single receiver tensor, ``"inputs"``. It is
    a 1-D string placeholder named ``'input_tensors'``, and each element is
    parsed with ``tf.parse_example``. The parse spec is derived from the
    same ``"thisisinput"`` numeric column (shape 4) that ``main`` trains on.
    The two definitions must be kept in sync by hand.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver`` mapping the parsed
        feature dict to the raw string placeholder.
    """
    serialized = tf.placeholder(tf.string, [None], name='input_tensors')

    # Mirror of the training feature column in main().
    parse_spec = tf.feature_column.make_parse_example_spec(
        [tf.feature_column.numeric_column(key="thisisinput", shape=4)])
    parsed = tf.parse_example(serialized, parse_spec)

    return tf.estimator.export.ServingInputReceiver(
        parsed, {"inputs": serialized})



def main(argv):
    """Train, evaluate and export a DNN classifier for the Iris features.

    Args:
        argv: Command-line arguments forwarded by ``tf.app.run``; unused.

    NOTE(review): ``xtrain``, ``ytrain``, ``xtest`` and ``ytest`` are not
    defined anywhere in this script as shown. They must exist at module level
    before ``main`` runs, otherwise the train/evaluate lambdas raise
    NameError. They are presumably the 120/30 Iris train/test split implied
    by ``num_train``/``num_test``; TODO confirm where they are loaded.
    """
    # Estimator builds a fresh tf.Graph for train(), evaluate() and the export,
    # so tf.set_random_seed() on the default graph never reaches the model.
    # The graph-level seed (which also seeds Dataset.shuffle) must be passed
    # through RunConfig to make runs reproducible.
    run_config = tf.estimator.RunConfig(tf_random_seed=1103)

    # Define feature columns. This must match the parse spec in my_serving_input_fn().
    feature_columns = [
        tf.feature_column.numeric_column(key="thisisinput", shape=4),
    ]

    # Define the estimator. RunConfig leaves model_dir unset, so the
    # model_dir argument below is used.
    classifier = tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[10],
        n_classes=3,
        optimizer=tf.train.GradientDescentOptimizer(0.001),
        activation_fn=tf.nn.relu,
        model_dir='modeliris2/',
        config=run_config,
    )

    # Train the model.
    classifier.train(
        input_fn=lambda: input_function(xtrain, ytrain, True)
    )

    # Evaluate the model.
    eval_result = classifier.evaluate(
        input_fn=lambda: input_function(xtest, ytest, False)
    )

    print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

    print('\nSaving models...')
    # export_savedmodel writes a timestamped subdirectory of "modeliris2pb"
    # containing a SavedModel (saved_model.pb plus variables/). That
    # saved_model.pb holds a SavedModel proto, not a bare GraphDef, so tools
    # that expect a frozen GraphDef will likely fail to parse it directly.
    export_dir = classifier.export_savedmodel("modeliris2pb", my_serving_input_fn)
    print('Exported SavedModel to: {}'.format(tf.compat.as_text(export_dir)))



if __name__ == "__main__":
    # Surface INFO-level TensorFlow logging, then let tf.app.run parse flags
    # and invoke main(argv).
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run(main=main)


拉风的咖菲猫
浏览 153回答 1
1回答
随时随地看视频慕课网APP

相关分类

Python
我要回答