
Tensorflow Object Detection - converting detected objects into images

I trained an ssd_mobilenet_v1 model to detect small objects in static grayscale images.

Now I want to determine things like the horizontal angle of an object. How can I "extract" a detected object as an image or image array for further geometric analysis?


This is my modified version of the object_detection_tutorial.ipynb file from the Tensorflow Object Detection API on GitHub (the original can be found here: https://github.com/tensorflow/models/tree/master/research/object_detection).


Code:


Imports


import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image


# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops

Object detection imports

from utils import label_map_util
from utils import visualization_utils as vis_util

Variables


# What model to download.
MODEL_NAME = 'shard_graph_ssd'

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'label_map.pbtxt')

NUM_CLASSES = 1
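
For reference, label_map.pbtxt is the protobuf text file read by the label_map_util calls further down. For a single class it is expected to look roughly like this (the class name small_object is only a placeholder; ids start at 1, since 0 is reserved for the background):

item {
  id: 1
  name: 'small_object'
}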

Load a (frozen) Tensorflow model into memory.


detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')
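
The answer below indexes into an output_dict of detection results, which the code above does not yet produce. A minimal sketch of that inference step, simplified from the tutorial's run_inference_for_single_image and relying on the standard output tensor names of graphs exported by the Object Detection API ('image_tensor', 'detection_boxes', 'detection_scores', 'detection_classes', 'num_detections'):

def run_inference_for_single_image(image_np, graph):
    with graph.as_default():
        with tf.Session() as sess:
            # Standard input/output tensors of an exported frozen detection graph
            image_tensor = graph.get_tensor_by_name('image_tensor:0')
            tensor_dict = {name: graph.get_tensor_by_name(name + ':0')
                           for name in ['num_detections', 'detection_boxes',
                                        'detection_scores', 'detection_classes']}
            # The model expects a batch dimension, hence np.expand_dims
            output_dict = sess.run(tensor_dict,
                                   feed_dict={image_tensor: np.expand_dims(image_np, 0)})
    # Strip the batch dimension; detections are sorted by score, so index 0 is the best one
    output_dict['num_detections'] = int(output_dict['num_detections'][0])
    output_dict['detection_classes'] = output_dict['detection_classes'][0].astype(np.int64)
    output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
    output_dict['detection_scores'] = output_dict['detection_scores'][0]
    return output_dict

With this, output_dict['detection_boxes'][0] holds [ymin, xmin, ymax, xmax] in normalized coordinates, which is what the answer's crop_objects assumes.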

Loading label map


label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)




Smart猫小萌
1 Answer

斯蒂芬大帝

I solved this with the following function; i is a variable used in the loop, basically the number of the current image.

import cv2  # needed for cv2.imwrite; missing from the imports above

def crop_objects(image, image_np, output_dict, i):
    global ymin, ymax, xmin, xmax
    width, height = image.size
    # Coordinates of the detected object
    ymin = int(output_dict['detection_boxes'][0][0]*height)
    xmin = int(output_dict['detection_boxes'][0][1]*width)
    ymax = int(output_dict['detection_boxes'][0][2]*height)
    xmax = int(output_dict['detection_boxes'][0][3]*width)
    crop_img = image_np[ymin:ymax, xmin:xmax]
    # 1. Only crop objects that are detected with an accuracy above 50%;
    #    images with objects below 50% will be filled with zeros (black image).
    #    This is something I need in my program.
    # 2. Only crop the object with the highest score (object zero).
    if output_dict['detection_scores'][0] < 0.5:
        crop_img.fill(0)
    # Save the cropped object as an image
    cv2.imwrite('Images/Step_2/' + str(i) + '.png', crop_img)
    return ymin, ymax, xmin, xmax

These are required for it to work:

def load_image_into_numpy_array(image):
    # Function needed for image recognition: repeats the grayscale image into
    # the 3-channel uint8 array the detector expects
    last_axis = -1
    dim_to_repeat = 2
    repeats = 3
    grscale_img_3dims = np.expand_dims(image, last_axis)
    training_image = np.repeat(grscale_img_3dims, repeats, dim_to_repeat).astype('uint8')
    assert len(training_image.shape) == 3
    assert training_image.shape[-1] == 3
    return training_image

image = Image.open(image_path)
image_np = load_image_into_numpy_array(image)

This is probably more code than is needed just to crop the objects.
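
The question also asks about the horizontal angle of the detected object, which the answer does not cover. One possible follow-up on the cropped patch, sketched here under the assumption that the object shows up as bright pixels on a darker background (estimate_horizontal_angle is not part of the original post and the threshold of 127 is only a placeholder):

import cv2

def estimate_horizontal_angle(crop_img, threshold=127):
    # Reduce the crop to a single channel if it is the 3-channel array from above
    gray = cv2.cvtColor(crop_img, cv2.COLOR_BGR2GRAY) if crop_img.ndim == 3 else crop_img
    # Separate the object from the background (threshold value is a placeholder)
    _, mask = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
    points = cv2.findNonZero(mask)
    if points is None:
        return None
    # The angle of the minimum-area rotated rectangle around the foreground
    # pixels gives a rough estimate of the object's orientation
    (_, _), (_, _), angle = cv2.minAreaRect(points)
    return angle

Usage would be, for example, angle = estimate_horizontal_angle(crop_img) right after the crop; note that OpenCV's minAreaRect angle convention differs between versions, so the returned value may need to be remapped to the angle definition you actually want.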
