65

基于Bert的NL2SQL模型:一个简明的Baseline

 4 years ago
source link: https://www.tuicool.com/articles/aEn6nyu
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

在之前的文章 《当Bert遇上Keras:这可能是Bert最简单的打开姿势》 中,我们介绍了基于微调Bert的三个NLP例子,算是体验了一把Bert的强大和Keras的便捷。而在这篇文章中,我们再添一个例子:基于Bert的NL2SQL模型。

NL2SQL的NL也就是Natural Language,所以NL2SQL的意思就是“自然语言转SQL语句”,近年来也颇多研究,它算是人工智能领域中比较实用的一个任务。而笔者做这个模型的契机,则是今年我司举办的 首届“中文NL2SQL挑战赛”

首届中文NL2SQL挑战赛,使用金融以及通用领域的表格数据作为数据源,提供在此基础上标注的自然语言与SQL语句的匹配对,希望选手可以利用数据训练出可以准确转换自然语言到SQL的模型。

这个NL2SQL比赛算是今年比较大型的NLP赛事了,赛前投入了颇多人力物力进行宣传推广,比赛的奖金也颇丰富,唯一的问题是NL2SQL本身算是偏冷门的研究领域,所以注定不会太火爆,为此主办方也放出了一个 Baseline ,基于Pytorch写的,希望能降低大家的入门难度。

抱着“Baseline怎么能少得了Keras版”的心态,我抽时间自己用Keras做了做这个比赛,为了简化模型并且提升效果也加载了预训练的Bert模型,最终形成此文。

每个数据样本如下:

{
    "table_id": "a1b2c3d4", # 相应表格的id
    "question": "世茂茂悦府新盘容积率大于1,请问它的套均面积是多少?", # 自然语言问句
    "sql":{ # 真实SQL
        "sel": [7], # SQL选择的列 
        "agg": [0], # 选择的列相应的聚合函数, '0'代表无
        "cond_conn_op": 0, # 条件之间的关系
        "conds": [
            [1, 2, "世茂茂悦府"], # 条件列, 条件类型, 条件值,col_1 == "世茂茂悦府"
            [6, 0, "1"]
        ]
    }
}

# 其中条件运算符、聚合符、连接符分别如下
op_sql_dict = {0:">", 1:"<", 2:"==", 3:"!="}
agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM"}
conn_sql_dict = {0:"", 1:"and", 2:"or"}

然后每个样本都对应着一个数据表,里边包含该表的所有列名,以及相应的数据记录,而原则上生成的SQL语句在对应的数据表上是可以执行的,并且都能返回有效的结果。

可以看到,虽然说是NL2SQL,但事实上主办方已经将SQL语句做了十分清晰的格式化,这样一来这个任务就可以相当大地简化了。比如sel这个字段,其实就是一个多标签分类模型,只不过类别可能会随时变化,因为这里的类别实际上对应着数据表的列,而每个样本的数据表及其含义都不尽相同,所以我们要根据表的列名来动态地编码一个类别向量;至于agg则跟sel一一对应,并且类别是固定的,cond_conn_op则是一个单标签分类问题。

最后的conds相对复杂一些,它需要结合字标注和分类,因为要同时决定哪一列是条件,条件的运算关系,以及条件对应的值,并且要注意的是,条件值并不总是question的一个片段,它有可能是格式化后的结果,比如question里边是“16年”,那么条件值可能是格式化之后的“2016”,不过,由于主办方保证生成的sql是能够在对应数据表中执行并且出有效结果的,所以如果条件运算符是“==”的时候,那么条件值肯定会出现在数据表对应列的值之中,比如上述示例样本中,数据表的第一列肯定会有“世茂茂悦府”这个值,而通过这个信息我们也可以去校正预测结果。

正式看本模型之前,读者不妨自己思考一番,想一下自己会怎么做。只有经过思考之后,才会明白其中的困难所在,也就这样才能理解本模型的一些处理技巧的要点。

本文的模型示意图如下:

M3Ifmif.png!web

本文的NL2SQL模型示意图。主要包括4个不同的分类器:序列标注器

作为一个SQL,最基本的是要决定哪些列会被select,而每个表的列的含义均不一样,所以我们要将question句子连同数据表的所有表头拼起来,一同输入到Bert模型中进行实时编码,其中每一个表头也视为一个句子,用[CLS]***[SEP]括住。经过Bert之后,我们就得到了一系列编码向量,然后就看我们怎么去用这些向量了。

