温馨提示×

温馨提示×

您好,登录后才能下订单哦!

密码登录×
登录注册×
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》

怎么使用Tensorflow hub完成目标检测

发布时间:2023-04-18 17:46:33 来源:亿速云 阅读:143 作者:iii 栏目:开发技术

今天小编给大家分享一下怎么使用Tensorflow hub完成目标检测的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。

前言

使用到的主要环境是:

  • tensorflow-cpu=2.10

  • tensorflow-hub=0.11.0

  • tensorflow-estimator=2.6.0

  • python=3.8

  • protobuf=3.20.1

导入必要的库

首先导入必要的 python 包,后面要做一些复杂的安装和配置工作,需要一点耐心和时间。在运行下面代码的时候可能会报错:

TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

你只需要重新使用 pip 安装,将 protobuf 降低到 3.20.x 版本即可。

import os
import pathlib
import matplotlib
import matplotlib.pyplot as plt
import io
import scipy.misc
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont
from six.moves.urllib.request import urlopen
import tensorflow as tf
import tensorflow_hub as hub

tf.get_logger().setLevel('ERROR')

准备数据和模型

(1)到 github.com/protocolbuf… 用迅雷下载对应操作系统的压缩包,我的是 win7 版本

(2)下载好之后随便解压到自定义目录,我的是 “主目录\protoc-22.1-win64”,然后将其中的 “主目录\protoc-22.1-win64\bin” 路径添加到用户环境变量中的 PATH 变量中,重新打开命令行,输入 protoc --version ,如果能正常返回版本号说明配置成功,可以开始使用。

(3)进入命令行,在和本文件同一个目录下,执行命令

git clone --depth 1 https://github.com/tensorflow/models

,将 models 文件夹下载下来,进入 models/research/ 下,使用命令执行

