
6

基于Unet+opencv实现天空对象的分割、替换和美化 - jsxyhelu
source link: https://www.cnblogs.com/jsxyhelu/p/16995892.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

基于Unet+opencv实现天空对象的分割、替换和美化
原文地址:https://www.cnblogs.com/jsxyhelu/p/16995892.html
传统图像处理算法进行“天空分割”存在精度问题且调参复杂,无法很好地应对云雾、阴霾等情况;本篇文章分享的“基于Unet+opencv实现天空对象的分割、替换和美化”,较好地解决了该问题,包括以下内容:
1、基于Unet语义分割的基本原理、环境构建、参数调节等 2、一种有效的天空分割数据集准备方法,并且获得数据集 3、基于OpenCV的Pytorch模型部署方法 4、融合效果极好的 SeamlessClone 技术 5、饱和度调整、颜色域等基础图像处理知识和编码技术
本文适合具备 OpenCV 和Pytorch相关基础,对“天空替换”感兴趣的人士。学完本文,可以获得基于Pytorch和OpenCV进行语义分割、解决实际问题的具体方法,提高环境构建、数据集准备、参数调节和运行部署等方面综合能力。
一、传统方法和语义分割基础
1.1传统方法主要通过“颜色域”来进行分割
比如,我们要找的是蓝天,那么在HSV域,就可以通过查表的方法找出蓝色区域。
在这张表中,蓝色的HSV的上下门限已经标注出来,我们编码实现。
cvtColor(matSrc,temp,COLOR_BGR2HSV); split(temp,planes); equalizeHist(planes[2],planes[2]);//对v通道进行equalizeHist merge(planes,temp); inRange(temp,Scalar(100,43,46),Scalar(124,255,255),temp); erode(temp,temp,Mat());//形态学变换,填补内部空洞 dilate(temp,temp,Mat()); imshow("原始图",matSrc);
在这段代码中,有两个小技巧,一个是对模板(MASK)进行了形态学变化,这个不展开说;一个是我们首先对HSV图进行了3通道分解,并且直方图增强V通道,而后将3通道合并回去。通过这种方法能够增强原图对比度,让蓝天更蓝、青山更青……大家可以自己调试看一下。 显示处理后识别为天空的结果(在OpenCV中,白色代表1也就是由数据,黑色代表0也就是没数据)

对于天坛这幅图来说,效果不错。虽然在右上角错误,而塔中间的一个很小的空洞,这些后期都是可以规避掉的错误。

但是对于阴霾图片来说,由于天空中没有蓝色,识别起来就很错误很多。
1.2 语义分割基础
图像语义分割(semantic segmentation),从字面意思上理解就是让计算机根据图像的语义来进行分割,例如让计算机在输入下面左图的情况下,能够输出右图。语义在语音识别中指的是语音的意思,在图像领域,语义指的是图像的内容,对图片意思的理解,比如左图的语义就是三个人骑着三辆自行车;分割的意思是从像素的角度分割出图片中的不同对象,对原图中的每个像素都进行标注,比如右图中粉红色代表人,绿色代表自行车。

那么对于天空分割问题来说,主要目标就是找到像素级别的天空对象,使用语义分割模型就是有效的。
二、Unet基本情况和环境构建
Unet 发表于 2015 年,属于 FCN 的一种变体,Unet 的初衷是为了解决生物医学图像方面的问题,由于效果确实很好后来也被广泛的应用在语义分割的各个方向,比如卫星图像分割,工业瑕疵检测等。它也有很多变体,但是对于天空分割问题来看,Unet的能力已经够了。
Unet 跟 FCN 都是 Encoder-Decoder 结构,结构简单但很有效。Encoder 负责特征提取,你可以将自己熟悉的各种特征提取网络放在这个位置。由于在医学方面,样本收集较为困难,作者为了解决这个问题,应用了图像增强的方法,在数据集有限的情况下获得了不错的精度。

如上图,Unet 网络结构是对称的,形似英文字母 U 所以被称为 Unet。整张图都是由蓝/白色框与各种颜色的箭头组成,其中,蓝/白色框表示 feature map;蓝色箭头表示 3x3 卷积,用于特征提取;灰色箭头表示 skip-connection,用于特征融合;红色箭头表示池化 pooling,用于降低维度;绿色箭头表示上采样 upsample,用于恢复维度;青色箭头表示 1x1 卷积,用于输出结果。
在环境构建这块,我建议一定要结合自己的实际情况,构建专用的代码库,这样才能够通过不断迭代,在总体正确的前提下形成自己风格。
在我的库中,基于现有的Unet代码进行了修改

其中checkpoints、data保持数据;unet是模型的具体实现,未来可以扩充为多模型;utils是常用函数;alibaba.py和oss2helper.py是阿里云的辅助函数;export_unet.py是输出函数;eveluate.py和train.py用于训练;predict.py用于本地测试;main.py是主要函数。
三、数据集准备和增强
3.1 数据集准备这块,我采取了增强的方法。由于个人习惯问题,采用的是OpenCV本地变换的方法
getFiles("e:/template/Data_sky/data", fileNames); string saveFile = "e:/template/Data_sky/dataEX3/"; for (int index = 0; index < fileNames.size(); index++) { Mat src = imread(fileNames[index]); Mat dst; string fileName; getFileName(fileNames[index], fileName); resize(src, dst, cv::Size(512, 512)); imwrite(saveFile + fileName + "_512.jpg", dst); resize(src, dst, cv::Size(256, 256)); imwrite(saveFile + fileName + "_256.jpg", dst); resize(src, dst, cv::Size(128, 128)); imwrite(saveFile + fileName + "_128.jpg", dst); cout << fileName << endl; } fileNames.clear(); getFiles("e:/template/Data_sky/mask", fileNames); saveFile = "e:/template/Data_sky/maskEX3/"; for (int index = 0; index < fileNames.size(); index++) { Mat src = imread(fileNames[index], 0); Mat dst; string fileName; getFileName(fileNames[index], fileName); fileName = fileName.substr(0, fileName.size() - 3); resize(src, dst, cv::Size(512, 512)); imwrite(saveFile + fileName + "_512_gt.jpg", dst); resize(src, dst, cv::Size(256, 256)); imwrite(saveFile + fileName + "_256_gt.jpg", dst); resize(src, dst, cv::Size(128, 128)); imwrite(saveFile + fileName + "_128_gt.jpg", dst); cout << fileName << endl; }

