Ai
1 Star 0 Fork 0

codeMonkey/retrievalQA

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
zl_chatbot.py 10.46 KB
一键复制 编辑 原始数据 按行查看 历史
codeMonkey 提交于 2023-10-16 10:56 +08:00 . 初始化
import json
import os
import threading
from langchain import FAISS
from langchain.callbacks.manager import CallbackManager
from langchain.schema.document import Document
from echo_ai.callbacks import StreamHandler
from echo_ai.embeddings import *
from echo_ai.retrival import MyRetrival
os.environ["OPENAI_API_KEY"] = "openai api key"
os.environ["OPENAI_API_BASE"] = "https://api.openai-proxy.com/v1"
class ZlChatBot:
"""
一个基于知识库的问答机器人的核心步骤包括:
1)检索: 从知识库中检索相关的知识。
2)后处理: 对检索得到的知识进行一些处理,比如,过滤,去重,召回,格式话等。
3)推理: 从得到的知识中推理出答案。
"""
def __init__(self):
self.retrival = MyRetrival() # 检索器,从单个获取多个知识库中检索相关的知识
self.llm = ChatOpenAI(temperature=0) ## temperature越低回答越准确,越高创造性越强
self.cache = {}
def format_json_data(self, json_data: List[dict]):
"""
格式化json格式的数据
:param json_data:
:return:
"""
full_text_docs = [] # 待embedding的文本是(问题+答案)文本
query_docs = [] # 待embedding的文本是问题文本
for d in json_data:
## 从json格式的数据中解析出问题文本, 和(问题+答案)文本
query = d['Q']
q_type = d['C']
answer = ''
for step in d['A']:
answer += step['T'] + "\n"
full_text = f"""
题目: {query}
分类: {q_type}
解决办法: {answer}
"""
full_text_docs.append(
Document(page_content=full_text, metadata=d)) # page_content是待embedding的文本,metadata是存到向量库中的数据
query_docs.append(Document(page_content=query, metadata=d))
return full_text_docs, query_docs
def init_from_json_data(self, json_data: List[dict], save_path: str):
"""
从自定义格式的json数据中初始化;
方法:
1)将json数据统一加载成Document形式, page_context保存待embedding的文本, 整个Document保存到向量库。
2)随后使用不同的方式对List[Document]进行embedding.
:param json_data: 保存到知识库中的json格式,需要满足标记的格式。
:param save_path: 保存知识库的目录
:return:
"""
full_text_docs, query_docs = self.format_json_data(json_data)
# 三种不同的方式的embedding
# 1) 对问题进行embedding
self.query_embed_db = FAISS.from_documents(query_docs, TextEmbedding())
self.query_embed_db.save_local(save_path, 'query_embed_db')
# 2) 获取(问题+答案文本)的关键词列表,然后进行embedding
self.keyword_embed_db = FAISS.from_documents(full_text_docs, KeywordEmbedding())
self.keyword_embed_db.save_local(save_path, 'keyword_embed_db')
# 3) 获取(问题+答案文本)的摘要,然后进行embedding
self.absract_embed_db = FAISS.from_documents(full_text_docs, AbstractEmbedding())
self.absract_embed_db.save_local(save_path, 'absract_embed_db')
def init_from_labeled_json(self, file_path: str, encoding='utf-8'):
"""
从自定义的json格式的文件中初始化本地知识库
:param file_path:
:param encoding:
:return:
"""
with open(file_path, encoding=encoding) as f: # 加载json数据
data = json.load(f)
self.init_from_json_data(data, "")
def init_from_dir_with_labeled_json(self, file_dir: str, vector_store_dir: str, encoding='utf-8'):
"""
将本地目录的所有的json文件构建知识库
:param file_dir:
:param vector_store_dir:
:param encoding:
:return:
"""
json_files = []
# 遍历目录及其子目录
for root, dirs, files in os.walk(file_dir):
# 遍历当前目录下的文件
for file in files:
# 判断文件是否以 .json 结尾
if file.endswith('.json'):
json_files.append(os.path.join(root, file))
json_data = []
for file in json_files:
with open(file, encoding=encoding) as f: ## 将某个目录下所有的json文件加载为一个list
json_data.extend(json.load(f))
self.init_from_json_data(json_data, save_path=vector_store_dir)
def init_chatbot_from_vec_db(self, db_dirs: List[str]):
"""
初始化中梁项目的知识库
:param db_dirs: 向量库所在的目录
:return:
"""
embeddings = [AbstractEmbedding(), KeywordEmbedding(), TextEmbedding()]
self.retrival.init_from_faiss_dbs(db_dirs, embeddings)
return self
def add_new_json_data(self, json_data: list):
"""
导入更多的json数据
:param json_data:
:return:
"""
full_text_docs, query_docs = self.format_json_data(json_data)
self.retrival.embed_dbs[0].add_documents(full_text_docs)
self.retrival.embed_dbs[1].add_documents(full_text_docs)
self.retrival.embed_dbs[2].add_documents(query_docs)
for i, db in enumerate(zip(self.retrival.embed_db_dirs, self.retrival.embed_dbs)):
db[1].save_local(db[0])
def get_from_cache(self, query: str):
"""
从缓存中取回答
:param query:
:return:
"""
hash_code = hash(query)
return self.cache.get(hash_code)
def query2llm(self, query: str):
"""
直接与大模型进行对话
:param query:
:return:
"""
return self.llm.predict(query)
def post_progress_data(self, docs_list: List[List]):
"""
主要对从多个向量库得到的信息进行一个过滤去冗余 和 对json格式的数据处理成字符串格式的。
:param docs_list:
:return:
"""
docs = [d for ds in docs_list for d in ds]
res = []
docs = sorted(docs, key=lambda x: x[1], reverse=True)
temp = set()
content = ""
for doc in docs:
item = doc[0].metadata
answer = ''
for step in item['A']:
answer += step['T'] + '\n'
content = f"问题: {item['Q']} \n分类: {item['C']} \n答案: {answer}"
if content != "" and content not in temp:
res.append((content, doc[1]))
temp.add(content)
return list(res)
def query2kb(self, query: str, llm=None):
"""
从本地知识库中检索相关知识并回答问题
:param llm:
:param query:
:return:
"""
if llm == None:
llm = self.llm
docs = self.retrival.get_relevant_documents(query)
relevant_docs = self.post_progress_data(docs)
content = ""
for d in relevant_docs:
content += d[0] + '\n'
# prompt = f"""
# You are a helpful AI assistant.
# The following are the relevant knowledge content fragments found from the knowledge base.
# The relevance is sorted from high to low.
# You can only answer according to the following content:
# \n>>>\n{content}\n<<<\n
# You need to carefully consider your answer to ensure that it is based on the context.
# If the context does not mention the content or it is uncertain whether it is correct,
# please answer "Current knowledge base cannot provide effective information."
# You must use {"Chinese"} to respond.
# """
# f"""
# You are a helpful AI customer service.
# The following are text snippets or dialogue records related to the problem found in the knowledge base.
# Relevance to the problem is sorted from high to low.
# You need to carefully consider your answer and ensure it is context-based.
# Please remove overly personalized content from the prompt.
# If the prompt does not contain the necessary knowledge to answer the question or if you are unsure of its correctness, please respond with "The current knowledge base cannot provide valid information."
# Please remember that you can only answer user queries based on the provided content and cannot perform any actions on behalf of the users.
# You must respond using {"Chinese"}.
#
# You can only answer based on the following information:
# \n>>>\n{content}\n<<<\n
# """
prompt = f"""
您是一个智能客服。
以下是与知识库中发现的问题相关的文本片段或对话记录。
与问题相关性从高到低排序。
您需要仔细考虑您的答案,并确保它基于上下文。
如果提示不包含回答问题所需的知识,或者您对其正确性不确定,请回复“当前知识库无法提供有效信息。”
在回答中不要包含过于个人化的内容。
请记住,您只能根据提供的内容回答用户问题,且回答的内容都只能是知识性的回答,并不能代表或者为用户执行任何操作,或是对用户执行额外操作,例如发文件等。
必须使用{"Chinese"}进行回应。
相关内容:
\n>>>\n{content}\n<<<\n
"""
messages = [
SystemMessage(content=prompt),
HumanMessage(content=query)
]
ans = llm.predict_messages(messages).content
self.cache[hash(query)] = ans
return ans
def get_stream(self, query):
handler = StreamHandler()
llm = ChatOpenAI(temperature=0, streaming=True, callback_manager=CallbackManager([handler]))
thread = threading.Thread(target=async_run, args=(self.query2kb, llm, query))
thread.start()
return handler.generate_tokens()
def async_run(fun, llm, query):
fun(query, llm)
if __name__ == '__main__':
bot = ZlChatBot().init_chatbot_from_vec_db(
[ './vector_storage/zhongliang_abstract',
'./vector_storage/zl_db/zhongliang_keyword',
'./vector_storage/zl_db/zhongliang_query'])
with open('data/output_fin.json', encoding='utf-8') as f:
data = json.load(f)
bot.add_new_json_data(data)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/codemonkey9527/retrieval-qa.git
git@gitee.com:codemonkey9527/retrieval-qa.git
codemonkey9527
retrieval-qa
retrievalQA
master

搜索帮助