16

PyTorch中基于TPU的FastAI多类图像分类

 3 years ago
source link: https://flashgene.com/archives/143508.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

作者|PRUDHVI VARMA 编译|VK 来源|Analytics Indiamag

计算机视觉因其广泛的应用而成为人工智能领域中最具发展趋势的子领域之一。在某些领域,甚至它们在快速准确地识别图像方面超越了人类的智能。

在本文中,我们将演示最流行的计算机视觉应用之一-多类图像分类问题,使用fastAI库和TPU作为硬件加速器。TPU,即张量处理单元,可以加速深度学习模型的训练过程。

本文涉及的主题:

多类图像分类

常用的图像分类模型

使用TPU并在PyTorch中实现

多类图像分类

我们使用图像分类来识别图像中的对象,并且可以用于检测品牌logo、对对象进行分类等。但是这些解决方案有一个局限性,即只能识别对象,但无法找到对象的位置。但是与目标定位相比,图像分类模型更容易实现。

图像分类的常用模型

我们可以使用VGG-16/19,Resnet,Inception v1,v2,v3,Wideresnt,Resnext,DenseNet等,它们是卷积神经网络的高级变体。这些是流行的图像分类网络,并被用作许多最先进的目标检测和分割算法的主干。

基于FasAI库和TPU硬件的图像分类

我们将在以下方面开展这项工作步骤:

1.选择硬件加速器

这里我们使用Google Colab来实现。要在Google Colab中使用TPU,我们需要打开edit选项,然后打开notebook设置,并将硬件加速器更改为TPU。

通过运行下面的代码片段,你可以检查你的Notebook是否正在使用TPU。

import os
assert os.environ['COLAB_TPU_ADDR']
Path = 'grpc://'+os.environ['COLAB_TPU_ADDR']
print('TPU Address:', Path)

2.加载FastAI库

在下面的代码片段中,我们将导入fastAI库。

from fastai.vision import *
from fastai.metrics import error_rate, accuracy

3.定制数据集

在下面的代码片段中,你还可以尝试使用自定义数据集。

PATH = '/content/images/dataset'
np.random.seed(24)
tfms = get_transforms(do_flip=True)
data = ImageDataBunch.from_folder(PATH, valid_pct=0.2, ds_tfms=tfms, size=299, bs=16).normalize(imagenet_stats)
data.show_batch(rows=4, figsize=(8, 8))

4.加载预训练的深度学习模型

在下面的代码片段中,我们将导入VGG-19 batch_normalisation模型。我们将把它作为fastAI的计算机视觉学习模块的一个实例。

learn = cnn_learner(data, models.vgg19_bn, metrics=accuracy)

5.训练模型

在下面的代码片段中,我们尝试使用一个epoch。

learn.fit_one_cycle(1)

在输出中,我们可以看到我们得到了0.99的准确度,它花了1分2秒。

在下面的代码片段中,我们使用混淆矩阵显示结果。

con_matrix = ClassificationInterpretation.from_learner(learn)
con_matrix.plot_confusion_matrix()

6.利用模型进行预测

在下面的代码片段中,我们可以通过在test_your_image中给出图像的路径来测试我们自己的图像。

test_your_image='/content/images (3).jpg'
test = open_image(test_your_image)
test.show()

在下面的代码片段中,我们可以得到输出张量及其所属的类。

learn.predict(test)

正如我们在上面的输出中看到的,模型已经预测了输入图像的类标签,它属于“flower”类别。

结论

在上面的演示中,我们使用带TPU的fastAI库和预训练VGG-19模型实现了一个多类的图像分类。在这项任务中,我们在对验证数据集进行分类时获得了0.99的准确率。

原文链接: https://analyticsindiamag.com…


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK