5 Star 3 Fork 0

is_good_bro / robot

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
work.py 14.39 KB
一键复制 编辑 原始数据 按行查看 历史
bitter465 提交于 2021-12-08 23:53 . work commit
import os
from flask import Flask, request
from flask_cors import CORS
import pickle
import numpy as np
from bert4keras.backend import keras
from keras.preprocessing.sequence import pad_sequences
from nlu.bert_intent_recognition.bert_model import build_bert_model
from knowledge_extraction.bilstm_crf.crf_layer import CRF
from gevent import pywsgi
from utils.json_utils import dump_user_dialogue_context,load_user_dialogue_context
from bert4keras.tokenizers import Tokenizer
import random
from py2neo import Graph
from nlu.sklearn_Classification.clf_model import CLFModel
# NOTE: keep the star import last so the names it defines take precedence.
from config import *
# Neo4j knowledge-graph connection. Credentials used to be hard-coded; they can
# now be overridden via environment variables, with the original values kept as
# defaults so existing deployments behave the same.
graph = Graph(
    host=os.environ.get("NEO4J_HOST", "127.0.0.1"),
    http_port=int(os.environ.get("NEO4J_HTTP_PORT", "7474")),
    user=os.environ.get("NEO4J_USER", "neo4j"),
    password=os.environ.get("NEO4J_PASSWORD", "123456"),
)
# Scikit-learn classifier for the coarse user intent (greet/goodbye/deny/isbot/accept/...).
clf_model = CLFModel('./nlu/sklearn_Classification/model_file/')
# Sequence length used when padding/truncating inputs for the BiLSTM-CRF NER model.
max_len = 80
class BertIntentModel(object):
    """BERT (RBT3) intent classifier for medical questions.

    Loads the vocabulary, BERT config/checkpoint, the intent label names and the
    fine-tuned weights from ``./nlu/bert_intent_recognition`` when constructed.
    """

    def __init__(self):
        super(BertIntentModel, self).__init__()
        self.dict_path = './nlu/bert_intent_recognition/chinese_rbt3_L-3_H-768_A-12/vocab.txt'
        self.config_path = './nlu/bert_intent_recognition/chinese_rbt3_L-3_H-768_A-12/bert_config_rbt3.json'
        self.checkpoint_path = './nlu/bert_intent_recognition/chinese_rbt3_L-3_H-768_A-12/bert_model.ckpt'
        # One label name per line; read with a context manager so the handle is closed.
        with open('./nlu/bert_intent_recognition/label', 'r', encoding='utf8') as label_file:
            self.label_list = [line.strip() for line in label_file]
        self.id2label = {idx: label for idx, label in enumerate(self.label_list)}
        self.tokenizer = Tokenizer(self.dict_path)
        # NOTE(review): 13 is presumably the number of intent classes the saved
        # weights were trained with - TODO confirm it matches the label file.
        self.model = build_bert_model(self.config_path, self.checkpoint_path, 13)
        self.model.load_weights('./nlu/bert_intent_recognition/best_model.weights')

    def predict(self, text):
        """Classify ``text`` and return its top-scoring intent.

        The text is tokenized with ``maxlen=60`` (longer input is truncated by
        the tokenizer).

        Returns:
            dict: ``{"name": <label>, "confidence": <model score as float>}``.
        """
        token_ids, segment_ids = self.tokenizer.encode(text, maxlen=60)
        proba = self.model.predict([[token_ids], [segment_ids]])
        scores = {label: p for label, p in zip(self.label_list, proba[0])}
        # max() returns the first maximal entry, exactly what the stable
        # descending sort used to yield at index 0.
        name, confidence = max(scores.items(), key=lambda kv: kv[1])
        return {"name": name, "confidence": float(confidence)}


