# TreeGen **Repository Path**: growinware/tree-gen ## Basic Information - **Project Name**: TreeGen - **Description**: 面向功能增强的程序代码自动生成工具。给定一串自然语言描述,本项目通过人工智能相关技术自动化生成程序代码。贡献单位:北京大学 - **Primary Language**: Unknown - **License**: MulanPSL-2.0 - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2021-04-28 - **Last Updated**: 2021-10-18 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README 1.安装说明 下载源码。 2. 使用说明 2.1 运行环境 • 运行系统: Ubuntu 16.04。 • Python 环境: Python 3.7。 • Python 依赖库: o NLTK 3.2.1 2.2 运行方法 2.2.1训练前 下载项目,下载想要执行的代码生成任务放到根目录下的一个自定义文件夹A中。 其中,用于训练的数据命名为train.txt,用于验证和测试的数据分别命名为dev.txt, test.txt。 其中 train.txt, dev.txt, test.txt 中数据表示为9元组,元组中每个元素一行,元素详情如下: 1) 输入代码的自然语言描述。 2-4 输入代码所生成步骤的抽象语法树表示。 5) 输入代码所生成步骤的Tree Path表示 6) 已使用的语法规则序列 7) 下一个要使用的语法规则 8) 空 9) 下一个要应用语法规则的结点其父结点的位置。 ![](https://images.gitee.com/uploads/images/2021/0428/102902_20f2e800_8450335.png "屏幕截图.png") 获取词表信息,将程序read_train.py移入自定义文件夹A。 使用命令行命令python3 read_train.py执行 read_train.py以生成词表文件nl_voc.txt, tree_voc.txt, char_voc.txt。 ![](https://images.gitee.com/uploads/images/2021/0428/102948_ca75774a_8450335.png "屏幕截图.png") 获取Tree Path信息,将程序read_tree_path.py移入自定义文件夹A。 使用命令行命令python3 read_tree_path.py执行read_tree_path.py以生成Tree Path文件train_tree.txt, test_tree.txt, dev_tree.txt。 ![](https://images.gitee.com/uploads/images/2021/0428/103015_9176e677_8450335.png "屏幕截图.png") 获取训练相关文件,将程序 trans_train.py移入自定义文件夹A。 使用命令行命令python3 trans_train.py执行 trans_train.py以生成Tree Path文件train_trans.txt, dev_trans.txt, test_trans.txt。 ![](https://images.gitee.com/uploads/images/2021/0428/103035_481292b8_8450335.png "屏幕截图.png") 2.2.2 执行训练 进入项目根目录,运行程序run.py。 该文件会对所给出数据进行训练,同时保存相应训练模型到指定文件夹下。 使用命令行命令python3 run.py A(其中A代表指定文件夹)以执行 run.py对模型进行训练。 ![](https://images.gitee.com/uploads/images/2021/0428/103056_11653884_8450335.png "屏幕截图.png") 2.2.3 执行预测 获取所需执行预测的文件input.txt放置于指定文件夹A中。 input.txt中,每行为一段用于生成代码的自然语言描述。 ![](https://images.gitee.com/uploads/images/2021/0428/103142_f312f7b9_8450335.png "屏幕截图.png") 进入项目根目录,运行程序predict.py以进行预测。 该文件会利用所给出的输入数据input.txt和已经训练完成的模型进行预测,同时保存相应的预测结果到指定文件夹下的out文件夹中。 使用命令行命令python3 predict.py A 5(其中A代表指定文件夹)以执行 predict.py对模型进行预测,其中5代表搜索的宽度大小。 ![](https://images.gitee.com/uploads/images/2021/0428/103208_c450cf70_8450335.png "屏幕截图.png") ![](https://images.gitee.com/uploads/images/2021/0428/103300_617fe985_8450335.png "屏幕截图.png") 2.2.3 查看执行结果 所预测的结果放置在指定文件夹A下的out文件夹中。 out文件夹中0.txt代表input.txt第一行的预测结果,1.txt代表input.txt第二行的预测结果,以此类推。 ![](https://images.gitee.com/uploads/images/2021/0428/103334_69a33612_8450335.png "屏幕截图.png") 输出结果为多个二元组,二元组第一行代表输出代码的抽象语法树,第二行则代表生成该代码的概率。 3. 算法说明 ![](https://images.gitee.com/uploads/images/2021/0428/103400_edc00beb_8450335.png "屏幕截图.png") 3.1 自然语言描述读取器 自然语言描述描述确定目标SQL代码的功能。对于一个给定自然语言描述,我们首先将其标记一系列令牌序列n1,n2,··· nL, 其中L表示输入的长度。 然后将每个令牌分割成字符 c1(ni),c2(ni),···,cS(ni),其中,S为字符数量。 所有标记和字符都通过embedding表示为实数向量n1,n2,···,nL 和 c1(ni),c2(ni),···,cS(ni)。 3.1.1 输入文本表示 字符级Embedding 经常会出现相似的词组具有相似的字符(例如“表A”和“表B”)的情况。要利用此属性,我们通过具有全连接层的字符级embedding来表示 ![](https://images.gitee.com/uploads/images/2021/0428/103428_b6f971c2_8450335.png "屏幕截图.png") 其中W(c)是权重,字符序列在输入前会被填充到预定义的最大长度M。 在全连接的层之后,我们还使用了层归一化处理。 这些输出向量会被输入到自然语言描述阅读器,并通过门控子层与 词语embedding集成在一起。 3.1.2 自然语言描述读取器的神经网络结构 自然语言描述读取器由多个结构相同的模块(总共Nd个)组成。 每个模块包含三个不同的子层(即,自注意力子层,门控子层和词组卷积子层)以提取特征,我们将在以下子部分中详细介绍。 在两个子层之间,我们采用了残差连接,并进行了层归一化。 自注意力子层 自注意力子层遵循Transformer的架构,并使用多头注意力来捕获长程依赖信息。 对于输入为 n1,n2,...,nL 的序列,我们将它们表示为实数向量序列,n1,n2,...,nL 。同时,我们使用位置embedding来编码单词位置信息。更具体的,我们将第i个单词在第 b 个 Transformer模块中的位置embedding计算为 ![](https://images.gitee.com/uploads/images/2021/0428/103453_207e44d9_8450335.png "屏幕截图.png") 其中,pb,i[.] 代表向量 pb,i 的维,而d是维度大小。 Transformer模块通过多头注意力学习非线性特征,从而产生矩阵Yb(self) 。为了简化符号,我们省略了下标b。多头层的计算公式为 ![](https://images.gitee.com/uploads/images/2021/0428/103510_e128b598_8450335.png "屏幕截图.png") 其中 H 表示头数, Wh 是权重。注意层应用于每个头部headt,由 ![](https://images.gitee.com/uploads/images/2021/0428/103523_891374d9_8450335.png "屏幕截图.png") 计算而来,其中dk 表示每个特征向量的长度。 Q,K和V的计算公式为 ![](https://images.gitee.com/uploads/images/2021/0428/103534_f8fccdb8_8450335.png "屏幕截图.png") 其中 WQ,WK,WV 是模型参数。 xi是此Transformer模块的输入。对于第一个注意力子层,它是embedding和位置embedding的矢量和;对于其他注意力子层,则是前置Transformer模块的输出和与该注意力子层相对应的位置嵌入的矢量和。 门控子层 在通过自注意力子层计算出特征之后,我们将字符embedding的信息进一步合并。 这是由基于softmax的门控机制给出的。 对于第一个词,我们通过线性变换 yi(self)来计算控制向量qi 。 用于字符embedding的权重ki(c)由等式2中的ni(c)进行线性变换得到。用于输入特征的权重ki(y)由另一个yi(self)进行线性变换得到。 该计算可表示为 ![](https://images.gitee.com/uploads/images/2021/0428/103559_b47ab067_8450335.png "屏幕截图.png") 它们用于权衡Transformer子层所得特征vi(y)和字符embedding特征vi(c) ![](https://images.gitee.com/uploads/images/2021/0428/103612_681c43c2_8450335.png "屏幕截图.png") 词组卷积子层 最后,我们将两个卷积层应用于门控子层的输出,以提取每个词组周围的局部特征。特别的,我们对于第一个和最后一个词,我们添加0作为填充。 在这些层之间,我们使用GELU作为激活函数。 3.2 抽象语法树读取器 抽象语法树读取器用以对已生成的部分AST的结构进行建模。 尽管我们的程序是通过预测语法规则的顺序生成的,但仅这些规则就缺少程序的具体信息,不足以预测下一个规则。 因此,我们的抽象语法树读取器会考虑更多信息,包括预测规则和树结构。 为了编码特定于程序的信息,我们首先将代码表示为规则序列,然后使用注意力机制对规则进行编码,最后使用树卷积层将每个节点及其祖先的编码表示形式组合在一起。 3.2.1 抽象语法树表示 规则序列Embedding 我们使用规则的ID编码规则信息。 假设我们有一个规则序列r1,r2,...,rP用于在解码步骤中生成部分AST,其中P表示序列的长度。 我们通过embedding的方法将这些规则表示为实数向量r1,r2,...,rP。 规则定义编码 embedding将语法规则视为原子标记,实际上,其丢失了规则内容的信息。 为了缓解此问题,我们使用规则定义的编码来增强规则的表示形式。 对于语法规则 i:a --> b1 ... bK ,其中a是父节点,而 b1 ... bK是前继节点。它们可以是终结符或非终结符。索引 i 是规则的ID。 我们使用全连接的方式,通过将规则内容编码为向量r(c) 。其中,输入为向量 a b1 bK。特别的,该序列也被填充到最大长度。 然后,规则定义特征y1(rule),...,yP(rule) $由另一个全连接层计算得出 ![](https://images.gitee.com/uploads/images/2021/0428/103636_fed4f574_8450335.png "屏幕截图.png") 其中ri 是规则 ri的表查询嵌入,ri(c)是内容编码规则表示,并且我们再次编码了前继节点信息 a。在该层之后,我们进行了层归一化。 位置及深度Embedding 由于我们的抽象语法树读取器将使用自注意力机制,因此我们需要表示所使用语法规则的位置。 我们首先采用位置embedding中的向量信息去编码语法规则的位置,表示何时在语法序列中使用该规则。该位置embedding用 p1(r), ...,pP(r) 来表示。 但是,这种位置嵌入不能捕获规则在抽象语法树中的位置,即树结构。 因而,我们进一步通过深度embedding对此类信息进行编码。 如果我们通过规则 r去扩展非终结符 a : a --> b1 ... bK,我们将通过其前继节点于抽象语法树中的深度(即 a的深度)表示规则的深度。 我们将embedding的方式将深度信息编码为向量,同时与输入的序列向量及位置embedding向量求和。以这种方式,我们便可以编码结构上的位置信息,帮助神经网络去理解抽象语法树。 3.2.2 抽象语法树读取器的神经网络结构 抽象语法树读取器同样由一系列结构相同的模块组成(总共 N1 个模块)。每个模块被表示为四个子层(即,自注意力子层,门控子层,自然语言注意力子层和树卷积子层)。除了树卷积层之外,我们在每个子层周围都采用了残差连接。在每个子层之后,我们使用力层归一化。 自注意力子层 为了编码抽象语法树的信息,我们构建了一个类似Transformer的自注意力层,其中输入是规则embedding,位置embedding和深度embedding的和。该自注意力子层使用了与自然语言读取器中自注意力子层相同的结构及计算过程,但是使用了不同的权重,并增加了深度embedding信息。 门控子层 我们结合规则表示与相应的规则信息。因而,我们采用如自然语言读取器中的门控子层,并且将其中结合的特征变为自注意力子层的输出向量(用于计算控制向量)和规则定义编码的输出向量。 自然语言注意力子层 在解码步骤中,应将输入的自然语言描述的信息输入与抽象语法树信息结合。因而,我们使用了一种多头的注意力机制,该机制类似于Transformer解码器对其编码器的注意力机制。我们使用该机制将门控子层的输出与自然语言描述读取器的输出结合,用以更好地帮助神经网络理解输入输出信息。 树卷积子层 如果我们仅考虑上述子层,那么读取器将很难将每个节点的信息与其祖先节点的信息结合起来。在规则序列表示中,节点可以远离其祖先,但在抽象语法树结构中,其则具有近邻的关系。因此,传统的编码器很难提取这种结构特征。 我们将节点的向量与其祖先的向量结合在一起。如果我们将AST视为图形,并使用邻接矩阵M表示有向图。其中,如果节点ai是aj的前继节点,则Mji = 1。假设所有节点都由特征f1,...,fn表示,则其前继节点的特征可以通过与邻接矩阵相乘来得出: ![](https://images.gitee.com/uploads/images/2021/0428/103658_84c5b6f9_8450335.png "屏幕截图.png") 其中fi(par)表示第i 个节点的父节点。特别的,对于根节点的父节点,我们用根节点本身的特征向量对其进行填充。 以这样的方式,应用于当前抽象语法树的树卷积窗口计算公式可以由下式给出 ![](https://images.gitee.com/uploads/images/2021/0428/103714_a4ed5303_8450335.png "屏幕截图.png") 其中 W(tconv, l)是卷积层的权重, kt 表示卷积窗口大小(在实验中设置为3), l 是树卷积层的层,对于抽象语法树读取器的最后一层模块,我们额外添加了两层树卷积层。在 等式中,f是激活函数(GELU)。 总的来说,抽象语法树读取器具有这四个子层的N1个模块,并输出以下特征:y1(ast) y2(ast) ... yP(ast)。 3.3 解码器 我们的最后一个组件是一个解码器,它将生成的SQL代码信息与自然语言描述描述结合在一起,并预测下一个要使用的语法规则。与抽象语法树读取器类似,在解码器中我们使用了多个结构相同的模块(每个模块包含多个子层,总共N2个模块)。在每个子层之间使用残差连接及层归一化。 解码器将要扩展的非终结节点作为查询输入。受先前方法[6]的启发,查询节点表示为从根节点到要扩展的节点的路径。我们将该路径中的节点表示为实树然后,对这些向量应用的全连接层,其的输出为qi(path)。 然后,我们应用两个与自然语言注意力子层相同结构的注意力子层来结合抽象语法树读取器和自然语言描述读取器的输出。 我们首先通过在抽象语法树读取器的输出上应用抽象语法树注意力子层,并提取特征。在这一层中,Q是根据查询qi(path)计算得出的; K和V是根据读取器输出的代码特征计算得出的。 我们将从输入描述中进一步结合到解码器中功能。这种结合也通过注意力子层来实现,其中Q由抽象语法树注意力子层的输出特征计算;和K 和 V自然语言描述读取器的输出计算。 最后,我们使用了两层全连接,其中第一层具有 GELU 激活函数,然后提取特征以进行预测。 训练及预测 我们根据解码器的最后一层输出特征,通过softmax预测所有可能的候选词中的下一个语法规则。 我们还引入一种可以直接从自然语言描述中复制词语c的指针网络(本质上是一种注意力机制)。在这种情况下,生成的语法规则为 a--> c,其中 a 是要扩展的非终结符,而c是终结符。这种指针机制对于用户定义的标识符(例如,变量和函数名称)很有帮助。 在softmax规则预测和指针网络之间的选择由另一个选择机制 pg算出,该选择机制也是根据解码器的最后一个特征计算得出的。 因而。下一个语法规则的总体预测概率为 ![](https://images.gitee.com/uploads/images/2021/0428/103732_84d93dc9_8450335.png "屏幕截图.png") 其中i表示规则的ID,D是预定义规则的集合,而C表示形式为a-->c的规则集合,其中c是在自然语言描述中说明中出现的终结符名称。 pg是使用预定义规则类型的概率,而p(ri | .)(每个预定义规则的概率)分别由两个分别具有sigmoid和softmax激活函数的单层感知器来计算,其输入是向量 h(dec)。 指针网络由 ![](https://images.gitee.com/uploads/images/2021/0428/103749_4ded8b9f_8450335.png "屏幕截图.png") 计算,其中h(dec)表示解码器的最后一层的输出向量。 我们通过针对参考SQL程序最大化负对数似然损失来优化模型。 预测过程从start规则开始:snode --> root,将特殊符号snode扩展为root符号。如果预测的AST中的每个叶节点都是终结符,则递归预测终止。在预测期间,我们使用大小为5的波束搜索。在波束搜索期间,我们将无效规则排除。 4. 代码结构 ![](https://images.gitee.com/uploads/images/2021/0428/103814_94efefb5_8450335.png "屏幕截图.png")