从而获得不同分辨率的目标数据,但是如何获得标注数据?我推荐一种方法。
3.2、通过对“阿里视觉智能开放平台”的研究,调用它的成果来进行训练。简单来说,它提供了天空分割的功能,但是要求数据的输入输出都保存在oss中,所以需要通过python来编写脚本。我对这段python代码进行了一些注释,放在这里。
# -*- coding: utf8 -*- from aliyunsdkcore.client import AcsClient from aliyunsdkimageseg.request.v20191230 import SegmentSkyRequest from aliyunsdkimageseg.request.v20191230.SegmentHDSkyRequest import SegmentHDSkyRequest import oss2 import os import json import urllib # 创建 AcsClient 实例 client = AcsClient("LTAI5tQCCmMyKSfifwsFHLpC", "JyzNfHsCnUaVTeS6Xg3ylMjQFC8C6L", "cn-shanghai") request = SegmentSkyRequest.SegmentSkyRequest() endpoint = "https://oss-cn-shanghai.aliyuncs.com" accesskey_id = "LTAI5tQCCmMyKSfifwsFHLpC" accesskey_secret = "JyzNfHsCnUaVTeS6Xg3ylMjQFC8C6L" bucket_name = "datasky2" bucket_name2 = "viapi-cn-shanghai-dha-segmenter" #本地文件保存路径前缀 download_local_save_prefix = "/home/helu/GOPytorchHelper/data/dataOss/" ''' 列举prefix全部文件 ''' def prefix_all_list(bucket,prefix): print("开始列举"+prefix+"全部文件"); oss_file_size = 0; for obj in oss2.ObjectIterator(bucket, prefix ='%s/'%prefix): print(' key : ' + obj.key) oss_file_size = oss_file_size + 1; download_to_local(bucket, obj.key, obj.key); print(prefix +" file size " + str(oss_file_size)); ''' 列举全部的根目录文件夹、文件 ''' def root_directory_list(bucket): # 设置Delimiter参数为正斜线(/)。 for obj in oss2.ObjectIterator(bucket, delimiter='/'): # 通过is_prefix方法判断obj是否为文件夹。 if obj.is_prefix(): # 文件夹 print('directory: ' + obj.key); prefix_all_list(bucket,str(obj.key).strip("/")); #去除/ else: # 文件 print('file: ' +obj.key) # 填写Object完整路径,例如exampledir/exampleobject.txt。Object完整路径中不能包含Bucket名称。 object_name = obj.key # 生成下载文件的签名URL,有效时间为60秒。 # 生成签名URL时,OSS默认会对Object完整路径中的正斜线(/)进行转义,从而导致生成的签名URL无法直接使用。 # 设置slash_safe为True,OSS不会对Object完整路径中的正斜线(/)进行转义,此时生成的签名URL可以直接使用。 url = bucket.sign_url('GET', object_name, 60, slash_safe=True) print('签名url的地址为:', url) ## 如下url替换为自有的上海region的oss文件地址 request.set_ImageURL(url) response = client.do_action_with_exception(request) print('response地址为:', response) user_dict = json.loads(response) for name in user_dict.keys(): if(name.title() == "Data"): inner_dict = user_dict[name] for innerName in inner_dict.keys(): if(innerName == "ImageURL"): finalName = inner_dict[innerName] print('finalName地址为:',str(finalName)) urllib.request.urlretrieve(str(finalName), download_local_save_prefix+obj.key) ''' 下载文件到本地 ''' def download_to_local(bucket,object_name,local_file): url = download_local_save_prefix + local_file; #文件名称 file_name = url[url.rindex("/")+1:] file_path_prefix = url.replace(file_name, "") if False == os.path.exists(file_path_prefix): os.makedirs(file_path_prefix); print("directory don't not makedirs "+ file_path_prefix); # 下载OSS文件到本地文件。如果指定的本地文件存在会覆盖,不存在则新建。 bucket.get_object_to_file(object_name, download_local_save_prefix+local_file); if __name__ == '__main__': print("start \n"); # 阿里云主账号AccessKey拥有所有API的访问权限,风险很高。强烈建议您创建并使用RAM账号进行API访问或日常运维,请登录 https://ram.console.aliyun.com 创建RAM账号。 auth = oss2.Auth(accesskey_id,accesskey_secret) # Endpoint以杭州为例,其它Region请按实际情况填写。 bucket = oss2.Bucket(auth,endpoint , bucket_name) bucket2= oss2.Bucket(auth,endpoint , bucket_name2) #单个文件夹下载 root_directory_list(bucket); print("end \n");
四、模型训练概要
将数据集放入项目中,运行u2net_train.py即可。
4.1读懂训练部分代码,其中在step5的地方,我添加了一段处理,用于float和int类型之间转换
# 5. Begin training for epoch in range(epochs): net.train() epoch_loss = 0 with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar: for batch in train_loader: images = batch['image'] true_masks = batch['mask'] assert images.shape[1] == net.n_channels, \ f'Network has been defined with {net.n_channels} input channels, ' \ f'but loaded images have {images.shape[1]} channels. Please check that ' \ 'the images are loaded correctly.' images = images.to(device=device, dtype=torch.float32) true_masks = true_masks.to(device=device, dtype=torch.long) ###### one = torch.ones_like(true_masks) zero = torch.zeros_like(true_masks) true_masks = torch.where(true_masks>0,one,zero) ##### with torch.cuda.amp.autocast(enabled=amp): masks_pred = net(images) loss = criterion(masks_pred, true_masks) \ + dice_loss(F.softmax(masks_pred, dim=1).float(), F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(), multiclass=True) optimizer.zero_grad(set_to_none=True) grad_scaler.scale(loss).backward() grad_scaler.step(optimizer) grad_scaler.update() pbar.update(images.shape[0]) global_step += 1 epoch_loss += loss.item() pbar.set_postfix(**{'loss (batch)': loss.item()}) # Evaluation round division_step = (n_train // (10 * batch_size)) if division_step > 0: if global_step % division_step == 0: histograms = {} for tag, value in net.named_parameters(): tag = tag.replace('/', '.') val_score = evaluate(net, val_loader, device) scheduler.step(val_score) logging.info('Validation Dice score: {}'.format(val_score)) if save_checkpoint: Path(dir_checkpoint).mkdir(parents=True, exist_ok=True) torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1))) logging.info(f'Checkpoint {epoch + 1} saved!')
4.2 推荐适当投资,采购了autodl进行在线训练