protoc object_detection/protos/*.proto --python_out=.

将 models/research/object_detection/packages/tf2/setup.py 拷贝到和 models/research/ 下,然后使用执行本文件的 python 对应的 pip 去执行安装包操作

..\Anaconda3\envs\tfcpu2.10_py38\Scripts\pip.exe install . -i https://pypi.tuna.tsinghua.edu.cn/simple

中间可能会报错“error: netadata-generation-failed”,一般都是某个包安装的时候出问题了,我们只需要看详细的日志,单独用 pip 进行安装即可,单独安装完之后,再去执行上面的根据 setup.py 的整装操作,反复即可,过程有点麻烦但还是都可以安装成功的。

(4)这里的模型本来在:

https://tfhub.dev/tensorflow/centernet/hourglass\_512x512\_kpts/1

但是由于网络问题无法获取,所以我们可以改为从

https://storage.googleapis.com/tfhub-modules/tensorflow/centernet/hourglass\_512x512\_kpts/1.tar.gz

获取模型。

from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import ops as utils_ops


PATH_TO_LABELS = './models/research/object_detection/data/mscoco_label_map.pbtxt'
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)

model_path = 'https://storage.googleapis.com/tfhub-modules/tensorflow/centernet/hourglass_512x512_kpts/1.tar.gz'
print('TensorFlow Hub 中的模型地址: {}'.format(model_path))
print('加载模型...')
hub_model = hub.load(model_path)
print('加载成功!')

打印结果:

TensorFlow Hub 中的模型地址: https://storage.googleapis.com/tfhub-modules/tensorflow/centernet/hourglass_512x512_kpts/1.tar.gz
加载模型...
WARNING:absl:Importing a function (__inference_batchnorm_layer_call_and_return_conditional_losses_42408) with ops with custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_batchnorm_layer_call_and_return_conditional_losses_209416) with ops with custom gradients. Will likely fail if a gradient is requested.
...
WARNING:absl:Importing a function (__inference_batchnorm_layer_call_and_return_conditional_losses_56488) with ops with custom gradients. Will likely fail if a gradient is requested.
加载成功!

(5)在这里我们主要定义了一个函数 load_image_into_numpy_array 来加载从网上下载图片的图片,并将其转换为模型可以适配的输入类型。

(6)IMAGES_FOR_TEST 字典中记录了多个可以用来测试的图片,但是这些都是在网上,用的使用需要调用 load_image_into_numpy_array 函数。

(7)COCO17_HUMAN_POSE_KEYPOINTS 记录了人体姿态关键点。

(8)我们这里展示了 dogs 这张图片,可以看到两条可爱的小狗。

def load_image_into_numpy_array(path):
    image = None
    if(path.startswith('http')):
        response = urlopen(path)
        image_data = response.read()
        image_data = BytesIO(image_data)
        image = Image.open(image_data)
    else:
        image_data = tf.io.gfile.GFile(path, 'rb').read()
        image = Image.open(BytesIO(image_data))

    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape((1, im_height, im_width, 3)).astype(np.uint8)


IMAGES_FOR_TEST = {
  'Beach' : 'models/research/object_detection/test_images/image2.jpg',
  'Dogs' : 'models/research/object_detection/test_images/image1.jpg',
  'Naxos Taverna' : 'https://upload.wikimedia.org/wikipedia/commons/6/60/Naxos_Taverna.jpg',
  'Beatles' : 'https://upload.wikimedia.org/wikipedia/commons/1/1b/The_Coleoptera_of_the_British_islands_%28Plate_125%29_%288592917784%29.jpg',
  'Phones' : 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/0d/Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg/1024px-Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg',
  'Birds' : 'https://upload.wikimedia.org/wikipedia/commons/0/09/The_smaller_British_birds_%288053836633%29.jpg',
}

COCO17_HUMAN_POSE_KEYPOINTS = [(0, 1), (0, 2),(1, 3),(2, 4),(0, 5),(0, 6),(5, 7),(7, 9),(6, 8),(8, 10),(5, 6),(5, 11), (6, 12),(11, 12),(11, 13),(13, 15),(12, 14),(14, 16)]

%matplotlib inline
selected_image = 'Dogs' 
image_path = IMAGES_FOR_TEST[selected_image]
image_np = load_image_into_numpy_array(image_path)
plt.figure(figsize=(24,32))
plt.imshow(image_np[0])
plt.show()

目标检测

我们这里将经过处理的小狗的图片传入模型中,会返回结果,我们只要使用结果来绘制出所检测目标的框,以及对应的类别,分数,可以看出来结果是相当的准确的,甚至通过人的腿就能识别出人的框。

results = hub_model(image_np)
result = {key:value.numpy() for key,value in results.items()}
label_id_offset = 0
image_np_with_detections = image_np.copy()

keypoints, keypoint_scores = None, None
if 'detection_keypoints' in result:
    keypoints = result['detection_keypoints'][0]
    keypoint_scores = result['detection_keypoint_scores'][0]

viz_utils.visualize_boxes_and_labels_on_image_array(
      image_np_with_detections[0],
      result['detection_boxes'][0],
      (result['detection_classes'][0] + label_id_offset).astype(int),
      result['detection_scores'][0],
      category_index,
      use_normalized_coordinates=True,
      max_boxes_to_draw=200,
      min_score_thresh=.30,
      agnostic_mode=False,
      keypoints=keypoints,
      keypoint_scores=keypoint_scores,
      keypoint_edges=COCO17_HUMAN_POSE_KEYPOINTS)

plt.figure(figsize=(24,32))
plt.imshow(image_np_with_detections[0])
plt.show()

以上就是“怎么使用Tensorflow hub完成目标检测”这篇文章的所有内容,感谢各位的阅读!相信大家阅读完这篇文章都有很大的收获,小编每天都会为大家更新不同的知识,如果还想学习更多的知识,请关注亿速云行业资讯频道。

向AI问一下细节

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

AI