# Module-level singleton used by intent_classifier().
BIM = BertIntentModel()
class BiLstmCrfModel(object):
    """Builder for a BiLSTM-CRF sequence-labelling (NER) Keras model."""

    def __init__(
        self,
        max_len,  # maximum sentence length (model input length)
        vocab_size,  # size of the embedding vocabulary
        embedding_dim,  # embedding vector dimension
        lstm_units,  # hidden units in each LSTM direction
        class_nums,  # number of tag classes
        embedding_matrix=None,  # optional pre-trained embedding matrix
    ):
        super(BiLstmCrfModel, self).__init__()
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.lstm_units = lstm_units
        self.class_nums = class_nums
        self.embedding_matrix = embedding_matrix
        # A supplied matrix determines the vocabulary size and embedding dimension.
        if self.embedding_matrix is not None:
            self.vocab_size, self.embedding_dim = self.embedding_matrix.shape

    def build(self):
        """Build, compile and return the ``keras.Model``.

        Layers: int32 input of length ``max_len`` -> Masking(0) -> frozen
        Embedding (mask_zero) -> BiLSTM (sequences) -> TimeDistributed
        Dropout(0.2) -> CRF, compiled with Adam and the CRF loss/accuracy.
        """
        inputs = keras.layers.Input(shape=(self.max_len,), dtype='int32')
        # Padding positions use id 0.
        x = keras.layers.Masking(mask_value=0)(inputs)
        # Keras expects `weights` as a list with one array per layer weight;
        # passing the bare matrix (as before) breaks the initial set_weights.
        embedding_weights = (
            [self.embedding_matrix] if self.embedding_matrix is not None else None
        )
        # NOTE(review): trainable=False also applies when no pre-trained matrix
        # is given, so the embeddings then stay at their random initialisation
        # during training - confirm this is intended.
        x = keras.layers.Embedding(
            input_dim=self.vocab_size,
            output_dim=self.embedding_dim,
            trainable=False,
            weights=embedding_weights,
            mask_zero=True,
        )(x)
        # Bidirectional LSTM returning one feature vector per character.
        x = keras.layers.Bidirectional(
            keras.layers.LSTM(self.lstm_units, return_sequences=True)
        )(x)
        x = keras.layers.TimeDistributed(keras.layers.Dropout(0.2))(x)
        crf = CRF(self.class_nums)
        outputs = crf(x)
        model = keras.Model(inputs=inputs, outputs=outputs)
        model.compile(
            optimizer='adam',
            loss=crf.loss_function,
            metrics=[crf.accuracy],  # per-character accuracy
        )
        # summary() prints by itself and returns None; wrapping it in print()
        # used to emit a stray "None".
        model.summary()
        return model
class MedicalNerModel(object):
def __init__(self):
super(MedicalNerModel, self).__init__()
self.word2id, _, self.id2tag = pickle.load(
open("./knowledge_extraction/bilstm_crf/checkpoint/word_tag_id.pkl", "rb")
)
self.model = BiLstmCrfModel(80, 2410, 200, 128, 24).build() # 使用模型
self.model.load_weights('./knowledge_extraction/bilstm_crf/checkpoint/best_bilstm_crf_model.h5') # 加载最佳参数
def tag_parser(self, string, tags):
item = {"string": string, "entities": []}
entity_name = ""
flag = []
visit = False
for char, tag in zip(string, tags): # 迭代出来是 感 B_disease 冒 I_disease
if tag[0] == "B":
if entity_name != "": # 如果是新的一个实体开头 且上一个实体不是空的 就将上一个实体放入
x = dict((a, flag.count(a)) for a in flag)
y = [k for k, v in x.items() if max(x.values()) == v]
item["entities"].append({"word": entity_name, "type": y[0]}) # 放入
flag.clear()
entity_name = "" # 清空
entity_name += char # 重新记录
flag.append(tag[2:])
elif tag[0] == "I":
entity_name += char
flag.append(tag[2:])
else:
if entity_name != "": # 如果当前拿到的实体是O
x = dict((a, flag.count(a)) for a in flag)
y = [k for k, v in x.items() if max(x.values()) == v]
item["entities"].append({"word": entity_name, "type": y[0]})
flag.clear()
flag.clear()
entity_name = ""
if entity_name != "": # 如果做到最后一个实体不是O 就重复放入程序
x = dict((a, flag.count(a)) for a in flag)
y = [k for k, v in x.items() if max(x.values()) == v]
item["entities"].append({"word": entity_name, "type": y[0]})
return item
def predict(self, texts): # 在线接收字符串用于预测
X = [[self.word2id.get(word, 1) for word in list(x)] for x in texts] # 先进行id映射
X = pad_sequences(X, maxlen=max_len, value=0) # 填充和截断
pred_id = self.model.predict(X) # 预测输入的句子
res = []
for text, pred in zip(texts, pred_id): # 将预测结果从独热码回到标签
tags = np.argmax(pred, axis=1) # 找出独热码中的1的下标
tags = [self.id2tag[i] for i in tags if i != 0] # 通过id转标签 将下标转为标签
res.append(self.tag_parser(text, tags)) # 改输出样式
return res
ner = MedicalNerModel()
def intent_classifier(text):
    """Return the BERT intent prediction ``{"name", "confidence"}`` for ``text``."""
    prediction = BIM.predict(text)
    return prediction
