[Object Detection in Practice] Part 1: Downloading, Configuring, and Using the TensorFlow Object Detection API


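First, a brief introduction: the TensorFlow Object Detection API is an open-source framework built on top of TensorFlow that makes it easy to construct, train, and deploy object detection models.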

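Setting up the basic deep learning environment on Windows 10 (Anaconda, the CPU or GPU build of TensorFlow, PyCharm, and so on) is not covered here, since plenty of tutorials are available online.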

The additional Python libraries required are pillow and lxml, both of which can be installed with pip install.
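To confirm that the extra libraries are actually visible to the interpreter you plan to use, a quick check along these lines may help. This is a minimal sketch: `find_missing` and `REQUIRED_MODULES` are hypothetical names introduced for illustration, and the only assumption is that pillow is imported under the name `PIL`.

```python
import importlib.util

# Import names to check. Note that the pillow package is imported as "PIL".
REQUIRED_MODULES = ["PIL", "lxml"]


def find_missing(modules):
    """Return the top-level module names from `modules` that cannot be found."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are installed.")
```

Run it with the same interpreter that PyCharm uses for the project, since a library installed into a different environment will still show up as missing.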

1. Downloading the TensorFlow Object Detection API

Download the source code directly from GitHub: https://github.com/tensorflow/models

2. Downloading Protoc

Protoc is used to compile the .proto files under object_detection/protos in the downloaded source into .py files.

On Windows, version 3.4 is recommended (download link).

After downloading, add the bin folder of the extracted directory to the PATH environment variable.

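Open a command prompt (cmd) and type protoc. If the command is recognized and protoc prints its own output, the installation succeeded.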
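The same check can be scripted, which is handy when PyCharm or another tool inherits a different PATH than the command prompt. The sketch below is only illustrative: `tool_version` is a hypothetical helper, and it assumes Python 3.7 or later and that the executable accepts a `--version` flag, as protoc does.

```python
import shutil
import subprocess


def tool_version(executable, version_flag="--version"):
    """Return the version text printed by `executable`, or None if it is not on PATH."""
    path = shutil.which(executable)
    if path is None:
        return None
    result = subprocess.run([path, version_flag], capture_output=True, text=True)
    # Some tools print their version to stderr instead of stdout.
    return (result.stdout or result.stderr).strip()


if __name__ == "__main__":
    version = tool_version("protoc")
    if version is None:
        print("protoc was not found on PATH; check the environment variable.")
    else:
        print("Found", version)
```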


3. Compiling the files under object_detection\protos

Unzip the TensorFlow Object Detection source downloaded earlier, cd into the models-master\research directory on the command line, and run:

protoc ./object_detection/protos/*.proto --python_out=. 

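This compiles the .proto files under object_detection/protos into .py files.

Once the command finishes, look in the object_detection/protos directory; the corresponding .py files should have been generated there.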
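Some protoc builds on Windows are reported not to expand the `*.proto` wildcard themselves, in which case the command fails with a "no such file" error. One possible workaround is to expand the wildcard in Python and call protoc once per file. The following is a sketch under that assumption, not the official procedure; `build_protoc_commands` and `compile_protos` are hypothetical helpers, and the repository path in the last line should be adjusted to your own.

```python
import glob
import os
import subprocess


def build_protoc_commands(research_dir, protos_subdir="object_detection/protos"):
    """Build one protoc command per .proto file, with paths relative to research_dir."""
    pattern = os.path.join(research_dir, protos_subdir, "*.proto")
    commands = []
    for proto_path in sorted(glob.glob(pattern)):
        # protoc expects forward slashes relative to the directory it runs in.
        relative = os.path.relpath(proto_path, research_dir).replace(os.sep, "/")
        commands.append(["protoc", relative, "--python_out=."])
    return commands


def compile_protos(research_dir):
    """Run protoc from research_dir so imports between .proto files resolve."""
    for command in build_protoc_commands(research_dir):
        subprocess.run(command, cwd=research_dir, check=True)


if __name__ == "__main__":
    # Assumed location of the unpacked repository; adjust to your own path.
    compile_protos(os.path.join("models-master", "research"))
```

Running protoc from the research directory matters because the .proto files import one another using paths that start with object_detection/.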

4. Using a pre-trained object detection model to detect objects

First, create a new project in PyCharm; mine is named using_pre-trained_model_to_detect_objects. Then copy models-master\research\object_detection from the downloaded TensorFlow Object Detection source into the new using_pre-trained_model_to_detect_objects project.

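Create a file named object_detection_tutorial.py in the project to run the detection. The project root then contains the copied object_detection folder and object_detection_tutorial.py side by side.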
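Because the detection script resolves every path relative to the project root, checking the layout before the first run can save a confusing failure. The sketch below lists the paths that the script in this article reads; `EXPECTED_PATHS` and `missing_paths` are hypothetical names introduced here for illustration.

```python
import os

# Paths the detection script in this article reads, relative to the project root.
EXPECTED_PATHS = [
    "object_detection/utils/label_map_util.py",
    "object_detection/data/mscoco_label_map.pbtxt",
    "object_detection/test_images/image1.jpg",
    "object_detection/test_images/image2.jpg",
    "ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb",
]


def missing_paths(project_root, expected=EXPECTED_PATHS):
    """Return the entries of `expected` that do not exist under project_root."""
    return [p for p in expected if not os.path.exists(os.path.join(project_root, p))]


if __name__ == "__main__":
    for path in missing_paths(os.getcwd()):
        print("Missing:", path)
```

The frozen_inference_graph.pb entry only exists after downloadModel() has run once, so on a fresh project that line is expected to be reported as missing until the model has been downloaded.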

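The detection script is shown below; pay attention to the paths it uses. It relies on TensorFlow 1.x APIs such as tf.Session and tf.GraphDef and requires version 1.12 or later: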

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
import matplotlib.pyplot as plt
from PIL import Image

from object_detection.utils import ops as utils_ops

if StrictVersion(tf.__version__) < StrictVersion('1.12.0'):
  raise ImportError('Please upgrade your TensorFlow installation to v1.12.*.')

from object_detection.utils import label_map_util

from object_detection.utils import visualization_utils as vis_util

MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

# Location of the frozen .pb model.
PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'

# Label map file for the COCO dataset.
PATH_TO_LABELS = os.path.join('object_detection/data', 'mscoco_label_map.pbtxt')

PATH_TO_TEST_IMAGES_DIR = 'object_detection/test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]

# Download the model archive and extract the frozen graph.
def downloadModel():
  opener = urllib.request.URLopener()
  opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
  tar_file = tarfile.open(MODEL_FILE)
  for file in tar_file.getmembers():
    file_name = os.path.basename(file.name)
    if 'frozen_inference_graph.pb' in file_name:
      tar_file.extract(file, os.getcwd())


# Load the frozen model into a new graph.
def loadModel():
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    return detection_graph

# Convert the image into a 3-D array of dtype uint8.
def load_image_into_numpy_array(image):
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)

# Run object detection on a single image.
def run_inference_for_single_image(image, graph):
  with graph.as_default():
    with tf.Session() as sess:
      # Get handles to input and output tensors
      ops = tf.get_default_graph().get_operations()
      all_tensor_names = {output.name for op in ops for output in op.outputs}
      tensor_dict = {}
      for key in [
          'num_detections', 'detection_boxes', 'detection_scores',
          'detection_classes'
      ]:
        tensor_name = key + ':0'
        if tensor_name in all_tensor_names:
          tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
              tensor_name)
      image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

      # Run inference
      output_dict = sess.run(tensor_dict,
                             feed_dict={image_tensor: image})

      # all outputs are float32 numpy arrays, so convert types as appropriate
      output_dict['num_detections'] = int(output_dict['num_detections'][0])
      output_dict['detection_classes'] = output_dict[
          'detection_classes'][0].astype(np.int64)
      output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
      output_dict['detection_scores'] = output_dict['detection_scores'][0]
  return output_dict

def predict(detection_graph):
    for image_path in TEST_IMAGE_PATHS:
        image = Image.open(image_path)
        # the array based representation of the image will be used later in order to prepare the
        # result image with boxes and labels on it.
        image_np = load_image_into_numpy_array(image)
        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        # Actual detection.
        output_dict = run_inference_for_single_image(image_np_expanded, detection_graph)
        # Build the index that maps class IDs to category descriptions.
        category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
        # Visualization of the results of a detection.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            output_dict['detection_boxes'],
            output_dict['detection_classes'],
            output_dict['detection_scores'],
            category_index,
            instance_masks=output_dict.get('detection_masks'),
            use_normalized_coordinates=True,
            line_thickness=8)
        plt.figure(figsize=(12, 8))
        plt.imshow(image_np)
        plt.axis('off')
        plt.show()


if __name__ == '__main__':
    # downloadModel()
    detection_graph = loadModel()
    predict(detection_graph)


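Running the script displays each test image with the detected bounding boxes, class labels, and scores drawn on it.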
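If the raw numbers are more useful than the plot, the dictionary returned by run_inference_for_single_image can be turned into a plain list of detections. The sketch below assumes the usual Object Detection API conventions: boxes are stored as normalized [ymin, xmin, ymax, xmax], and 0.5 mirrors the default score threshold used by the visualization helper. `summarize_detections` is a hypothetical helper, not part of the API.

```python
def summarize_detections(output_dict, image_width, image_height, min_score=0.5):
    """Keep detections scoring at least min_score and convert boxes to pixel coordinates."""
    results = []
    for i in range(output_dict["num_detections"]):
        score = float(output_dict["detection_scores"][i])
        if score < min_score:
            continue
        # Boxes are assumed to be normalized [ymin, xmin, ymax, xmax].
        ymin, xmin, ymax, xmax = (float(v) for v in output_dict["detection_boxes"][i])
        results.append({
            "class_id": int(output_dict["detection_classes"][i]),
            "score": score,
            # Pixel box as (left, top, right, bottom).
            "box_pixels": (
                int(round(xmin * image_width)),
                int(round(ymin * image_height)),
                int(round(xmax * image_width)),
                int(round(ymax * image_height)),
            ),
        })
    return results
```

Inside predict, it could be called as summarize_detections(output_dict, image.size[0], image.size[1]), and each class_id can then be looked up in category_index to get the display name.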


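You are welcome to follow my personal WeChat official account, AI计算机视觉工坊, which posts articles on machine learning, deep learning, computer vision, and related topics from time to time. Everyone is welcome to learn and exchange ideas with me.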