第一个[CLS]对应的向量,我们可以认为是整个问题的句向量,我们用它来预测conds的连接符。后面的每个[CLS]对应的向量,我们认为是 每个表头的编码向量 ,我们把它拿出来,用来预测该表头表示的列是否应该被select。注意,此处预测有一个技巧,就是前面说了除了预测sel外,还要预测对应的agg,其中agg一共有6个类别,代表不同的距离运算,既然如此,我们干脆多增加一个类别,第7个类别代表着此列不被select,这样一来,每一列都对应着一个7分类问题,如果分到前6类,那么代表着此类被select而且同时预测到了agg,如果分到了第7类,那意味着此类不被select。

现在就剩下比较复杂的conds了,就是 where col_1 == value_1 这样子的,col_1、value_1以及运算符==都要找出来。conds的预测分两步, 第一步预测条件值,第二步预测条件列 ,预测条件值其实就是一个序列标注问题,而条件值对应的运算符有4个,我们同样新增一类变成5类,第5类代表着当前字不被标注,否则就被标注,这样我们就能预测出条件值和运算符了。剩下的就是预测条件值对应的列,我们将标注出来的值的字向量跟每个表头的向量一一算相似度,然后softmax。我这里算相似度的方法是最简单的,直接将字向量和表头向量拼接起来,然后接一个 Dense(1)做得这么简单,一是因为本文主要目的是给出一个基本可行的demo而非一个完善的程序,需要留些改进空间给读者,二是因为做得复杂了,就很容易显存不足而OOM了

顺便说一下,本文的模型是自己根据比赛任务“闭门造车”出来的,如果读者需要跟我讨论主流的NL2SQL模型,我可能就无能为力了,请见谅。

还是那句话,只要你认真地观察过比赛数据、独立思考过这个任务,上面本文介绍的模型其实很容易理解,而简单的模型能够有不错的效果,则得益于Bert强大的语义编码能力。 在线下valid集上,本文的模型生成的SQL全匹配率大概为50%左右,而官方的评估指标是(全匹配率 + 执行匹配率) / 2,也就是说你有可能写出跟标注答案不同的SQL语句,但执行结果是一致的,这也算正确一半。

这样一来最终得分肯定会比50%要高,我估计是 55% 左右吧,我看了看当前榜单,如果55%的话还可以排在前几名。由于公司员工不允许打榜评测,所以我没参与过评测,也不知道线上提交会有多少,有兴趣试用的选手自行提交测试吧。

对了,要跑这个脚本最好有个1080ti或者以上的显卡,如果你没有那么多显存,你可以试着降低maxlen和batch size。还有,现在Bert可用的中文预训练权重有两个,一个是 官方版 的,一个是 哈工大版 的,两个的最终效果差不多,但是哈工大版的收敛更快。

纵观整个模型,在实现上,最困难的应当是要精细地考虑各种mask,在上述脚本中, xm, hm, cm 是三个mask变量,就是要去除训练过程中padding部分带来的效应。注意mask不是Keras独有,不管你用Tensorflow还是Pytorch,理论上你都要仔细地处理好mask。如果读者实在读不懂mask部分,欢迎留言提问讨论,但是在提问之前,请你回答以下问题:

mask之前的序列大概是怎样的?mask之后序列的哪些位置发生了变化?变成了怎么样?

回答这个问题是证明“你已经能明白程序做了什么运算了,只是不明白为什么这样运算”。而如果你连这个运算本身都看不明白,我想我们很难沟通下去了(哪部分发生了变化总能知道吧...),还是好好学学Keras或者Tensorflow再来玩吧,不能想着一蹴而就。

对于模型来说,实现的困难部分在于mask。不过如果看整个脚本的话,其实代码占比最多的就是数据的读取和预处理、结果的后处理这两部分罢了,真正搭建模型也就只有二十行左右(再次惊叹Keras的简洁和Bert的强大吧)。

其中我们说过条件值不一定出现在question中,那如何用对question字标注的方法来抽取条件值呢?

我的方法是,如果条件值不出现在question,那么我就对question分词,然后找出question的所有1gram、2gram、3gram,然后根据条件值找出一个最相近的ngram作为标注片段。而在预测的时候,如果找到了一个ngram作为条件值、并且运算符为==时,我们会判断这个ngram有没有在数据库出现过,如果有直接保留,如果没有则在数据库中找一个最相近的值。

相应的过程在代码中都有体现,欢迎细读。

欢迎大家来玩~

nQbiYbv.png!web

首届中文NL2SQL挑战赛

https://tianchi.aliyun.com/markets/tianchi/zhuiyi_cn

祝大家取得好成绩!

转载到请包括本文地址: https://kexue.fm/archives/6771

更详细的转载事宜请参考: 《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎/本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (2019, Jun 29). 《 基于Bert的NL2SQL模型:一个简明的Baseline 》[Blog post]. Retrieved from https://kexue.fm/archives/6771


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK