"""Unit tests for the DocsRag RAG demo module."""
import os
import unittest
from unittest.mock import Mock, patch

from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import FAISS
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever

# BUG FIX: the module under test is DocsRag.py, not myFirstRagDemo —
# the original import could never resolve.
from DocsRag import (
    load_knowledge_base,
    create_vector_db,
    setup_rag_chain,
    get_local_embeddings,
    get_deepseek_llm,
)


class TestRAGDemo(unittest.TestCase):
    def test_load_knowledge_base(self):
        """The knowledge base loads as a string with the expected content."""
        knowledge_base = load_knowledge_base()

        self.assertIsInstance(knowledge_base, str)
        self.assertEqual(knowledge_base.strip(), "This is a test text.")

    def test_create_vector_db(self):
        """A FAISS store is built from the text, one chunk for a short input."""

        class MockEmbeddings(Embeddings):
            # Deterministic fixed-size vectors (BERT-style 768 dims);
            # avoids downloading a real model.
            def embed_query(self, text):
                return [0.1] * 768

            def embed_documents(self, texts):
                return [[0.1] * 768 for _ in texts]

        mock_embeddings = MockEmbeddings()
        test_text = "这是一个测试文本,用于验证向量数据库创建功能。"

        vector_db = create_vector_db(test_text, mock_embeddings)

        self.assertIsInstance(vector_db, FAISS)
        # Text is shorter than chunk_size=100, so exactly one document.
        self.assertEqual(len(vector_db.docstore._dict), 1)

    def test_setup_rag_chain(self):
        """setup_rag_chain wires the retriever and llm into a RetrievalQA chain."""

        class MockFAISS(FAISS):
            def as_retriever(self, search_kwargs=None):
                # spec= so LangChain's pydantic validation accepts the mock.
                return Mock(spec=BaseRetriever)

        # BUG FIX: FAISS.__init__ requires index/docstore arguments; the
        # original MockFAISS(FAISS)() call raised TypeError. __new__ bypasses
        # __init__, which is fine because only as_retriever is exercised.
        mock_vector_db = MockFAISS.__new__(MockFAISS)
        mock_llm = Mock(spec=BaseLanguageModel)

        rag_chain = setup_rag_chain(mock_vector_db, mock_llm)

        self.assertIsNotNone(rag_chain)
        self.assertTrue(rag_chain.return_source_documents)

    def test_get_local_embeddings(self):
        """The local embedding model is configured for bert-base-chinese on CPU."""
        embeddings = get_local_embeddings()

        self.assertIsNotNone(embeddings)
        self.assertEqual(embeddings.model_name, 'bert-base-chinese')
        self.assertEqual(embeddings.model_kwargs, {"device": "cpu"})

    # BUG FIX: patch the name where it is *used* — DocsRag imports
    # ChatDeepSeek directly, so patching 'langchain_deepseek.ChatDeepSeek'
    # had no effect on the call inside get_deepseek_llm.
    @patch('DocsRag.ChatDeepSeek')
    def test_get_deepseek_llm(self, mock_chat_deepseek):
        """get_deepseek_llm exports the key and constructs the chat model."""
        test_api_key = "test_api_key_123"

        llm = get_deepseek_llm(test_api_key)

        # The key must be exported for downstream SDK use.
        self.assertEqual(os.environ["DEEPSEEK_API_KEY"], test_api_key)

        mock_chat_deepseek.assert_called_once_with(
            model="deepseek-chat",
            temperature=0,
            max_tokens=1024
        )
        # With the correct patch target, the factory returns the mock instance.
        self.assertIs(llm, mock_chat_deepseek.return_value)


if __name__ == '__main__':
    unittest.main()
+""" + return knowledge_base + +#2.分开和向量化 +def create_vector_db(txt:str, embeddings: Embeddings) -> FAISS: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=100,#每块最多100字符 + chunk_overlap=20,#相邻重叠20字符 + ) + # 生产适合大小文本 + chunks = text_splitter.create_documents([txt]) + + # 创建向量 + vector_db = FAISS.from_documents(chunks, embeddings) + return vector_db + +#3.检索生成 +def setup_rag_chain(vector_db:FAISS, llm) ->RetrievalQA: + #创检索器 + retriever = vector_db.as_retriever(search_kwargs={"k": 2}) + #创RAG链 + rag_chain = RetrievalQA.from_chain_type( + llm=llm, + chain_type="stuff", + retriever=retriever, + return_source_documents=True + ) + return rag_chain + +#获取embeddings +def get_local_embeddings() -> HuggingFaceEmbeddings: + """加载本地嵌入模型,兼容TensorFlow格式""" + + embedder = HuggingFaceEmbeddings( + model_name='bert-base-chinese', + model_kwargs={"device": "cpu"} + ) + return embedder + +#获取语言模型 +def get_deepseek_llm(api_key:str): + os.environ["DEEPSEEK_API_KEY"] = api_key + return ChatDeepSeek( + model="deepseek-chat", # 支持模型如deepseek-chat(DeepSeek-V3)、deepseek-reasoner(DeepSeek-R1) + temperature=0, + max_tokens=1024 + ) + + +def main(): + parser = argparse.ArgumentParser(description="First RAG demo.") + parser.add_argument("--api_key", type=str, default="*****", help="Please input AI API key.") + + args = parser.parse_args() + key = args.api_key + if not key: + print("The Key input Error.") + return + + #加载文档库 + knowledge_base = load_knowledge_base() + + #创建embedding模型 + embeddings = get_local_embeddings() + vector_db = create_vector_db(knowledge_base, embeddings) + + #创建llm + print(key) + llm = get_deepseek_llm(key) + rag_chain = setup_rag_chain(vector_db, llm) + + print("\nRAG系统已就绪,请输入问题(输入'Q'结束对话):") + while True: + user_query = input("\n问题: ") + if user_query.lower() == "Q": + break + + # 执行RAG流程 + result = rag_chain({"query": user_query}) + + # 显示结果 + print("\n回答:") + print(result["result"]) + + # 显示来源 + print("\n参考内容:") + for doc in result["source_documents"]: 
+ print(f"- {doc.page_content[:100]}...") + +# Press the green button in the gutter to run the script. +if __name__ == '__main__': + main()