50

基于 Pytorch 的多类别图像分类实战

 4 years ago
source link: https://www.tuicool.com/articles/iyueaqR
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

Z7BBjeZ.jpg!web

本篇基于 Pytorch 完成一个多类别图像分类实战。

1 简介

fQzQJbN.jpg!web 实现一个完整的图像分类任务,大致需要分为五个步骤:

1、选择开源框架

目前常用的深度学习框架主要包括 tensorflow、caffe、pytorch、mxnet 等;

2、构建并读取数据集

根据任务需求搜集相关图像搭建相应的数据集,常见的方式包括:网络爬虫、实地拍摄、公共数据使用等。随后根据所选开源框架读取数据集。

3、框架搭建

选择合适的网络模型、损失函数以及优化方式,以完成整体框架的搭建

4、训练并调试参数

通过训练选定合适超参数

5、测试准确率

在测试集上验证模型的最终性能

本文利用 Pytorch 框架,按照上述结构实现一个基本的图像分类任务,并详细阐述其中的细节及注意事项。

2 数据集

7jeEjuz.jpg!web

本次实战选择的数据集为 Kaggle 竞赛中的细胞数据集,共包含 9961 个训练样本,2491 个测试样本,可以分为嗜曙红细胞、淋巴细胞、单核细胞、中性白细胞 4 个类别,图片大小为 320x240。

Pytorch 中封装了相应的数据读取的类函数,通过调用 torch.utils.data.Datasets 函数,则可以实现读取功能。

eUrQ7nm.jpg!web

init()模块用来定义相关的参数, len ()模块用来获取训练样本个数, getitem ()模块则用来获取每张具体的图片,在读取图片时其可以通过 opencv 库、PIL 库等进行读取,具体代码如下:

数据集

class dataset(data.Dataset):

# 参数预定义

def init (self, anno_pd, transforms=None):

self.paths = anno_pd[‘ImageName’].tolist()

self.labels = anno_pd[‘label’].tolist()

self.transforms = transforms

# 返回图片个数

def len (self):

return len(self.paths)

# 获取每个图片

def getitem (self, item):

img_path =self.paths[item]

img_id =img_path.split("/")[-1]

img =cv2.imread(img_path)

img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)

if self.transforms is not None:

img = self.transforms(img)

label = self.labels[item]

return torch.from_numpy(img).float(), int(label)

此外,需要定义图像增强模块,即上述代码中的 transform,通常采取的操作为翻转、剪切等,关于图像增强的具体介绍可以参考公众号前作。

【技术综述】深度学习中的数据增强方法都有哪些?

需要特别强调的是对图像进行去均值处理,很多同学不明白为何要减去均值,其主要的原因是图像作为一种平稳的数据分布,通过减去数据对应维度的统计平均值,可以消除公共部分,以凸显个体之间的特征和差异。进行去均值前后操作后的图像对比如下:

vAnERvE.jpg!web

3 框架搭建

本次实战主要选取了 VGG16、Resnet50、InceptionV4 三个经典网络,也是对前篇文章的一个总结。

损失函数则选择交叉熵损失函数: 【技术综述】一文道尽 softmax loss 及其变种

优化方式选择 SGD、Adam 优化两种: 【模型训练】SGD 的那些变种,真的比 SGD 强吗

4 训练及参数调试

初始学习率设置为 0.01,batch size 设置为 8,衰减率设置为 0.00001,迭代周期为 15,在不同框架组合下的最佳准确率和最低 loss 如下图所示:

fYjimu7.png!webJjumee3.png!web

可以发现在验证集上 Resnet-50+SGD+Cross Entropy 的组合下取得了 99% 左右的准确率,相反 VGG-16 结果则稍微差一些。

最佳组合下的准确率走势曲线如下图所示:

Yb2Mvmr.jpg!web

5 测试

对上述模型分别在测试集上进行测试,所获得的结果如下图所示,整体精度比训练集上约下降了一个百分点:

qe2Un2e.png!web

总结

以上就是整个多类别图像分类实战的过程,由于时间限制,本次实战并没有对多个数据集进行练,因此没有列出同一模型在不同数据集上的表现。

作者介绍

郭冰洋,公众号“有三 AI”作者。该公号聚焦于让大家能够系统性地完成 AI 各个领域所需的专业知识的学习。

原文链接

https://mp.weixin.qq.com/s/jPpZLYXQBX7l5AUfFV5n3g


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK