如何在Tensorflow 2.0中获取其他指标(不仅仅是准确性)?

我是Tensorflow领域的新手,我正在研究mnist数据集分类的简单示例。我想知道除了准确性和损失(并可能显示它们)之外,我还如何获得其他指标(例如精度,召回率等)。这是我的代码:


from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

from tensorflow.keras.callbacks import ModelCheckpoint

from tensorflow.keras.callbacks import TensorBoard

import os 


#load mnist dataset

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()


#create and compile the model

model = tf.keras.models.Sequential([

  tf.keras.layers.Flatten(input_shape=(28, 28)), 

  tf.keras.layers.Dense(128, activation='relu'), 

  tf.keras.layers.Dropout(0.2), 

  tf.keras.layers.Dense(10, activation='softmax') 

])

model.summary()


model.compile(optimizer='adam',

              loss='sparse_categorical_crossentropy',

              metrics=['accuracy'])


#model checkpoint (only if there is an improvement)


checkpoint_path = "logs/weights-improvement-{epoch:02d}-{accuracy:.2f}.hdf5"


cp_callback = ModelCheckpoint(checkpoint_path, monitor='accuracy',save_best_only=True,verbose=1, mode='max')


#Tensorboard

NAME = "tensorboard_{}".format(int(time.time())) #name of the model with timestamp

tensorboard = TensorBoard(log_dir="logs/{}".format(NAME))


#train the model

model.fit(x_train, y_train, callbacks = [cp_callback, tensorboard], epochs=5)


#evaluate the model

model.evaluate(x_test,  y_test, verbose=2)

由于我只获得准确性和损失,因此如何获得其他指标?提前感谢您,如果这是一个简单的问题或如果已经在某个地方回答了,我很抱歉。


德玛西亚99
浏览 147回答 3
3回答

吃鸡游戏

我正在添加另一个答案,因为这是在测试集上正确计算这些指标的最干净方法(截至2020年3月22日)。您需要做的第一件事是创建自定义回调,在其中发送测试数据:import tensorflow as tffrom tensorflow.keras.callbacks import Callbackfrom sklearn.metrics import classification_report class MetricsCallback(Callback):    def __init__(self, test_data, y_true):        # Should be the label encoding of your classes        self.y_true = y_true        self.test_data = test_data            def on_epoch_end(self, epoch, logs=None):        # Here we get the probabilities        y_pred = self.model.predict(self.test_data))        # Here we get the actual classes        y_pred = tf.argmax(y_pred,axis=1)        # Actual dictionary        report_dictionary = classification_report(self.y_true, y_pred, output_dict = True)        # Only printing the report        print(classification_report(self.y_true,y_pred,output_dict=False)                         在主节点中,加载数据集并添加回调的位置:metrics_callback = MetricsCallback(test_data = my_test_data, y_true = my_y_true)......#train the modelmodel.fit(x_train, y_train, callbacks = [cp_callback, metrics_callback,tensorboard], epochs=5)         

饮歌长啸

从TensorFlow 2.X开始,两者都可作为内置指标使用。precisionrecall因此,您不需要手动实现它们。除此之外,它们之前在Keras 2.X版本中被删除,因为它们具有误导性---因为它们是以批处理方式计算的,精度和召回率的全局(真实)值实际上会有所不同。你可以看看这里:https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall现在,他们有一个内置的累加器,可确保正确计算这些指标。model.compile(optimizer='adam',               loss='binary_crossentropy',               metrics=['accuracy',tf.keras.metrics.Precision(),tf.keras.metrics.Recall()])

慕哥6287543

Keras 文档中提供了可用指标的列表。它包括 、 等。recallprecision例如,回想一下:model.compile('adam', loss='binary_crossentropy',      metrics=[tf.keras.metrics.Recall()])
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python