# AgentSQL
**Repository Path**: ezemeti/AgentSQL
## Basic Information
- **Project Name**: AgentSQL
- **Description**: AgentSQL是一种将自然语言转换为SQL查询的技术。用户只需用日常的语言描述需求,系统就能自动生成对应的SQL语句并执行查询,最终以自然语言的形式返回结果。这种技术极大地降低了数据库操作的门槛,特别适合非技术人员使用。
- **Primary Language**: Unknown
- **License**: Not specified
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No
## Statistics
- **Stars**: 2
- **Forks**: 1
- **Created**: 2025-03-03
- **Last Updated**: 2025-03-30
## Categories & Tags
**Categories**: Uncategorized
**Tags**: None
## README
# Text2SQL-智能对话数据库:让自然语言与数据库无缝交互
在当今数据驱动的时代,数据库操作是每个开发者、数据分析师甚至是内容创作者都无法绕开的重要技能。然而,SQL语言的学习门槛和复杂性,常常让人望而却步。今天,我将为大家介绍一种全新的解决方案——**Text2SQL**,它能够通过自然语言直接生成SQL查询语句,并返回人类可读的结果。本文将以代码为核心,深入解析其工作原理。
---
## 什么是Text2SQL?
Text2SQL是一种将自然语言转换为SQL查询的技术。用户只需用日常的语言描述需求,系统就能自动生成对应的SQL语句并执行查询,最终以自然语言的形式返回结果。这种技术极大地降低了数据库操作的门槛,特别适合非技术人员使用。
---

## 核心模块解析
接下来,我们将从代码层面剖析Text2SQL的核心实现。整个项目分为以下几个关键模块:
1. **数据库操作(Tools/database.py)**
2. **模型调用与链式处理(Chat/model_handler.py)**
3. **提示词设计(Chat/prompt_templates.py)**
---
### 1. 数据库操作:`UniversalDBConnector`
#### 核心功能
`UniversalDBConnector` 是一个通用的数据库连接器,支持多种数据库(如SQLite、MySQL、PostgreSQL)。它的主要职责包括:
- 获取表结构信息。
- 执行安全的SQL查询和写操作。
- 提供数据库方言(dialect)信息。
#### 代码解读
```python
class UniversalDBConnector:
def __init__(self, connection_str: str):
self.engine = create_engine(connection_str)
self.db = SQLDatabase(self.engine)
def get_table_info(self) -> str:
"""获取数据库表结构信息"""
return self.db.get_table_info()
def execute_safe_query(self, query: str):
if any(cmd in query.upper() for cmd in ["ALTER"]):
raise ValueError("写操作被禁止")
# 处理写操作
if any(cmd in query.upper() for cmd in ["DELETE", "UPDATE", "INSERT"]):
with self.engine.connect() as conn:
result = conn.execute(text(query))
conn.commit()
return result
# 处理查询操作
limited_query = query
with self.engine.connect() as conn:
result = conn.execute(text(limited_query))
return result
```
#### 关键点
- **安全性**:通过检查SQL关键字(如`ALTER`),防止恶意操作。
- **灵活性**:支持多种数据库类型,适应不同场景。
- **表结构提取**:`get_table_info()` 方法为后续的SQL生成提供了必要的元数据。
---
### 2. 模型调用与链式处理:`Text2SQLGenerator` 和 `SQL2TextGenerator`
#### 核心功能
这两个类分别负责:
- **Text2SQLGenerator**:将自然语言问题转化为SQL语句。
- **SQL2TextGenerator**:将SQL查询结果转化为自然语言回答。
#### 代码解读
```python
class Text2SQLGenerator:
"""支持上下文的Text2SQL生成器"""
def __init__(self, llm, prompt_template, memory):
self.llm = llm
self.prompt_template = prompt_template
self.memory = memory
# 构建新的链式结构
self.chain = (
RunnableParallel(
# 动态获取当前输入参数和内存中的历史记录
dialect=lambda x: x["dialect"],
table_info=lambda x: x["table_info"],
question=lambda x: x["question"],
history=lambda _: self.memory.load_memory_variables({})["history"]
)
| self.prompt_template
| self.llm
| StrOutputParser()
)
def generate_sql(self, inputs: Dict[str, Any]) -> str:
"""生成并返回SQL查询"""
response = self.chain.invoke(inputs)
return response.strip()
class SQL2TextGenerator:
"""支持上下文的SQL2Text生成器"""
def __init__(self, llm, prompt_template):
self.llm = llm
self.prompt_template = prompt_template
self.chain = (
RunnableParallel(
result=lambda x: x["result"],
question=lambda x: x["question"],
sqlcode=lambda x: x["sqlcode"],
)
| self.prompt_template
| self.llm
| StrOutputParser()
)
def generate_text(self, inputs: Dict[str, Any]) -> str:
"""生成并返回自然语言回答"""
response = self.chain.invoke(inputs)
return response.strip()
```
#### 关键点
- **链式处理**:通过`RunnableParallel`将输入参数动态组合,形成完整的上下文。
- **记忆功能**:利用`memory`保存对话历史,确保上下文连贯。
- **输出格式化**:最终输出严格遵循JSON格式,便于后续处理。
---
### 3. 提示词设计:`SQL_PROMPT` 和 `RESULT_PROMPT`
#### 核心功能
提示词模板定义了模型的行为规则,直接影响生成SQL的质量和自然语言回答的准确性。
#### 代码解读
```python
from langchain_core.prompts import ChatPromptTemplate
SQL_PROMPT = ChatPromptTemplate.from_messages([
("system",
"""你是一个专业的SQL专家,根据以下数据库结构和对话历史生成SQL查询:
Dialect: {dialect}
Tables:
{table_info}
规则:
1. 支持生成SELECT,UPDATE,INSERT,DELETE语句。根据用户的需求生成SQL语句。按严格遵守用户需求。
2. 不支持创建管理员用户,修改密码等操作
3. 并且查询的时候必须加上LIMIT限制,并且不能超过50条。
4. 优先使用EXISTS代替IN
5. 历史上下文:{history}
6. 按以下JSON格式响应,```json * ```:
"message": "自然语言回复",
"sqlcode": "生成的SQL语句,列如:[更新或者修改或者删除等等] FROM users WHERE id = (SELECT id FROM users ORDER BY date_of_birth DESC LIMIT 1);",
"iscode": "是否需要执行SQL,有则需要要执行SQL的话,值为true,否则为false",
"""
),
("human", "用户问题:{question}"),
])
RESULT_PROMPT = ChatPromptTemplate.from_messages([
("system",
"""你是一个专业的数据分析师,根据以下查询结果生成自然语言回答:
SQL语句:{sqlcode}
查询结果:
{result}
规则:
1. 回答必须简洁明了。
2. 回答必须基于查询结果。
3. 按以下JSON格式响应,```json * ```:
"message": "自然语言回复",
"""
),
("human", "用户问题:{question}"),
])
```
#### 关键点
- **规则约束**:通过明确的规则(如`LIMIT`限制),确保生成的SQL语句高效且安全。
- **上下文感知**:结合历史对话和表结构信息,生成更精准的SQL。
- **自然语言回复**:不仅返回SQL语句,还提供清晰的解释。
---
## 工作流程
为了更好地理解Text2SQL的工作原理,我们将其整体流程总结如下:
1. **用户提问**:用户以自然语言提出问题。
2. **SQL生成**:
- 系统根据提示词模板和数据库表结构生成SQL语句。
- 输出格式为JSON,包含SQL语句和是否需要执行的标志。
3. **SQL执行**:
- 如果需要执行SQL,系统会调用`UniversalDBConnector`进行安全查询。
- 对于写操作(如`DELETE`),返回受影响的行数;对于查询操作,返回前N条记录。
4. **结果解释**:
- 系统将查询结果转化为自然语言回答。
- 最终答案以简洁明了的方式呈现给用户。
---
## 调用
创建一个main.py文件,调用Text2SQL的各个模块,并展示完整的工作流程。

```python
# main.py
import time
from langchain.memory import ConversationBufferMemory
from langchain_mistralai import ChatMistralAI
from Chat.model_handler import Text2SQLGenerator, SQL2TextGenerator
from Chat.prompt_templates import SQL_PROMPT, RESULT_PROMPT
from Tools.database import UniversalDBConnector
from typing import Dict, Any
import re
import json
import os
def Text2SQL(llm, memory):
"""自然语言转换SQL语句"""
try:
return Text2SQLGenerator(llm, SQL_PROMPT, memory)
except Exception as e:
print("Text2SQL初始化错误:", e)
return None
def SQL2Text(llm):
"""SQL结果转自然语言"""
try:
return SQL2TextGenerator(llm, RESULT_PROMPT)
except Exception as e:
print("SQL2Text初始化错误:", e)
return None
def parse_json_response(response):
"""增强版JSON解析"""
try:
json_match = re.search(r"```json(.*?)```", response, re.DOTALL)
if not json_match:
raise ValueError("未找到JSON代码块")
json_str = json_match.group(1).strip()
data = json.loads(json_str)
# 如果iscode存在的话,检查其值是否为true
if "iscode" in data and data["iscode"] == "true":
data["iscode"] = True
return data
except Exception as e:
raise ValueError(f"JSON解析失败: {str(e)}")
def call_ai_for_final_answer(interpreter, question, result, code):
"""获取最终回答"""
return interpreter.generate_text({
"question": question,
"result": str(result), # 确保结果为字符串
"sqlcode": code,
})
def call_ai_for_sql(generator, db, question):
"""获取SQL查询"""
return generator.generate_sql({
"dialect": db.dialect,
"table_info": db.get_table_info(),
"question": question
})
def main():
# 初始化组件
llm = ChatMistralAI(
api_key=os.getenv("MISTRAL_API_KEY"),
model="codestral-2501",
temperature=0.3
)
# 使用新版Memory初始化
memory = ConversationBufferMemory(
memory_key="history",
input_key="question",
output_key="answer",
return_messages=True
)
db = UniversalDBConnector("sqlite:///student.db")
sql_generator = Text2SQL(llm, memory)
result_interpreter = SQL2Text(llm)
if not sql_generator or not result_interpreter:
print("系统初始化失败")
return
while True:
try:
question = input("\n用户问题(输入exit退出): ").strip()
if question.lower() == "exit":
break
if not question:
continue
# 生成SQL响应
sql_response = call_ai_for_sql(sql_generator, db, question)
# print(f"\n[原始响应]\n{sql_response}")
sql_response = re.sub(r'.*?', '', sql_response, flags=re.DOTALL).strip()
# 解析响应
sql_data = parse_json_response(sql_response)
sql_code = sql_data["sqlcode"].strip()
is_code = sql_data["iscode"]
if is_code and sql_code:
print(f"\n[生成SQL]\n{sql_code}")
# 执行安全查询
result = db.execute_safe_query(sql_code)
if sql_code.strip().upper().startswith("DELETE") or sql_code.strip().upper().startswith(
"UPDATE") or sql_code.strip().upper().startswith("INSERT"):
# 对于DELETE操作,检查受影响的行数
affected_rows = result.rowcount
if affected_rows > 0:
# print(f"\n[提示] 已删除 {affected_rows} 条记录")
time.sleep(2)
# 生成最终解释
final_response = call_ai_for_final_answer(
result_interpreter,
question,
f" 已处理 {affected_rows} 条记录",
sql_code
)
final_response = re.sub(r'.*?', '', final_response, flags=re.DOTALL).strip()
final_data = parse_json_response(final_response)
print(f"\n[最终回答]\n{final_data['message']}")
memory.save_context(
{"question": question},
{"answer": f"已删除 {affected_rows} 条记录"}
)
else:
print("\n[提示] 未删除任何记录")
memory.save_context(
{"question": question},
{"answer": "未删除任何记录"}
)
else:
rows = result.fetchall()
if rows:
# print(rows)
# print(f"\n[查询结果] {len(rows)}条记录")
print(f"\n[前5条记录]\n{rows[:5]}")
time.sleep(2)
# 生成最终解释
final_response = call_ai_for_final_answer(
result_interpreter,
question,
"\n".join(str(row) for row in rows[:20]), # 只展示前5条
sql_code
)
final_response = re.sub(r'.*?', '', final_response, flags=re.DOTALL).strip()
final_data = parse_json_response(final_response)
print(f"\n[最终回答]\n{final_data['message']}")
# 保存上下文
memory.save_context(
{"question": question},
{"answer": f"{final_data['message']}\nSQL: {sql_code}"}
)
else:
time.sleep(2)
# 生成最终解释
final_response = call_ai_for_final_answer(
result_interpreter,
question,
f"[提示] 已处理 {len(rows)}条记录",
sql_code
)
final_response = re.sub(r'.*?', '', final_response, flags=re.DOTALL).strip()
final_data = parse_json_response(final_response)
print(f"\n[最终回答]\n{final_data['message']}")
memory.save_context(
{"question": question},
{"answer": "该查询未返回有效结果"}
)
else:
print(f"\n[AI回答]\n{sql_data['message']}")
memory.save_context(
{"question": question},
{"answer": sql_data["message"]}
)
except Exception as e:
print(f"\n[错误] {str(e)}")
memory.save_context(
{"question": question},
{"answer": f"处理时发生错误:{str(e)}"}
)
if __name__ == "__main__":
main()
```

## 总结
Text2SQL技术通过结合自然语言处理、数据库操作和提示工程,实现了从自然语言到SQL再到自然语言的完整闭环。无论是数据分析、内容创作还是日常办公,这项技术都能显著提升效率,降低学习成本。
希望这篇文章能帮助大家更好地理解Text2SQL的核心原理。如果你对代码或实现细节有任何疑问,欢迎在评论区留言讨论!
---
**关注公众号“XXX”,获取更多技术干货!**