Keras - 如何使用 argmax 进行预测

我有 3 类课程Tree, Stump, Ground。我为这些类别列出了一个清单:


CATEGORIES = ["Tree", "Stump", "Ground"]

当我打印我的预测时,它给了我输出


[[0. 1. 0.]]

我已经阅读了 numpy 的 Argmax,但我不完全确定在这种情况下如何使用它。


我试过使用


print(np.argmax(prediction))

但这给了我1. 太好了,但我想找出索引是什么1,然后打印出类别而不是最高值。


import cv2

import tensorflow as tf

import numpy as np


CATEGORIES = ["Tree", "Stump", "Ground"]



def prepare(filepath):

    IMG_SIZE = 150 # This value must be the same as the value in Part1

    img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)

    new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))

    return new_array.reshape(-1, IMG_SIZE, IMG_SIZE, 1)


# Able to load a .model, .h3, .chibai and even .dog

model = tf.keras.models.load_model("models/test.model")


prediction = model.predict([prepare('image.jpg')])

print("Predictions:")

print(prediction)

print(np.argmax(prediction))


我希望我的预测告诉我:


Predictions:

[[0. 1. 0.]]

Stump


斯蒂芬大帝
浏览 420回答 2
2回答

慕无忌1623718

您只需要使用以下结果对类别进行索引np.argmax:pred_name = CATEGORIES[np.argmax(prediction)]print(pred_name)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python