Excellent Results | OpenVINO Handwritten Digit Recognition

source link: https://bbs.cvmart.net/articles/4383


Source: OpenCV学堂

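【March】The OpenVINO™ Hands-on Project Training Camp is now recruiting! https://bbs.cvmart.net/topics/3522
(One-on-one guidance to complete an OpenVINO™ project and the OpenVINO™ intermediate skills certification; those who finish can enter 2+N draws for Intel second-generation Neural Compute Sticks)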

Model introduction

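I hadn't noticed it before, but I recently found a handwritten digit recognition model in the model zoo of OpenVINO 2020 R4. It recognizes numbers in the `<digit>` or `<digit>.<digit>` format, i.e. digits plus an optional decimal point. The model is: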

handwritten-score-recognition-0003

The model is a bidirectional LSTM network trained with CTC loss.

Input shape: [NCHW] = [1x1x32x64]
Output shape: [WxBxL] = [16x1x13]

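Here W = 16 is the number of time steps, B = 1 is the batch size, and the L = 13 classes correspond to the characters "0123456789._#", where # is the CTC blank and _ stands for any non-digit character. The output can be decoded with either CTC greedy decoding or beam search. The demo uses CTC greedy decoding because it is the simplest, and I like it!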
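Because the model only targets scores in the `<digit>` or `<digit>.<digit>` format, a decoded string can be sanity-checked against that pattern before it is used; a result containing `_` (a non-digit character) or a stray decimal point would be rejected. A minimal sketch, assuming `<digit>` means a single digit 0-9; the pattern and the helper name `is_valid_score` are illustrative and not part of OpenVINO:

```python
import re

# Assumption: "<digit>" is a single digit 0-9, so valid scores look like "4" or "4.5".
# [0-9] is used instead of \d because \d also matches non-ASCII Unicode digits.
SCORE_PATTERN = re.compile(r"[0-9](\.[0-9])?")


def is_valid_score(text):
    """Return True if a decoded string has the <digit> or <digit>.<digit> form."""
    return SCORE_PATTERN.fullmatch(text) is not None


for s in ["4", "5.5", "5.", "_3", "4.5.5"]:
    print(s, is_valid_score(s))
```

If multi-digit scores such as "10" should also be accepted, the pattern would become `[0-9]+(\.[0-9])?`.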

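The code is implemented with the OpenVINO Python SDK. The main class in the OpenVINO Python SDK is IECore: first create an IECore instance, then carry out the steps below.

Create the instance and load the model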

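import logging as log
from openvino.inference_engine import IECore
log.info("Creating Inference Engine")
ie = IECore()
# model_xml / model_bin: paths to the model's .xml and .bin IR files
net = ie.read_network(model=model_xml, weights=model_bin)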

Get the input and output layer names

log.info("Preparing input blobs")
# The model has one input and one output, so take the first name of each
input_blob = next(iter(net.input_info))
print(input_blob)
out_blob = next(iter(net.outputs))

# Read and pre-process input images
print(net.input_info[input_blob].input_data.shape)  # [1, 1, 32, 64]
n, c, h, w = net.input_info[input_blob].input_data.shape

Load the network as an executable network:

# Loading model to the plugin
exec_net = ie.load_network(network=net, device_name="CPU")

Read the input image and cut out each score region (a `<digit>` or `<digit>.<digit>` score). The code is as follows:

import time
import cv2 as cv
import numpy as np

ocr = cv.imread("D:/images/zsxq/ocr1.png")
cv.imshow("input", ocr)
gray = cv.cvtColor(ocr, cv.COLOR_BGR2GRAY)
# Inverted adaptive threshold: dark ink becomes white (255) on a black background
binary = cv.adaptiveThreshold(gray, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY_INV, 25, 10)
cv.imshow("binary", binary)
# Remove noise: fill (erase) every external contour whose area is below 10 pixels
contours, hierarchy = cv.findContours(binary, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
for cnt in range(len(contours)):
    area = cv.contourArea(contours[cnt])
    if area < 10:
        cv.drawContours(binary, contours, cnt, (0), -1, 8)
cv.imshow("remove noise", binary)

# Get each score: a wide, flat 45x5 kernel merges the characters of one score into one blob
temp = np.copy(binary)
se = cv.getStructuringElement(cv.MORPH_RECT, (45, 5))
temp = cv.dilate(temp, se)
contours, hierarchy = cv.findContours(temp, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
for cnt in range(len(contours)):
    x, y, iw, ih = cv.boundingRect(contours[cnt])
    roi = gray[y:y + ih, x:x + iw]
    image = cv.resize(roi, (w, h))  # cv.resize takes (width, height): 64x32 model input

[Figure: input image]
[Figure: after binarization]
[Figure: after noise removal]
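The grouping step relies on dilating the cleaned binary image with a wide, flat 45x5 rectangle. Anchored at its centre, that kernel grows each stroke by 22 pixels left and right but only 2 pixels up and down, so horizontal gaps of up to 44 pixels between the digits and decimal point of one score are bridged, while scores on different lines stay separate as long as they are more than 4 pixels apart. To make this concrete, the sketch below reimplements binary dilation in plain NumPy on a synthetic image; the helper `dilate_rect` is hypothetical (in the article `cv.dilate` does this job), and for odd kernel sizes it should agree with `cv.dilate` using a `MORPH_RECT` element and the default anchor.

```python
import numpy as np


def dilate_rect(binary, ksize):
    """Dilate a 2-D uint8 image with a ksize = (width, height) rectangle centred on each pixel."""
    kw, kh = ksize
    h, w = binary.shape
    # Zero-pad so output pixel (y, x) sees input rows y - kh//2 .. y + (kh - 1 - kh//2)
    # and columns x - kw//2 .. x + (kw - 1 - kw//2).
    padded = np.pad(binary, ((kh // 2, kh - 1 - kh // 2), (kw // 2, kw - 1 - kw // 2)))
    out = np.zeros_like(binary)
    for dy in range(kh):
        for dx in range(kw):
            # Each shift is one kernel position; dilation keeps the maximum over all of them
            out = np.maximum(out, padded[dy:dy + h, dx:dx + w])
    return out


canvas = np.zeros((50, 120), np.uint8)
canvas[15:25, 10:20] = 255  # first character of a score
canvas[15:25, 40:50] = 255  # next character on the same line, 20 px to the right
canvas[35:40, 10:20] = 255  # a stroke on the next line, 10 px lower
merged = dilate_rect(canvas, (45, 5))
print(merged[20, 30])  # inside the horizontal gap
print(merged[29, 15])  # inside the vertical gap
```

The first print shows the horizontal gap filled in (255) and the second shows the vertical gap left empty (0), which is why `cv.findContours` on the dilated image returns one contour per score line.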
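Inference and decoding

The code below runs inside the per-region loop above, once for each score, so x, y, iw, ih and image refer to the current region.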

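# Start sync inference
# image is the (h, w) = (32, 64) grayscale ROI: add the channel axis here;
# the [ ] around img_blob below supplies the batch axis, giving [1, 1, 32, 64]
img_blob = np.expand_dims(image, 0)
log.info("Starting inference in synchronous mode")
inf_start1 = time.time()
res = exec_net.infer(inputs={input_blob: [img_blob]})
inf_end1 = time.time() - inf_start1
print("inference time(ms) : %.3f" % (inf_end1 * 1000))
res = res[out_blob]

# CTC greedy decode from here
print(res.shape)  # (16, 1, 13)
# Decode the output text: class index -> character, '#' is the CTC blank
digit_nums = "0123456789._#"
ocrstr = ""
prev_pad = False  # True if a blank was seen since the last emitted character
for i in range(res.shape[0]):
    ctc = res[i]  # 1x13
    ctc = np.squeeze(ctc, 0)  # -> 13 class scores for this time step
    index, prob = ctc_soft_max(ctc)
    if digit_nums[index] == '#':
        prev_pad = True
    else:
        # Emit the character unless it repeats the previous one with no blank in between
        if len(ocrstr) == 0 or prev_pad or digit_nums[index] != ocrstr[-1]:
            prev_pad = False
            ocrstr += digit_nums[index]
# Draw the recognized score above its bounding box
cv.putText(ocr, ocrstr, (x, y - 5), cv.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2, 8)
cv.rectangle(ocr, (x, y), (x + iw, y + ih), (0, 255, 0), 2, 8, 0)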
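The inference snippet times a single call with `time.time()`. The first inference can include one-off setup costs, and `time.time()` follows the adjustable system clock, so a steadier figure comes from a warm-up call followed by several timed runs measured with `time.perf_counter()`, Python's high-resolution clock for short durations. A minimal sketch; the helper name `time_call_ms` and its default warm-up and run counts are illustrative choices:

```python
import time


def time_call_ms(fn, *args, warmup=1, runs=10, **kwargs):
    """Call fn `warmup` times untimed, then `runs` times timed.

    Returns (result of the last timed call, mean time per call in milliseconds).
    """
    if runs < 1:
        raise ValueError("runs must be >= 1")
    for _ in range(warmup):
        fn(*args, **kwargs)  # the first call may include one-off setup costs
    start = time.perf_counter()
    for _ in range(runs):
        result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed * 1000.0 / runs
```

Inside the loop it could wrap the existing call, e.g. `res, ms = time_call_ms(exec_net.infer, inputs={input_blob: [img_blob]})`, after which `res[out_blob]` is decoded exactly as before.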

CTC greedy decoding

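Someone asked me about this last time, and it turned out that in code I wrote long ago I had never explained CTC greedy decoding. OpenVINO's text and digit recognition models all support CTC greedy decoding, and it is very simple to implement. Start with the output shape [16x1x13], which can be simplified to [16x13]. For each of the 16 rows, take the class with the largest softmax value among its 13 columns (you could also apply a threshold to that value). Then merge consecutive repeats of the same class and drop the blanks (#), as the loop above does; a blank between two identical characters keeps them as two characters. What remains is the output. That is the most direct explanation of CTC greedy decoding. No need to look at the formulas: they will make your head spin and you still won't be able to write the code! The ctc_soft_max function used above is: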

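import numpy as np

def ctc_soft_max(data):
    # Return the index of the most likely class and its softmax probability.
    # Subtracting the maximum keeps np.exp() from overflowing; the winning class
    # then contributes exp(0) = 1 to the sum, so its probability is 1 / sum.
    max_val = np.max(data)
    index = np.argmax(data)
    total = np.sum(np.exp(data - max_val))
    prob = 1.0 / total
    return index, prob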
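The per-step loop and `ctc_soft_max` can be folded into one vectorized function. The sketch below assumes, as `ctc_soft_max` does, that the output holds unnormalized scores; the decoded text does not depend on that assumption because softmax preserves the argmax, only the probability does. The `threshold` parameter is a hypothetical way to realize the "maybe apply a threshold" idea: it rejects a result whose best-path probability is too low. That probability (the product of the per-step maxima) is the probability of the single best alignment, a lower bound on the probability of the decoded string, since several alignments can collapse to the same text. The names `greedy_decode` and `softmax` are illustrative.

```python
import numpy as np

ALPHABET = "0123456789._#"   # class order of the 13 outputs, as listed in the model description
BLANK = ALPHABET.index("#")  # 12: the CTC blank


def softmax(scores):
    """Numerically stable softmax over the last axis."""
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def greedy_decode(logits, threshold=None):
    """CTC greedy decoding of a [T, 1, C] (or [T, C]) score array.

    Returns (text, path_prob). If a threshold is given and path_prob
    falls below it, text is None.
    """
    scores = np.asarray(logits, dtype=np.float64).reshape(len(logits), -1)  # [16, 1, 13] -> [16, 13]
    probs = softmax(scores)
    best = probs.argmax(axis=1).tolist()           # most likely class per time step
    path_prob = float(np.prod(probs.max(axis=1)))  # probability of the best alignment
    chars, prev = [], None
    for idx in best:
        if idx != prev and idx != BLANK:           # merge repeats, drop blanks
            chars.append(ALPHABET[idx])
        prev = idx
    if threshold is not None and path_prob < threshold:
        return None, path_prob
    return "".join(chars), path_prob
```

In the loop it could replace the hand-written decoding, e.g. `ocrstr, p = greedy_decode(res, threshold=0.5)`, skipping the region when `ocrstr` is None; 0.5 is an arbitrary example value, not a tuned setting.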

The final test result is shown below:
[Figure: final recognition result]

Copyright notice: free to repost, non-commercial, no derivatives, attribution required (Creative Commons 3.0 license)

