轻松学Pytorch-使用ResNet50实现图像分类
source link: http://mp.weixin.qq.com/s?__biz=MzAxMjMwODMyMQ%3D%3D&%3Bmid=2456351188&%3Bidx=1&%3Bsn=fbd197f8b57f4e6e049280fd19d7651a
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
Hello大家好,这篇文章给大家详细介绍一下pytorch中最重要的组件torchvision,它包含了常见的数据集、模型架构与预训练模型权重文件、常见图像变换、计算机视觉任务训练。可以是说是pytorch中非常有用的模型迁移学习神器。本文将会介绍如何使用torchvison的预训练模型ResNet50实现图像分类。
模型
Torchvision.models包里面包含了常见的各种基础模型架构,主要包括:
AlexNet VGG ResNet SqueezeNet DenseNet Inception v3 GoogLeNet ShuffleNet v2 MobileNet v2 ResNeXt Wide ResNet MNASNet
这里我选择了ResNet50,基于ImageNet训练的基础网络来实现图像分类, 网络模型下载与加载如下:
model = torchvision.models.resnet50(pretrained=True).eval().cuda()
tf = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)])
使用模型实现图像分类
这里首先需要加载ImageNet的分类标签,目的是最后显示分类的文本标签时候使用。然后对输入图像完成预处理,使用ResNet50模型实现分类预测,对预测结果解析之后,显示标签文本,完整的代码演示如下:
1with open('imagenet_classes.txt') as f:
2 labels = [line.strip() for line in f.readlines()]
3
4src = cv.imread("D:/images/space_shuttle.jpg") # aeroplane.jpg
5image = cv.resize(src, (224, 224))
6image = np.float32(image) / 255.0
7image[:,:,] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406))
8image[:,:,] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225))
9image = image.transpose((2, 0, 1))
10input_x = torch.from_numpy(image).unsqueeze(0)
11print(input_x.size())
12pred = model(input_x.cuda())
13pred_index = torch.argmax(pred, 1).cpu().detach().numpy()
14print(pred_index)
15print("current predict class name : %s"%labels[pred_index[0]])
16cv.putText(src, labels[pred_index[0]], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
17cv.imshow("input", src)
18cv.waitKey(0)
19cv.destroyAllWindows()
运行结果如下:
转ONNX支持
在torchvision中的模型基本上都可以转换为ONNX格式,而且被OpenCV DNN模块所支持,所以,很方便的可以对torchvision自带的模型转为ONNX,实现OpenCV DNN的调用,首先转为ONNX模型,直接使用torch.onnx.export即可转换(还不知道怎么转,快点看前面的例子)。转换之后使用OpenCV DNN调用的代码如下:
1with open('imagenet_classes.txt') as f:
2 labels = [line.strip() for line in f.readlines()]
3net = cv.dnn.readNetFromONNX("resnet.onnx")
4src = cv.imread("D:/images/messi.jpg") # aeroplane.jpg
5image = cv.resize(src, (224, 224))
6image = np.float32(image) / 255.0
7image[:, :, ] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406))
8image[:, :, ] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225))
9blob = cv.dnn.blobFromImage(image, 1.0, (224, 224), (0, 0, 0), False)
10net.setInput(blob)
11probs = net.forward()
12index = np.argmax(probs)
13cv.putText(src, labels[index], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
14cv.imshow("input", src)
15cv.waitKey(0)
16cv.destroyAllWindows()
运行结果见上图,这里就不再贴了。
✄------------------------------------------------
看到这里,说明你喜欢这篇文章,请点击「 在看 」或顺手「 转发 」「 点赞 」。
欢迎微信搜索「 panchuangxx 」,添加小编 磐小小仙 微信,每日朋友圈更新一篇高质量推文(无广告),为您提供更多精彩内容。
▼ ▼ 扫描二维码添加小编 ▼ ▼
Recommend
-
33
使用PyTorch实现经典的卷积神经网络。从LeNet到DenseNet介绍
-
29
↑ 点击 蓝字 关注极市平台 作者|Happy...
-
9
[PyTorch 学习笔记] 8.1 图像分类简述与 ResNet 源码分析哈尔滨工程大学 计算机硕士在读 本章代码:
-
5
无需额外数据、Tricks、架构调整,CMU 开源首个提升 ResNet50 精度至 80%+ 新方法 7个月前...
-
2
老潘家的潘老师 CV方向小学二年级在读 一个使用pytorch的图片分类教程——以明信片分类为例 ...
-
1
作者:贺弘 谦言 临在图像分割(Image Segmentation)是指对图片进行像素级的分类,根据分类粒度的不同可以分为语义分割(Semantic Segmentation)、实例分割(Instance Segmentation)、全景分割(Panoptic Segmentation)三类。图像分割是计算机视...
-
0
图神经网络(GNN)目前的主流实现方式就是节点之间的信息汇聚,也就是类似于卷积网络的邻域加权和,比如图卷积网络(GCN)、图注意力网络(GAT)等。下面根据GCN的实现原理使用Pytorch张量,和调用torch_geometric包,分别对Cora数据集进行节点分类实验。
-
3
MMClassification 图像分类使用 OpenMMLab 是香港中文大学-商汤科技联合实验室 MMLab 自 2018 年 10 月开源的一个计算机视觉领域的 AI 算法框架。其包含了众...
-
1
ResNet50的猫狗分类训练及预测 相比于之前写的ResN...
-
5
VGG16 、VGG19 、ResNet50 、Inception V3 、Xception介绍发布于 2020-04-26 05:45:232.8K0文章被收录于专栏:
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK