diff --git a/jiuwen/agent_builder/prompt_builder/tune/base/context_manager.py b/jiuwen/agent_builder/prompt_builder/tune/base/context_manager.py
index 961d859b0324d8f87ddedbb738feae4b70e1c382..5eea7c703c9c914e338d6585f801eada9650261c 100644
--- a/jiuwen/agent_builder/prompt_builder/tune/base/context_manager.py
+++ b/jiuwen/agent_builder/prompt_builder/tune/base/context_manager.py
@@ -3,24 +3,19 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved

 import copy
-import os
 import threading
 import hashlib
 import time
-from typing import Dict, Optional, Any, Callable, List, Set
-from datetime import datetime, timezone, timedelta
+from typing import Dict, Optional, Any

 from cacheout import Cache
-from pydantic import BaseModel, Field

 from jiuwen.agent_builder.prompt_builder.tune.common.singleton import Singleton
 from logging import getLogger

 logger = getLogger(__name__)

-from jiuwen.agent_builder.prompt_builder.tune.common.exception import JiuWenBaseException, StatusCode
 from jiuwen.agent_builder.prompt_builder.tune.base.constant import TuneConstant, TaskStatus
-from jiuwen.agent_builder.prompt_builder.tune.base.utils import OptimizeInfo

 Context = Dict[str, Any]
 STOP_EVENT = "stop_event"
diff --git a/jiuwen/agent_builder/prompt_builder/tune/base/utils.py b/jiuwen/agent_builder/prompt_builder/tune/base/utils.py
index 08207c3c47d8615bbe00f3333e5d2c4919e95028..a221c3b93641af1b3e32d9253cc3e527e9673000 100644
--- a/jiuwen/agent_builder/prompt_builder/tune/base/utils.py
+++ b/jiuwen/agent_builder/prompt_builder/tune/base/utils.py
@@ -4,17 +4,17 @@
 prompt optimization utils
 """

-import os
 from typing import List, Dict, Optional, Any, Union
 from datetime import datetime, timezone, timedelta

-from pydantic import BaseModel, field_validator, Field, FieldValidationInfo
+from pydantic import BaseModel, Field, FieldValidationInfo
 import yaml

-from jiuwen.agent_builder.prompt_builder.tune.base.exception import OnStopException
+from jiuwen.core.utils.llm.base import BaseModelInfo
+from jiuwen.core.utils.llm.model_utils.model_factory import ModelFactory
 from jiuwen.agent_builder.prompt_builder.tune.common.exception import JiuWenBaseException, StatusCode
 from jiuwen.agent_builder.prompt_builder.tune.base.case import Case
-from jiuwen.agent_builder.prompt_builder.tune.base.constant import TuneConstant, TaskStatus
+from jiuwen.agent_builder.prompt_builder.tune.base.constant import TuneConstant


 class TaskInfo(BaseModel):
@@ -82,29 +82,6 @@ class LLMModelInfo(BaseModel):
     api_key: str = Field(default="", min_length=0, max_length=256)


-class QwenLLM:
-    def __init__(self, model_info: LLMModelInfo):
-        self.url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
-        self.api_key = "sk-f9410e0700c94022a16b78341e860c45"
-        self.model_name = model_info.model
-
-    def chat(self, messages):
-        import requests
-        body = {
-            "messages": messages,
-            "model": self.model_name,
-            "stream": False
-        }
-        headers = {
-            "Authorization": f"{self.api_key}",
-        }
-        response = requests.post(self.url, json=body, headers=headers)
-        if response.status_code != 200:
-            print(response.text)
-            raise Exception(response.text)
-        content = response.json().get("choices")[0].get("message").get("content")
-        return dict(content=content)
-
 class LLMModelProcess:
     """LLM invoke process"""
     def __init__(self, llm_model_info: LLMModelInfo):
@@ -115,12 +92,24 @@
                     error_msg="prompt optimization llm config is missing"
                 )
             )
-        self.chat_llm = None
-        self.model_info = llm_model_info
+        if not llm_model_info.model_source or not llm_model_info.model:
+            raise JiuWenBaseException(
+                error_code=StatusCode.LLM_CONFIG_MISS_ERROR.code,
+                message=StatusCode.LLM_CONFIG_MISS_ERROR.errmsg.format(
+                    error_msg="prompt optimization llm config is missing"
+                )
+            )
+        model_info = BaseModelInfo(
+            api_key=llm_model_info.api_key,
+            api_base=llm_model_info.url,
+            model=llm_model_info.model
+        )
+        self.chat_llm = ModelFactory().get_model(llm_model_info.model_source, model_info)

     def chat(self, messages: List[Any]) -> Dict:
         """chat"""
-        return QwenLLM(self.model_info).chat(messages)
+        reply_message = self.chat_llm.invoke(messages)
+        return dict(content=reply_message.content)


 def load_yaml_to_dict(file_path: str) -> Dict:
diff --git a/jiuwen/agent_builder/prompt_builder/tune/joint_optimizer.py b/jiuwen/agent_builder/prompt_builder/tune/joint_optimizer.py
index 3dbaebbada5c10568858b2aa8c0f421e2b171d91..7c73c9d0e3c6318d52d88f98978178db7750e2d4 100644
--- a/jiuwen/agent_builder/prompt_builder/tune/joint_optimizer.py
+++ b/jiuwen/agent_builder/prompt_builder/tune/joint_optimizer.py
@@ -10,7 +10,7 @@ import threading
 import copy
 from os.path import dirname, join
 from dataclasses import dataclass
-from typing import List, Dict, Optional, Any, Tuple
+from typing import List, Optional, Tuple
 from logging import getLogger

 logger = getLogger(__name__)
@@ -458,8 +458,6 @@
             response = self.chat_completion(examples_optimization_template)
             return self.extract_examples_from_response(response)
         except (KeyError, TypeError, AttributeError) as e:
-            import traceback
-            traceback.print_exc()
             logger.warning(f"Error occur while generating best reasoning examples: {e}")
             return []

@@ -578,8 +576,6 @@
             logger.info(f"Joint optimization task {task_info.task_id} stopped.")
             return
         except Exception as e:
-            import traceback
-            traceback.print_exc()
             context[TaskStatus.TASK_STATUS] = TaskStatus.TASK_FAILED
             context["run_time"] = calculate_runtime(context.get("create_time", ""))
             checkpoint = ContextManager().get_checkpoint(task_info.task_id) or context
diff --git a/jiuwen/core/utils/llm/model_library/siliconflow.py b/jiuwen/core/utils/llm/model_library/siliconflow.py
index c2b9e9b4d6e575fb03fde27a92053a6e9d736f78..8df41b51779b41b8cf8bef160f1516a0cdabe1e3 100644
--- a/jiuwen/core/utils/llm/model_library/siliconflow.py
+++ b/jiuwen/core/utils/llm/model_library/siliconflow.py
@@ -8,7 +8,7 @@ from pydantic import Field, BaseModel
 from jiuwen.core.utils.llm.base import BaseChatModel, BaseModelInfo
 from jiuwen.core.utils.llm.messages import AIMessage
 from jiuwen.core.utils.llm.messages_chunk import AIMessageChunk
-from jiuwen.core.utils.llm.model_utils.defult_model import RequestChatModel
+from jiuwen.core.utils.llm.model_utils.default_model import RequestChatModel


 class Siliconflow(BaseModel, BaseChatModel):
diff --git a/jiuwen/core/utils/llm/model_utils/defult_model.py b/jiuwen/core/utils/llm/model_utils/default_model.py
similarity index 100%
rename from jiuwen/core/utils/llm/model_utils/defult_model.py
rename to jiuwen/core/utils/llm/model_utils/default_model.py
diff --git a/jiuwen/core/utils/llm/model_utils/model_factory.py b/jiuwen/core/utils/llm/model_utils/model_factory.py
index 2482bc3cd69d86d3ec245c05dcfae4fe23ccd9ee..ddfd11ef8b6cd252cffc9ad18662c0d918a74e55 100644
--- a/jiuwen/core/utils/llm/model_utils/model_factory.py
+++ b/jiuwen/core/utils/llm/model_utils/model_factory.py
@@ -67,7 +67,6 @@ class ModelFactory(metaclass=Singleton):

     def get_model(self, model_provider: str, model_info: BaseModelInfo) -> BaseChatModel:
         model_cls = self.model_map.get(model_provider.lower())
-        print(model_cls)
         if not model_cls:
             available_models = ", ".join(self.model_map.keys())
             raise ValueError(f"Unavailable model provider: {model_provider}. Available models: {available_models}")
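Reviewer note: below is a minimal usage sketch of the refactored `LLMModelProcess` path introduced by this patch. The field names (`model_source`, `model`, `url`, `api_key`) and the `ModelFactory().get_model(...)` / `chat_llm.invoke(...)` calls are taken from the hunks above; the provider key, model name, endpoint, and message format are illustrative assumptions, not values confirmed by the patch.

```python
# Usage sketch only, not part of the patch; assumptions flagged inline.
from jiuwen.agent_builder.prompt_builder.tune.base.utils import (
    LLMModelInfo,
    LLMModelProcess,
)

llm_info = LLMModelInfo(
    model_source="siliconflow",           # assumed provider key; must exist in ModelFactory.model_map
    model="Qwen/Qwen2.5-7B-Instruct",     # assumed model identifier
    url="https://api.siliconflow.cn/v1",  # assumed endpoint; mapped to BaseModelInfo.api_base
    api_key="<api-key>",                  # injected at runtime, never hardcoded (cf. the removed QwenLLM)
)

# __init__ now validates model_source/model, builds a BaseModelInfo, and
# resolves the concrete chat model via ModelFactory().get_model(...).
process = LLMModelProcess(llm_info)

# chat() delegates to chat_llm.invoke(messages) and normalizes the reply to
# {"content": ...}; OpenAI-style dict messages are an assumption here, and
# the jiuwen message types may be required instead.
reply = process.chat([{"role": "user", "content": "Say hello."}])
print(reply["content"])
```

Routing everything through `ModelFactory` keeps provider-specific clients (such as `Siliconflow`) behind a single `get_model` lookup, which is also why the `defult_model` to `default_model` rename and the `print(model_cls)` removal land in the same patch.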