How do I compute precision-recall for a decision tree in sklearn?

I tried to make predictions on the standard dataset "iris.csv":


import pandas as pd
from sklearn import tree

df = pd.read_csv('iris.csv')
df.columns = ['X1', 'X2', 'X3', 'X4', 'Y']
df.head()

# Decision tree
from sklearn.model_selection import train_test_split

decision = tree.DecisionTreeClassifier(criterion='gini')
X = df.values[:, 0:4]
Y = df.values[:, 4]
trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.25)
decision.fit(trainX, trainY)
y_score = decision.score(testX, testY)
print('Accuracy: ', y_score)

# Compute the average precision score
from sklearn.metrics import average_precision_score
average_precision = average_precision_score(testY, y_score)

print('Average precision-recall score: {0:0.2f}'.format(
      average_precision))

I get a ValueError:


File "C:/Users/Ultra/PycharmProjects/poker_ML/decision_tree.py", line 20, in <module>
    average_precision = average_precision_score(testY, y_score)
  File "C:\Users\Ultra\PycharmProjects\poker_ML\venv\lib\site-packages\sklearn\metrics\ranking.py", line 241, in average_precision_score
    average, sample_weight=sample_weight)
  File "C:\Users\Ultra\PycharmProjects\poker_ML\venv\lib\site-packages\sklearn\metrics\base.py", line 74, in _average_binary_score
    raise ValueError("{0} format is not supported".format(y_type))
ValueError: multiclass format is not supported

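How do I compute precision-recall for 3 classes? How does precision-recall work for a decision tree in sklearn? Maybe I made a mistake when computing "y_score"?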
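A minimal sketch of what `average_precision_score` expects for a three-class problem, under stated assumptions: it uses sklearn's bundled `load_iris` instead of reading iris.csv, and `random_state=0` is added only for a reproducible split. Two things differ from the question's code. First, `decision.score()` returns a single accuracy number, whereas average precision needs a score per sample and per class, which `predict_proba` provides. Second, the multiclass labels are turned into a one-vs-rest indicator matrix with `label_binarize`, which `average_precision_score` accepts as a multilabel target. `type_of_target` is printed to show why the original call was rejected.

```python
# Sketch: average precision for the 3-class iris problem via one-vs-rest labels.
# Assumption: sklearn's bundled iris data (load_iris) stands in for iris.csv,
# and random_state=0 is added only to make the split reproducible.
from sklearn.datasets import load_iris
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.multiclass import type_of_target

X, y = load_iris(return_X_y=True)
trainX, testX, trainY, testY = train_test_split(X, y, test_size=0.25, random_state=0)

decision = DecisionTreeClassifier(criterion='gini', random_state=0)
decision.fit(trainX, trainY)

# decision.score() returns one accuracy number; average precision instead needs
# a score per sample and per class, shape (n_samples, n_classes).
y_score = decision.predict_proba(testX)

# In the question's sklearn version, this target type is the one rejected
# with "multiclass format is not supported".
print(type_of_target(testY))      # multiclass

# One-vs-rest indicator matrix whose columns follow decision.classes_,
# so column k of testY_bin lines up with column k of y_score.
testY_bin = label_binarize(testY, classes=decision.classes_)
print(type_of_target(testY_bin))  # multilabel-indicator

ap_micro = average_precision_score(testY_bin, y_score, average='micro')
ap_per_class = average_precision_score(testY_bin, y_score, average=None)
print('Micro-averaged average precision: {0:0.2f}'.format(ap_micro))
print('Per-class average precision:', ap_per_class)
```

Note that a fully grown decision tree mostly outputs probabilities of 0 or 1, so its precision-recall curve has very few thresholds; average precision is more informative for models that produce graded scores.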


慕的地8271018
1 Answer

忽然笑

According to the scikit-learn documentation, average_precision_score cannot handle multiclass classification. Instead, you can use precision_score like this:

# Decision tree...
y_pred = decision.predict(testX)
y_score = decision.score(testX, testY)
print('Accuracy: ', y_score)

# Compute precision with different averaging strategies
from sklearn.metrics import precision_score
micro_precision = precision_score(testY, y_pred, average='micro')
print('Micro-averaged precision score: {0:0.2f}'.format(
      micro_precision))
macro_precision = precision_score(testY, y_pred, average='macro')
print('Macro-averaged precision score: {0:0.2f}'.format(
      macro_precision))
per_class_precision = precision_score(testY, y_pred, average=None)
print('Per-class precision score:', per_class_precision)

Note that you need to specify how the scores are averaged. This matters most when the labels are imbalanced, which is not the case for iris.
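The snippet above reports precision only, while the question asks about precision and recall. A minimal sketch that gets both per class and macro-averaged, assuming sklearn's bundled `load_iris` in place of iris.csv and a fixed `random_state=0` for reproducibility: `precision_recall_fscore_support` returns the per-class arrays, and `classification_report` prints the same numbers as a table. The true labels go first and the predictions second; swapping them turns per-class precision into per-class recall.

```python
# Sketch: precision and recall for all three iris classes.
# Assumption: sklearn's bundled iris data (load_iris) stands in for iris.csv,
# and random_state=0 is added only to make the split reproducible.
from sklearn.datasets import load_iris
from sklearn.metrics import classification_report, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target
trainX, testX, trainY, testY = train_test_split(X, y, test_size=0.25, random_state=0)

decision = DecisionTreeClassifier(criterion='gini', random_state=0)
decision.fit(trainX, trainY)
y_pred = decision.predict(testX)

# y_true first, y_pred second; average=None gives one value per class.
precision, recall, f1, support = precision_recall_fscore_support(testY, y_pred, average=None)
# average='macro' is the unweighted mean over classes (support is returned as None).
macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(testY, y_pred, average='macro')

print('Per-class precision:', precision)
print('Per-class recall:   ', recall)
print('Macro precision {0:0.2f}, macro recall {1:0.2f}'.format(macro_p, macro_r))

# The same metrics as a text table, one row per class plus averages.
print(classification_report(testY, y_pred, target_names=list(iris.target_names)))
```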