diff --git a/apps/llm/patterns/executor.py b/apps/llm/patterns/executor.py
index 34985745eaadd30f70e7d6f503049bbb82c29717..13b44a8eefb77e8410e9c985ce1fd29b684a76ad 100644
--- a/apps/llm/patterns/executor.py
+++ b/apps/llm/patterns/executor.py
@@ -96,7 +96,8 @@ class ExecutorThought(CorePattern):
             last_thought: str = kwargs["last_thought"]
             user_question: str = kwargs["user_question"]
             tool_info: dict[str, Any] = kwargs["tool_info"]
-            language: LanguageType = kwargs.get("language", LanguageType.CHINESE)
+            language: LanguageType = kwargs.get(
+                "language", LanguageType.CHINESE)
         except Exception as e:
             err = "参数不正确!"
             raise ValueError(err) from e
@@ -184,7 +185,7 @@ class ExecutorSummary(CorePattern):
         enable_thinking: bool = False,
     ) -> None:
         """初始化Background模式
-
+
         :param system_prompt: 系统提示词
         :param user_prompt: 用户提示词
         :param llm_id: 大模型ID,如果为None则使用系统默认模型
@@ -198,7 +199,8 @@ class ExecutorSummary(CorePattern):
         """进行初始背景生成"""
         import logging
         logger = logging.getLogger(__name__)
-        logger.info(f"[ExecutorSummary] 初始化参数 - llm_id: {self.llm_id}, enable_thinking: {self.enable_thinking}")
+        logger.info(
+            f"[ExecutorSummary] 初始化参数 - llm_id: {self.llm_id}, enable_thinking: {self.enable_thinking}")
         background: ExecutorBackground = kwargs["background"]
         conversation_str = convert_context_to_prompt(background.conversation)
         facts_str = facts_to_prompt(background.facts)
@@ -216,35 +218,37 @@ class ExecutorSummary(CorePattern):
         ]
 
         result = ""
-
+
         # 根据llm_id获取模型配置
         llm_config = None
         if self.llm_id:
             from apps.services.llm import LLMManager
             from apps.llm.adapters import get_provider_from_endpoint
             from apps.schemas.config import LLMConfig
-
+
             llm_info = await LLMManager.get_llm_by_id(self.llm_id)
-            logger.info(f"[ExecutorSummary] 根据llm_id获取模型信息: {llm_info.model_name if llm_info else 'None'}")
+            logger.info(
+                f"[ExecutorSummary] 根据llm_id获取模型信息: {llm_info.model_name if llm_info else 'None'}")
             if llm_info:
                 # 获取provider,如果没有则从endpoint推断
-                provider = llm_info.provider or get_provider_from_endpoint(llm_info.openai_base_url)
-
+                provider = llm_info.provider or get_provider_from_endpoint(
+                    llm_info.openai_base_url)
+
                 llm_config = LLMConfig(
                     provider=provider,
                     endpoint=llm_info.openai_base_url,
-                    key=llm_info.openai_api_key,
+                    api_key=llm_info.openai_api_key,
                     model=llm_info.model_name,
                     max_tokens=llm_info.max_tokens,
                     temperature=0.7,
                 )
-
+
         # 初始化LLM客户端
         llm = ReasoningLLM(llm_config) if llm_config else ReasoningLLM()
-
+
         async for chunk in llm.call(
-            messages,
-            streaming=False,
+            messages,
+            streaming=False,
             temperature=0.7,
             enable_thinking=self.enable_thinking
         ):
diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py
index 60fe4468f7fc1ce701defc11bdf5fead1bece374..eab73b6581c10648b40d286f0e37d9a5bbc6008c 100644
--- a/apps/llm/patterns/rewrite.py
+++ b/apps/llm/patterns/rewrite.py
@@ -42,7 +42,7 @@ class QuestionRewrite(CorePattern):
         enable_thinking: bool = False,
     ) -> None:
         """初始化问题改写模式
-
+
         :param system_prompt: 系统提示词
         :param user_prompt: 用户提示词
         :param llm_id: 大模型ID,如果为None则使用系统默认模型
@@ -186,31 +186,32 @@ class QuestionRewrite(CorePattern):
         history = kwargs.get("history", [])
         question = kwargs["question"]
         language = kwargs.get("language", LanguageType.CHINESE)
-
+
         # 根据llm_id获取模型配置并创建LLM实例
         llm = None
         if self.llm_id:
             from apps.services.llm import LLMManager
             from apps.llm.adapters import get_provider_from_endpoint
             from apps.schemas.config import LLMConfig
-
+
             llm_info = await LLMManager.get_llm_by_id(self.llm_id)
             if llm_info:
-                provider = llm_info.provider or get_provider_from_endpoint(llm_info.openai_base_url)
-
+                provider = llm_info.provider or get_provider_from_endpoint(
+                    llm_info.openai_base_url)
+
                 llm_config = LLMConfig(
                     provider=provider,
                     endpoint=llm_info.openai_base_url,
-                    key=llm_info.openai_api_key,
+                    api_key=llm_info.openai_api_key,
                     model=llm_info.model_name,
                     max_tokens=llm_info.max_tokens,
                     temperature=0.7,
                 )
                 llm = ReasoningLLM(llm_config)
-
+
         if not llm:
             llm = ReasoningLLM()
