在 keras 中为多类分类生成混淆矩阵

通过训练模型获得高达 98% 的准确率,但混淆矩阵显示出非常高的误分类。


我正在使用 keras 在预训练的 VGG16 模型上使用迁移学习方法进行多类分类。


问题是使用 CNN 将图像分类为 5 种番茄病害。


有 5 个疾病类别,6970 张训练图像和 70 张测试图像。


训练模型显示 98.65% 的准确率,而测试显示 94% 的准确率。


但问题是当我生成混淆矩阵时,它显示出非常高的误分类。


有人请帮助我,是我的代码错误还是模型错误?我很困惑我的模型是否给了我正确的结果。


如果有人可以向我解释 keras 如何使用 model.fit_generator 函数实际计算准确度,因为在混淆矩阵上应用准确度的一般公式并没有给我与 keras 计算出的相同的结果。


用于测试数据集的代码是:


test_generator = test_datagen.flow_from_directory(

test_dir,

target_size=(150, 150),

batch_size=20,

class_mode='categorical')

test_loss, test_acc = model.evaluate_generator(test_generator, steps=50)

print('test acc:', test_acc)

我从论坛之一找到了生成混淆矩阵的代码;


代码是:


import numpy as np

from sklearn.metrics import confusion_matrix,classification_report

batch_size = 20

num_of_test_samples = 70

predictions = model.predict_generator(test_generator,  num_of_test_samples // batch_size+1)


y_pred = np.argmax(predictions, axis=1)


true_classes = test_generator.classes


class_labels = list(test_generator.class_indices.keys())   


print(class_labels)


print(confusion_matrix(test_generator.classes, y_pred))


report = classification_report(true_classes, y_pred, target_names=class_labels)

print(report)

以下是我得到的结果:


测试精度:


Found 70 images belonging to 5 classes.

test acc: 0.9420454461466182

混淆矩阵的结果:


['TEB', 'TH', 'TLB', 'TLM', 'TSL']

[[2 3 2 4 3]

 [4 2 3 0 5]

 [3 3 3 2 3]

 [3 3 2 4 2]

 [2 2 4 4 2]]]

              precision    recall  f1-score   support


         TEB       0.14      0.14      0.14        14

          TH       0.15      0.14      0.15        14

         TLB       0.21      0.21      0.21        14

         TLM       0.29      0.29      0.29        14

         TSL       0.13      0.14      0.14        14


   micro avg       0.19      0.19      0.19        70

   macro avg       0.19      0.19      0.19        70

weighted avg       0.19      0.19      0.19        70


慕妹3146593
浏览 648回答 6
6回答

慕虎7371278

测试标签应该是 class_indices 而不是 classestrue_classes = test_generator.class_indices

狐的传说

亲爱的总是对任何分类性能参数做以下事情:首先重置您在预测中使用的生成器将 shuffle 等于 false 在 flow_from_directory()
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python