def slot_recognizer(text):
    """Run NER on one sentence; returns a one-element list of ``tag_parser`` dicts."""
    batch = [text]
    return ner.predict(batch)
def entity_link(mention, etype):
    """Map a recognised mention to its knowledge-graph entity name.

    Currently an identity mapping; ``etype`` is accepted but not used.
    """
    linked_name = mention
    return linked_name
def classifier(text):
    """Return the coarse user intent of ``text`` from the scikit-learn classifier."""
    label = clf_model.predict(text)
    return label
def _flatten_records(records):
    """Flatten Neo4j result rows (dicts) into one list of values.

    For each row, if its first value is a list only that list is used;
    otherwise all of the row's values are used.
    """
    values = []
    for record in records:
        row = list(record.values())
        if isinstance(row[0], list):
            values.extend(row[0])
        else:
            values.extend(row)
    return values


def neo4j_searcher(cql_list):
    """Run one or more Cypher queries against ``graph`` and format the results.

    Args:
        cql_list: a single query string, or a list/tuple of query strings.

    Returns:
        str: for a single query, its values joined by "、" ("" when no rows).
        For several queries, one "、"-joined line, each ending in a newline,
        per query that returned rows; queries without rows are skipped.
    """
    if isinstance(cql_list, (list, tuple)):
        lines = []
        for cql in cql_list:
            data = graph.run(cql).data()
            if not data:
                continue
            lines.append("、".join(str(v) for v in _flatten_records(data)) + "\n")
        return "".join(lines)
    data = graph.run(cql_list).data()
    if not data:
        return ""
    return "、".join(str(v) for v in _flatten_records(data))
def semantic_parser(text, user):
    """Build the structured semantic frame (intent + filled slots) for ``text``.

    Looks up the frame template for the detected intent in ``semantic_slot``
    (presumably provided by ``from config import *``), fills its
    ``slot_list`` from the NER entities, inherits still-empty slots from
    ``user``'s previous turn, and picks a reply strategy from the intent
    confidence and ``intent_threshold_config``.

    Returns:
        dict: a shallow copy of the template with ``slot_values`` and
        ``intent_strategy`` set, or a copy of the ``"unrecognized"`` template
        when no usable intent is found (None if that template is missing).
    """
    intent_rst = intent_classifier(text)
    slot_rst = slot_recognizer(text)
    # TODO: inheriting the previous turn's intent is not implemented.
    if intent_rst == -1 or slot_rst == -1 or intent_rst.get("name") == "其他":
        template = None
    else:
        template = semantic_slot.get(intent_rst.get("name"))
    if template is None:
        # Unrecognized, or an intent without a configured frame (which used to
        # crash with AttributeError on None).
        fallback = semantic_slot.get("unrecognized")
        return dict(fallback) if fallback is not None else None

    # Work on a copy: the template is a shared module-level dict, and writing
    # slot values / answers into it leaked state between requests and users.
    slot_info = dict(template)

    # Fill slots: an entity fills a slot when its type equals the lower-cased
    # slot name; the last matching entity wins.
    slot_values = {}
    for slot in slot_info.get("slot_list"):
        slot_values[slot] = None
        for ent_info in slot_rst:
            for e in ent_info["entities"]:
                if slot.lower() == e['type']:
                    slot_values[slot] = entity_link(e['word'], e['type'])

    # Inherit still-empty slots from this user's previous turn; the stored
    # context is loaded at most once (previously once per empty slot).
    missing = [k for k, v in slot_values.items() if v is None]
    if missing:
        last_slot_values = load_user_dialogue_context(user)["slot_values"]
        for k in missing:
            slot_values[k] = last_slot_values.get(k, None)
    slot_info["slot_values"] = slot_values

    # Reply strategy from intent confidence.
    conf = intent_rst.get("confidence")
    if conf >= intent_threshold_config["accept"]:
        slot_info["intent_strategy"] = "accept"
    elif conf >= intent_threshold_config["deny"]:
        slot_info["intent_strategy"] = "clarify"
    else:
        slot_info["intent_strategy"] = "deny"
    return slot_info