通过predict生成模板结果,在Photoshop中进行比对发现边界已经比较贴合,最终在增强的数据集上,实现了DICE90%的目标。

五、基于OpenCV的Pytorch模型部署方法
这里为了进行总结,我对分别对目前使用Python和C++下的几种可行可用的推断方法进行汇总,并进一步比对。
5.1 (python)使用onnxruntime方法进行推断
session = onnxruntime.InferenceSession("转换的onnx文件") input_name = session.get_inputs()[0].name label_name = session.get_outputs()[0].name img_name_list = ['需要处理的图片'] image = Image.open(img_name_list[0]) w, h = image.size dataset = SalObjDataset( img_name_list=img_name_list, lbl_name_list=[], transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]) ) data_loader = DataLoader( dataset, batch_size=1, shuffle=False, num_workers=1 ) im = list(data_loader)[0]['image'] inputs_test = im inputs_test = inputs_test.type(torch.FloatTensor) with torch.no_grad(): inputs_test = Variable(inputs_test) res = session.run([label_name], {input_name: inputs_test.numpy().astype(np.float32)}) result = torch.from_numpy(res[0]) pred = result[:, 0, :, :] pred = normPRED(pred) pred = pred.squeeze() predict_np = pred.cpu().data.numpy() im = Image.fromarray(predict_np * 255).convert('RGB') im = im.resize((w, h), resample=Image.BILINEAR) im.show()
5.2 (python) 使用opencv方法
import os import argparse from skimage import io, transform import numpy as np from PIL import Image import cv2 as cv parser = argparse.ArgumentParser(description='Demo: U2Net Inference Using OpenCV') parser.add_argument('--input', '-i') parser.add_argument('--model', '-m', default='u2net_human_seg.onnx') args = parser.parse_args() def normPred(d): ma = np.amax(d) mi = np.amin(d) return (d - mi)/(ma - mi) def save_output(image_name, predict): img = cv.imread(image_name) h, w, _ = img.shape predict = np.squeeze(predict, axis=0) img_p = (predict * 255).astype(np.uint8) img_p = cv.resize(img_p, (w, h)) print('{}-result-opencv_dnn.png-------------------------------------'.format(image_name)) cv.imwrite('{}-result-opencv_dnn.png'.format(image_name), img_p) def main(): # load net net = cv.dnn.readNet('saved_models/sky_split.onnx') input_size = 320 # fixed # build blob using OpenCV img = cv.imread('test_imgs/sky1.jpg') blob = cv.dnn.blobFromImage(img, scalefactor=(1.0/255.0), size=(input_size, input_size), swapRB=True) # Inference net.setInput(blob) d0 = net.forward('output') # Norm pred = normPred(d0[:, 0, :, :]) # Save save_output('test_imgs/sky1.jpg', pred) if __name__ == '__main__': main()
5.3 (c++)使用libtorch方法
// std::string strModelPath = "E:/template/u2net_train.pt"; void bgr_u2net(cv::Mat& image_src, cv::Mat& result, torch::jit::Module& model) { //1.模型已经导入 auto device = torch::Device("cpu"); //2.输入图片,变换到320 cv::Mat image_src1 = image_src.clone(); cv::resize(image_src1, image_src1, cv::Size(320, 320)); cv::cvtColor(image_src1, image_src1, cv::COLOR_BGR2RGB); // 3.图像转换为Tensor torch::Tensor tensor_image_src = torch::from_blob(image_src1.data, { image_src1.rows, image_src1.cols, 3 }, torch::kByte); tensor_image_src = tensor_image_src.permute({ 2,0,1 }); // RGB -> BGR互换 tensor_image_src = tensor_image_src.toType(torch::kFloat); tensor_image_src = tensor_image_src.div(255); tensor_image_src = tensor_image_src.unsqueeze(0); // 拿掉第一个维度 [3, 320, 320] //4.网络前向计算 auto src = tensor_image_src.to(device); auto pred = model.forward({ src }).toTuple()->elements()[0].toTensor(); //模型返回多个结果,用toTuple,其中elements()[i-1]获取第i个返回值 //d1,d2,d3,d4,d5,d6,d7= net(inputs_test) //pred = d1[:,0,:,:] auto res_tensor = (pred * torch::ones_like(src)); res_tensor = normPRED(res_tensor); //是否就是Tensor转换为图像 res_tensor = res_tensor.squeeze(0).detach(); res_tensor = res_tensor.mul(255).clamp(0, 255).to(torch::kU8); //mul函数,表示张量中每个元素乘与一个数,clamp表示夹紧,限制在一个范围内输出 res_tensor = res_tensor.to(torch::kCPU); //5.输出最终结果 cv::Mat resultImg(res_tensor.size(1), res_tensor.size(2), CV_8UC3); std::memcpy((void*)resultImg.data, res_tensor.data_ptr(), sizeof(torch::kU8) * res_tensor.numel()); cv::resize(resultImg, resultImg, cv::Size(image_src.cols, image_src.rows), cv::INTER_LINEAR); result = resultImg.clone(); }
5.4 (c++)使用opencv方法
#include "opencv2/dnn.hpp" #include "opencv2/imgproc.hpp" #include "opencv2/highgui.hpp" #include <iostream> #include "opencv2/objdetect.hpp" using namespace cv; using namespace std; using namespace cv::dnn; int main(int argc, char ** argv) { Net net = readNetFromONNX("E:/template/sky_split.onnx"); if (net.empty()) { printf("read model data failure...\n"); return -1; } // load image data Mat frame = imread("e:/template/sky14.jpg"); Mat blob; blobFromImage(frame, blob, 1.0 / 255.0, Size(320, 320), cv::Scalar(), true); net.setInput(blob); Mat prob = net.forward("output"); Mat slice(cv::Size(prob.size[2], prob.size[3]), CV_32FC1, prob.ptr<float>(0, 0)); normalize(slice, slice, 0, 255, NORM_MINMAX, CV_8U); resize(slice, slice, frame.size()); return 0; }
综合考虑后,选择opencv onnx的部署方式
import os import torch from unet import UNet def main(): net = UNet(n_channels=3, n_classes=2, bilinear=True) net.load_state_dict(torch.load("checkpoints/skyseg0113.pth", map_location=torch.device('cpu'))) net.eval() # --------- model 序列化 --------- example = torch.zeros(1, 3, 320, 320) #这里经过实验,最大是 example = torch.zeros(1, 3, 411, 411) torch_script_module = torch.jit.trace(net, example) #torch_script_module.save('unet_empty.pt') torch.onnx.export(net, example, 'checkpoints/skyseg0113.onnx', opset_version=11) print('over') if __name__ == "__main__": main() int main() { //参数和常量准备 Net net = readNetFromONNX("E:/template/skyseg0113.onnx"); if (net.empty()) { printf("read model data failure...\n"); return -1; } // load image data Mat frame = imread("E:\\sandbox/sky4.jpg"); pyrDown(frame, frame); Mat blob; blobFromImage(frame, blob, 1.0 / 255.0, Size(320, 320), cv::Scalar(), true); net.setInput(blob); Mat prob = net.forward("473");//???对于Unet来说,example最大为(411,411),原理上来说,值越大越有利于分割 Mat slice(cv::Size(prob.size[2], prob.size[3]), CV_32FC1, prob.ptr<float>(0, 0)); threshold(slice, slice, 0.1, 1, cv::THRESH_BINARY_INV); normalize(slice, slice, 0, 255, NORM_MINMAX, CV_8U); Mat mask; resize(slice, mask, frame.size());//制作mask }
通过这种方法,就能够获得模型推断的模板对象,其中“473”是模型训练过程的层名,由于我们在训练的过程中没有指定,所以按照系统自己的名字给出。

