# AgentSQL **Repository Path**: ezemeti/AgentSQL ## Basic Information - **Project Name**: AgentSQL - **Description**: AgentSQL是一种将自然语言转换为SQL查询的技术。用户只需用日常的语言描述需求,系统就能自动生成对应的SQL语句并执行查询,最终以自然语言的形式返回结果。这种技术极大地降低了数据库操作的门槛,特别适合非技术人员使用。 - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 2 - **Forks**: 1 - **Created**: 2025-03-03 - **Last Updated**: 2025-03-30 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # Text2SQL-智能对话数据库:让自然语言与数据库无缝交互 在当今数据驱动的时代,数据库操作是每个开发者、数据分析师甚至是内容创作者都无法绕开的重要技能。然而,SQL语言的学习门槛和复杂性,常常让人望而却步。今天,我将为大家介绍一种全新的解决方案——**Text2SQL**,它能够通过自然语言直接生成SQL查询语句,并返回人类可读的结果。本文将以代码为核心,深入解析其工作原理。 --- ## 什么是Text2SQL? Text2SQL是一种将自然语言转换为SQL查询的技术。用户只需用日常的语言描述需求,系统就能自动生成对应的SQL语句并执行查询,最终以自然语言的形式返回结果。这种技术极大地降低了数据库操作的门槛,特别适合非技术人员使用。 --- ![alt](./img/Text2SQL.png) ## 核心模块解析 接下来,我们将从代码层面剖析Text2SQL的核心实现。整个项目分为以下几个关键模块: 1. **数据库操作(Tools/database.py)** 2. **模型调用与链式处理(Chat/model_handler.py)** 3. **提示词设计(Chat/prompt_templates.py)** --- ### 1. 数据库操作:`UniversalDBConnector` #### 核心功能 `UniversalDBConnector` 是一个通用的数据库连接器,支持多种数据库(如SQLite、MySQL、PostgreSQL)。它的主要职责包括: - 获取表结构信息。 - 执行安全的SQL查询和写操作。 - 提供数据库方言(dialect)信息。 #### 代码解读 ```python class UniversalDBConnector: def __init__(self, connection_str: str): self.engine = create_engine(connection_str) self.db = SQLDatabase(self.engine) def get_table_info(self) -> str: """获取数据库表结构信息""" return self.db.get_table_info() def execute_safe_query(self, query: str): if any(cmd in query.upper() for cmd in ["ALTER"]): raise ValueError("写操作被禁止") # 处理写操作 if any(cmd in query.upper() for cmd in ["DELETE", "UPDATE", "INSERT"]): with self.engine.connect() as conn: result = conn.execute(text(query)) conn.commit() return result # 处理查询操作 limited_query = query with self.engine.connect() as conn: result = conn.execute(text(limited_query)) return result ``` #### 关键点 - **安全性**:通过检查SQL关键字(如`ALTER`),防止恶意操作。 - **灵活性**:支持多种数据库类型,适应不同场景。 - **表结构提取**:`get_table_info()` 方法为后续的SQL生成提供了必要的元数据。 --- ### 2. 模型调用与链式处理:`Text2SQLGenerator` 和 `SQL2TextGenerator` #### 核心功能 这两个类分别负责: - **Text2SQLGenerator**:将自然语言问题转化为SQL语句。 - **SQL2TextGenerator**:将SQL查询结果转化为自然语言回答。 #### 代码解读 ```python class Text2SQLGenerator: """支持上下文的Text2SQL生成器""" def __init__(self, llm, prompt_template, memory): self.llm = llm self.prompt_template = prompt_template self.memory = memory # 构建新的链式结构 self.chain = ( RunnableParallel( # 动态获取当前输入参数和内存中的历史记录 dialect=lambda x: x["dialect"], table_info=lambda x: x["table_info"], question=lambda x: x["question"], history=lambda _: self.memory.load_memory_variables({})["history"] ) | self.prompt_template | self.llm | StrOutputParser() ) def generate_sql(self, inputs: Dict[str, Any]) -> str: """生成并返回SQL查询""" response = self.chain.invoke(inputs) return response.strip() class SQL2TextGenerator: """支持上下文的SQL2Text生成器""" def __init__(self, llm, prompt_template): self.llm = llm self.prompt_template = prompt_template self.chain = ( RunnableParallel( result=lambda x: x["result"], question=lambda x: x["question"], sqlcode=lambda x: x["sqlcode"], ) | self.prompt_template | self.llm | StrOutputParser() ) def generate_text(self, inputs: Dict[str, Any]) -> str: """生成并返回自然语言回答""" response = self.chain.invoke(inputs) return response.strip() ``` #### 关键点 - **链式处理**:通过`RunnableParallel`将输入参数动态组合,形成完整的上下文。 - **记忆功能**:利用`memory`保存对话历史,确保上下文连贯。 - **输出格式化**:最终输出严格遵循JSON格式,便于后续处理。 --- ### 3. 提示词设计:`SQL_PROMPT` 和 `RESULT_PROMPT` #### 核心功能 提示词模板定义了模型的行为规则,直接影响生成SQL的质量和自然语言回答的准确性。 #### 代码解读 ```python from langchain_core.prompts import ChatPromptTemplate SQL_PROMPT = ChatPromptTemplate.from_messages([ ("system", """你是一个专业的SQL专家,根据以下数据库结构和对话历史生成SQL查询: Dialect: {dialect} Tables: {table_info} 规则: 1. 支持生成SELECT,UPDATE,INSERT,DELETE语句。根据用户的需求生成SQL语句。按严格遵守用户需求。 2. 不支持创建管理员用户,修改密码等操作 3. 并且查询的时候必须加上LIMIT限制,并且不能超过50条。 4. 优先使用EXISTS代替IN 5. 历史上下文:{history} 6. 按以下JSON格式响应,```json * ```: "message": "自然语言回复", "sqlcode": "生成的SQL语句,列如:[更新或者修改或者删除等等] FROM users WHERE id = (SELECT id FROM users ORDER BY date_of_birth DESC LIMIT 1);", "iscode": "是否需要执行SQL,有则需要要执行SQL的话,值为true,否则为false", """ ), ("human", "用户问题:{question}"), ]) RESULT_PROMPT = ChatPromptTemplate.from_messages([ ("system", """你是一个专业的数据分析师,根据以下查询结果生成自然语言回答: SQL语句:{sqlcode} 查询结果: {result} 规则: 1. 回答必须简洁明了。 2. 回答必须基于查询结果。 3. 按以下JSON格式响应,```json * ```: "message": "自然语言回复", """ ), ("human", "用户问题:{question}"), ]) ``` #### 关键点 - **规则约束**:通过明确的规则(如`LIMIT`限制),确保生成的SQL语句高效且安全。 - **上下文感知**:结合历史对话和表结构信息,生成更精准的SQL。 - **自然语言回复**:不仅返回SQL语句,还提供清晰的解释。 --- ## 工作流程 为了更好地理解Text2SQL的工作原理,我们将其整体流程总结如下: 1. **用户提问**:用户以自然语言提出问题。 2. **SQL生成**: - 系统根据提示词模板和数据库表结构生成SQL语句。 - 输出格式为JSON,包含SQL语句和是否需要执行的标志。 3. **SQL执行**: - 如果需要执行SQL,系统会调用`UniversalDBConnector`进行安全查询。 - 对于写操作(如`DELETE`),返回受影响的行数;对于查询操作,返回前N条记录。 4. **结果解释**: - 系统将查询结果转化为自然语言回答。 - 最终答案以简洁明了的方式呈现给用户。 --- ## 调用 创建一个main.py文件,调用Text2SQL的各个模块,并展示完整的工作流程。 ![alt](./img/export.png) ```python # main.py import time from langchain.memory import ConversationBufferMemory from langchain_mistralai import ChatMistralAI from Chat.model_handler import Text2SQLGenerator, SQL2TextGenerator from Chat.prompt_templates import SQL_PROMPT, RESULT_PROMPT from Tools.database import UniversalDBConnector from typing import Dict, Any import re import json import os def Text2SQL(llm, memory): """自然语言转换SQL语句""" try: return Text2SQLGenerator(llm, SQL_PROMPT, memory) except Exception as e: print("Text2SQL初始化错误:", e) return None def SQL2Text(llm): """SQL结果转自然语言""" try: return SQL2TextGenerator(llm, RESULT_PROMPT) except Exception as e: print("SQL2Text初始化错误:", e) return None def parse_json_response(response): """增强版JSON解析""" try: json_match = re.search(r"```json(.*?)```", response, re.DOTALL) if not json_match: raise ValueError("未找到JSON代码块") json_str = json_match.group(1).strip() data = json.loads(json_str) # 如果iscode存在的话,检查其值是否为true if "iscode" in data and data["iscode"] == "true": data["iscode"] = True return data except Exception as e: raise ValueError(f"JSON解析失败: {str(e)}") def call_ai_for_final_answer(interpreter, question, result, code): """获取最终回答""" return interpreter.generate_text({ "question": question, "result": str(result), # 确保结果为字符串 "sqlcode": code, }) def call_ai_for_sql(generator, db, question): """获取SQL查询""" return generator.generate_sql({ "dialect": db.dialect, "table_info": db.get_table_info(), "question": question }) def main(): # 初始化组件 llm = ChatMistralAI( api_key=os.getenv("MISTRAL_API_KEY"), model="codestral-2501", temperature=0.3 ) # 使用新版Memory初始化 memory = ConversationBufferMemory( memory_key="history", input_key="question", output_key="answer", return_messages=True ) db = UniversalDBConnector("sqlite:///student.db") sql_generator = Text2SQL(llm, memory) result_interpreter = SQL2Text(llm) if not sql_generator or not result_interpreter: print("系统初始化失败") return while True: try: question = input("\n用户问题(输入exit退出): ").strip() if question.lower() == "exit": break if not question: continue # 生成SQL响应 sql_response = call_ai_for_sql(sql_generator, db, question) # print(f"\n[原始响应]\n{sql_response}") sql_response = re.sub(r'.*?', '', sql_response, flags=re.DOTALL).strip() # 解析响应 sql_data = parse_json_response(sql_response) sql_code = sql_data["sqlcode"].strip() is_code = sql_data["iscode"] if is_code and sql_code: print(f"\n[生成SQL]\n{sql_code}") # 执行安全查询 result = db.execute_safe_query(sql_code) if sql_code.strip().upper().startswith("DELETE") or sql_code.strip().upper().startswith( "UPDATE") or sql_code.strip().upper().startswith("INSERT"): # 对于DELETE操作,检查受影响的行数 affected_rows = result.rowcount if affected_rows > 0: # print(f"\n[提示] 已删除 {affected_rows} 条记录") time.sleep(2) # 生成最终解释 final_response = call_ai_for_final_answer( result_interpreter, question, f" 已处理 {affected_rows} 条记录", sql_code ) final_response = re.sub(r'.*?', '', final_response, flags=re.DOTALL).strip() final_data = parse_json_response(final_response) print(f"\n[最终回答]\n{final_data['message']}") memory.save_context( {"question": question}, {"answer": f"已删除 {affected_rows} 条记录"} ) else: print("\n[提示] 未删除任何记录") memory.save_context( {"question": question}, {"answer": "未删除任何记录"} ) else: rows = result.fetchall() if rows: # print(rows) # print(f"\n[查询结果] {len(rows)}条记录") print(f"\n[前5条记录]\n{rows[:5]}") time.sleep(2) # 生成最终解释 final_response = call_ai_for_final_answer( result_interpreter, question, "\n".join(str(row) for row in rows[:20]), # 只展示前5条 sql_code ) final_response = re.sub(r'.*?', '', final_response, flags=re.DOTALL).strip() final_data = parse_json_response(final_response) print(f"\n[最终回答]\n{final_data['message']}") # 保存上下文 memory.save_context( {"question": question}, {"answer": f"{final_data['message']}\nSQL: {sql_code}"} ) else: time.sleep(2) # 生成最终解释 final_response = call_ai_for_final_answer( result_interpreter, question, f"[提示] 已处理 {len(rows)}条记录", sql_code ) final_response = re.sub(r'.*?', '', final_response, flags=re.DOTALL).strip() final_data = parse_json_response(final_response) print(f"\n[最终回答]\n{final_data['message']}") memory.save_context( {"question": question}, {"answer": "该查询未返回有效结果"} ) else: print(f"\n[AI回答]\n{sql_data['message']}") memory.save_context( {"question": question}, {"answer": sql_data["message"]} ) except Exception as e: print(f"\n[错误] {str(e)}") memory.save_context( {"question": question}, {"answer": f"处理时发生错误:{str(e)}"} ) if __name__ == "__main__": main() ``` ![alt](./img/runCode.png) ## 总结 Text2SQL技术通过结合自然语言处理、数据库操作和提示工程,实现了从自然语言到SQL再到自然语言的完整闭环。无论是数据分析、内容创作还是日常办公,这项技术都能显著提升效率,降低学习成本。 希望这篇文章能帮助大家更好地理解Text2SQL的核心原理。如果你对代码或实现细节有任何疑问,欢迎在评论区留言讨论! --- **关注公众号“XXX”,获取更多技术干货!**