几行代码构建全功能的对象检测模型,他是如何做到的?

新闻 前端
如今,机器学习和计算机视觉已成为一种热潮。我们都看过关于自动驾驶汽车和面部识别的新闻,可能会想象建立自己的计算机视觉模型有多酷。

 如今,机器学习和计算机视觉已成为一种热潮。我们都看过关于自动驾驶汽车和面部识别的新闻,可能会想象建立自己的计算机视觉模型有多酷。然而,进入这个领域并不总是那幺容易,尤其是在没有很强的数学背景的情况下。如果你只想做一些小的实验,像PyTorch和TensorFlow这样的库可能会很枯燥。

在本教程中,作者提供了一种简单的方法,任何人都可以使用几行代码构建全功能的对象检测模型。更具体地说,我们将使用Detecto,这是一个在PyTorch之上构建的Python软件包,可简化该过程并向所有级别的程序员开放。

快速简单的例子

为了演示如何简单地使Detecto,让我们加载一个预先训练的模型,并对以下图像进行推断:

首先,使用pip下载Detecto软件包:

pip3 install detecto

然后,将上面的图像另存为“fruit.jpg”,并在与图像相同的文件夹中创建一个Python文件。在Python文件中,编写以下5行代码:

  1. from detectoimport core, utils, visualize 
  2. image = utils.read_image('fruit.jpg'
  3. model = core.Model() 
  4. labels, boxes, scores = model.predict_top(image) 
  5. visualize.show_labeled_image(image, boxes, labels) 

运行此文件后(如果你的计算机上没有启用CUDA的GPU,可能会花费几秒钟;稍后再进行介绍),你应该会看到类似下面的图:

作者仅用了5行代码就完成了所有工作,真的是太棒了。下面是我们每步中分别做的:

1)导入Detecto模块

2)读入图像

3)初始化预训练模型

4)在图像上生成最高预测

5)为预测绘图

绘制我们的预测

Detecto使用来自PyTorch模型动物园中的Faster R-CNN ResNet-50 FPN,它能够检测大约80种不同的物体,例如动物,车辆,厨房用具等。但是,如果你想要检测自定义对象,例如可口可乐与百事可乐罐,斑马与长颈鹿,该怎幺办呢?

这时你会发现,在自定义数据集上训练探测器模型同样简单; 同样,你只需要5行代码,以及现有的数据集或花一些时间标记图像。

构建自定义数据集

在本教程中,作者将从头开始构建自己的数据集。建议你也这样做,但是如果你想跳过这一步,你可以在这里下载一个示例数据集(从斯坦福的Dog数据集修改)。

对于我们的数据集,我们将训练我们的模型来检测来自RoboSub竞赛的水下外星人,蝙蝠和女巫,如下所示:

[[315504]]

理想情况下,每个类至少需要100张图像。好在每张图像中可以有多个对象,所以理论上,如果每张图像包含你想要检测的每类对象,那幺你可以总共获得100张图像。另外,如果你有视频素材,Detico可以轻松地将这些视频素材分割成可用于数据集的图像:

  1. from detecto.utilsimport split_video 
  2. split_video('video.mp4','frames/', step_size=4

上面的代码在“video.mp4”中每第4帧拍摄一次,并将其另存为JPEG文件存在“frames”文件夹中。

生成训练数据集后,应该具有一个类似于以下内容的文件夹:

  1. images/ 
  2. |   image0.jpg 
  3. |   image1.jpg 
  4. |   image2.jpg 
  5. |   ... 

如果需要的话,你还可以使用另一个文件夹,其中包含一组验证图像。

现在是耗时的部分:标记。Detecto支持PASCAL VOC格式,其中具有XML文件,其中包含图像中每个对象的标签和位置数据。要创建这些XML文件,可以使用开源LabelImg工具,如下所示:

  1. pip3 install labelImg   # Download LabelImg using pip 
  2. labelImg                # Launch the application 

现在,你应该会看到一个弹出窗口。单击左侧“打开目录”按钮,然后选择想要标记的图像文件夹。如果一切正常,你应该会看到类似以下内容:

要绘制边界框,请单击左侧菜单栏中的图标(或使用键盘快捷键“w”)。然后,你可以在对象周围拖动一个框并编写/选择标签:

标记完图像后,请使用CTRL+S或CMD+S保存XML文件(为简便起见,你可以使用自动填充的默认文件位置和名称)。要标记下一张图像,请单击“下一张图像”(或使用键盘快捷键“d”)。

整个数据集处理完毕之后,你的文件夹应如下所示:

  1. images/ 
  2. |   image0.jpg 
  3. |   image0.xml 
  4. |   image1.jpg 
  5. |   image1.xml 
  6. |   ... 

我们已经准备好开始训练我们的对象检测模型了!

访问GPU

首先,检查你的计算机是否具有启用CUDA的GPU。由于深度学习需要大量处理能力,因此在通常的CPU上进行训练可能会非常缓慢。值得庆幸的是,大多数现代深度学习框架(例如PyTorch和Tensorflow)都可以在GPU上运行,从而使处理速度更快。 确保已经下载了PyTorch(如果你安装了Detecto,应该已经下载了),然后运行以下两行代码:

  1. import torch 
  2. print(torch.cuda.is_available()) 

如果打印True,那你可以跳到下一部分。如果显示False,不要担心。请按照以下步骤创建Google Colaboratory笔记本,这是一个在线编码环境,带有免费可用的GPU。对于本教程,你将只在Google Drive文件夹中工作,而不是在计算机上工作。

1)登录到Google Drive

2)创建一个名为“Detecto Tutorial”的文件夹并导航到该文件夹

3)将你的训练图像(和/或验证图像)上传到此文件夹

4)右键单击,转到“更多”,然后单击“Google Colaboratory”:

你现在应该看到这样的界面:

5)根据需要给笔记本起个名字,然后转到“编辑”->“笔记本设置”->“硬件加速器”,然后选择“GPU”

6)输入以下代码以“装入”你的云端硬盘,将目录更改为当前文件夹,然后安装Detecto:

  1. import os 
  2. from google.colabimport drive 
  3. drive.mount('/content/drive'
  4. os.chdir('/content/drive/My Drive/Detecto Tutorial'
  5. !pip install detecto 

为了确保一切正常,你可以创建一个新的代码单元,然后输入 !ls 以检查你是否处于正确的目录中。

训练自定义模型

最后,我们现在可以在自定义数据集上训练模型了。如前所述,这是容易的部分。它只需要4行代码:

  1. from detectoimport core, utils, visualize 
  2. dataset = core.Dataset('images/'
  3. model = core.Model(['alien','bat','witch']) 
  4. model.fit(dataset) 

让我们再次分解一下我们每行代码所做的工作:

1、导入的Detecto模块

2、从“images”文件夹(包含我们的JPEG和XML文件)创建了一个数据集

3、初始化模型检测自定义对象(外星人,蝙蝠和女巫)

4、在数据集上训练我们的模型

根据数据集的大小,这可能需要10分钟到1个小时以上的时间来运行,因此请确保你的程序在完成上述语句后不会立即退出(例如:你使用的是Jupyter / Colab笔记本,它在活动时保留状态)。

使用训练好的模型

现在你已经有了训练好的模型,让我们在一些图像上对其进行测试。要从文件路径读取图像,可以使用 detecto.utils 模块中的 read_image 函数(也可以使用上面创建的数据集中的图像):

  1. # Specify the path to your image 
  2. image = utils.read_image('images/image0.jpg'
  3. predictions = model.predict(image) 
  4. # predictions format: (labels, boxes, scores) 
  5. labels, boxes, scores = predictions 
  6. # ['alien''bat''bat'
  7. print(labels) 
  8. #           xmin       ymin       xmax       ymax 
  9. # tensor([[ 569.2125,  203.67021003.4383,  658.1044], 
  10. #         [ 276.2478,  144.0074,  579.6044,  508.7444], 
  11. #         [ 277.2929,  162.6719,  627.9399,  511.9841]]) 
  12. print(boxes) 
  13. # tensor([0.99520.98370.5153]) 
  14. print(scores) 

正像你看到的,模型的预测方法返回一个由3个元素组成的元组:标签,方框和分数。在上面的示例中,此模型在坐标[569、204、1003、658](框[0])处预测了一个外星人(标签[0]),其置信度为0.995(得分[0])。

根据这些预测,我们可以使用 detecto.visualize 模块绘制结果。例如:

  1. visualize.show_labeled_image(image, boxes, labels) 

将上面的代码与收到的图像和预测一起运行将产生如下所示的内容:

[[315505]]

如果你有一个视频,你可以在它上面运行对象检测:

  1. visualize.detect_video(model,'input.mp4','output.avi'

这将获取一个名为“input.mp4”的视频文件,并根据给定模型的预测结果生成一个“output.avi”文件。如果你使用VLC或其他视频播放器打开此文件,应该会看到一些希望看到的结果!

最后,你可以从文件中保存和加载模型,从而可以保存进度并稍后返回:

  1. model.save('model_weights.pth'
  2. # ... Later ... 
  3. model = core.Model.load('model_weights.pth', ['alien','bat','witch']) 

高级用法

你会发现Detecto不仅限于5行代码。举例来说,这个模型没有你希望的那幺好。我们可以尝试通过使用Torchvision转换来扩展我们的数据集并定义一个自定义数据加载器来提高其性能:

  1. from torchvisionimport transforms 
  2. augmentations = transforms.Compose([ 
  3. transforms.ToPILImage(), 
  4. transforms.RandomHorizontalFlip(0.5), 
  5. transforms.ColorJitter(saturation=0.5), 
  6. transforms.ToTensor(), 
  7. utils.normalize_transform(), 
  8. ]) 
  9. dataset = core.Dataset('images/', transform=augmentations) 
  10. loader = core.DataLoader(dataset, batch_size=2, shuffle=True) 

此代码对数据集中的图像应用了随机的水平翻转和饱和效果,从而增加了数据的多样性。然后,我们使用 batch_size = 2 定义一个数据加载对象;我们将其传递给 model.fit 而不是Dataset,这样来告诉我们的模型是对2张图像进行批量训练,而不是默认的1张。

如果你之前创建了单独的验证数据集,那幺现在是在训练期间加载它的时候了。通过提供验证数据集, fit 方法将返回每个时期的损失列表,如果 verbose = True ,则会在训练过程中将其打印出来。以下代码块演示了这一点,并自定义了其他几个训练参数:

  1. import matplotlib.pyplotas plt 
  2. val_dataset = core.Dataset('validation_images/'
  3. losses = model.fit(loader, val_dataset, epochs=10, learning_rate=0.001
  4. lr_step_size=5, verbose=True) 
  5. plt.plot(losses) 
  6. plt.show() 

损失的结果图应或多或少地减少:

为了更具有灵活性和对模型的控制,你可以完全绕过Detecto。你可以根据需要随意调整 model.get_internal_model 方法返回使用的基础模型。

结论

在本教程中,作者展示了计算机视觉和对象检测不需要具有挑战性。你所需要的是一点时间和耐心来处理标记的的数集。

如果你对进一步探索感兴趣的话,请查看Detecto on GitHub或访问文档以获取更多教程和用例!

责任编辑:张燕妮 来源: AI科技大本营
相关推荐

2023-11-30 10:13:17

TensorRT架构

2017-11-14 08:25:36

数据库MySQL安全登陆

2016-11-30 14:18:30

互联网

2021-08-02 09:01:05

MySQL 多版本并发数据库

2019-12-23 09:25:29

日志Kafka消息队列

2018-05-15 16:19:39

程序员bug代码

2019-01-03 14:00:37

降价青云全栈云

2011-11-09 15:49:52

API

2011-06-22 09:45:46

JavaScriptAPI

2019-11-27 18:33:32

Docker架构数据

2017-12-05 11:48:44

AI人工智能开发者

2020-06-01 08:41:29

苏宁分析大数据

2024-03-08 07:58:13

QPShttpsync

2011-08-01 09:08:49

程序员

2020-09-25 09:52:48

机器学习人工智能计算机

2018-09-07 18:14:37

2011-04-29 10:32:46

项目管理

2009-11-20 11:37:11

Oracle完全卸载

2014-04-01 09:29:12

2018-07-12 09:51:04

Python代码对象模型
点赞
收藏

51CTO技术栈公众号