def get_answer(slot_info):
cql_template = slot_info.get("cql_template") # 获得类sql语句
reply_template = slot_info.get("reply_template") # 获得回复模板
ask_template = slot_info.get("ask_template") # 获得追问模板
slot_values = slot_info.get("slot_values") # 获得槽位值
strategy = slot_info.get("intent_strategy") # 获得意图策略值 接受 澄清 否认
if not slot_values: # 如果没有槽位值
return slot_info # 返回语义结构化信息
if strategy == "accept": # 如果意图策略值为接受
cql = [] # 建立一个空列表用于存放类sql语句
if isinstance(cql_template,list):
for cqlt in cql_template:
cql.append(cqlt.format(**slot_values)) # 把槽位值填进去类sql语句
else:
cql = cql_template.format(**slot_values)
answer = neo4j_searcher(cql) # 向图数据库发送类sql语句 并获得反馈
if not answer: # 如果没有反馈信息
slot_info["replay_answer"] = "唔~我装满知识的大脑此刻很贫瘠"
else:
pattern = reply_template.format(**slot_values) # 将槽位值填进去回复模板
slot_info["replay_answer"] = pattern + answer # 回复模板接上图数据库的反馈作为机器人的回复
elif strategy == "clarify": # 如果意图策略值为澄清
# 澄清用户是否问该问题
pattern = ask_template.format(**slot_values) # 将槽位值填进去追问模板
slot_info["replay_answer"] = pattern # 追问模板作为机器人的回复
# 得到肯定意图之后需要给用户回复的答案
cql = []
if isinstance(cql_template,list):
for cqlt in cql_template:
cql.append(cqlt.format(**slot_values))
else:
cql = cql_template.format(**slot_values)
answer = neo4j_searcher(cql)
if not answer:
slot_info["replay_answer"] = "唔~我装满知识的大脑此刻很贫瘠"
else:
pattern = reply_template.format(**slot_values)
slot_info["choice_answer"] = pattern + answer # 回复模板接上图数据库的反馈作为机器人的回复
elif strategy == "deny": # 如果意图策略值为否认
slot_info["replay_answer"] = slot_info.get("deny_response") # 否认模板作为机器人的回复
return slot_info # 返回语义结构化信息
def gossip_robot(intent):
    """Return a random canned chit-chat reply for ``intent`` from ``gossip_corpus``."""
    candidates = gossip_corpus.get(intent)
    return random.choice(candidates)
def medical_robot(text, user):
    """Answer a medical question: parse it into a frame, then fill in the reply."""
    return get_answer(semantic_parser(text, user))
def text_replay(msg):
    """Produce the bot's reply for one incoming message.

    Args:
        msg: dict with ``"Text"`` (the user utterance) and ``"NickName"``
            (the key under which the user's dialogue context is stored).

    Returns:
        The reply (None if the selected frame carries no reply).
    """
    user_intent = classifier(msg['Text'])
    # Chit-chat intents get a canned reply.
    if user_intent in ["greet", "goodbye", "deny", "isbot"]:
        return gossip_robot(user_intent)
    # "accept": the user confirmed a clarifying question from the previous turn.
    if user_intent == "accept":
        context = load_user_dialogue_context(msg['NickName'])
        return context.get("choice_answer")
    frame = medical_robot(msg['Text'], msg['NickName'])
    # .get: a frame without a "slot_values" key used to raise KeyError here.
    if frame.get("slot_values"):
        # Remember this turn so later turns can inherit slots / accept.
        dump_user_dialogue_context(msg['NickName'], frame)
    return frame.get("replay_answer")
if __name__ == '__main__':
    app = Flask(__name__)
    CORS(app, resources=r'/*')

    @app.route("/service/api/work", methods=["GET", "POST"])
    def work():
        """Chat endpoint: JSON body ``{"Text": ..., "NickName": ...}`` -> reply text."""
        param = request.get_json(silent=True)
        # Reject bodies that are not JSON objects with the fields text_replay
        # reads, instead of crashing with a 500 on None / missing keys.
        if not isinstance(param, dict) or "Text" not in param or "NickName" not in param:
            return "invalid request: expected JSON with 'Text' and 'NickName'", 400
        reply = text_replay(param)
        # A Flask view may not return None; send an empty body instead.
        return reply if reply is not None else ""

    server = pywsgi.WSGIServer(("0.0.0.0", 60064), app)
    server.serve_forever()
1
https://gitee.com/is_good_bro/robot.git
git@gitee.com:is_good_bro/robot.git
is_good_bro
robot
robot
master

搜索帮助