我们可以通过netron的方式查看获得这里的名称。
六、结合SeamlessClone等图像处理方法,实现最终效果
int main() { //参数和常量准备 Net net = readNetFromONNX("E:/template/skyseg0113.onnx"); if (net.empty()) { printf("read model data failure...\n"); return -1; } // load image data Mat frame = imread("E:\\sandbox/sky4.jpg"); pyrDown(frame, frame); Mat blob; blobFromImage(frame, blob, 1.0 / 255.0, Size(320, 320), cv::Scalar(), true); net.setInput(blob); Mat prob = net.forward("473"); Mat slice(cv::Size(prob.size[2], prob.size[3]), CV_32FC1, prob.ptr<float>(0, 0)); threshold(slice, slice, 0.1, 1, cv::THRESH_BINARY_INV); normalize(slice, slice, 0, 255, NORM_MINMAX, CV_8U); Mat mask; resize(slice, mask, frame.size());//制作mask Mat matSrc = frame.clone(); VP maxCountour = FindBigestContour(mask); Rect maxRect = boundingRect(maxCountour); if (maxRect.height == 0 || maxRect.width == 0) maxRect = Rect(0, 0, mask.cols, mask.rows);//特殊情况 ////天空替换 Mat matCloud = imread("E:/template/cloud/cloud1.jpg"); resize(matCloud, matCloud, frame.size()); //直接拷贝 matCloud.copyTo(matSrc, mask); imshow("matSrc", matSrc); //seamless clone matSrc = frame.clone(); Point center = Point((maxRect.x + maxRect.width) / 2, (maxRect.y + maxRect.height) / 2);//中间位置为蓝天的背景位置 Mat normal_clone; Mat mixed_clone; Mat monochrome_clone; seamlessClone(matCloud, matSrc, mask, center, normal_clone, NORMAL_CLONE); seamlessClone(matCloud, matSrc, mask, center, mixed_clone, MIXED_CLONE); seamlessClone(matCloud, matSrc, mask, center, monochrome_clone, MONOCHROME_TRANSFER); imshow("normal_clone", normal_clone); imshow("mixed_clone", mixed_clone); imshow("monochrome_clone", monochrome_clone); waitKey(); return 0; }
在调用seamlessClone()的时候报错:

报错原因:可以看seamlessClone源码(opencv/modules/photo/src/seamless_cloning.cpp),在执行seamlessClone的时候,会先求mask内物体的boundingRect,然后会把这个最小框矩形复制到dst上,矩形中心对齐center

这个过程中可能矩形会超出dst的边界范围,就会报上面的roi边界错误。

这里错误的根源应该还是OpenCV 这块的代码有问题,其中roi_s不应该适用BoundingRect进行处理。除了进行修改重新编译,或者直接进行PR解决之外,我们可以采取一些补救的。这里我采取了2手方法来避免异常:一个是在模板制作的过程中,除了获得的最大区域之外,主动地将其他区域涂黑,从而保证BoundingRect能够准确地框选天空区域;二个是在seamlessClone之前,对模板进行异常判断,对可能出现的情况进程处置。
通过添加opencv代码,进行系统联调:

修改后的代码为:
int main() { //参数和常量准备 Net net = readNetFromONNX("E:/template/skyseg0113.onnx"); if (net.empty()) { printf("read model data failure...\n"); return -1; } vector<string> vecFilePaths; getFiles("e:/template/sky", vecFilePaths); string strSavePath = "e:/template/sky_change_result"; for (int index = 0;index<vecFilePaths.size();index++) { try{ string strFilePath = vecFilePaths[index]; string strFileName; getFileName(strFilePath, strFileName); Mat frame = imread(strFilePath); pyrDown(frame, frame); Mat blob; blobFromImage(frame, blob, 1.0 / 255.0, Size(320, 320), cv::Scalar(), true); net.setInput(blob); Mat prob = net.forward("473"); Mat slice(cv::Size(prob.size[2], prob.size[3]), CV_32FC1, prob.ptr<float>(0, 0)); threshold(slice, slice, 0.1, 1, cv::THRESH_BINARY_INV); normalize(slice, slice, 0, 255, NORM_MINMAX, CV_8U); Mat mask; resize(slice, mask, frame.size());//制作mask Mat matSrc = frame.clone(); VP maxCountour = FindBigestContour(mask); Rect maxRect = boundingRect(maxCountour); if (maxRect.height == 0 || maxRect.width == 0) maxRect = Rect(0, 0, mask.cols, mask.rows);//特殊情况 Mat maskRedux(mask.size(), mask.type(), Scalar::all(0)); Mat roi1 = mask(maxRect); Mat roi2 = maskRedux(maxRect); roi1.copyTo(roi2); ////天空替换 Mat matCloud = imread("E:/template/cloud/cloud2.jpg"); resize(matCloud, matCloud, frame.size()); //直接拷贝 matCloud.copyTo(matSrc, maskRedux); matSrc = frame.clone(); cv::Point center = Point((maxRect.x + maxRect.width) / 2, (maxRect.y + maxRect.height) / 2);//中间位置为蓝天的背景位置 Rect roi_s = maxRect; Rect roi_d(center.x - roi_s.width / 2, center.y - roi_s.height / 2, roi_s.width, roi_s.height); if(! (0 <= roi_d.x && 0 <= roi_d.width && roi_d.x + roi_d.width <= matSrc.cols && 0 <= roi_d.y && 0 <= roi_d.height && roi_d.y + roi_d.height <= matSrc.rows)) center = Point(matSrc.cols / 2, matSrc.rows / 2);//这里错误的根源应该还是OpenCV 这块的代码有问题,其中roi_s不应该适用BoundingRect进行处理.所以采取补救的方法 Mat mixed_clone; seamlessClone(matCloud, matSrc, maskRedux, center, mixed_clone, MIXED_CLONE); string saveFileName = strSavePath + "/" + strFileName + "_cloud2.jpg"; imwrite(saveFileName, mixed_clone); } catch (Exception * e) { continue; } }
2022 0312 更新代码
int main() { Mat src = imread("e:/template/tiantan.jpg"); Mat matCloud = imread("E:/template/cloud/cloud2.jpg"); Mat mask = imread("e:/template/tiantanmask2.jpg", 0); resize(matCloud, matCloud, src.size()); resize(mask, mask, src.size()); Mat matSrc = src.clone(); Mat board = mask.clone(); cvtColor(board, board, COLOR_GRAY2BGR); //寻找模板最大轮廓 VP maxCountour = FindBigestContour(mask); Rect maxRect = boundingRect(maxCountour); //异常处理 Mat maskCopy = mask.clone(); copyMakeBorder(maskCopy, maskCopy, 1, 1, 1, 1, BORDER_ISOLATED | BORDER_CONSTANT, Scalar(0)); Rect roi_s = boundingRect(maskCopy); if (roi_s.empty()) return -1; cv::Point center = Point((maxRect.x + maxRect.width) / 2, (maxRect.y + maxRect.height) / 2); Rect roi_d(center.x - roi_s.width / 2, center.y - roi_s.height / 2, roi_s.width, roi_s.height); if (!(0 <= roi_d.x && 0 <= roi_d.width && roi_d.x + roi_d.width <= matSrc.cols && 0 <= roi_d.y && 0 <= roi_d.height && roi_d.y + roi_d.height <= matSrc.rows)) center = Point(matSrc.cols / 2, matSrc.rows / 2); //融合 Mat normal_clone, mixed_clone, monochrome_clone; seamlessClone(matCloud, matSrc, mask, center, normal_clone, NORMAL_CLONE); seamlessClone(matCloud, matSrc, mask, center, mixed_clone, MIXED_CLONE); seamlessClone(matCloud, matSrc, mask, center, monochrome_clone, MONOCHROME_TRANSFER); waitKey(); return 0; }
七、结果对比和小结
效果是相当不错的,但是在部署过程中也可能会遇到一些问题;特别是如果用于手机端部署,必然有工具链的问题。


我在hugginface上也实现了可以在线测试的效果。分别是skgseg和skgchange
https://huggingface.co/spaces/jsxyhelu/skyseg

最后,“天空替换”整个问题,只是语义分割的一种应用,结果是美化的图片。这是价值比较有限的,必须要转换为量化的结果,用于定量计数,才能够推动生产实践。
此外,关于算法运行效率,也是部署应用的重要环节,在部署实现的时候也需要重点考虑。
Recommend
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK