训练期间每类验证的准确性

Keras 在训练时给出了整体trainingvalidation准确率。

http://img2.mukewang.com/646c21c90001805815150311.jpg

有没有办法在培训期间获得a per-class validation accuracy?


更新:来自 Pycharm 的错误日志


File "C:/Users/wj96hq/PycharmProjects/PedestrianClassification/Awareness.py", line 82, in <module>

shuffle=True, callbacks=callbacks)

File "C:\Users\wj96hq\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py", line 66, in _method_wrapper

return method(self, *args, **kwargs)

File "C:\Users\wj96hq\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py", line 876, in fit

callbacks.on_epoch_end(epoch, epoch_logs)

File "C:\Users\wj96hq\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\callbacks.py", line 365, in on_epoch_end

callback.on_epoch_end(epoch, logs)

File "C:/Users/wj96hq/PycharmProjects/PedestrianClassification/Awareness.py", line 36, in on_epoch_end

x_test, y_test = self.validation_data[0], self.validation_data[1]

TypeError: 'NoneType' object is not subscriptable


跃然一笑
浏览 148回答 3
3回答

慕斯709654

使用它来获得每类准确性:model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])class Metrics(keras.callbacks.Callback):&nbsp; &nbsp; def on_train_begin(self, logs={}):&nbsp; &nbsp; &nbsp; &nbsp; self._data = []&nbsp; &nbsp; def on_epoch_end(self, batch, logs={}):&nbsp; &nbsp; &nbsp; &nbsp; x_test, y_test = self.validation_data[0], self.validation_data[1]&nbsp; &nbsp; &nbsp; &nbsp; y_predict = np.asarray(model.predict(x_test))&nbsp; &nbsp; &nbsp; &nbsp; true = np.argmax(y_test, axis=1)&nbsp; &nbsp; &nbsp; &nbsp; pred = np.argmax(y_predict, axis=1)&nbsp; &nbsp; &nbsp; &nbsp;&nbsp;&nbsp; &nbsp; &nbsp; &nbsp; cm = confusion_matrix(true, pred)&nbsp; &nbsp; &nbsp; &nbsp; cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]&nbsp; &nbsp; &nbsp; &nbsp; self._data.append({&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; 'classLevelaccuracy':cm.diagonal() ,&nbsp; &nbsp; &nbsp; &nbsp; })&nbsp; &nbsp; &nbsp; &nbsp; return&nbsp; &nbsp; def get_data(self):&nbsp; &nbsp; &nbsp; &nbsp; return self._datametrics = Metrics()history = model.fit(x_train, y_train, epochs=100, validation_data=(x_test, y_test), callbacks=[metrics])metrics.get_data()您可以在指标类中更改代码。随心所欲..并且这个工作。你只是用来metrics.get_data()获取所有信息..

猛跑小猪

好吧,准确性是一个global指标,没有per-class accuracy. 也许你的意思是,这就是orproportion of the class correctly identified的确切定义。TPRrecall

倚天杖

如果您想获得某个类别或一组特定类别的准确性,掩码可能是一个很好的解决方案。看这段代码:def cus_accuracy(real, pred):&nbsp; &nbsp; score = accuracy(real, pred)&nbsp; &nbsp; mask = tf.math.greater_equal(real, 5)&nbsp; &nbsp; mask = tf.cast(mask, dtype=real.dtype)&nbsp; &nbsp; score *= mask&nbsp; &nbsp; mask2 = tf.math.less_equal(real, 10)&nbsp; &nbsp; mask2 = tf.cast(mask2, dtype=real.dtype)&nbsp; &nbsp; score *= mask2return tf.reduce_mean(score)这个指标给出了 5 到 10 类的准确度。我用它来测量 seq2seq 模型中某些单词的准确度。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python