A Simple PyTorch Implementation of the FCN Image Segmentation Network

Published: 2020-01-25 · Category: Deep Learning
Source: http://ishero.net/%E4%BD%BF%E7%94%A8PyTorch%E7%AE%80%E5%8D%95%E5%AE%9E%E7%8E%B0%E5%9B%BE%E5%83%8F%E5%88%86%E5%89%B2%E7%BD%91%E7%BB%9CFCN.html


We build a simple fully convolutional network as the decoder, with a pretrained ResNet18 as the encoder. The dataset is VOC2012.

I ran into a few snags along the way, so I am recording them here.

Ground truth during training

GT for short, i.e., the image annotation.
Computing the loss requires the shape of the predicted output to match that of its label, the GT. The model's output has shape (batch_size, num_classes, height, width), but the raw label has no classes dimension: after mapping colors to class indices it is just a per-pixel index map, i.e., (batch_size, height, width). Since we train with BCEWithLogitsLoss, the label must be one-hot encoded along a class dimension before it is fed in. The corresponding code:

def __getitem__(self, idx):
    feature, label = voc_rand_crop(self.features[idx], self.labels[idx],
                                   *self.crop_size)
    label = voc_label_indices(label, self.colormap2label).numpy().astype('uint8')

    # Convert the (h, w) index map into a one-hot (21, h, w) GT tensor
    h, w = label.shape
    target = torch.zeros(21, h, w)
    for c in range(21):
        target[c][label == c] = 1

    return (self.tsf(feature), target)
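
As a side note, the per-class loop can be replaced with torch.nn.functional.one_hot. A minimal sketch of my own, assuming label is the (h, w) uint8 index map produced above:

import torch
import torch.nn.functional as F

label_t = torch.from_numpy(label).long()      # (h, w) class indices
one_hot = F.one_hot(label_t, num_classes=21)  # (h, w, 21)
target = one_hot.permute(2, 0, 1).float()     # (21, h, w), identical to the loop's result

An alternative design would be to train with nn.CrossEntropyLoss directly on the (h, w) index map and skip one-hot encoding entirely; this post uses BCEWithLogitsLoss, which expects the one-hot target.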

Here is the corresponding snippet from the full code:

resnet18 = models.resnet18(pretrained=True)
resnet18_modules = [layer for layer in resnet18.children()]
net = nn.Sequential()
for i, layer in enumerate(resnet18_modules[:-2]):
    net.add_module(str(i), layer)

net.add_module("LinearTranspose", nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module("ConvTranspose2d",
               nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, padding=16, stride=32))

net[-1].weight = nn.Parameter(bilinear_kernel(num_classes, num_classes, 64), True)
net[-2].weight = nn.init.xavier_uniform_(net[-2].weight)

A quick explanation. First we use the pretrained resnet18 model provided by PyTorch. To extract the parts we need, we iterate over its children and add the required modules to our net. After the resnet18 backbone we append a convolution with kernel size 1, which maps the 512 channels down to the number of classes. We then append a transposed convolution layer that upsamples the feature map back to the input size. To help the model converge quickly, we specify how the kernels of the two new layers are initialized: the transposed convolution is initialized as a bilinear interpolation of its input.
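
As a quick sanity check of the upsampling arithmetic (my own addition, assuming the net and num_classes defined above): a transposed convolution maps an input of size n to (n - 1) * stride - 2 * padding + kernel_size, so with stride=32, padding=16 and kernel_size=64 we get (n - 1) * 32 - 32 + 64 = 32 * n, exactly undoing the 32x downsampling of the ResNet18 backbone.

x = torch.randn(1, 3, 320, 480)  # dummy input batch
print(net(x).shape)              # expected: torch.Size([1, 21, 320, 480])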

def bilinear_kernel(in_channels, out_channels, kernel_size):
    # Build an (in_channels, out_channels, k, k) weight tensor whose
    # diagonal channel pairs hold a 2-D bilinear interpolation filter
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype='float32')
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.tensor(weight)
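
To see what this produces, a small illustrative check of my own (the values follow directly from the formula above; for kernel_size=4, factor is 2 and the center is 1.5):

w = bilinear_kernel(1, 1, 4)
print(w[0, 0])
# tensor([[0.0625, 0.1875, 0.1875, 0.0625],
#         [0.1875, 0.5625, 0.5625, 0.1875],
#         [0.1875, 0.5625, 0.5625, 0.1875],
#         [0.0625, 0.1875, 0.1875, 0.0625]])

Each weight falls off linearly with distance from the kernel center, which is exactly bilinear interpolation; initializing the transposed convolution this way makes it start out as a fixed bilinear upsampler.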

Full training code

from tqdm import tqdm

from FCN.VOC2012Dataset import VOC2012SegDataIter
import torch
from torch import nn, optim
import numpy as np
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_classes = 21


def bilinear_kernel(in_channels, out_channels, kernel_size):
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype='float32')
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.tensor(weight)


if __name__ == '__main__':
    batch_size = 4
    train_iter, val_iter = VOC2012SegDataIter(batch_size, (320, 480), 2, 200)

    # Encoder: all of ResNet18 except the final avgpool and fc layers
    resnet18 = models.resnet18(pretrained=True)
    resnet18_modules = [layer for layer in resnet18.children()]
    net = nn.Sequential()
    for i, layer in enumerate(resnet18_modules[:-2]):
        net.add_module(str(i), layer)

    # Decoder: 1x1 conv to num_classes channels, then 32x upsampling
    net.add_module("LinearTranspose", nn.Conv2d(512, num_classes, kernel_size=1))
    net.add_module("ConvTranspose2d",
                   nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, padding=16, stride=32))

    net[-1].weight = nn.Parameter(bilinear_kernel(num_classes, num_classes, 64), True)
    net[-2].weight = nn.init.xavier_uniform_(net[-2].weight)

    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    lossFN = nn.BCEWithLogitsLoss()

    num_epochs = 10
    for epoch in range(num_epochs):
        sum_loss = 0
        batch_count = 0
        for X, y in tqdm(train_iter):
            X = X.to(device)
            y = y.to(device)
            y_pred = net(X)
            loss = lossFN(y_pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.cpu().item()
            batch_count += 1
        print("epoch %d: loss=%.4f" % (epoch + 1, sum_loss / batch_count))

Reading the VOC dataset

import torch
import torchvision
from PIL import Image
import numpy as np


def voc_label_indices(colormap, colormap2label):
    """
    Convert a colormap (PIL image) to a label-index map (uint8 tensor)
    via the colormap2label lookup table.
    """
    colormap = np.array(colormap.convert("RGB")).astype('int32')
    idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
           + colormap[:, :, 2])
    return colormap2label[idx]


def read_voc_images(root="./dataset/VOCdevkit/VOC2012",
                    is_train=True, max_num=None):
    txt_fname = '%s/ImageSets/Segmentation/%s' % (
        root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    if max_num is not None:
        images = images[:min(max_num, len(images))]
    features, labels = [None] * len(images), [None] * len(images)
    for i, fname in enumerate(images):
        features[i] = Image.open('%s/JPEGImages/%s.jpg' % (root, fname)).convert("RGB")
        labels[i] = Image.open('%s/SegmentationClass/%s.png' % (root, fname)).convert("RGB")
    return features, labels  # PIL images


def voc_rand_crop(feature, label, height, width):
    """
    Random-crop feature (PIL image) and label (PIL image) with the same window.
    """
    i, j, h, w = torchvision.transforms.RandomCrop.get_params(
        feature, output_size=(height, width))

    feature = torchvision.transforms.functional.crop(feature, i, j, h, w)
    label = torchvision.transforms.functional.crop(label, i, j, h, w)

    return feature, label


class VOCSegDataset(torch.utils.data.Dataset):
    def __init__(self, is_train, crop_size, voc_dir, colormap2label, max_num=None):
        """
        crop_size: (h, w)
        """
        self.rgb_mean = np.array([0.485, 0.456, 0.406])
        self.rgb_std = np.array([0.229, 0.224, 0.225])
        self.tsf = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=self.rgb_mean,
                                             std=self.rgb_std)
        ])

        self.crop_size = crop_size  # (h, w)
        features, labels = read_voc_images(root=voc_dir,
                                           is_train=is_train,
                                           max_num=max_num)
        self.features = self.filter(features)  # PIL images
        self.labels = self.filter(labels)  # PIL images
        self.colormap2label = colormap2label
        print('read ' + str(len(self.features)) + ' valid examples')

    def filter(self, imgs):
        # Keep only images at least as large as the crop size
        return [img for img in imgs if (
            img.size[1] >= self.crop_size[0] and
            img.size[0] >= self.crop_size[1])]

    def __getitem__(self, idx):
        feature, label = voc_rand_crop(self.features[idx], self.labels[idx],
                                       *self.crop_size)
        label = voc_label_indices(label, self.colormap2label).numpy().astype('uint8')

        # Convert the (h, w) index map into a one-hot (21, h, w) GT tensor
        h, w = label.shape
        target = torch.zeros(21, h, w)
        for c in range(21):
            target[c][label == c] = 1

        return (self.tsf(feature), target)

    def __len__(self):
        return len(self.features)


def VOC2012SegDataIter(batch_size=64, crop_size=(320, 480), num_workers=4, max_num=None):
    VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                    [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                    [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                    [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                    [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                    [0, 64, 128]]
    VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
                   'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
                   'diningtable', 'dog', 'horse', 'motorbike', 'person',
                   'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

    # Encode each RGB colormap entry as a single integer index into a lookup table
    colormap2label = torch.zeros(256 ** 3, dtype=torch.uint8)
    for i, colormap in enumerate(VOC_COLORMAP):
        colormap2label[(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i

    voc_train = VOCSegDataset(True, crop_size, "../dataset/VOCdevkit/VOC2012", colormap2label, max_num)
    voc_val = VOCSegDataset(False, crop_size, "../dataset/VOCdevkit/VOC2012", colormap2label, max_num)
    train_iter = torch.utils.data.DataLoader(voc_train, batch_size, shuffle=True, drop_last=True,
                                             num_workers=num_workers)
    val_iter = torch.utils.data.DataLoader(voc_val, batch_size, drop_last=True, num_workers=num_workers)
    return train_iter, val_iter
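
Example usage, assuming the VOC2012 archive is extracted to ../dataset/VOCdevkit/VOC2012 as the hard-coded paths above expect:

train_iter, val_iter = VOC2012SegDataIter(batch_size=4, crop_size=(320, 480),
                                          num_workers=2, max_num=200)
X, y = next(iter(train_iter))
print(X.shape)  # torch.Size([4, 3, 320, 480])  -- normalized image batch
print(y.shape)  # torch.Size([4, 21, 320, 480]) -- one-hot GT batch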