-
+
         leave_tokens = llm._config.max_tokens
         leave_tokens -= TokenCalculator().calculate_token_length(
             messages=[{"role": "system", "content": _env.from_string(self.system_prompt[language]).render(
diff --git a/apps/llm/reasoning.py b/apps/llm/reasoning.py
index 14649b8d0e05cb09a27a699fd7a22b0b5293ffc5..dde4a32b1277c0aa70474a0df5e53d624c688dd4 100644
--- a/apps/llm/reasoning.py
+++ b/apps/llm/reasoning.py
@@ -56,9 +56,10 @@ class ReasoningContent:
                 # 在内容中间,分离前后部分
                 parts = content.split(token, 1)
                 text = parts[0]  # 之前的内容作为普通文本
-                reason = "" + (parts[1] if len(parts) > 1 else "")
+                reason = "" + \
+                    (parts[1] if len(parts) > 1 else "")
                 break
-
+
         # 如果没有检测到思维链标记,将内容作为普通文本
         if not self.is_reasoning:
             text = content
@@ -107,9 +108,10 @@ class ReasoningContent:
                 # 在内容中间,分离前后部分
                 parts = content.split(token, 1)
                 reason = parts[0] + ""
-                text = parts[1] if len(parts) > 1 else ""  # 之后的内容作为普通文本
+                # 之后的内容作为普通文本
+                text = parts[1] if len(parts) > 1 else ""
                 break
-
+
         if not end_token_found:
             if self.is_reasoning:
                 # 仍在推理中,将内容作为推理内容
@@ -135,25 +137,26 @@ class ReasoningLLM:
         else:
             self._config: LLMConfig = llm_config
         self._init_client()
-
+
         # 初始化适配器
         # 优先使用配置中的provider,如果没有则从endpoint推断
         if hasattr(self._config, 'provider') and self._config.provider:
             self._provider = self._config.provider
         else:
             self._provider = get_provider_from_endpoint(self._config.endpoint)
-        self._adapter = AdapterFactory.create_adapter(self._provider, self._config.model)
+        self._adapter = AdapterFactory.create_adapter(
+            self._provider, self._config.model)
 
     def _init_client(self) -> None:
         """初始化OpenAI客户端"""
-        if not self._config.key:
+        if not self._config.api_key:
             self._client = AsyncOpenAI(
                 base_url=self._config.endpoint,
             )
             return
         self._client = AsyncOpenAI(
-            api_key=self._config.key,
+            api_key=self._config.api_key,
             base_url=self._config.endpoint,
         )
@@ -187,10 +190,10 @@ class ReasoningLLM:
         """创建流式响应"""
         if model is None:
             model = self._config.model
-
+
         # 处理思维链控制
         messages_copy = [msg.copy() for msg in messages]
-
+
         # 如果不支持原生thinking,使用prompt方式控制
         if self._adapter.should_use_prompt_thinking(enable_thinking):
             # 启用思维链但模型不支持原生thinking,不添加/no_think
@@ -204,7 +207,7 @@
             else:
                 messages_copy.append(
                     {"role": "user", "content": "/no_think"})
-
+
         # 构建基础参数
         base_params = {
             "model": model,
@@ -215,13 +218,13 @@
             "stream_options": {"include_usage": True},
             "timeout": 300,
         }
-
+
         # 初始化 extra_body
         extra_body_params = {}
-
+
         # enable_thinking 始终放在 extra_body 中
         extra_body_params["enable_thinking"] = enable_thinking
-
+
         # 添加扩展参数到 extra_body(这些参数不被标准 OpenAI SDK 支持)
         if frequency_penalty is not None:
             extra_body_params["frequency_penalty"] = frequency_penalty
@@ -234,24 +237,26 @@
         if top_p is not None:
             # top_p 是标准参数,但某些 provider 可能需要特殊处理
             base_params["top_p"] = top_p
-
+
         # 只有当有扩展参数时才添加 extra_body
         if extra_body_params:
             base_params["extra_body"] = extra_body_params
-
+
         # 使用适配器调整参数
-        adapted_params = self._adapter.adapt_create_params(base_params, enable_thinking)
-
+        adapted_params = self._adapter.adapt_create_params(
+            base_params, enable_thinking)
+
         logger.info(f"[{self._provider}] 调用参数: model={model}, enable_thinking={enable_thinking}, "
-
f"supports_native_thinking={self._adapter.capabilities.supports_enable_thinking}") - + f"supports_native_thinking={self._adapter.capabilities.supports_enable_thinking}") + # 打印完整请求体(排除messages内容以避免日志过长) log_params = adapted_params.copy() if 'messages' in log_params: log_params['messages'] = f"<{len(log_params['messages'])} messages>" logger.info(f"[{self._provider}] 请求体: {log_params}") - - return await self._client.chat.completions.create(**adapted_params) # type: ignore[] + + # type: ignore[] + return await self._client.chat.completions.create(**adapted_params) async def call( # noqa: C901, PLR0912, PLR0913 self, @@ -279,10 +284,10 @@ class ReasoningLLM: model = self._config.model msg_list = self._validate_messages(messages) stream = await self._create_stream( - msg_list, - max_tokens, - temperature, - model, + msg_list, + max_tokens, + temperature, + model, enable_thinking, frequency_penalty, presence_penalty, @@ -326,7 +331,7 @@ class ReasoningLLM: yield result logger.info("[Reasoning] 推理内容: %s\n\n%s", reasoning_content, result) - + # 如果streaming模式下没有返回任何text,至少返回一个空格 # 避免下游处理空字符串时出错 if streaming and not result: diff --git a/apps/llm/schema.py b/apps/llm/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..6124efecc85d88665a4263853421cb68cf31b971 --- /dev/null +++ b/apps/llm/schema.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class DefaultModelId(str, Enum): + DEFAULT_EMBEDDING_MODEL_ID = "default-embedding-model_id" + DEFAULT_RERANKER_MODEL_ID = "default-reranker-model_id" + DEFAULT_CHAT_MODEL_ID = "default-chat-model_id" + DEFAULT_FUNCTION_CALL_MODEL_ID = "default-function-call-model_id" diff --git a/apps/routers/llm.py b/apps/routers/llm.py index d9f11b42ca36195998e25edeff3f32edd6df7a4f..483641e9dac7215a942da24d60ada49323b7b940 100644 --- a/apps/routers/llm.py +++ b/apps/routers/llm.py @@ -133,100 +133,3 @@ async def update_conv_llm( result=llm_id, ).model_dump(exclude_none=True, by_alias=True), ) - - -@router.get("/embedding", response_model=ResponseData) -async def get_embedding_config( - user_sub: Annotated[str, Depends(get_user)] -) -> JSONResponse: - """获取 Embedding 模型列表""" - # 使用单一查询获取用户可访问的所有embedding模型 - all_models = await LLMManager.list_all_embedding_models(user_sub) - - # 为了兼容原有接口,返回第一个模型的配置 - if all_models: - first_model = all_models[0] - embedding_config = { - 'llmId': first_model.llm_id, - 'type': 'embedding', - 'endpoint': first_model.openai_base_url, - 'api_key': first_model.openai_api_key, - 'model': first_model.model_name, - 'icon': first_model.icon, - } - else: - # 如果没有模型,使用配置文件的默认值 - config = Config().get_config() - embedding_config = { - 'type': 'embedding', - 'endpoint': config.embedding.endpoint, - 'api_key': config.embedding.api_key, - 'model': config.embedding.model, - 'icon': config.embedding.icon, - } - - return JSONResponse( - status_code=status.HTTP_200_OK, - content=ResponseData( - code=status.HTTP_200_OK, - message="success", - result=embedding_config, - ).model_dump(exclude_none=True, by_alias=True), - ) - -@router.get("/reranker", response_model=ResponseData) -async def get_reranker_config( - user_sub: Annotated[str, Depends(get_user)] -) -> JSONResponse: - """获取 Reranker 模型列表""" - # 使用单一查询获取用户可访问的所有reranker模型 - all_models = await LLMManager.list_all_reranker_models(user_sub) - - result = [] - - # 添加算法类型reranker - result.append({ - 'type': 'algorithm', - 'name': 'jaccard_dis_reranker', - 'llmId': 'algorithm_jaccard', - 'modelName': 'Jaccard Distance Reranker', - 'icon': '' - }) - - # 添加数据库中的reranker模型 - for model in 
all_models: - result.append({ - 'type': 'reranker', - 'llmId': model.llm_id, - 'modelName': model.model_name, - 'icon': model.icon, - 'endpoint': model.openai_base_url, - 'apiKey': model.openai_api_key, - 'model': model.model_name - }) - - return JSONResponse( - status_code=status.HTTP_200_OK, - content=ResponseData( - code=status.HTTP_200_OK, - message="success", - result=result, - ).model_dump(exclude_none=True, by_alias=True), - ) - - -@router.get("/capabilities", response_model=ResponseData) -async def get_llm_capabilities( - user_sub: Annotated[str, Depends(get_user)], - llm_id: Annotated[str, Query(description="大模型ID", alias="llmId")] -) -> JSONResponse: - """获取指定模型支持的参数配置项""" - capabilities = await LLMManager.get_model_capabilities(user_sub, llm_id) - return JSONResponse( - status_code=status.HTTP_200_OK, - content=ResponseData( - code=status.HTTP_200_OK, - message="success", - result=capabilities, - ).model_dump(exclude_none=True, by_alias=True), - ) \ No newline at end of file diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index fc76182e5694357eaeb1fc12ce4bdc4ce1f9d741..656a9f067abc9db6a8a56422adc2d17c91dfd718 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -29,7 +29,8 @@ if TYPE_CHECKING: class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): """提取事实工具""" answer: str = Field(description="用户输入") - llm_id: str | None = Field(default=None, description="大模型ID,如果为None则使用系统默认模型") + llm_id: str | None = Field( + default=None, description="大模型ID,如果为None则使用系统默认模型") enable_thinking: bool = Field(default=False, description="是否启用思维链") i18n_info: ClassVar[dict[str, dict]] = { LanguageType.CHINESE: { @@ -50,7 +51,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): # 提取 llm_id 和 enable_thinking,避免重复传递 llm_id = kwargs.pop("llm_id", None) enable_thinking = kwargs.pop("enable_thinking", False) - + obj = cls( answer=executor.task.runtime.answer, name=executor.step.step.name, @@ -64,7 +65,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): await obj._set_input(executor) return obj - async def _init(self, call_vars: CallVars) -> FactsInput: """初始化工具""" # 组装必要变量 @@ -78,7 +78,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): message=message, ) - async def _exec( self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE ) -> AsyncGenerator[CallOutputChunk, None]: @@ -100,11 +99,11 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": facts_prompt}, - ], + ], FactsGen, llm_id=self.llm_id, enable_thinking=self.enable_thinking, - ) # type: ignore[arg-type] + ) # type: ignore[arg-type] except Exception as e: # 如果 LLM 返回格式不正确,使用默认空列表 logging.warning(f"[FactsCall] 事实提取失败,使用默认值: {e}") @@ -118,11 +117,11 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": domain_prompt}, - ], + ], DomainGen, llm_id=self.llm_id, enable_thinking=self.enable_thinking, - ) # type: ignore[arg-type] + ) # type: ignore[arg-type] except Exception as e: # 如果 LLM 返回格式不正确,使用默认空列表 logging.warning(f"[FactsCall] 域名提取失败,使用默认值: {e}") @@ -138,47 +137,47 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): domain=domain_list.keywords, ).model_dump(by_alias=True, 
exclude_none=True), ) - + async def _json_with_config( - self, - messages: list[dict[str, Any]], + self, + messages: list[dict[str, Any]], schema: type[BaseModel], llm_id: str | None = None, enable_thinking: bool = False, ) -> BaseModel: """使用配置的模型进行JSON生成""" from apps.llm.function import FunctionLLM - + # 根据llm_id获取模型配置 llm_config = None if llm_id: from apps.services.llm import LLMManager from apps.llm.adapters import get_provider_from_endpoint from apps.schemas.config import LLMConfig - + llm_info = await LLMManager.get_llm_by_id(llm_id) if llm_info: - provider = llm_info.provider or get_provider_from_endpoint(llm_info.openai_base_url) - + provider = llm_info.provider or get_provider_from_endpoint( + llm_info.openai_base_url) + llm_config = LLMConfig( provider=provider, endpoint=llm_info.openai_base_url, - key=llm_info.openai_api_key, + api_key=llm_info.openai_api_key, model=llm_info.model_name, max_tokens=llm_info.max_tokens, temperature=0.7, ) - + # 初始化Function LLM json_gen = FunctionLLM(llm_config) if llm_config else FunctionLLM() result = await json_gen.call( - messages=messages, + messages=messages, schema=schema.model_json_schema(), enable_thinking=enable_thinking, ) return schema.model_validate(result) - async def exec( self, executor: "StepExecutor", @@ -191,5 +190,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): if not isinstance(content, dict): err = "[FactsCall] 工具输出格式错误" raise TypeError(err) - executor.task.runtime.facts = FactsOutput.model_validate(content).facts + executor.task.runtime.facts = FactsOutput.model_validate( + content).facts yield chunk diff --git a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py index 94ff2fa0116903a5006aa9a6fd65595983a28651..14aa58f10aa28a282d8f7e5d0f2072d70403cc8a 100644 --- a/apps/scheduler/call/llm/llm.py +++ b/apps/scheduler/call/llm/llm.py @@ -31,7 +31,7 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): to_user: bool = Field(default=True) controlled_output: bool = Field(default=True) - + # 输出参数配置 output_parameters: dict[str, Any] = Field(description="输出参数配置", default={ "reply": {"type": "string", "description": "大模型的回复内容"}, @@ -39,7 +39,7 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): # 模型配置 llmId: str = Field(description="大模型ID", default="") - + # 大模型基础参数 temperature: float = Field(description="大模型温度(随机化程度)", default=0.7) enable_temperature: bool = Field(description="是否启用温度参数", default=True) @@ -51,11 +51,13 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): description="大模型系统提示词", default="You are a helpful assistant.") user_prompt: str = Field(description="大模型用户提示词", default=LLM_DEFAULT_PROMPT) - + # 新增参数配置 - enable_frequency_penalty: bool = Field(description="是否启用频率惩罚", default=False) + enable_frequency_penalty: bool = Field( + description="是否启用频率惩罚", default=False) frequency_penalty: float = Field(description="频率惩罚", default=0.0) - enable_presence_penalty: bool = Field(description="是否启用内容重复度惩罚", default=False) + enable_presence_penalty: bool = Field( + description="是否启用内容重复度惩罚", default=False) presence_penalty: float = Field(description="内容重复度惩罚", default=0.0) enable_min_p: bool = Field(description="是否启用动态过滤阈值", default=False) min_p: float = Field(description="动态过滤阈值", default=0.0) @@ -65,7 +67,8 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): top_p: float = Field(description="Top-P采样值", default=0.9) enable_search: bool = Field(description="是否启用联网搜索", default=False) enable_json_mode: bool = 
Field(description="是否启用JSON模式输出", default=False) - enable_structured_output: bool = Field(description="是否启用结构化输出", default=False) + enable_structured_output: bool = Field( + description="是否启用结构化输出", default=False) i18n_info: ClassVar[dict[str, dict]] = { LanguageType.CHINESE: { @@ -147,33 +150,34 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): if self.llmId: from apps.services.llm import LLMManager from apps.llm.adapters import get_provider_from_endpoint - + llm_info = await LLMManager.get_llm_by_id(self.llmId) if llm_info: from apps.schemas.config import LLMConfig - + # 获取provider,如果没有则从endpoint推断 - provider = llm_info.provider or get_provider_from_endpoint(llm_info.openai_base_url) - + provider = llm_info.provider or get_provider_from_endpoint( + llm_info.openai_base_url) + llm_config = LLMConfig( provider=provider, endpoint=llm_info.openai_base_url, - key=llm_info.openai_api_key, + api_key=llm_info.openai_api_key, model=llm_info.model_name, max_tokens=llm_info.max_tokens, temperature=self.temperature if self.enable_temperature else 0.7, ) - + # 初始化LLM客户端(会自动加载适配器) llm = ReasoningLLM(llm_config) if llm_config else ReasoningLLM() - + # 准备参数,只传递enable为True的参数 call_params = { "messages": data.message, "enable_thinking": self.enable_thinking, "temperature": self.temperature if self.enable_temperature else None, } - + # 添加可选参数(只在enable为True时传递) if self.enable_frequency_penalty: call_params["frequency_penalty"] = self.frequency_penalty @@ -185,7 +189,7 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): call_params["top_k"] = self.top_k if self.enable_top_p: call_params["top_p"] = self.top_p - + async for chunk in llm.call(**call_params): if not chunk: continue @@ -193,11 +197,12 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): yield CallOutputChunk(type=CallOutputType.TEXT, content=chunk) self.tokens.input_tokens = llm.input_tokens self.tokens.output_tokens = llm.output_tokens - + # 最后输出一个DATA chunk,包含完整的输出数据,用于保存到变量池 yield CallOutputChunk( type=CallOutputType.DATA, - content=LLMOutput(reply=full_reply).model_dump(by_alias=True, exclude_none=True) + content=LLMOutput(reply=full_reply).model_dump( + by_alias=True, exclude_none=True) ) except Exception as e: raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 76fde14818295e9b8eab807e46f33cc2e913f865..439212aa035557dadd97381040a9a1e148092030 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -5,6 +5,7 @@ import asyncio import logging from datetime import UTC, datetime +from apps.llm.schema import DefaultModelId from apps.llm.reasoning import ReasoningLLM from apps.schemas.config import LLMConfig from apps.llm.patterns.rewrite import QuestionRewrite @@ -85,22 +86,24 @@ class Scheduler: self.task.ids.user_sub, self.task.ids.conversation_id, ) logger.info(f"[Scheduler] 使用对话记录中的模型ID: {llm_id}") - + if not llm_id: logger.error("[Scheduler] 获取大模型ID失败") return None - - # 首先尝试通过用户ID和LLM ID查找 - try: - llm = await LLMManager.get_llm_by_user_sub_and_id(self.task.ids.user_sub, llm_id) - except ValueError: - # 如果用户级别的LLM不存在,尝试查找系统级别的LLM - logger.info(f"[Scheduler] 用户级别LLM {llm_id} 不存在,尝试查找系统级别LLM") + + # 首先尝试通过用户ID和LLM ID查找,如果是系统级别的LLM,则user_sub为空字符串 + default_model_ids = [] + for model_id in DefaultModelId: + default_model_ids.append(model_id.value) + if llm_id in default_model_ids: try: llm = await LLMManager.get_llm_by_id(llm_id) + 
logger.info(f"[Scheduler] 使用系统默认模型ID: {llm_id}") + return llm except ValueError: - logger.error(f"[Scheduler] 系统级别LLM {llm_id} 也不存在") - llm = None + logger.error(f"[Scheduler] 系统默认模型ID {llm_id} 不存在") + return None + llm = await LLMManager.get_llm_by_user_sub_and_id(self.task.ids.user_sub, llm_id) if not llm: logger.error("[Scheduler] 获取大模型失败") return None @@ -228,27 +231,11 @@ class Scheduler: if not app_metadata: logger.error("[Scheduler] 未找到Agent应用") return - logger.info(f"[Scheduler] 应用配置的模型ID: {app_metadata.llm_id}, 启用思维链: {app_metadata.enable_thinking if hasattr(app_metadata, 'enable_thinking') else 'N/A'}") - if not app_metadata.llm_id or app_metadata.llm_id == "empty": + logger.info( + f"[Scheduler] 应用配置的模型ID: {app_metadata.llm_id}, 启用思维链: {app_metadata.enable_thinking if hasattr(app_metadata, 'enable_thinking') else 'N/A'}") + if not app_metadata.llm_id: # 获取系统默认模型 - llm_collection = MongoDB().get_collection("llm") - config = Config().get_config() - - system_llm = await llm_collection.find_one({ - "user_sub": "", - "type": "chat", - "model_name": config.llm.model - }) - - if not system_llm: - await LLMManager.init_system_chat_model() - system_llm = await llm_collection.find_one({ - "user_sub": "", - "type": "chat", - "model_name": config.llm.model - }) - - llm = LLM.model_validate(system_llm) + llm = await LLMManager.get_llm_by_id(DefaultModelId.DEFAULT_CHAT_MODEL_ID.value) else: llm = await LLMManager.get_llm_by_id(app_metadata.llm_id) if not llm: @@ -259,7 +246,7 @@ class Scheduler: LLMConfig( provider=llm.provider, endpoint=llm.openai_base_url, - key=llm.openai_api_key, + api_key=llm.openai_api_key, model=llm.model_name, max_tokens=llm.max_tokens, ) @@ -268,18 +255,12 @@ class Scheduler: try: # 使用function call模型进行问题改写 # 降级顺序:应用配置模型 -> 用户偏好的function call模型 -> 系统默认function call模型 -> 系统默认chat模型 - app_llm_id = app_metadata.llm_id if hasattr(app_metadata, 'llm_id') and app_metadata.llm_id != "empty" else None - llm_id_for_rewrite = await LLMManager.get_function_call_model_id( - self.task.ids.user_sub, - app_llm_id=app_llm_id # 传递应用配置的模型ID(最高优先级) - ) - # 如果仍然没有找到模型(极端情况),则使用应用配置的模型作为最后降级方案 + llm_id_for_rewrite = app_metadata.llm_id if not llm_id_for_rewrite: - llm_id_for_rewrite = app_llm_id - logger.warning("[Scheduler] 未找到任何系统模型,使用应用配置的模型进行问题改写") - - enable_thinking_for_rewrite = app_metadata.enable_thinking if hasattr(app_metadata, 'enable_thinking') else False - + llm_id_for_rewrite = DefaultModelId.DEFAULT_FUNCTION_CALL_MODEL_ID.value + enable_thinking_for_rewrite = app_metadata.enable_thinking if hasattr( + app_metadata, 'enable_thinking') else False + logger.info(f"[Scheduler] 问题改写使用模型ID: {llm_id_for_rewrite}") question_obj = QuestionRewrite( llm_id=llm_id_for_rewrite, @@ -330,8 +311,10 @@ class Scheduler: msg_queue=queue, question=post_body.question, post_body_app=app_info, - enable_thinking=app_metadata.enable_thinking if hasattr(app_metadata, 'enable_thinking') else False, - llm_id=app_metadata.llm_id if hasattr(app_metadata, 'llm_id') and app_metadata.llm_id != "empty" else None, + enable_thinking=app_metadata.enable_thinking if hasattr( + app_metadata, 'enable_thinking') else False, + llm_id=app_metadata.llm_id if hasattr( + app_metadata, 'llm_id') and app_metadata.llm_id != "empty" else None, background=background, ) diff --git a/apps/schemas/collection.py b/apps/schemas/collection.py index ba6c7d882dbbbc39e9353ea54028ea06b9fcb914..eb81215c8861df7c3c5f26cdb34ec3f1e198ab7a 100644 --- a/apps/schemas/collection.py +++ b/apps/schemas/collection.py @@ -26,7 +26,8 @@ class 
Blacklist(BaseModel): is_audited: bool = False reason_type: str = "" reason: str | None = None - updated_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) + updated_at: float = Field(default_factory=lambda: round( + datetime.now(tz=UTC).timestamp(), 3)) class UserDomainData(BaseModel): @@ -40,7 +41,8 @@ class AppUsageData(BaseModel): """User表子项:应用使用情况数据""" count: int = 0 - last_used: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) + last_used: float = Field(default_factory=lambda: round( + datetime.now(tz=UTC).timestamp(), 3)) class User(BaseModel): @@ -53,7 +55,8 @@ class User(BaseModel): id: str = Field(alias="_id") user_name: str = Field(default="", description="用户名") - last_login: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) + last_login: float = Field(default_factory=lambda: round( + datetime.now(tz=UTC).timestamp(), 3)) is_active: bool = False is_whitelisted: bool = False credit: int = 100 @@ -83,31 +86,39 @@ class LLM(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") user_sub: str = Field(default="", description="用户ID") title: str = Field(default=NEW_CHAT) - icon: str = Field(default=llm_provider_dict["ollama"]["icon"], description="图标") + icon: str = Field( + default=llm_provider_dict["ollama"]["icon"], description="图标") openai_base_url: str = Field(default=Config().get_config().llm.endpoint) - openai_api_key: str = Field(default=Config().get_config().llm.key) + openai_api_key: str = Field(default=Config().get_config().llm.api_key) model_name: str = Field(default=Config().get_config().llm.model) - max_tokens: int | None = Field(default=Config().get_config().llm.max_tokens) - type: list[str] | str = Field(default=['chat'], description="模型类型,支持单个类型或多个类型") - + max_tokens: int | None = Field( + default=Config().get_config().llm.max_tokens) + type: list[str] | str = Field( + default=['chat'], description="模型类型,支持单个类型或多个类型") + # 模型能力字段 provider: str = Field(default="", description="模型提供商") supports_thinking: bool = Field(default=False, description="是否支持思维链") - can_toggle_thinking: bool = Field(default=False, description="是否支持开关思维链(仅当supports_thinking=True时有效)") - supports_function_calling: bool = Field(default=True, description="是否支持函数调用") + can_toggle_thinking: bool = Field( + default=False, description="是否支持开关思维链(仅当supports_thinking=True时有效)") + supports_function_calling: bool = Field( + default=True, description="是否支持函数调用") supports_json_mode: bool = Field(default=True, description="是否支持JSON模式") - supports_structured_output: bool = Field(default=False, description="是否支持结构化输出") - max_tokens_param: str = Field(default="max_tokens", description="最大token参数名") + supports_structured_output: bool = Field( + default=False, description="是否支持结构化输出") + max_tokens_param: str = Field( + default="max_tokens", description="最大token参数名") notes: str = Field(default="", description="备注信息") - - created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) - + + created_at: float = Field(default_factory=lambda: round( + datetime.now(tz=UTC).timestamp(), 3)) + def normalize_type(self) -> list[str]: """标准化type字段为列表格式""" if isinstance(self.type, str): return [self.type] return self.type - + def model_dump(self, **kwargs): """重写model_dump方法,确保type字段存储为列表""" data = super().model_dump(**kwargs) @@ -144,7 +155,8 @@ class Conversation(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") user_sub: str title: str = 
NEW_CHAT - created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) + created_at: float = Field(default_factory=lambda: round( + datetime.now(tz=UTC).timestamp(), 3)) app_id: str | None = Field(default="") tasks: list[str] = [] unused_docs: list[str] = [] @@ -166,7 +178,8 @@ class Document(BaseModel): name: str type: str size: float - created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) + created_at: float = Field(default_factory=lambda: round( + datetime.now(tz=UTC).timestamp(), 3)) conversation_id: str | None = Field(default=None) @@ -180,7 +193,8 @@ class Audit(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") user_sub: str | None = None http_method: str - created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) + created_at: float = Field(default_factory=lambda: round( + datetime.now(tz=UTC).timestamp(), 3)) module: str client_ip: str | None = None message: str @@ -196,4 +210,5 @@ class Domain(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") name: str definition: str - updated_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) + updated_at: float = Field(default_factory=lambda: round( + datetime.now(tz=UTC).timestamp(), 3)) diff --git a/apps/schemas/config.py b/apps/schemas/config.py index f26ec85bd2cbf97b7e8ea3287e646fca597b3df2..95c9a72cd02eb8742fe3dece7a548f3ab03854b4 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -37,7 +37,8 @@ class OIDCConfig(BaseModel): login_api: str = Field(description="EulerCopilot登录API") app_id: str = Field(description="OIDC AppID") app_secret: str = Field(description="OIDC App Secret") - redirect_settings_url: str | None = Field(description="用户设置页面重定向URL", default=None) + redirect_settings_url: str | None = Field( + description="用户设置页面重定向URL", default=None) class AutheliaConfig(BaseModel): @@ -49,14 +50,16 @@ class AutheliaConfig(BaseModel): redirect_uri: str = Field(description="重定向URI") enable_pkce: bool = Field(description="是否启用PKCE", default=True) pkce_challenge_method: str = Field(description="PKCE挑战方法", default="S256") - redirect_settings_url: str | None = Field(description="用户设置页面重定向URL", default=None) + redirect_settings_url: str | None = Field( + description="用户设置页面重定向URL", default=None) class OpenEulerConfig(BaseModel): """OpenEuler认证配置""" host: str = Field(description="OpenEuler服务路径") - redirect_settings_url: str | None = Field(description="用户设置页面重定向URL", default=None) + redirect_settings_url: str | None = Field( + description="用户设置页面重定向URL", default=None) class FixedUserConfig(BaseModel): @@ -79,17 +82,17 @@ class EmbeddingConfig(BaseModel): provider: str = Field(description="Embedding提供商") endpoint: str = Field(description="Embedding模型地址") - api_key: str = Field(description="Embedding模型API Key") + api_key: str = Field(description="Embedding模型API Key", default="") model: str = Field(description="Embedding模型名称") class RerankerConfig(BaseModel): """Reranker配置""" - provider: str | None = Field(default=None, description="Reranker提供商") - endpoint: str | None = Field(default=None, description="Reranker模型地址") - api_key: str | None = Field(default=None, description="Reranker模型API Key") - model: str | None = Field(default=None, description="Reranker模型名称") + provider: str | None = Field(description="Reranker提供商") + endpoint: str | None = Field(description="Reranker模型地址") + api_key: str | None = 
Field(description="Reranker模型API Key", default="") + model: str | None = Field(description="Reranker模型名称") class RAGConfig(BaseModel): @@ -141,7 +144,7 @@ class LLMConfig(BaseModel): """LLM配置""" provider: str = Field(description="LLM提供商") - key: str = Field(description="LLM API密钥") + api_key: str = Field(description="LLM API密钥", default="") endpoint: str = Field(description="LLM API URL地址") model: str = Field(description="LLM API 模型名") max_tokens: int | None = Field( @@ -155,7 +158,7 @@ class FunctionCallConfig(BaseModel): provider: str | None = Field(default=None, description="Function Call 提供商") model: str = Field(description="Function Call 模型名") endpoint: str = Field(description="Function Call API URL地址") - api_key: str = Field(description="Function Call API密钥") + api_key: str = Field(description="Function Call API密钥", default="") max_tokens: int | None = Field( description="Function Call 最大Token数", default=None) temperature: float | None = Field( diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 1e2286bdb794bbbb130a5c9e0295595adfee89ff..8336d7ac38cdeb5fa89e2cbfcf495c0bb4e3bbad 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -59,15 +59,6 @@ class ConversationManager: "model_name": config.llm.model }) - if not system_llm: - # 如果系统模型不存在,创建它 - await LLMManager.init_system_chat_model() - system_llm = await llm_collection.find_one({ - "user_sub": "", - "type": "chat", - "model_name": config.llm.model - }) - llm_item = LLMItem( llm_id=str(system_llm["_id"]), model_name=system_llm["model_name"], diff --git a/apps/services/llm.py b/apps/services/llm.py index 6ae2dca4c64986abc36820367fe66d3158beffff..eb095b3e945647a66e1665a11578cd3f59dc2b26 100644 --- a/apps/services/llm.py +++ b/apps/services/llm.py @@ -3,14 +3,17 @@ import logging + from apps.common.config import Config from apps.common.mongo import MongoDB +from apps.schemas.config import EmbeddingConfig, RerankerConfig, LLMConfig, FunctionCallConfig from apps.schemas.collection import LLM, LLMItem from apps.schemas.request_data import ( UpdateLLMReq, ) from apps.schemas.response_data import LLMProvider, LLMProviderInfo from apps.templates.generate_llm_operator_config import llm_provider_dict +from apps.llm.schema import DefaultModelId from apps.llm.model_registry import model_registry from apps.llm.adapters import get_provider_from_endpoint @@ -66,14 +69,14 @@ class LLMManager: :return: 大模型对象 """ llm_collection = MongoDB().get_collection("llm") - + result = await llm_collection.find_one({"_id": llm_id}) - + if not result: err = f"[LLMManager] LLM {llm_id} 不存在" logger.error(err) raise ValueError(err) - + return LLM.model_validate(result) @staticmethod @@ -86,14 +89,14 @@ class LLMManager: :return: 大模型对象 """ llm_collection = MongoDB().get_collection("llm") - + result = await llm_collection.find_one({"_id": llm_id, "user_sub": user_sub}) - + if not result: err = f"[LLMManager] LLM {llm_id} 不存在" logger.error(err) raise ValueError(err) - + return LLM.model_validate(result) @staticmethod @@ -116,33 +119,17 @@ class LLMManager: if model_type: # 支持type字段既可以是字符串也可以是数组 base_query["type"] = model_type - + result = await llm_collection.find(base_query).sort({"created_at": 1}).to_list(length=None) llm_list = [] - - # 只有查询chat类型或者没有指定类型时,才检查并创建默认模型 - if not model_type or model_type == 'chat': - # 检查是否已存在系统默认chat模型 - config = Config().get_config() - existing_default = await llm_collection.find_one({ - "user_sub": "", - "type": "chat", - "model_name": config.llm.model - }) - - if not 
existing_default: - # 如果不存在,创建系统默认chat模型 - await LLMManager.init_system_chat_model() - # 重新查询以包含新创建的模型 - result = await llm_collection.find(base_query).sort({"created_at": 1}).to_list(length=None) for llm in result: # 标准化type字段为列表格式 llm_type = llm.get("type", "chat") if isinstance(llm_type, str): llm_type = [llm_type] - + llm_item = LLMProviderInfo( llmId=llm["_id"], # _id已经是UUID字符串 icon=llm["icon"], @@ -156,9 +143,11 @@ class LLMManager: provider=llm.get("provider", ""), supportsThinking=llm.get("supports_thinking", False), canToggleThinking=llm.get("can_toggle_thinking", False), - supportsFunctionCalling=llm.get("supports_function_calling", True), + supportsFunctionCalling=llm.get( + "supports_function_calling", True), supportsJsonMode=llm.get("supports_json_mode", True), - supportsStructuredOutput=llm.get("supports_structured_output", False), + supportsStructuredOutput=llm.get( + "supports_structured_output", False), maxTokensParam=llm.get("max_tokens_param", "max_tokens"), notes=llm.get("notes", ""), ) @@ -179,13 +168,14 @@ class LLMManager: llm_collection = mongo.get_collection("llm") # 推断模型能力 - provider = req.provider or get_provider_from_endpoint(req.openai_base_url) - + provider = req.provider or get_provider_from_endpoint( + req.openai_base_url) + # 检查provider类型,如果是public类型,则验证URL if provider in llm_provider_dict: provider_info = llm_provider_dict[provider] provider_type = provider_info.get("type", "public") - + # 如果是public类型的provider if provider_type == "public": standard_url = provider_info.get("url", "") @@ -197,14 +187,14 @@ class LLMManager: # 如果用户没有提供URL,使用标准URL if not req.openai_base_url: req.openai_base_url = standard_url - + model_info = model_registry.get_model_info(provider, req.model_name) - + # 标准化type字段为列表格式 model_type = req.type if isinstance(model_type, str): model_type = [model_type] - + # 使用请求中的能力信息,如果没有则从model_registry获取,最后使用默认值 capabilities = { "provider": provider, @@ -223,13 +213,13 @@ class LLMManager: err = f"[LLMManager] LLM {llm_id} 不存在" logger.error(err) raise ValueError(err) - + # 检查是否为系统级别模型(不允许编辑) if not llm_dict.get("user_sub"): err = f"[LLMManager] 系统级别模型 {llm_id} 不允许编辑" logger.error(err) raise ValueError(err) - + llm = LLM( _id=llm_id, user_sub=user_sub, @@ -283,13 +273,13 @@ class LLMManager: err = f"[LLMManager] LLM {llm_id} 不存在" logger.error(err) raise ValueError(err) - + # 检查是否为系统级别模型(不允许删除) if not llm_config.get("user_sub"): err = f"[LLMManager] 系统级别模型 {llm_id} 不允许删除" logger.error(err) raise ValueError(err) - + # 检查是否为当前用户的模型 if llm_config.get("user_sub") != user_sub: err = f"[LLMManager] 无权限删除模型 {llm_id}" @@ -305,16 +295,7 @@ class LLMManager: "type": "chat", "model_name": config.llm.model }) - - if not system_llm: - # 如果系统模型不存在,创建它 - await LLMManager.init_system_chat_model() - system_llm = await llm_collection.find_one({ - "user_sub": "", - "type": "chat", - "model_name": config.llm.model - }) - + await conv_collection.update_many( {"_id": conv_dict["_id"], "user_sub": user_sub}, {"$set": {"llm": { @@ -347,16 +328,7 @@ class LLMManager: "type": "chat", "model_name": config.llm.model }) - - if not system_llm: - # 如果系统模型不存在,创建它 - await LLMManager.init_system_chat_model() - system_llm = await llm_collection.find_one({ - "user_sub": "", - "type": "chat", - "model_name": config.llm.model - }) - + llm_dict = { "llm_id": str(system_llm["_id"]), "model_name": system_llm["model_name"], @@ -401,24 +373,24 @@ class LLMManager: async def list_embedding_models(user_sub: str = "") -> list[LLMProviderInfo]: """ 获取embedding模型列表 - + :param user_sub: 
用户ID,为空时返回系统级别的模型 :return: embedding模型列表 """ mongo = MongoDB() llm_collection = mongo.get_collection("llm") - + query = {"type": "embedding", "user_sub": user_sub} - + result = await llm_collection.find(query).sort({"created_at": 1}).to_list(length=None) - + llm_list = [] for llm in result: # 标准化type字段为列表格式 llm_type = llm.get("type", "embedding") if isinstance(llm_type, str): llm_type = [llm_type] - + llm_item = LLMProviderInfo( llmId=llm["_id"], # _id已经是UUID字符串 icon=llm["icon"], @@ -436,28 +408,28 @@ class LLMManager: async def list_all_embedding_models(user_sub: str) -> list[LLMProviderInfo]: """ 获取用户可访问的所有embedding模型列表(包括系统模型和用户模型) - + :param user_sub: 用户ID :return: embedding模型列表 """ mongo = MongoDB() llm_collection = mongo.get_collection("llm") - + # 使用$or查询同时获取系统模型和用户模型 query = { "type": "embedding", "$or": [{"user_sub": ""}, {"user_sub": user_sub}] } - + result = await llm_collection.find(query).sort({"created_at": 1}).to_list(length=None) - + llm_list = [] for llm in result: # 标准化type字段为列表格式 llm_type = llm.get("type", "embedding") if isinstance(llm_type, str): llm_type = [llm_type] - + llm_item = LLMProviderInfo( llmId=llm["_id"], # _id已经是UUID字符串 icon=llm["icon"], @@ -475,24 +447,24 @@ class LLMManager: async def list_reranker_models(user_sub: str = "") -> list[LLMProviderInfo]: """ 获取reranker模型列表 - + :param user_sub: 用户ID,为空时返回系统级别的模型 :return: reranker模型列表 """ mongo = MongoDB() llm_collection = mongo.get_collection("llm") - + query = {"type": "reranker", "user_sub": user_sub} - + result = await llm_collection.find(query).sort({"created_at": 1}).to_list(length=None) - + llm_list = [] for llm in result: # 标准化type字段为列表格式 llm_type = llm.get("type", "reranker") if isinstance(llm_type, str): llm_type = [llm_type] - + llm_item = LLMProviderInfo( llmId=llm["_id"], # _id已经是UUID字符串 icon=llm["icon"], @@ -510,28 +482,28 @@ class LLMManager: async def list_all_reranker_models(user_sub: str) -> list[LLMProviderInfo]: """ 获取用户可访问的所有reranker模型列表(包括系统模型和用户模型) - + :param user_sub: 用户ID :return: reranker模型列表 """ mongo = MongoDB() llm_collection = mongo.get_collection("llm") - + # 使用$or查询同时获取系统模型和用户模型 query = { "type": "reranker", "$or": [{"user_sub": ""}, {"user_sub": user_sub}] } - + result = await llm_collection.find(query).sort({"created_at": 1}).to_list(length=None) - + llm_list = [] for llm in result: # 标准化type字段为列表格式 llm_type = llm.get("type", "reranker") if isinstance(llm_type, str): llm_type = [llm_type] - + llm_item = LLMProviderInfo( llmId=llm["_id"], # _id已经是UUID字符串 icon=llm["icon"], @@ -546,50 +518,10 @@ class LLMManager: return llm_list @staticmethod - async def init_system_chat_model(): - """初始化系统级别的chat模型""" - config = Config().get_config() - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") - - # 推断chat模型能力 - # 优先使用配置文件中明确指定的provider,如果没有则从endpoint推断 - provider = getattr(config.llm, 'provider', '') or get_provider_from_endpoint(config.llm.endpoint) - model_info = model_registry.get_model_info(provider, config.llm.model) - - # 根据provider获取图标 - provider_icon = llm_provider_dict.get(provider, {}).get("icon", "") - - # 创建系统chat模型 - chat_llm = LLM( - user_sub="", # 系统级别模型 - title="System Chat Model", - icon=provider_icon, - openai_api_key=config.llm.key, - openai_base_url=config.llm.endpoint, - model_name=config.llm.model, - max_tokens=config.llm.max_tokens or (model_info.max_tokens_param if model_info else 8192), - type=['chat'], # 使用列表格式 - provider=provider, - supports_thinking=model_info.supports_thinking if model_info else False, - 
can_toggle_thinking=model_info.can_toggle_thinking if model_info else False, - supports_function_calling=model_info.supports_function_calling if model_info else False, - supports_json_mode=model_info.supports_json_mode if model_info else False, - supports_structured_output=model_info.supports_structured_output if model_info else False, - max_tokens_param=model_info.max_tokens_param if model_info else None, - notes=model_info.notes if model_info else "", - ) - - # 使用by_alias=True将id字段作为_id插入,保持UUID字符串格式 - insert_data = chat_llm.model_dump(by_alias=True) - await llm_collection.insert_one(insert_data) - logger.info(f"已初始化系统chat模型: {config.llm.model}") - - @staticmethod - async def _init_system_model(model_type: str, model_config, title: str): + async def _init_system_model(model_id: str, model_type: str, model_config: EmbeddingConfig | RerankerConfig | LLMConfig | FunctionCallConfig, title: str): """ 初始化系统模型的通用方法 - + :param model_type: 模型类型 ('embedding', 'reranker', 'function_call') :param model_config: 模型配置对象 :param title: 模型标题 @@ -598,22 +530,27 @@ class LLMManager: if model_type == "reranker": # 如果关键字段都为空,则跳过初始化(使用jaccard算法作为默认) if not model_config.provider and not model_config.model: - logger.info(f"[LLMManager] 跳过系统{model_type}模型初始化(将使用jaccard算法作为默认)") + logger.info( + f"[LLMManager] 跳过系统{model_type}模型初始化(将使用jaccard算法作为默认)") return - + mongo = MongoDB() llm_collection = mongo.get_collection("llm") - + # 推断模型能力 # 优先使用配置文件中明确指定的provider,如果没有则从endpoint推断 - provider = getattr(model_config, 'provider', '') or getattr(model_config, 'backend', '') or get_provider_from_endpoint(model_config.endpoint) - model_info = model_registry.get_model_info(provider, model_config.model) - + provider = getattr(model_config, 'provider', '') or getattr( + model_config, 'backend', '') or get_provider_from_endpoint(model_config.endpoint) + model_info = model_registry.get_model_info( + provider, model_config.model) + # 根据provider获取图标,如果没有则使用配置文件中的图标 - provider_icon = llm_provider_dict.get(provider, {}).get("icon", getattr(model_config, 'icon', '')) - + provider_icon = llm_provider_dict.get(provider, {}).get( + "icon", getattr(model_config, 'icon', '')) + # 创建系统模型 system_llm = LLM( + _id=model_id, user_sub="", # 系统级别模型 title=title, icon=provider_icon, @@ -631,23 +568,24 @@ class LLMManager: max_tokens_param=model_info.max_tokens_param if model_info else "max_tokens", notes=model_info.notes if model_info else "", ) - + # 使用by_alias=True将id字段作为_id插入,保持UUID字符串格式 insert_data = system_llm.model_dump(by_alias=True) - await llm_collection.insert_one(insert_data) + # 如果模型已存在,则更新,否则插入 + await llm_collection.update_one({"_id": model_id}, {"$set": insert_data}, upsert=True) logger.info(f"[LLMManager] 创建系统{model_type}模型: {model_config.model}") @staticmethod async def get_function_call_model_id(user_sub: str, app_llm_id: str | None = None) -> str | None: """ 获取function call场景使用的模型ID - + 优先级顺序(从高到低): 1. 应用配置的模型 (app_llm_id) - 最高优先级 2. 用户偏好的 function call 模型 3. 系统默认的 function call 模型 4. 
系统默认的 chat 模型 - + :param user_sub: 用户ID :param app_llm_id: 应用配置的模型ID(可选) :return: function call模型ID或chat模型ID,如果都不存在则返回None @@ -655,29 +593,32 @@ class LLMManager: try: mongo = MongoDB() llm_collection = mongo.get_collection("llm") - + # 🔑 第一优先级:应用配置的模型(最高优先级) if app_llm_id: logger.info(f"[LLMManager] 使用应用配置的模型用于函数调用: {app_llm_id}") return app_llm_id - + # 第二优先级:获取用户偏好的function call模型 from apps.services.user import UserManager user_preferences = await UserManager.get_user_preferences_by_user_sub(user_sub) - + # 如果用户配置了function call模型偏好,检查该模型是否真正支持函数调用 if user_preferences.function_call_model_preference: llm_id = user_preferences.function_call_model_preference.llm_id # 检查该模型是否支持函数调用 llm_data = await llm_collection.find_one({"_id": llm_id}) if llm_data: - supports_fc = llm_data.get("supports_function_calling", True) + supports_fc = llm_data.get( + "supports_function_calling", True) if supports_fc: - logger.info(f"[LLMManager] 使用用户偏好的function call模型: {llm_id}") + logger.info( + f"[LLMManager] 使用用户偏好的function call模型: {llm_id}") return llm_id else: - logger.warning(f"[LLMManager] 用户偏好的模型 {llm_id} 不支持函数调用,将使用其他模型") - + logger.warning( + f"[LLMManager] 用户偏好的模型 {llm_id} 不支持函数调用,将使用其他模型") + # 第三优先级:获取系统默认的function call模型 # 注意:type字段可能是数组或字符串,需要同时支持两种格式 system_function_call_model = await llm_collection.find_one({ @@ -688,14 +629,15 @@ class LLMManager: ], "supports_function_calling": True # 确保支持函数调用 }) - + if system_function_call_model: llm_id = str(system_function_call_model["_id"]) logger.info(f"[LLMManager] 使用系统默认的function call模型: {llm_id}") return llm_id - + # 第四优先级:如果没有专门的function call模型,尝试找支持函数调用的chat模型 - logger.warning("[LLMManager] 未找到专门的function call模型,寻找支持函数调用的chat模型") + logger.warning( + "[LLMManager] 未找到专门的function call模型,寻找支持函数调用的chat模型") system_chat_with_fc = await llm_collection.find_one({ "user_sub": "", "$or": [ @@ -704,12 +646,12 @@ class LLMManager: ], "supports_function_calling": True }) - + if system_chat_with_fc: llm_id = str(system_chat_with_fc["_id"]) logger.info(f"[LLMManager] 使用支持函数调用的chat模型: {llm_id}") return llm_id - + # 最后降级:使用系统默认的chat模型 logger.warning("[LLMManager] 未找到支持函数调用的模型,降级使用系统默认的chat模型") config = Config().get_config() @@ -721,27 +663,15 @@ class LLMManager: ], "model_name": config.llm.model }) - - if not system_chat_model: - # 如果系统chat模型不存在,创建它 - await LLMManager.init_system_chat_model() - system_chat_model = await llm_collection.find_one({ - "user_sub": "", - "$or": [ - {"type": "chat"}, # 兼容字符串格式 - {"type": {"$in": ["chat"]}} # 兼容数组格式 - ], - "model_name": config.llm.model - }) - + if system_chat_model: llm_id = str(system_chat_model["_id"]) logger.info(f"[LLMManager] 降级使用系统默认的chat模型: {llm_id}") return llm_id - + logger.error("[LLMManager] 未找到任何可用的模型") return None - + except Exception as e: logger.error(f"[LLMManager] 获取模型失败: {e}") return None @@ -755,27 +685,36 @@ class LLMManager: config = Config().get_config() mongo = MongoDB() llm_collection = mongo.get_collection("llm") - + # 清理所有系统模型(user_sub为空的模型) delete_result = await llm_collection.delete_many({"user_sub": ""}) logger.info(f"[LLMManager] 清理了 {delete_result.deleted_count} 个旧系统模型") - + # 初始化embedding模型 await LLMManager._init_system_model( - "embedding", - config.embedding, + DefaultModelId.DEFAULT_EMBEDDING_MODEL_ID.value, + "embedding", + config.embedding, "System Embedding Model" ) - + # 初始化reranker模型 await LLMManager._init_system_model( - "reranker", - config.reranker, + DefaultModelId.DEFAULT_RERANKER_MODEL_ID.value, + "reranker", + config.reranker, "System Reranker Model" ) - + # 初始化chat模型 + await 
LLMManager._init_system_model( + DefaultModelId.DEFAULT_CHAT_MODEL_ID.value, + "chat", + config.llm, + "System Chat Model" + ) # 初始化function_call模型 await LLMManager._init_system_model( + DefaultModelId.DEFAULT_FUNCTION_CALL_MODEL_ID.value, "function_call", config.function_call, "System Function Call Model" @@ -785,39 +724,41 @@ class LLMManager: async def get_model_capabilities(user_sub: str, llm_id: str) -> dict: """ 获取指定模型支持的参数配置项 - + :param user_sub: 用户ID :param llm_id: 模型ID :return: 模型能力配置字典 """ from apps.llm.model_types import ModelType - + # 获取模型信息(支持系统模型和用户模型) mongo = MongoDB() llm_collection = mongo.get_collection("llm") - + result = await llm_collection.find_one({ "_id": llm_id, "$or": [{"user_sub": user_sub}, {"user_sub": ""}] }) - + if not result: err = f"[LLMManager] LLM {llm_id} 不存在或无权限访问" logger.error(err) raise ValueError(err) - + llm = LLM.model_validate(result) - + # 从注册表获取模型能力 - provider = llm.provider or get_provider_from_endpoint(llm.openai_base_url) - capabilities = model_registry.get_model_capabilities(provider, llm.model_name, ModelType.CHAT) - + provider = llm.provider or get_provider_from_endpoint( + llm.openai_base_url) + capabilities = model_registry.get_model_capabilities( + provider, llm.model_name, ModelType.CHAT) + # 构建参数配置项响应 result_dict = { "provider": provider, "modelName": llm.model_name, "modelType": "chat", - + # 基础参数支持 "supportsTemperature": capabilities.supports_temperature if capabilities else True, "supportsTopP": capabilities.supports_top_p if capabilities else True, @@ -825,7 +766,7 @@ class LLMManager: "supportsFrequencyPenalty": capabilities.supports_frequency_penalty if capabilities else False, "supportsPresencePenalty": capabilities.supports_presence_penalty if capabilities else False, "supportsMinP": capabilities.supports_min_p if capabilities else False, - + # 高级功能 "supportsThinking": capabilities.supports_thinking if capabilities else False, "canToggleThinking": capabilities.can_toggle_thinking if capabilities else False, @@ -833,15 +774,15 @@ class LLMManager: "supportsFunctionCalling": capabilities.supports_function_calling if capabilities else True, "supportsJsonMode": capabilities.supports_json_mode if capabilities else True, "supportsStructuredOutput": capabilities.supports_structured_output if capabilities else False, - + # 上下文支持(所有chat模型都支持) "supportsContext": True, - + # 参数名称 "maxTokensParam": capabilities.max_tokens_param if capabilities else "max_tokens", - + # 备注信息 "notes": llm.notes or "" } - + return result_dict diff --git a/apps/services/rag.py b/apps/services/rag.py index ceaa6ad8e95c488389c0ecf832c817964087710c..fb4b00fb29ddd92af1355dbd2352d3ab2b35ebc4 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -365,12 +365,12 @@ Please generate a detailed, well-structured, and clearly formatted answer based LLMConfig( provider=llm.provider, endpoint=llm.openai_base_url, - key=llm.openai_api_key, + api_key=llm.openai_api_key, model=llm.model_name, max_tokens=llm.max_tokens, ) ) - + # 用于问题改写的LLM if history: try: @@ -381,15 +381,16 @@ Please generate a detailed, well-structured, and clearly formatted answer based user_sub, app_llm_id=None # 无应用对话不传递应用模型ID,使用用户preference或系统默认 ) - + # 如果没有找到函数调用模型,使用对话模型作为降级方案 if not function_call_model_id: logger.warning("[RAG] 未找到函数调用模型,使用对话模型进行问题改写") function_call_model_id = llm.id else: logger.info(f"[RAG] 问题改写使用模型: {function_call_model_id}") - - question_obj = QuestionRewrite(llm_id=function_call_model_id, enable_thinking=False) + + question_obj = QuestionRewrite( + 
llm_id=function_call_model_id, enable_thinking=False) data.query = await question_obj.generate( history=history, question=data.query, language=language )