使用 TensorFlow 模型评估输入的简单方法?

在这里,我有一个使用生成的数据训练的增强决策树,并保存为est:


from sklearn.datasets import make_blobs

import pandas as pd

import tensorflow as tf


#creates an input function for a tf model

def make_input_fn(X, Y, n_epochs=None, shuffle=True, verbose=False):

    batch_len = len(Y)

    def input_fn():

        dataset = tf.data.Dataset.from_tensor_slices((dict(X), Y))

        if shuffle:

            dataset = dataset.shuffle(batch_len)

        # For training, cycle thru dataset as many times as need (n_epochs=None).

        dataset = dataset.repeat(n_epochs)

        #dividing data into batches

        dataset = dataset.batch(batch_len)

        return dataset

    return input_fn


#making data

trainX, trainY = make_blobs(n_samples=10, centers=2, n_features=3, random_state=0)


#xVals

trainX = pd.DataFrame(trainX)

trainX.columns = ['feature{}'.format(num) for num in trainX.columns]


#yVals

trainY = pd.DataFrame(trainY)

trainY.columns = ['flag']


# Defining input function

train_input_fn = make_input_fn(trainX, trainY)


#defining tf feature columns

feature_columns=[]

for feature_name in list(trainX.columns):

    feature_columns.append(tf.feature_column.numeric_column(feature_name,dtype=tf.float32))

    

#creating the estimator

n_batches = 1

est = tf.estimator.BoostedTreesClassifier(feature_columns, n_batches_per_layer=n_batches)


est.train(train_input_fn, max_steps=10)

我想使用该模型根据一行训练数据进行预测以用于测试目的;像这样的事情:res = est.predict(trainX.loc[0])但是,我很难弄清楚如何去做。


猛跑小猪
浏览 90回答 1
1回答

慕哥6287543

您必须像训练时一样创建输入函数。代码:def my_input_fn(features, batch_size=256):    """An input function for prediction."""    # Convert the inputs to a Dataset without labels.    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)testX = pd.DataFrame(trainX.loc[0]).Tpredictions = est.predict(    input_fn=lambda: my_input_fn(testX))预测将为您提供一个生成器对象。你必须迭代它才能获得预测for pred_dict in predictions:    class_id = pred_dict['class_ids'][0]    probability = pred_dict['probabilities'][class_id]    print(class_id, probability)class_id是预测的ID。请注意,pred_dict 还包含其他信息。以下是 pred_dict 中包含的信息:{'all_class_ids': array([0, 1]), 'all_classes': array([b'0', b'1'], dtype=object), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object), 'logistic': array([0.17926924], dtype=float32), 'logits': array([-1.5213063], dtype=float32), 'probabilities': array([0.82073075, 0.17926925], dtype=float32)}
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python