diff --git a/apps/common/mongo.py b/apps/common/mongo.py index 4e9fbde9c6f81cb3d4bf9448f2141d9f6f883934..17cef16f870b6fafbd8cbd4166a0f27cf4156dbf 100644 --- a/apps/common/mongo.py +++ b/apps/common/mongo.py @@ -15,18 +15,14 @@ logger = logging.getLogger(__name__) class MongoDB: + from pymongo import AsyncMongoClient """MongoDB连接器""" + _client: "AsyncMongoClient" = AsyncMongoClient( + f"mongodb://{urllib.parse.quote_plus(Config().get_config().mongodb.user)}:{urllib.parse.quote_plus(Config().get_config().mongodb.password)}@{Config().get_config().mongodb.host}:{Config().get_config().mongodb.port}/?directConnection=true", + ) - def __init__(self) -> None: - """初始化MongoDB连接器""" - from pymongo import AsyncMongoClient - - self._client = AsyncMongoClient( - f"mongodb://{urllib.parse.quote_plus(Config().get_config().mongodb.user)}:{urllib.parse.quote_plus(Config().get_config().mongodb.password)}@{Config().get_config().mongodb.host}:{Config().get_config().mongodb.port}/?directConnection=true", - ) - - - def get_collection(self, collection_name: str) -> "AsyncCollection": + @staticmethod + def get_collection(collection_name: str) -> "AsyncCollection": """ 获取MongoDB集合 @@ -34,20 +30,20 @@ class MongoDB: :return: 集合对象 :rtype: AsyncCollection """ - return self._client[Config().get_config().mongodb.database][collection_name] + return MongoDB._client[Config().get_config().mongodb.database][collection_name] - - async def clear_collection(self, collection_name: str) -> None: + @staticmethod + async def clear_collection(collection_name: str) -> None: """ 清空MongoDB集合 :param str collection_name: 集合名称 :return: 无 """ - await self._client[Config().get_config().mongodb.database][collection_name].delete_many({}) - + await MongoDB._client[Config().get_config().mongodb.database][collection_name].delete_many({}) - def get_session(self) -> "AsyncClientSession": + @staticmethod + def get_session() -> "AsyncClientSession": """ 获取MongoDB会话 @@ -56,4 +52,4 @@ class MongoDB: :return: 会话对象 :rtype: AsyncClientSession """ - return self._client.start_session() + return MongoDB._client.start_session() diff --git a/apps/common/oidc.py b/apps/common/oidc.py index 5c67b9f64ec76db3546b18e9b073b842b7e3ce12..cd0272362989f0f17e395c432a81618f8fe62063 100644 --- a/apps/common/oidc.py +++ b/apps/common/oidc.py @@ -34,8 +34,7 @@ class OIDCProvider: @staticmethod async def set_token(user_sub: str, access_token: str, refresh_token: str) -> None: """设置MongoDB中的OIDC Token到sessions集合""" - mongo = MongoDB() - sessions_collection = mongo.get_collection("session") + sessions_collection = MongoDB.get_collection("session") await sessions_collection.update_one( {"_id": f"access_token_{user_sub}"}, @@ -67,17 +66,14 @@ class OIDCProvider: """检查登录状态""" return await self.provider.get_login_status(cookie) - async def oidc_logout(self, cookie: dict[str, str]) -> None: """触发OIDC的登出""" return await self.provider.oidc_logout(cookie) - async def get_oidc_token(self, code: str) -> dict[str, Any]: """获取OIDC 访问Token""" return await self.provider.get_oidc_token(code) - async def get_oidc_user(self, access_token: str) -> dict[str, Any]: """获取OIDC 用户信息""" return await self.provider.get_oidc_user(access_token) diff --git a/apps/common/oidc_provider/authelia.py b/apps/common/oidc_provider/authelia.py index aeeb9eb2c3e3c48f5679291fbb33508b5666966b..a799c4048f9f086792269801161f1495d20675d4 100644 --- a/apps/common/oidc_provider/authelia.py +++ b/apps/common/oidc_provider/authelia.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) class AutheliaOIDCProvider(OIDCProviderBase): """Authelia OIDC Provider""" - + # PKCE相关的类变量,用于存储code_verifier _code_verifier: str = "" @@ -27,16 +27,17 @@ class AutheliaOIDCProvider(OIDCProviderBase): def _generate_pkce_params(cls) -> tuple[str, str]: """生成PKCE参数""" # 生成code_verifier (43-128个字符的随机字符串) - code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=') - + code_verifier = base64.urlsafe_b64encode( + secrets.token_bytes(32)).decode('utf-8').rstrip('=') + # 生成code_challenge (code_verifier的SHA256哈希值) code_challenge = base64.urlsafe_b64encode( hashlib.sha256(code_verifier.encode('utf-8')).digest() ).decode('utf-8').rstrip('=') - + # 存储code_verifier供后续使用 cls._code_verifier = code_verifier - + return code_verifier, code_challenge @classmethod @@ -60,7 +61,7 @@ class AutheliaOIDCProvider(OIDCProviderBase): "grant_type": "authorization_code", "code": code, } - + # 如果启用了PKCE,添加code_verifier参数 if login_config.enable_pkce and cls._code_verifier: data["code_verifier"] = cls._code_verifier @@ -114,9 +115,11 @@ class AutheliaOIDCProvider(OIDCProviderBase): result = resp.json() # 获取用户名和默认的user_sub - user_name = result.get("name", result.get("preferred_username", result.get("nickname", ""))) - default_user_sub = result.get("sub", result.get("preferred_username", "")) - + user_name = result.get("name", result.get( + "preferred_username", result.get("nickname", ""))) + default_user_sub = result.get( + "sub", result.get("preferred_username", "")) + # 管理员用户特殊处理逻辑 final_user_sub = await cls._handle_admin_user_sub(user_name, default_user_sub) @@ -124,39 +127,40 @@ class AutheliaOIDCProvider(OIDCProviderBase): "user_sub": final_user_sub, "user_name": user_name, } - + @classmethod async def _handle_admin_user_sub(cls, user_name: str, default_user_sub: str) -> str: """处理管理员用户的user_sub逻辑(仅适用于Authelia)""" from apps.common.config import Config - + config = Config().get_config() - + # 只有在使用Authelia provider时才应用此逻辑 if config.login.provider != "authelia": return default_user_sub - + # 检查是否启用了管理员配置且用户名匹配 if not config.admin.enable or user_name != config.admin.user_name: return default_user_sub - + # 检查数据库中是否已存在管理员用户 try: from apps.common.mongo import MongoDB - mongo = MongoDB() - user_collection = mongo.get_collection("user") - + user_collection = MongoDB.get_collection("user") + existing_admin = await user_collection.find_one({"_id": config.admin.user_sub}) - + if existing_admin: # 数据库中已存在管理员用户,使用默认的user_sub - logger.info(f"[_handle_admin_user_sub] 管理员用户已存在,使用默认user_sub: {default_user_sub}") + logger.info( + f"[_handle_admin_user_sub] 管理员用户已存在,使用默认user_sub: {default_user_sub}") return default_user_sub else: # 数据库中不存在管理员用户,使用配置的管理员user_sub - logger.info(f"[_handle_admin_user_sub] 管理员用户不存在,使用配置的user_sub: {config.admin.user_sub}") + logger.info( + f"[_handle_admin_user_sub] 管理员用户不存在,使用配置的user_sub: {config.admin.user_sub}") return config.admin.user_sub - + except Exception as e: logger.error(f"[_handle_admin_user_sub] 检查管理员用户时出错: {e}") # 出错时使用默认的user_sub @@ -182,7 +186,7 @@ class AutheliaOIDCProvider(OIDCProviderBase): err = f"[Authelia] 获取登录状态失败: {resp.status_code},完整输出: {resp.text}" raise RuntimeError(err) result = resp.json() - + # Authelia 返回用户信息表示已登录,需要获取或生成token # 这里返回空的token,实际使用中可能需要根据具体情况调整 return { @@ -218,7 +222,7 @@ class AutheliaOIDCProvider(OIDCProviderBase): # 生成随机的 state 参数以确保安全性和唯一性 state = secrets.token_urlsafe(32) - + # 基础URL参数 url_params = [ f"client_id={login_config.client_id}", @@ -227,7 +231,7 @@ class AutheliaOIDCProvider(OIDCProviderBase): f"redirect_uri={login_config.redirect_uri}", f"state={state}" ] - + # 如果启用PKCE,添加PKCE参数 if login_config.enable_pkce: code_verifier, code_challenge = cls._generate_pkce_params() @@ -236,7 +240,7 @@ class AutheliaOIDCProvider(OIDCProviderBase): f"code_challenge_method={login_config.pkce_challenge_method}" ]) logger.info("[Authelia] 启用PKCE流程,生成code_challenge参数") - + return f"{login_config.host.rstrip('/')}/api/oidc/authorization?" + "&".join(url_params) @classmethod diff --git a/apps/main.py b/apps/main.py index 3f9f143838c4b4ae87345d90c4424b3a81340ef8..1a280e5ed8e373afa06d3edd50860fcee8ec9631 100644 --- a/apps/main.py +++ b/apps/main.py @@ -160,8 +160,7 @@ async def add_no_auth_user() -> None: from apps.common.config import Config config = Config().get_config() - mongo = MongoDB() - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") # 使用配置文件中的no_auth设置 user_sub = config.no_auth.user_sub @@ -196,8 +195,7 @@ async def set_administrator() -> None: from apps.common.config import Config config = Config().get_config() - mongo = MongoDB() - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") # 获取管理员配置 admin_user_sub = config.admin.user_sub @@ -232,8 +230,7 @@ async def clear_user_activity() -> None: """清除所有用户的活跃状态""" from apps.services.activity import Activity from apps.common.mongo import MongoDB - mongo = MongoDB() - activity_collection = mongo.get_collection("activity") + activity_collection = MongoDB.get_collection("activity") await activity_collection.delete_many({}) logging.info("清除所有用户活跃状态完成") @@ -325,11 +322,10 @@ async def startup_file_cleanup(): from apps.scheduler.variable.type import VariableType from apps.common.mongo import MongoDB - mongo = MongoDB() - doc_collection = mongo.get_collection("document") - variables_collection = mongo.get_collection("variables") - record_group_collection = mongo.get_collection("record_group") - conversation_collection = mongo.get_collection("conversation") + doc_collection = MongoDB.get_collection("document") + variables_collection = MongoDB.get_collection("variables") + record_group_collection = MongoDB.get_collection("record_group") + conversation_collection = MongoDB.get_collection("conversation") # 获取所有文档ID all_file_ids = set() @@ -409,12 +405,10 @@ async def cleanup_orphaned_files(): from apps.scheduler.variable.type import VariableType from apps.common.mongo import MongoDB - mongo = MongoDB() - doc_collection = mongo.get_collection("document") - variables_collection = mongo.get_collection("variables") - record_group_collection = mongo.get_collection("record_group") - conversation_collection = mongo.get_collection("conversation") - + doc_collection = MongoDB.get_collection("document") + variables_collection = MongoDB.get_collection("variables") + record_group_collection = MongoDB.get_collection("record_group") + conversation_collection = MongoDB.get_collection("conversation") # 获取所有文档ID all_file_ids = set() async for doc in doc_collection.find({}, {"_id": 1}): @@ -504,8 +498,7 @@ async def start_periodic_cleanup(): logger.info("开始定期清理任务") # 获取活跃的对话列表 - mongo = MongoDB() - conv_collection = mongo.get_collection("conversation") + conv_collection = MongoDB.get_collection("conversation") # 获取最近24小时内有活动的对话ID cutoff_time = datetime.utcnow() - timedelta(hours=24) diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py index a3a494d215bf1df6103035b841cf08fe82926e30..73b6f13c71c279c27c5fbd876016a032abf010fe 100644 --- a/apps/routers/appcenter.py +++ b/apps/routers/appcenter.py @@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse from apps.dependency.user import get_user, verify_user from apps.exceptions import InstancePermissionError +from apps.llm.schema import DefaultModelId from apps.schemas.appcenter import AppFlowInfo, AppPermissionData from apps.schemas.enum_var import AppFilterType, AppType from apps.schemas.request_data import CreateAppRequest, ModFavAppRequest @@ -234,38 +235,14 @@ async def get_application( )) if not app_data.llm_id: # 获取系统默认模型 - from apps.common.mongo import MongoDB - from apps.common.config import Config - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") - config = Config().get_config() - - system_llm = await llm_collection.find_one({ - "user_sub": "", - "type": "chat", - "model_name": config.llm.model - }) - - if not system_llm: - await LLMManager.init_system_chat_model() - system_llm = await llm_collection.find_one({ - "user_sub": "", - "type": "chat", - "model_name": config.llm.model - }) - - llm_item = LLMIteam( - llmId=str(system_llm["_id"]), - modelName=system_llm["model_name"], - icon=system_llm["icon"] - ) + llm_collection = await LLMManager.get_llm_by_id(DefaultModelId.DEFAULT_CHAT_MODEL_ID.value) else: llm_collection = await LLMManager.get_llm_by_id(app_data.llm_id) - llm_item = LLMIteam( - llmId=llm_collection.id, - modelName=llm_collection.model_name, - icon=llm_collection.icon - ) + llm_item = LLMIteam( + llmId=llm_collection.id, + modelName=llm_collection.model_name, + icon=llm_collection.icon + ) return JSONResponse( status_code=status.HTTP_200_OK, content=GetAppPropertyRsp( diff --git a/apps/routers/variable.py b/apps/routers/variable.py index d8a4bc0e17a361fafb101c0a5d8d64e5da844bfb..a8d752f5e95b37868021a1d6b72f69c25d054fd1 100644 --- a/apps/routers/variable.py +++ b/apps/routers/variable.py @@ -38,27 +38,27 @@ async def _get_predecessor_node_variables( current_step_id: str ) -> List: """获取前置节点的输出变量(优化版本,使用缓存) - + Args: user_sub: 用户ID flow_id: 流程ID conversation_id: 对话ID(可选,配置阶段可能为None) current_step_id: 当前步骤ID - + Returns: List: 前置节点的输出变量列表 """ try: variables = [] pool_manager = await get_pool_manager() - + if conversation_id: # 运行阶段:从对话池获取实际的前置节点变量 conversation_pool = await pool_manager.get_conversation_pool(conversation_id) if conversation_pool: # 获取所有对话变量 all_conversation_vars = await conversation_pool.list_variables() - + # 筛选出前置节点的输出变量(格式为 node_id.key) for var in all_conversation_vars: var_name = var.name @@ -66,7 +66,7 @@ async def _get_predecessor_node_variables( if "." in var_name and not var_name.startswith("system."): # 提取节点ID node_id = var_name.split(".")[0] - + # 检查是否为前置节点(这里可以根据需要添加更精确的前置判断逻辑) if node_id != current_step_id: # 不是当前节点的变量 variables.append(var) @@ -74,30 +74,30 @@ async def _get_predecessor_node_variables( try: # 使用缓存获取变量列表 from apps.services.predecessor_cache_service import PredecessorCacheService - + cached_var_data = await PredecessorCacheService.get_predecessor_variables_optimized( flow_id, current_step_id, user_sub, max_wait_time=5 ) - + # 将缓存的变量数据转换为Variable对象 for var_data in cached_var_data: try: var_name = var_data['name'] - + # 检查是否为当前步骤的输出变量 if "." in var_name and not var_name.startswith("system."): # 提取节点ID node_id = var_name.split(".")[0] - + # 排除当前步骤的输出变量 if node_id == current_step_id: continue # 跳过当前步骤的输出变量 - + from apps.scheduler.variable.variables import create_variable from apps.scheduler.variable.base import VariableMetadata from apps.scheduler.variable.type import VariableType, VariableScope from datetime import datetime - + # 创建变量元数据 metadata = VariableMetadata( name=var_name, @@ -105,28 +105,32 @@ async def _get_predecessor_node_variables( scope=VariableScope(var_data['scope']), description=var_data.get('description', ''), created_by=user_sub, - created_at=datetime.fromisoformat(var_data['created_at'].replace('Z', '+00:00')), - updated_at=datetime.fromisoformat(var_data['updated_at'].replace('Z', '+00:00')) + created_at=datetime.fromisoformat( + var_data['created_at'].replace('Z', '+00:00')), + updated_at=datetime.fromisoformat( + var_data['updated_at'].replace('Z', '+00:00')) ) - + # 创建变量对象,并附加缓存的节点信息(使用None避免类型验证失败) - variable = create_variable(metadata, var_data.get('value')) - + variable = create_variable( + metadata, var_data.get('value')) + # 将节点信息附加到变量对象上(用于后续响应格式化) if hasattr(variable, '_cache_data'): variable._cache_data = var_data else: # 如果对象不支持动态属性,我们可以创建一个包装类或者在响应时处理 setattr(variable, '_cache_data', var_data) - + variables.append(variable) - + except Exception as var_create_error: logger.warning(f"创建缓存变量对象失败: {var_create_error}") continue - - logger.info(f"配置阶段:为节点 {current_step_id} 找到前置节点变量总数: {len([v for v in variables if hasattr(v, 'name') and '.' in v.name and not v.name.startswith('system.')])}") - + + logger.info( + f"配置阶段:为节点 {current_step_id} 找到前置节点变量总数: {len([v for v in variables if hasattr(v, 'name') and '.' in v.name and not v.name.startswith('system.')])}") + except Exception as flow_error: logger.warning(f"配置阶段获取前置节点变量失败,降级到实时解析: {flow_error}") # 降级到原有的实时解析逻辑 @@ -134,17 +138,14 @@ async def _get_predecessor_node_variables( flow_id, current_step_id, user_sub ) variables.extend(predecessor_vars) - + return variables - + except Exception as e: logger.error(f"获取前置节点变量失败: {e}") return [] - - - # 请求和响应模型 class CreateVariableRequest(BaseModel): """创建变量请求""" @@ -153,26 +154,38 @@ class CreateVariableRequest(BaseModel): scope: VariableScope = Field(description="变量作用域") value: Optional[Any] = Field(default=None, description="变量值") description: Optional[str] = Field(default=None, description="变量描述") - flow_id: Optional[str] = Field(default=None, description="流程ID(环境级和对话级变量必需)") + flow_id: Optional[str] = Field( + default=None, description="流程ID(环境级和对话级变量必需)") # 文件类型变量专用字段 - supported_types: Optional[List[str]] = Field(default=None, description="支持的文件类型(文件类型变量专用)") - upload_methods: Optional[List[str]] = Field(default=None, description="支持的上传方式列表(文件类型变量专用)") - max_files: Optional[int] = Field(default=None, description="最大上传文件数(文件类型变量专用)") - max_file_size: Optional[int] = Field(default=None, description="单个文件最大大小(MB,文件类型变量专用)") - required: Optional[bool] = Field(default=None, description="文件是否必填(文件类型变量专用)") + supported_types: Optional[List[str]] = Field( + default=None, description="支持的文件类型(文件类型变量专用)") + upload_methods: Optional[List[str]] = Field( + default=None, description="支持的上传方式列表(文件类型变量专用)") + max_files: Optional[int] = Field( + default=None, description="最大上传文件数(文件类型变量专用)") + max_file_size: Optional[int] = Field( + default=None, description="单个文件最大大小(MB,文件类型变量专用)") + required: Optional[bool] = Field( + default=None, description="文件是否必填(文件类型变量专用)") class UpdateVariableRequest(BaseModel): """更新变量请求""" value: Optional[Any] = Field(default=None, description="新的变量值") - var_type: Optional[VariableType] = Field(default=None, description="新的变量类型") + var_type: Optional[VariableType] = Field( + default=None, description="新的变量类型") description: Optional[str] = Field(default=None, description="新的变量描述") # 文件类型变量专用字段(用于更新文件配置) - supported_types: Optional[List[str]] = Field(default=None, description="支持的文件类型(文件类型变量专用)") - upload_methods: Optional[List[str]] = Field(default=None, description="支持的上传方式列表(文件类型变量专用)") - max_files: Optional[int] = Field(default=None, description="最大上传文件数(文件类型变量专用)") - max_file_size: Optional[int] = Field(default=None, description="单个文件最大大小(MB,文件类型变量专用)") - required: Optional[bool] = Field(default=None, description="文件是否必填(文件类型变量专用)") + supported_types: Optional[List[str]] = Field( + default=None, description="支持的文件类型(文件类型变量专用)") + upload_methods: Optional[List[str]] = Field( + default=None, description="支持的上传方式列表(文件类型变量专用)") + max_files: Optional[int] = Field( + default=None, description="最大上传文件数(文件类型变量专用)") + max_file_size: Optional[int] = Field( + default=None, description="单个文件最大大小(MB,文件类型变量专用)") + required: Optional[bool] = Field( + default=None, description="文件是否必填(文件类型变量专用)") class VariableResponse(BaseModel): @@ -232,7 +245,7 @@ async def create_variable( status_code=status.HTTP_403_FORBIDDEN, detail="不允许创建系统级变量" ) - + # 类型转换和验证 converted_value = None if request.value is not None: @@ -244,29 +257,33 @@ async def create_variable( "supported_types": request.supported_types or [], "upload_methods": request.upload_methods or ["manual"], "max_files": request.max_files or (1 if request.var_type == VariableType.FILE else 10), - "max_file_size": request.max_file_size or (10 * 1024 * 1024), # 默认10MB - "required": request.required if request.required is not None else False # 默认非必填 + # 默认10MB + "max_file_size": request.max_file_size or (10 * 1024 * 1024), + "required": request.required if request.required is not None else False # 默认非必填 } - + # 如果提供了value,合并到配置中 if isinstance(request.value, dict): file_config.update(request.value) else: # 如果value不是字典,将其作为文件ID处理 if request.var_type == VariableType.FILE: - file_config["file_id"] = request.value if isinstance(request.value, str) else "" + file_config["file_id"] = request.value if isinstance( + request.value, str) else "" else: - file_config["file_ids"] = request.value if isinstance(request.value, list) else [] - + file_config["file_ids"] = request.value if isinstance( + request.value, list) else [] + converted_value = await convert_file_value_by_type( - file_config, - request.var_type, - user_sub, + file_config, + request.var_type, + user_sub, conversation_id=None, # 创建变量时没有conversation_id flow_id=request.flow_id ) else: - converted_value = convert_value_by_type(request.value, request.var_type) + converted_value = convert_value_by_type( + request.value, request.var_type) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -279,17 +296,18 @@ async def create_variable( "supported_types": request.supported_types or [], "upload_methods": request.upload_methods or ["manual"], "max_files": request.max_files or (1 if request.var_type == VariableType.FILE else 10), - "max_file_size": request.max_file_size or (10 * 1024 * 1024), # 默认10MB - "required": request.required if request.required is not None else False # 默认非必填 - }, - request.var_type, - user_sub, + # 默认10MB + "max_file_size": request.max_file_size or (10 * 1024 * 1024), + "required": request.required if request.required is not None else False # 默认非必填 + }, + request.var_type, + user_sub, conversation_id=getattr(request, 'conversation_id', None), flow_id=request.flow_id ) - + pool_manager = await get_pool_manager() - + # 根据作用域获取合适的变量池 if request.scope == VariableScope.USER: # 用户级变量需要user_sub参数 @@ -321,13 +339,13 @@ async def create_variable( status_code=status.HTTP_400_BAD_REQUEST, detail=f"不支持的变量作用域: {request.scope.value}" ) - + if not pool: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="无法获取变量池" ) - + # 根据作用域创建不同类型的变量 if request.scope == VariableScope.CONVERSATION: # 创建对话变量模板 @@ -348,13 +366,12 @@ async def create_variable( created_by=user_sub ) - return ResponseData( code=200, message="变量创建成功", result={"variable_name": variable.name}, ) - + except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -380,14 +397,16 @@ async def update_variable( user_sub: Annotated[str, Depends(get_user)], name: str = Query(..., description="变量名称"), scope: VariableScope = Query(..., description="变量作用域"), - flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), - conversation_id: Optional[str] = Query(default=None, description="对话ID(对话级变量运行时必需)"), + flow_id: Optional[str] = Query( + default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query( + default=None, description="对话ID(对话级变量运行时必需)"), request: UpdateVariableRequest = Body(...), ) -> ResponseData: """更新变量值""" try: pool_manager = await get_pool_manager() - + # 根据作用域获取合适的变量池 if scope == VariableScope.USER: if not user_sub: @@ -428,13 +447,13 @@ async def update_variable( status_code=status.HTTP_400_BAD_REQUEST, detail=f"不支持的变量作用域: {scope.value}" ) - + if not pool: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="无法获取变量池" ) - + # 类型转换和验证(仅当提供了新值和新类型时) converted_value = request.value if request.value is not None and request.var_type is not None: @@ -449,31 +468,35 @@ async def update_variable( "supported_types": request.supported_types or [], "upload_methods": request.upload_methods or ["manual"], "max_files": request.max_files or (1 if request.var_type == VariableType.FILE else 10), - "max_file_size": request.max_file_size or (10 * 1024 * 1024), # 默认10MB - "required": request.required if request.required is not None else False # 默认非必填 + # 默认10MB + "max_file_size": request.max_file_size or (10 * 1024 * 1024), + "required": request.required if request.required is not None else False # 默认非必填 } - + # 如果提供了value,合并到配置中 if isinstance(request.value, dict): file_config.update(request.value) else: # 如果value不是字典,将其作为文件ID处理 if request.var_type == VariableType.FILE: - file_config["file_id"] = request.value if isinstance(request.value, str) else "" + file_config["file_id"] = request.value if isinstance( + request.value, str) else "" else: - file_config["file_ids"] = request.value if isinstance(request.value, list) else [] - + file_config["file_ids"] = request.value if isinstance( + request.value, list) else [] + file_value = file_config - + converted_value = await convert_file_value_by_type( - file_value, - request.var_type, - user_sub, + file_value, + request.var_type, + user_sub, conversation_id=conversation_id, flow_id=flow_id ) else: - converted_value = convert_value_by_type(request.value, request.var_type) + converted_value = convert_value_by_type( + request.value, request.var_type) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -492,7 +515,7 @@ async def update_variable( existing_config = {} if isinstance(existing_variable.value, dict): existing_config = existing_variable.value.copy() - + # 构建新的文件配置 file_config = { "supported_types": request.supported_types or existing_config.get("supported_types", []), @@ -501,13 +524,15 @@ async def update_variable( "max_file_size": request.max_file_size or existing_config.get("max_file_size", (10 * 1024 * 1024)), "required": request.required if request.required is not None else existing_config.get("required", False) } - + # 保留现有的文件ID if existing_variable.metadata.var_type == VariableType.FILE: - file_config["file_id"] = existing_config.get("file_id", "") + file_config["file_id"] = existing_config.get( + "file_id", "") else: - file_config["file_ids"] = existing_config.get("file_ids", []) - + file_config["file_ids"] = existing_config.get( + "file_ids", []) + # 如果提供了新的value,使用新的value处理文件ID if request.value is not None: if isinstance(request.value, dict): @@ -515,21 +540,24 @@ async def update_variable( else: # 如果value不是字典,将其作为文件ID处理 if existing_variable.metadata.var_type == VariableType.FILE: - file_config["file_id"] = request.value if isinstance(request.value, str) else "" + file_config["file_id"] = request.value if isinstance( + request.value, str) else "" else: - file_config["file_ids"] = request.value if isinstance(request.value, list) else [] - + file_config["file_ids"] = request.value if isinstance( + request.value, list) else [] + file_value = file_config - + converted_value = await convert_file_value_by_type( - file_value, - existing_variable.metadata.var_type, - user_sub, + file_value, + existing_variable.metadata.var_type, + user_sub, conversation_id=conversation_id, flow_id=flow_id ) else: - converted_value = convert_value_by_type(request.value, existing_variable.metadata.var_type) + converted_value = convert_value_by_type( + request.value, existing_variable.metadata.var_type) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -538,7 +566,7 @@ async def update_variable( except Exception: # 如果获取现有变量失败,使用原值 converted_value = request.value - + # 更新变量 variable = await pool.update_variable( name=name, @@ -546,13 +574,13 @@ async def update_variable( var_type=request.var_type, description=request.description ) - + return ResponseData( code=200, message="变量更新成功", result={"variable_name": variable.name} ) - + except ValueError as e: logger.error(f"更新变量失败(ValueError): {e}", exc_info=True) raise HTTPException( @@ -585,13 +613,15 @@ async def delete_variable( user_sub: Annotated[str, Depends(get_user)], name: str = Query(..., description="变量名称"), scope: VariableScope = Query(..., description="变量作用域"), - flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), - conversation_id: Optional[str] = Query(default=None, description="对话ID(对话级变量运行时必需)"), + flow_id: Optional[str] = Query( + default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query( + default=None, description="对话ID(对话级变量运行时必需)"), ) -> ResponseData: """删除变量""" try: pool_manager = await get_pool_manager() - + # 根据作用域获取合适的变量池 if scope == VariableScope.USER: if not user_sub: @@ -632,28 +662,28 @@ async def delete_variable( status_code=status.HTTP_400_BAD_REQUEST, detail=f"不支持的变量作用域: {scope.value}" ) - + if not pool: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="无法获取变量池" ) - + # 删除变量 success = await pool.delete_variable(name) - + if not success: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="变量不存在" ) - + return ResponseData( code=200, message="变量删除成功", result={"variable_name": name} ) - + except ValueError as e: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -682,47 +712,52 @@ async def get_variable( user_sub: Annotated[str, Depends(get_user)], name: str = Query(..., description="变量名称"), scope: VariableScope = Query(..., description="变量作用域"), - flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), - conversation_id: Optional[str] = Query(default=None, description="对话ID(系统级和对话级变量必需)"), + flow_id: Optional[str] = Query( + default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query( + default=None, description="对话ID(系统级和对话级变量必需)"), ) -> VariableResponse: """获取单个变量""" try: pool_manager = await get_pool_manager() - + # 根据作用域获取变量 variable = await pool_manager.get_variable_from_any_pool( name=name, scope=scope, user_id=user_sub if scope == VariableScope.USER else None, - flow_id=flow_id if scope in [VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, - conversation_id=conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None + flow_id=flow_id if scope in [ + VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + conversation_id=conversation_id if scope in [ + VariableScope.SYSTEM, VariableScope.CONVERSATION] else None ) - + if not variable: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="变量不存在" ) - + # 检查权限 if not variable.can_access(user_sub): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="没有权限访问此变量" ) - + # 构建响应 var_dict = variable.to_dict() return VariableResponse( name=variable.name, var_type=variable.var_type.value, scope=variable.scope.value, - value=str(var_dict["value"]) if var_dict["value"] is not None else "", + value=str(var_dict["value"] + ) if var_dict["value"] is not None else "", description=variable.metadata.description, created_at=variable.metadata.created_at.isoformat(), updated_at=variable.metadata.updated_at.isoformat(), ) - + except HTTPException: raise except Exception as e: @@ -741,75 +776,83 @@ async def get_variable( async def list_variables( user_sub: Annotated[str, Depends(get_user)], scope: VariableScope = Query(..., description="变量作用域"), - flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), - conversation_id: Optional[str] = Query(default=None, description="对话ID(系统级和对话级变量必需)"), - current_step_id: Optional[str] = Query(default=None, description="当前步骤ID(用于获取前置节点变量)"), - exclude_pattern: Optional[str] = Query(default=None, description="排除模式:'step_id'排除包含.的变量名") + flow_id: Optional[str] = Query( + default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query( + default=None, description="对话ID(系统级和对话级变量必需)"), + current_step_id: Optional[str] = Query( + default=None, description="当前步骤ID(用于获取前置节点变量)"), + exclude_pattern: Optional[str] = Query( + default=None, description="排除模式:'step_id'排除包含.的变量名") ) -> VariableListResponse: """列出指定作用域的变量""" try: pool_manager = await get_pool_manager() - + # 获取变量列表 variables = await pool_manager.list_variables_from_any_pool( scope=scope, user_id=user_sub if scope == VariableScope.USER else None, - flow_id=flow_id if scope in [VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, - conversation_id=conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None + flow_id=flow_id if scope in [ + VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + conversation_id=conversation_id if scope in [ + VariableScope.SYSTEM, VariableScope.CONVERSATION] else None ) - + # 如果是对话级变量且提供了current_step_id,则额外获取前置节点的输出变量 if scope == VariableScope.CONVERSATION and current_step_id and flow_id: predecessor_variables = await _get_predecessor_node_variables( user_sub, flow_id, conversation_id, current_step_id ) variables.extend(predecessor_variables) - + # 应用排除模式过滤 if exclude_pattern == "step_id" and scope == VariableScope.CONVERSATION: # 排除包含"."的变量名(即节点特定变量),只保留全局对话变量 variables = [var for var in variables if "." not in var.name] - + # 过滤权限并构建响应 filtered_variables = [] for variable in variables: if variable.can_access(user_sub): var_dict = variable.to_dict() - + # 检查是否为前置节点变量 is_predecessor_var = ( - "." in variable.name and + "." in variable.name and not variable.name.startswith("system.") and scope == VariableScope.CONVERSATION and flow_id ) - + if is_predecessor_var: # 前置节点变量特殊处理 parts = variable.name.split(".", 1) if len(parts) == 2: step_id, var_name = parts - + # 确保不是当前步骤的输出变量(双重保险) if current_step_id and step_id == current_step_id: continue - + # 优先使用缓存数据中的节点信息 if hasattr(variable, '_cache_data') and variable._cache_data: cache_data = variable._cache_data step_name = cache_data.get('step_name', step_id) - step_id_from_cache = cache_data.get('step_id', step_id) + step_id_from_cache = cache_data.get( + 'step_id', step_id) else: # 降级到实时获取节点信息 node_info = await _get_node_info_by_step_id(flow_id, step_id) step_name = node_info["name"] step_id_from_cache = node_info["step_id"] - + filtered_variables.append(VariableResponse( name=var_name, # 只保留变量名部分 var_type=variable.var_type.value, scope=variable.scope.value, - value=str(var_dict["value"]) if var_dict["value"] is not None else "", + value=str( + var_dict["value"]) if var_dict["value"] is not None else "", description=variable.metadata.description, created_at=variable.metadata.created_at.isoformat(), updated_at=variable.metadata.updated_at.isoformat(), @@ -822,7 +865,8 @@ async def list_variables( name=variable.name, var_type=variable.var_type.value, scope=variable.scope.value, - value=str(var_dict["value"]) if var_dict["value"] is not None else "", + value=str( + var_dict["value"]) if var_dict["value"] is not None else "", description=variable.metadata.description, created_at=variable.metadata.created_at.isoformat(), updated_at=variable.metadata.updated_at.isoformat(), @@ -833,17 +877,18 @@ async def list_variables( name=variable.name, var_type=variable.var_type.value, scope=variable.scope.value, - value=str(var_dict["value"]) if var_dict["value"] is not None else "", + value=str( + var_dict["value"]) if var_dict["value"] is not None else "", description=variable.metadata.description, created_at=variable.metadata.created_at.isoformat(), updated_at=variable.metadata.updated_at.isoformat(), )) - + return VariableListResponse( variables=filtered_variables, total=len(filtered_variables) ) - + except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -870,18 +915,18 @@ async def parse_template( flow_id=request.flow_id, conversation_id=None, # 不再使用conversation_id ) - + # 解析模板 parsed_template = await parser.parse_template(request.template) - + # 提取使用的变量 variables_used = await parser.extract_variables(request.template) - + return ParseTemplateResponse( parsed_template=parsed_template, variables_used=variables_used ) - + except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -908,15 +953,15 @@ async def validate_template( flow_id=request.flow_id, conversation_id=None, # 不再使用conversation_id ) - + # 验证模板 is_valid, invalid_refs = await parser.validate_template(request.template) - + return ValidateTemplateResponse( is_valid=is_valid, invalid_references=invalid_refs ) - + except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -957,18 +1002,18 @@ async def clear_conversation_variables( pool_manager = await get_pool_manager() # 清空工作流的对话级变量 await pool_manager.clear_conversation_variables(flow_id) - + return ResponseData( code=200, message="工作流对话变量已清空", result={"flow_id": flow_id} ) - + except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"清空对话变量失败: {str(e)}" - ) + ) async def _get_node_info_by_step_id(flow_id: str, step_id: str) -> Dict[str, str]: @@ -977,7 +1022,7 @@ async def _get_node_info_by_step_id(flow_id: str, step_id: str) -> Dict[str, str flow_item = await _get_flow_by_flow_id(flow_id) if not flow_item: return {"name": step_id, "step_id": step_id} # 降级返回step_id作为名称 - + # 查找对应的节点 for node in flow_item.nodes: if node.step_id == step_id: @@ -985,41 +1030,42 @@ async def _get_node_info_by_step_id(flow_id: str, step_id: str) -> Dict[str, str "name": node.name or step_id, # 如果没有名称则使用step_id "step_id": step_id } - + # 如果没有找到节点,返回默认值 return {"name": step_id, "step_id": step_id} - + except Exception as e: logger.error(f"获取节点信息失败: {e}") return {"name": step_id, "step_id": step_id} async def _get_predecessor_variables_from_topology( - flow_id: str, - current_step_id: str, + flow_id: str, + current_step_id: str, user_sub: str ) -> List: """通过工作流拓扑分析获取前置节点变量""" try: variables = [] - + # 直接通过flow_id获取工作流拓扑信息 flow_item = await _get_flow_by_flow_id(flow_id) if not flow_item: logger.warning(f"无法获取工作流信息: flow_id={flow_id}") return variables - + # 分析前置节点 predecessor_nodes = _find_predecessor_nodes(flow_item, current_step_id) - + # 为每个前置节点创建潜在的输出变量 for node in predecessor_nodes: node_vars = await _create_node_output_variables(node, user_sub) variables.extend(node_vars) - - logger.info(f"通过拓扑分析为节点 {current_step_id} 创建了 {len(variables)} 个前置节点变量") + + logger.info( + f"通过拓扑分析为节点 {current_step_id} 创建了 {len(variables)} 个前置节点变量") return variables - + except Exception as e: logger.error(f"通过拓扑分析获取前置节点变量失败: {e}") return [] @@ -1029,25 +1075,25 @@ async def _get_flow_by_flow_id(flow_id: str): """直接通过flow_id获取工作流信息""" try: from apps.common.mongo import MongoDB - - app_collection = MongoDB().get_collection("app") - + + app_collection = MongoDB.get_collection("app") + # 查询包含此flow_id的app,同时获取app_id app_record = await app_collection.find_one( {"flows.id": flow_id}, {"_id": 1} ) - + if not app_record: logger.warning(f"未找到包含flow_id {flow_id} 的应用") return None - + app_id = app_record["_id"] - + # 使用现有的FlowManager方法获取flow flow_item = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) return flow_item - + except Exception as e: logger.error(f"通过flow_id获取工作流失败: {e}") return None @@ -1057,7 +1103,7 @@ def _find_predecessor_nodes(flow_item, current_step_id: str) -> List: """在工作流中查找前置节点""" try: predecessor_nodes = [] - + # 遍历边,找到指向当前节点的边 for edge in flow_item.edges: if edge.target_node == current_step_id: @@ -1068,10 +1114,10 @@ def _find_predecessor_nodes(flow_item, current_step_id: str) -> List: ) if source_node: predecessor_nodes.append(source_node) - + logger.info(f"为节点 {current_step_id} 找到 {len(predecessor_nodes)} 个前置节点") return predecessor_nodes - + except Exception as e: logger.error(f"查找前置节点失败: {e}") return [] @@ -1083,13 +1129,13 @@ async def _create_node_output_variables(node, user_sub: str) -> List: from apps.scheduler.variable.variables import create_variable from apps.scheduler.variable.base import VariableMetadata from datetime import datetime, UTC - + variables = [] node_id = node.step_id - + # 调试:输出节点的完整参数信息 logger.info(f"节点 {node_id} 的参数结构: {node.parameters}") - + # 统一从节点的output_parameters创建变量 output_params = {} if hasattr(node, 'parameters') and node.parameters: @@ -1098,14 +1144,15 @@ async def _create_node_output_variables(node, user_sub: str) -> List: output_params = node.parameters.get('output_parameters', {}) logger.info(f"从字典中获取output_parameters: {output_params}") else: - output_params = getattr(node.parameters, 'output_parameters', {}) + output_params = getattr( + node.parameters, 'output_parameters', {}) logger.info(f"从对象属性中获取output_parameters: {output_params}") - + # 如果没有配置output_parameters,跳过此节点 if not output_params: logger.info(f"节点 {node_id} 没有配置output_parameters,跳过创建输出变量") return variables - + # 遍历output_parameters中的每个key-value对,创建对应的变量 for param_name, param_config in output_params.items(): # 解析参数配置 @@ -1116,7 +1163,7 @@ async def _create_node_output_variables(node, user_sub: str) -> List: # 如果param_config不是字典,可能是简单的类型字符串 param_type = str(param_config) if param_config else 'string' description = '' - + # 确定变量类型 var_type = VariableType.STRING # 默认类型 if param_type == 'number': @@ -1143,7 +1190,7 @@ async def _create_node_output_variables(node, user_sub: str) -> List: var_type = VariableType.FILE elif param_type == 'secret': var_type = VariableType.SECRET - + # 创建变量元数据 metadata = VariableMetadata( name=f"{node_id}.{param_name}", @@ -1154,14 +1201,15 @@ async def _create_node_output_variables(node, user_sub: str) -> List: created_at=datetime.now(UTC), updated_at=datetime.now(UTC) ) - + # 创建变量对象(使用None作为默认值,避免类型验证失败) variable = create_variable(metadata, None) # 配置阶段的潜在变量,值为None variables.append(variable) - - logger.info(f"为节点 {node_id} 创建了 {len(variables)} 个输出变量: {[v.name for v in variables]}") + + logger.info( + f"为节点 {node_id} 创建了 {len(variables)} 个输出变量: {[v.name for v in variables]}") return variables - + except Exception as e: logger.error(f"创建节点输出变量失败: {e}") return [] @@ -1170,25 +1218,25 @@ async def _create_node_output_variables(node, user_sub: str) -> List: def convert_string_value_by_type(value: str, var_type: VariableType) -> Any: """ 根据变量类型将字符串值转换为对应的Python类型 - + Args: value: 字符串格式的值 var_type: 变量类型 - + Returns: 转换后的值 - + Raises: ValueError: 当值无法转换为指定类型时 """ if value is None: return None - + try: match var_type: case VariableType.STRING | VariableType.SECRET: return str(value) - + case VariableType.NUMBER: # 尝试转换为数字 if isinstance(value, str) and value.strip() == "": @@ -1198,7 +1246,7 @@ def convert_string_value_by_type(value: str, var_type: VariableType) -> Any: return float(value) else: return int(value) - + case VariableType.BOOLEAN: # 处理布尔值转换 if isinstance(value, bool): @@ -1212,7 +1260,7 @@ def convert_string_value_by_type(value: str, var_type: VariableType) -> Any: else: raise ValueError(f"无法将 '{value}' 转换为布尔值") return bool(value) - + case VariableType.OBJECT: # 处理对象类型 if isinstance(value, dict): @@ -1223,7 +1271,7 @@ def convert_string_value_by_type(value: str, var_type: VariableType) -> Any: except json.JSONDecodeError as e: raise ValueError(f"无法解析JSON对象: {e}") return dict(value) - + case VariableType.ARRAY_STRING: # 处理字符串数组 if isinstance(value, list): @@ -1240,7 +1288,7 @@ def convert_string_value_by_type(value: str, var_type: VariableType) -> Any: # 按逗号分割 return [item.strip() for item in value.split(',') if item.strip()] return list(value) - + case VariableType.ARRAY_NUMBER: # 处理数字数组 if isinstance(value, list): @@ -1254,14 +1302,15 @@ def convert_string_value_by_type(value: str, var_type: VariableType) -> Any: if isinstance(item, (int, float)): result.append(item) else: - result.append(float(item) if '.' in str(item) else int(item)) + result.append(float(item) if '.' in str( + item) else int(item)) return result else: raise ValueError("期望数组格式") except (json.JSONDecodeError, ValueError) as e: raise ValueError(f"无法解析数字数组: {e}") return list(value) - + case VariableType.ARRAY_BOOLEAN: # 处理布尔数组 if isinstance(value, list): @@ -1276,7 +1325,7 @@ def convert_string_value_by_type(value: str, var_type: VariableType) -> Any: except json.JSONDecodeError as e: raise ValueError(f"无法解析布尔数组: {e}") return list(value) - + case VariableType.ARRAY_OBJECT: # 处理对象数组 if isinstance(value, list): @@ -1291,7 +1340,7 @@ def convert_string_value_by_type(value: str, var_type: VariableType) -> Any: except json.JSONDecodeError as e: raise ValueError(f"无法解析对象数组: {e}") return list(value) - + case VariableType.FILE | VariableType.ARRAY_FILE: # 文件类型需要保持字典格式或正确解析字符串 if isinstance(value, dict): @@ -1311,7 +1360,7 @@ def convert_string_value_by_type(value: str, var_type: VariableType) -> Any: else: # 其他类型尝试转换为字符串 return str(value) - + case VariableType.ARRAY_SECRET: # 密钥数组 if isinstance(value, list): @@ -1326,11 +1375,11 @@ def convert_string_value_by_type(value: str, var_type: VariableType) -> Any: except json.JSONDecodeError as e: raise ValueError(f"无法解析密钥数组: {e}") return list(value) - + case _: # 默认返回字符串 return str(value) - + except (ValueError, TypeError, json.JSONDecodeError) as e: raise ValueError(f"无法将值 '{value}' 转换为类型 '{var_type.value}': {str(e)}") @@ -1341,7 +1390,7 @@ def convert_value_by_type(value: Any, var_type: VariableType) -> Any: match var_type: case VariableType.STRING: return str(value) - + case VariableType.NUMBER: if isinstance(value, (int, float)): return value @@ -1349,7 +1398,7 @@ def convert_value_by_type(value: Any, var_type: VariableType) -> Any: return float(value) if '.' in value else int(value) else: return float(value) - + case VariableType.BOOLEAN: if isinstance(value, bool): return value @@ -1365,7 +1414,7 @@ def convert_value_by_type(value: Any, var_type: VariableType) -> Any: return bool(value) else: return bool(value) - + case VariableType.OBJECT: if isinstance(value, dict): return value @@ -1376,10 +1425,10 @@ def convert_value_by_type(value: Any, var_type: VariableType) -> Any: raise ValueError(f"无法解析JSON对象: {e}") else: raise ValueError(f"无法将类型 {type(value).__name__} 转换为对象") - + case VariableType.SECRET: return str(value) - + case VariableType.ARRAY_STRING: # 如果已经是列表,检查元素类型 if isinstance(value, list): @@ -1397,7 +1446,7 @@ def convert_value_by_type(value: Any, var_type: VariableType) -> Any: return [item.strip() for item in value.split(',') if item.strip()] else: return [str(value)] - + case VariableType.ARRAY_NUMBER: # 如果已经是列表,检查元素类型 if isinstance(value, list): @@ -1406,7 +1455,8 @@ def convert_value_by_type(value: Any, var_type: VariableType) -> Any: if isinstance(item, (int, float)): result.append(item) else: - result.append(float(item) if '.' in str(item) else int(item)) + result.append(float(item) if '.' in str( + item) else int(item)) return result if isinstance(value, str): try: @@ -1417,7 +1467,8 @@ def convert_value_by_type(value: Any, var_type: VariableType) -> Any: if isinstance(item, (int, float)): result.append(item) else: - result.append(float(item) if '.' in str(item) else int(item)) + result.append(float(item) if '.' in str( + item) else int(item)) return result else: raise ValueError("期望数组格式") @@ -1425,7 +1476,7 @@ def convert_value_by_type(value: Any, var_type: VariableType) -> Any: raise ValueError(f"无法解析数字数组: {e}") else: return [value] - + case VariableType.ARRAY_BOOLEAN: # 如果已经是列表,检查元素类型 if isinstance(value, list): @@ -1441,7 +1492,7 @@ def convert_value_by_type(value: Any, var_type: VariableType) -> Any: raise ValueError(f"无法解析布尔数组: {e}") else: return [value] - + case VariableType.ARRAY_OBJECT: # 如果已经是列表,检查元素类型 if isinstance(value, list): @@ -1457,7 +1508,7 @@ def convert_value_by_type(value: Any, var_type: VariableType) -> Any: raise ValueError(f"无法解析对象数组: {e}") else: return [value] - + case VariableType.ARRAY_SECRET: # 密钥数组 if isinstance(value, list): @@ -1473,20 +1524,21 @@ def convert_value_by_type(value: Any, var_type: VariableType) -> Any: raise ValueError(f"无法解析密钥数组: {e}") else: return [str(value)] - + case _: # 默认返回字符串 return str(value) - + except (ValueError, TypeError, json.JSONDecodeError) as e: - raise ValueError(f"无法将值 '{value}' (类型: {type(value).__name__}) 转换为类型 '{var_type.value}': {str(e)}") + raise ValueError( + f"无法将值 '{value}' (类型: {type(value).__name__}) 转换为类型 '{var_type.value}': {str(e)}") async def convert_file_value_by_type(value: Any, var_type: VariableType, user_sub: str, conversation_id: Optional[str] = None, flow_id: Optional[str] = None) -> Any: """异步处理文件类型的值转换""" if var_type not in [VariableType.FILE, VariableType.ARRAY_FILE]: return value - + # 文件类型变量的默认存储结构 if var_type == VariableType.FILE: # 单个文件变量结构 @@ -1498,10 +1550,13 @@ async def convert_file_value_by_type(value: Any, var_type: VariableType, user_su # 缺少必要字段,使用默认结构 return { "file_id": "", # 文件ID,默认为空 - "supported_types": value.get("supported_types", []), # 支持的文件类型 - "upload_methods": value.get("upload_methods", ["manual"]), # 支持的上传方式 + # 支持的文件类型 + "supported_types": value.get("supported_types", []), + # 支持的上传方式 + "upload_methods": value.get("upload_methods", ["manual"]), "max_files": value.get("max_files", 1), # 最大文件数(单个文件为1) - "max_file_size": value.get("max_file_size", 10 * 1024 * 1024), # 默认10MB + # 默认10MB + "max_file_size": value.get("max_file_size", 10 * 1024 * 1024), "required": value.get("required", False) # 是否必填 } elif isinstance(value, str): @@ -1513,19 +1568,19 @@ async def convert_file_value_by_type(value: Any, var_type: VariableType, user_su "supported_types": [], "upload_methods": ["manual"], "max_files": 1, - "max_file_size": 10 * 1024 * 1024, # 默认10MB + "max_file_size": 10 * 1024 * 1024, # 默认10MB "required": False # 默认非必填 } else: # 不是文件ID,使用默认结构 - return { - "file_id": "", - "supported_types": [], - "upload_methods": ["manual"], - "max_files": 1, - "max_file_size": 10 * 1024 * 1024, # 默认10MB - "required": False # 默认非必填 - } + return { + "file_id": "", + "supported_types": [], + "upload_methods": ["manual"], + "max_files": 1, + "max_file_size": 10 * 1024 * 1024, # 默认10MB + "required": False # 默认非必填 + } else: # 其他类型,使用默认结构 return { @@ -1533,10 +1588,10 @@ async def convert_file_value_by_type(value: Any, var_type: VariableType, user_su "supported_types": [], "upload_methods": ["manual"], "max_files": 1, - "max_file_size": 10 * 1024 * 1024, # 默认10MB + "max_file_size": 10 * 1024 * 1024, # 默认10MB "required": False # 默认非必填 } - + elif var_type == VariableType.ARRAY_FILE: # 文件列表变量结构 if isinstance(value, dict): @@ -1547,10 +1602,13 @@ async def convert_file_value_by_type(value: Any, var_type: VariableType, user_su # 缺少必要字段,使用默认结构 return { "file_ids": [], # 文件ID列表,默认为空 - "supported_types": value.get("supported_types", []), # 支持的文件类型 - "upload_methods": value.get("upload_methods", ["manual"]), # 支持的上传方式 + # 支持的文件类型 + "supported_types": value.get("supported_types", []), + # 支持的上传方式 + "upload_methods": value.get("upload_methods", ["manual"]), "max_files": value.get("max_files", 10), # 最大文件数 - "max_file_size": value.get("max_file_size", 10 * 1024 * 1024), # 默认10MB + # 默认10MB + "max_file_size": value.get("max_file_size", 10 * 1024 * 1024), "required": value.get("required", False) # 是否必填 } elif isinstance(value, list): @@ -1562,7 +1620,7 @@ async def convert_file_value_by_type(value: Any, var_type: VariableType, user_su "supported_types": [], "upload_methods": ["manual"], "max_files": len(value), - "max_file_size": 10 * 1024 * 1024, # 默认10MB + "max_file_size": 10 * 1024 * 1024, # 默认10MB "required": False # 默认非必填 } else: @@ -1572,7 +1630,7 @@ async def convert_file_value_by_type(value: Any, var_type: VariableType, user_su "supported_types": [], "upload_methods": ["manual"], "max_files": 10, - "max_file_size": 10 * 1024 * 1024, # 默认10MB + "max_file_size": 10 * 1024 * 1024, # 默认10MB "required": False # 默认非必填 } else: @@ -1582,8 +1640,8 @@ async def convert_file_value_by_type(value: Any, var_type: VariableType, user_su "supported_types": [], "upload_methods": ["manual"], "max_files": 10, - "max_file_size": 10 * 1024 * 1024, # 默认10MB + "max_file_size": 10 * 1024 * 1024, # 默认10MB "required": False # 默认非必填 } - - return value \ No newline at end of file + + return value diff --git a/apps/scheduler/call/reply/direct_reply.py b/apps/scheduler/call/reply/direct_reply.py index 171bdbb1f08d2506ba1e73943b075359ccce71c3..90a4b7c71f588dd2b62a937a0471f3aaf7c0eed5 100644 --- a/apps/scheduler/call/reply/direct_reply.py +++ b/apps/scheduler/call/reply/direct_reply.py @@ -54,11 +54,12 @@ class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectRep """解析附件变量引用,返回变量名到文件ID列表的映射""" variable_file_map = {} pool_manager = await get_pool_manager() - + for var_name, var_info in attachment_dict.items(): # 获取完整的displayName,如 "conversation.file" - display_name = var_info.get('displayName', f"conversation.{var_name}") - + display_name = var_info.get( + 'displayName', f"conversation.{var_name}") + # 解析displayName获取scope和实际变量名 parts = display_name.split('.') if len(parts) >= 2: @@ -67,36 +68,39 @@ class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectRep else: logger.warning(f"[DirectReply] 无效的变量路径: {display_name}") continue - + # 确定变量作用域 try: scope = VariableScope(scope_str) except ValueError: logger.warning(f"[DirectReply] 无效的变量作用域: {scope_str}") continue - + # 从变量池中获取变量 variable = await pool_manager.get_variable_from_any_pool( name=actual_var_name, scope=scope, user_id=getattr(self._sys_vars.ids, 'user_sub', None), flow_id=getattr(self._sys_vars.ids, 'flow_id', None), - conversation_id=getattr(self._sys_vars.ids, 'conversation_id', None) + conversation_id=getattr( + self._sys_vars.ids, 'conversation_id', None) ) - + if variable is None: - logger.warning(f"[DirectReply] 变量不存在: {display_name} (scope: {scope_str}, name: {actual_var_name})") + logger.warning( + f"[DirectReply] 变量不存在: {display_name} (scope: {scope_str}, name: {actual_var_name})") continue - + # 检查是否为文件类型变量 if variable.var_type not in [VariableType.FILE, VariableType.ARRAY_FILE]: - logger.warning(f"[DirectReply] 变量不是文件类型: {display_name}, 类型: {variable.var_type}") + logger.warning( + f"[DirectReply] 变量不是文件类型: {display_name}, 类型: {variable.var_type}") continue - + # 提取文件ID file_ids = [] value = variable.value - + if variable.var_type == VariableType.FILE: # 单文件变量 if isinstance(value, str): @@ -113,39 +117,40 @@ class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectRep file_ids.append(item["file_id"]) elif isinstance(value, dict) and "file_ids" in value: file_ids.extend(value["file_ids"]) - + if file_ids: variable_file_map[display_name] = file_ids - logger.info(f"[DirectReply] 提取到文件变量 {display_name}: {len(file_ids)} 个文件") + logger.info( + f"[DirectReply] 提取到文件变量 {display_name}: {len(file_ids)} 个文件") else: logger.warning(f"[DirectReply] 文件变量 {display_name} 中没有有效的文件ID") - + return variable_file_map - + async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行直接回复""" data = DirectReplyInput(**input_data) - + try: # 使用基类的变量解析功能处理文本中的变量引用 final_answer = await self._resolve_variables_in_text(data.answer, self._sys_vars) - + logger.info(f"[DirectReply] 原始答案: {data.answer}") logger.info(f"[DirectReply] 解析后答案: {final_answer}") - + # 首先返回文本内容 yield CallOutputChunk( - type=CallOutputType.TEXT, + type=CallOutputType.TEXT, content=final_answer ) - + # 处理附件 if data.attachment: logger.info(f"[DirectReply] 开始处理附件: {data.attachment}") - + # 解析变量引用并提取文件ID variable_file_map = await self._parse_variable_references(data.attachment) - + if variable_file_map: # 🔑 重要修改:按变量分组处理文件,支持array[file]类型 for var_name, file_ids in variable_file_map.items(): @@ -154,32 +159,38 @@ class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectRep file_id = file_ids[0] try: # 从MongoDB获取文件信息 - mongo = MongoDB() - doc_collection = mongo.get_collection("document") + doc_collection = MongoDB.get_collection( + "document") doc_info = await doc_collection.find_one({"_id": file_id}) if not doc_info: - logger.warning(f"[DirectReply] 文件信息不存在: {file_id}") + logger.warning( + f"[DirectReply] 文件信息不存在: {file_id}") continue - - filename = doc_info.get("name", f"file_{file_id}") + + filename = doc_info.get( + "name", f"file_{file_id}") file_size = doc_info.get("size", 0) - file_type = doc_info.get("type", "application/octet-stream") - + file_type = doc_info.get( + "type", "application/octet-stream") + # 从MinIO下载文件数据 try: - metadata, file_data = MinioClient.download_file("document", file_id) - + metadata, file_data = MinioClient.download_file( + "document", file_id) + # 从metadata中获取原始文件名(如果存在) if "file_name" in metadata: try: - original_filename = base64.b64decode(metadata["file_name"]).decode('utf-8') + original_filename = base64.b64decode( + metadata["file_name"]).decode('utf-8') filename = original_filename except Exception: pass # 使用默认文件名 - + # 将文件数据编码为base64 - file_content_b64 = base64.b64encode(file_data).decode('utf-8') - + file_content_b64 = base64.b64encode( + file_data).decode('utf-8') + # 单个文件:使用FILE类型 yield CallOutputChunk( type=CallOutputType.FILE, @@ -192,49 +203,58 @@ class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectRep "variable_name": var_name } ) - logger.info(f"[DirectReply] 成功返回单文件: {var_name}/{filename} (ID: {file_id})") - + logger.info( + f"[DirectReply] 成功返回单文件: {var_name}/{filename} (ID: {file_id})") + except Exception as minio_error: - logger.error(f"[DirectReply] 从MinIO下载文件失败: {file_id}, 错误: {minio_error}") + logger.error( + f"[DirectReply] 从MinIO下载文件失败: {file_id}, 错误: {minio_error}") continue - + except Exception as file_error: - logger.error(f"[DirectReply] 处理文件失败: {file_id}, 错误: {file_error}") + logger.error( + f"[DirectReply] 处理文件失败: {file_id}, 错误: {file_error}") continue - + elif len(file_ids) > 1: # 多文件:合并为一个FILES类型的响应 files_data = [] - + for file_id in file_ids: try: # 从MongoDB获取文件信息 - mongo = MongoDB() - doc_collection = mongo.get_collection("document") + doc_collection = MongoDB.get_collection( + "document") doc_info = await doc_collection.find_one({"_id": file_id}) if not doc_info: - logger.warning(f"[DirectReply] 文件信息不存在: {file_id}") + logger.warning( + f"[DirectReply] 文件信息不存在: {file_id}") continue - - filename = doc_info.get("name", f"file_{file_id}") + + filename = doc_info.get( + "name", f"file_{file_id}") file_size = doc_info.get("size", 0) - file_type = doc_info.get("type", "application/octet-stream") - + file_type = doc_info.get( + "type", "application/octet-stream") + # 从MinIO下载文件数据 try: - metadata, file_data = MinioClient.download_file("document", file_id) - + metadata, file_data = MinioClient.download_file( + "document", file_id) + # 从metadata中获取原始文件名(如果存在) if "file_name" in metadata: try: - original_filename = base64.b64decode(metadata["file_name"]).decode('utf-8') + original_filename = base64.b64decode( + metadata["file_name"]).decode('utf-8') filename = original_filename except Exception: pass # 使用默认文件名 - + # 将文件数据编码为base64 - file_content_b64 = base64.b64encode(file_data).decode('utf-8') - + file_content_b64 = base64.b64encode( + file_data).decode('utf-8') + files_data.append({ "file_id": file_id, "filename": filename, @@ -243,15 +263,17 @@ class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectRep "content": file_content_b64, "variable_name": var_name }) - + except Exception as minio_error: - logger.error(f"[DirectReply] 从MinIO下载文件失败: {file_id}, 错误: {minio_error}") + logger.error( + f"[DirectReply] 从MinIO下载文件失败: {file_id}, 错误: {minio_error}") continue - + except Exception as file_error: - logger.error(f"[DirectReply] 处理文件失败: {file_id}, 错误: {file_error}") + logger.error( + f"[DirectReply] 处理文件失败: {file_id}, 错误: {file_error}") continue - + # 如果有成功处理的文件,返回FILES类型 if files_data: yield CallOutputChunk( @@ -262,13 +284,15 @@ class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectRep "files": files_data } ) - logger.info(f"[DirectReply] 成功返回多文件: {var_name} ({len(files_data)} 个文件)") + logger.info( + f"[DirectReply] 成功返回多文件: {var_name} ({len(files_data)} 个文件)") else: logger.warning("[DirectReply] 没有找到有效的文件变量") - + except Exception as e: logger.error(f"[DirectReply] 处理回复内容失败: {e}") raise CallError( - message=f"直接回复处理失败:{e!s}", - data={"original_answer": data.answer, "attachment": data.attachment} + message=f"直接回复处理失败:{e!s}", + data={"original_answer": data.answer, + "attachment": data.attachment} ) from e diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 9e712998859b341425c4759239efd361826b98ad..b95147401ae85008a89df7c744f03dde709100d0 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -58,8 +58,7 @@ class MCPHost: async def get_client(self, mcp_id: str) -> MCPClient | None: """获取MCP客户端""" - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") # 检查用户是否启用了这个mcp mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub}) @@ -189,8 +188,7 @@ class MCPHost: async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]: """获取工具列表""" - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") # 获取工具列表 tool_list = [] diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py index 7cbb719990cafc7dfe698815e4d3559017c929e8..3559382b137bfeae782972215c1d6b7dbbbcf652 100644 --- a/apps/scheduler/mcp/select.py +++ b/apps/scheduler/mcp/select.py @@ -46,7 +46,7 @@ class MCPSelector: ) # 拿到工具 - tool_collection = MongoDB().get_collection("mcp") + tool_collection = MongoDB.get_collection("mcp") llm_tool_list = [] for tool_vec in tool_vecs: diff --git a/apps/scheduler/pool/check.py b/apps/scheduler/pool/check.py index 3a76f6c65a50a0b5364d8fc3498717a9c06e346a..f8f54870e627eb359b4d225af71fcb24ee767743 100644 --- a/apps/scheduler/pool/check.py +++ b/apps/scheduler/pool/check.py @@ -53,10 +53,10 @@ class FileChecker: async def diff(self, check_type: MetadataType) -> tuple[list[str], list[str]]: """生成更新列表和删除列表""" if check_type == MetadataType.APP: - collection = MongoDB().get_collection("app") + collection = MongoDB.get_collection("app") self._dir_path = Path(Config().get_config().deploy.data_dir) / "semantics" / "app" elif check_type == MetadataType.SERVICE: - collection = MongoDB().get_collection("service") + collection = MongoDB.get_collection("service") self._dir_path = Path(Config().get_config().deploy.data_dir) / "semantics" / "service" changed_list = [] diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index aefeda8219fdb02f939f550ca2e3264d50cc8f2b..c63d560370a7c942f32c6ea14dec95602ba0c817 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -113,11 +113,10 @@ class AppLoader: :param app_id: 应用 ID """ - mongo = MongoDB() try: - app_collection = mongo.get_collection("app") + app_collection = MongoDB.get_collection("app") await app_collection.delete_one({"_id": app_id}) # 删除应用数据 - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") # 删除用户使用记录 await user_collection.update_many( {f"app_usage.{app_id}": {"$exists": True}}, @@ -144,9 +143,8 @@ class AppLoader: logger.error(err) raise ValueError(err) # 更新应用数据 - mongo = MongoDB() try: - app_collection = mongo.get_collection("app") + app_collection = MongoDB.get_collection("app") metadata.permission = metadata.permission if metadata.permission else Permission() await app_collection.update_one( {"_id": metadata.id}, diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index 45bee6773378d93e3c48841f802b9e4b44318807..7f8e7b312bf87548c0f98a82e27238d34a7f39fe 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -147,9 +147,8 @@ class CallLoader(metaclass=SingletonMeta): async def _delete_from_db(self, call_name: str) -> None: """从数据库中删除单个Call""" # 从MongoDB中删除 - mongo = MongoDB() - call_collection = mongo.get_collection("call") - node_collection = mongo.get_collection("node") + call_collection = MongoDB.get_collection("call") + node_collection = MongoDB.get_collection("node") try: await call_collection.delete_one({"_id": call_name}) await node_collection.delete_many({"call_id": call_name}) @@ -168,9 +167,8 @@ class CallLoader(metaclass=SingletonMeta): async def _add_to_db(self, call_metadata: list[CallPool]) -> None: # noqa: C901 """更新数据库""" # 更新MongoDB - mongo = MongoDB() - call_collection = mongo.get_collection("call") - node_collection = mongo.get_collection("node") + call_collection = MongoDB.get_collection("call") + node_collection = MongoDB.get_collection("node") call_descriptions = [] try: for call in call_metadata: @@ -217,9 +215,8 @@ class CallLoader(metaclass=SingletonMeta): async def load(self) -> None: """初始化Call信息""" # 清空collection - mongo = MongoDB() - call_collection = mongo.get_collection("call") - node_collection = mongo.get_collection("node") + call_collection = MongoDB.get_collection("call") + node_collection = MongoDB.get_collection("node") try: await call_collection.delete_many({}) await node_collection.delete_many({"service_id": ""}) diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index a332a84ceaf5f5a62b4f6e0e842ef0fcb7dda657..332528cc71b7115c4719f0dbc07f0abcb2b60967 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -252,7 +252,7 @@ class FlowLoader: async def _update_db(self, app_id: str, metadata: AppFlow) -> None: # noqa: C901 """更新数据库""" try: - app_collection = MongoDB().get_collection("app") + app_collection = MongoDB.get_collection("app") # 获取当前的flows app_data = await app_collection.find_one({"_id": app_id}) if not app_data: @@ -392,7 +392,7 @@ class FlowLoader: async def _update_subflow_db(self, app_id: str, flow_id: str, metadata: "AppSubFlow") -> None: """更新数据库中的子工作流元数据""" try: - app_collection = MongoDB().get_collection("app") + app_collection = MongoDB.get_collection("app") # 查找应用 app_record = await app_collection.find_one({"_id": app_id}) @@ -433,7 +433,7 @@ class FlowLoader: async def _delete_subflow_db(self, app_id: str, flow_id: str, sub_flow_id: str) -> None: """从数据库中删除子工作流元数据""" try: - app_collection = MongoDB().get_collection("app") + app_collection = MongoDB.get_collection("app") # 从应用的子工作流列表中移除 await app_collection.update_one( diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 9b187689c507e42e9fea6668583d9acb540f0788..be38141659bbaf5d5ef80456729bf8064bf230b6 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -134,7 +134,7 @@ class MCPLoader(metaclass=SingletonMeta): """ 清除状态为ready或failed的MCP安装任务 """ - mcp_collection = MongoDB().get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") mcp_ids = ProcessHandler.get_all_task_ids() # 检索_id在mcp_ids且状态为ready或者failed的MCP的内容 db_service_list = await mcp_collection.find( @@ -180,8 +180,7 @@ class MCPLoader(metaclass=SingletonMeta): """ template_path = MCP_PATH / "template" logger.info("[MCPLoader] 初始化所有MCP模板: %s", template_path) - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") # 遍历所有模板 mcp_ids = [] async for mcp_dir in template_path.iterdir(): @@ -265,7 +264,7 @@ class MCPLoader(metaclass=SingletonMeta): @staticmethod async def _insert_template_db(mcp_id: str, config: MCPServerConfig) -> None: """插入单个MCP Server模板信息到数据库""" - mcp_collection = MongoDB().get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") await mcp_collection.update_one( {"_id": mcp_id}, { @@ -294,7 +293,7 @@ class MCPLoader(metaclass=SingletonMeta): tool_list = await MCPLoader._get_template_tool(mcp_id, config) # 基本信息插入数据库 - mcp_collection = MongoDB().get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") # 清空当前工具列表 await mcp_collection.update_one( {"_id": mcp_id}, @@ -401,8 +400,7 @@ class MCPLoader(metaclass=SingletonMeta): :return: 无 """ # 更新数据库 - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") logger.info("[MCPLoader] 更新MCP模板状态: %s -> %s", mcp_id, status) await mcp_collection.update_one( {"_id": mcp_id}, @@ -468,8 +466,7 @@ class MCPLoader(metaclass=SingletonMeta): ) await f.aclose() # 更新数据库 - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") await mcp_collection.update_one( {"_id": mcp_id}, {"$addToSet": {"activated": user_sub}}, @@ -491,8 +488,7 @@ class MCPLoader(metaclass=SingletonMeta): await asyncer.asyncify(shutil.rmtree)(user_path.as_posix(), ignore_errors=True) # 更新数据库 - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") await mcp_collection.update_one( {"_id": mcp_id}, {"$pull": {"activated": user_sub}}, @@ -508,7 +504,7 @@ class MCPLoader(metaclass=SingletonMeta): """ deleted_mcp_list = [] - mcp_collection = MongoDB().get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") mcp_list = await mcp_collection.find({}, {"_id": 1}).to_list(None) for db_item in mcp_list: mcp_path: Path = MCP_PATH / "template" / db_item["_id"] @@ -525,8 +521,7 @@ class MCPLoader(metaclass=SingletonMeta): :param list[str] cancel_mcp_list: 需要取消的MCP列表 :return: 无 """ - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") # 更新数据库状态 cancel_mcp_list = await mcp_collection.distinct("_id", {"_id": {"$in": cancel_mcp_list}, "status": MCPInstallStatus.INSTALLING}) await mcp_collection.update_many( @@ -546,7 +541,7 @@ class MCPLoader(metaclass=SingletonMeta): :return: 无 """ # 从MongoDB中移除 - mcp_collection = MongoDB().get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") mcp_service_list = await mcp_collection.find( {"_id": {"$in": deleted_mcp_list}}, ).to_list(None) @@ -593,8 +588,7 @@ class MCPLoader(metaclass=SingletonMeta): logger.warning("[MCPLoader] users目录不存在,跳过加载用户MCP") return - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") mcp_list = {} # 遍历users目录 diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index d0d20318238d1a809111a147c2f77f37497fcf35..cd32bb021e957d8d8677a3ad48423ec97f1e129a 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -73,9 +73,8 @@ class ServiceLoader: async def delete(self, service_id: str, *, is_reload: bool = False) -> None: """删除Service,并更新数据库""" - mongo = MongoDB() - service_collection = mongo.get_collection("service") - node_collection = mongo.get_collection("node") + service_collection = MongoDB.get_collection("service") + node_collection = MongoDB.get_collection("node") try: await service_collection.delete_one({"_id": service_id}) await node_collection.delete_many({"service_id": service_id}) @@ -102,9 +101,8 @@ class ServiceLoader: logger.error(err) raise ValueError(err) # 更新MongoDB - mongo = MongoDB() - service_collection = mongo.get_collection("service") - node_collection = mongo.get_collection("node") + service_collection = MongoDB.get_collection("service") + node_collection = MongoDB.get_collection("node") try: # 先删除旧的节点 await node_collection.delete_many({"service_id": metadata.id}) diff --git a/apps/scheduler/pool/mcp/pool.py b/apps/scheduler/pool/mcp/pool.py index 83a2b19c8a5581532b63c4471468fdab14776874..09a89550bac1583311a8f516526342fb18b89758 100644 --- a/apps/scheduler/pool/mcp/pool.py +++ b/apps/scheduler/pool/mcp/pool.py @@ -54,8 +54,7 @@ class MCPPool(metaclass=SingletonMeta): async def _validate_user(self, mcp_id: str, user_sub: str) -> bool: """验证用户是否已激活""" - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": user_sub}) return mcp_db_result is not None @@ -86,7 +85,8 @@ class MCPPool(metaclass=SingletonMeta): logger.warning("[MCPPool] 用户 %s 不在池中,无法停止", user_sub) return if mcp_id not in self.pool[user_sub]: - logger.warning("[MCPPool] MCP %s 不在用户 %s 的池中,无法停止", mcp_id, user_sub) + logger.warning( + "[MCPPool] MCP %s 不在用户 %s 的池中,无法停止", mcp_id, user_sub) return await self.pool[user_sub][mcp_id].stop() - del self.pool[user_sub][mcp_id] \ No newline at end of file + del self.pool[user_sub][mcp_id] diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index 52e740f3b28c792cd120063461816dde5596ba65..13821f8965f7bb7587217ed653a4bf1a5184e9be 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -131,8 +131,7 @@ class Pool: async def get_flow_metadata(self, app_id: str) -> list[AppFlow]: """从数据库中获取特定App的全部Flow的元数据""" - mongo = MongoDB() - app_collection = mongo.get_collection("app") + app_collection = MongoDB.get_collection("app") flow_metadata_list = [] try: flow_list = await app_collection.find_one({"_id": app_id}, {"flows": 1}) @@ -166,8 +165,7 @@ class Pool: return Plugin # 从MongoDB里拿到数据 - mongo = MongoDB() - call_collection = mongo.get_collection("call") + call_collection = MongoDB.get_collection("call") call_db_data = await call_collection.find_one({"_id": call_id}) if not call_db_data: err = f"[Pool] Call{call_id}不存在" diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 51d6c7d6ee54abcd48ffef1f55fcacbfeb409f19..6980dbc9f3bb11b318b6d7749db60c517d3a9c1a 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -226,7 +226,7 @@ class Scheduler: logger.error("[Scheduler] 未使用应用中心功能!") return # 获取agent信息 - app_collection = MongoDB().get_collection("app") + app_collection = MongoDB.get_collection("app") app_metadata = AppPool.model_validate(await app_collection.find_one({"_id": app_info.app_id})) if not app_metadata: logger.error("[Scheduler] 未找到Agent应用") diff --git a/apps/scheduler/variable/pool_base.py b/apps/scheduler/variable/pool_base.py index 79df334690260471bc44259f9a10c91082a20ff3..023f249bc353015f59035d8b5d1b67e2f51291d3 100644 --- a/apps/scheduler/variable/pool_base.py +++ b/apps/scheduler/variable/pool_base.py @@ -15,10 +15,10 @@ logger = logging.getLogger(__name__) class BaseVariablePool(ABC): """变量池基类""" - + def __init__(self, pool_id: str, scope: VariableScope): """初始化变量池 - + Args: pool_id: 池标识符(如user_id、flow_id、conversation_id等) scope: 池作用域 @@ -28,12 +28,12 @@ class BaseVariablePool(ABC): self._variables: Dict[str, BaseVariable] = {} self._initialized = False self._lock = asyncio.Lock() - + @property def is_initialized(self) -> bool: """检查是否已初始化""" return self._initialized - + async def initialize(self): """初始化变量池""" async with self._lock: @@ -41,37 +41,38 @@ class BaseVariablePool(ABC): await self._load_variables() await self._setup_default_variables() self._initialized = True - logger.info(f"已初始化变量池: {self.__class__.__name__}({self.pool_id})") - + logger.info( + f"已初始化变量池: {self.__class__.__name__}({self.pool_id})") + @abstractmethod async def _load_variables(self): """从存储加载变量""" pass - + @abstractmethod async def _setup_default_variables(self): """设置默认变量""" pass - + @abstractmethod def can_modify(self) -> bool: """检查是否允许修改变量""" pass - - async def add_variable(self, - name: str, - var_type: VariableType, - value: Any = None, - description: Optional[str] = None, - created_by: Optional[str] = None, - is_system: bool = False) -> BaseVariable: + + async def add_variable(self, + name: str, + var_type: VariableType, + value: Any = None, + description: Optional[str] = None, + created_by: Optional[str] = None, + is_system: bool = False) -> BaseVariable: """添加变量""" if not self.can_modify(): raise PermissionError(f"不允许修改{self.scope.value}级变量") - + if name in self._variables: raise ValueError(f"变量 {name} 已存在") - + # 创建变量元数据 metadata = VariableMetadata( name=name, @@ -84,41 +85,42 @@ class BaseVariablePool(ABC): created_by=created_by or "system", is_system=is_system # 标记是否为系统变量 ) - + # 创建变量 variable = create_variable(metadata, value) self._variables[name] = variable - + # 持久化 await self._persist_variable(variable) - - logger.info(f"已添加{'系统' if is_system else ''}变量: {name} 到池 {self.pool_id}") + + logger.info( + f"已添加{'系统' if is_system else ''}变量: {name} 到池 {self.pool_id}") return variable - - async def update_variable(self, - name: str, - value: Optional[Any] = None, - var_type: Optional[VariableType] = None, - description: Optional[str] = None, - force_system_update: bool = False) -> BaseVariable: + + async def update_variable(self, + name: str, + value: Optional[Any] = None, + var_type: Optional[VariableType] = None, + description: Optional[str] = None, + force_system_update: bool = False) -> BaseVariable: """更新变量值、类型或描述""" if not self.can_modify() and not force_system_update: raise PermissionError(f"不允许修改{self.scope.value}级变量") - + if name not in self._variables: raise ValueError(f"变量 {name} 不存在") - + variable = self._variables[name] - + # 检查是否为系统变量(除非强制更新) - if (hasattr(variable.metadata, 'is_system') and - variable.metadata.is_system and - not force_system_update): + if (hasattr(variable.metadata, 'is_system') and + variable.metadata.is_system and + not force_system_update): raise PermissionError(f"系统变量 {name} 不允许修改") - + # 🔑 新增:对于文件类型变量,在更新前清理旧文件资源 old_file_ids = await self._get_file_ids_from_variable(variable) - + # 🔑 重要:如果类型改变,需要重新创建变量对象 if var_type is not None and var_type != variable.metadata.var_type: from .variables import create_variable @@ -127,10 +129,10 @@ class BaseVariablePool(ABC): old_metadata.var_type = var_type if description is not None: old_metadata.description = description - + # 创建新类型的变量对象 variable = create_variable(old_metadata, value) - + # 更新到字典中 self._variables[name] = variable else: @@ -139,54 +141,54 @@ class BaseVariablePool(ABC): variable.metadata.description = description if value is not None: variable.value = value - + # 🔑 新增:清理被替换的文件 if value is not None: new_file_ids = await self._get_file_ids_from_variable(variable) await self._cleanup_replaced_files(variable, old_file_ids, new_file_ids) - + # 持久化到数据库 await self._persist_variable(variable) - + return variable - + async def delete_variable(self, name: str) -> bool: """删除变量""" if not self.can_modify(): raise PermissionError(f"不允许修改{self.scope.value}级变量") - + if name not in self._variables: return False - + variable = self._variables[name] - + # 检查是否为系统变量 if hasattr(variable.metadata, 'is_system') and variable.metadata.is_system: raise PermissionError(f"系统变量 {name} 不允许删除") - + # 🔑 新增:对于文件类型变量,清理关联的文件资源 await self._cleanup_file_resources_if_needed(variable) - + del self._variables[name] - + # 从数据库删除 await self._delete_variable_from_db(variable) - + return True - + async def _cleanup_file_resources_if_needed(self, variable: BaseVariable) -> None: """如果变量是文件类型,清理关联的文件资源(但保护已绑定历史记录的文件)""" try: from .type import VariableType - + if variable.metadata.var_type not in [VariableType.FILE, VariableType.ARRAY_FILE]: return - + if not isinstance(variable.value, dict): return - + file_ids_to_cleanup = [] - + if variable.metadata.var_type == VariableType.FILE: file_id = variable.value.get("file_id") if file_id: @@ -194,39 +196,40 @@ class BaseVariablePool(ABC): else: # ARRAY_FILE file_ids = variable.value.get("file_ids", []) file_ids_to_cleanup.extend(file_ids) - + if file_ids_to_cleanup: # 🔑 修正:检查文件是否已绑定历史记录 protected_file_ids = await self._get_protected_file_ids(file_ids_to_cleanup) - actual_cleanup_ids = [fid for fid in file_ids_to_cleanup if fid not in protected_file_ids] - + actual_cleanup_ids = [ + fid for fid in file_ids_to_cleanup if fid not in protected_file_ids] + if actual_cleanup_ids: user_id = getattr(variable.metadata, 'created_by', None) if user_id: from apps.services.document import DocumentManager await DocumentManager.delete_document(user_id, actual_cleanup_ids) - + if protected_file_ids: - logger.info(f"保护了变量 {variable.name} 中 {len(protected_file_ids)} 个已绑定历史记录的文件") + logger.info( + f"保护了变量 {variable.name} 中 {len(protected_file_ids)} 个已绑定历史记录的文件") else: logger.warning(f"无法确定变量 {variable.name} 的创建者,跳过文件清理") else: logger.info(f"变量 {variable.name} 的所有文件都已绑定历史记录,跳过清理") - + except Exception as e: logger.error(f"清理变量 {variable.name} 的文件资源失败: {e}") # 不抛出异常,避免影响变量删除流程 - + async def _get_protected_file_ids(self, file_ids: list[str]) -> set[str]: """获取已经绑定到历史记录的文件ID列表""" try: from apps.common.mongo import MongoDB - - mongo = MongoDB() - record_group_collection = mongo.get_collection("record_group") - + + record_group_collection = MongoDB.get_collection("record_group") + protected_ids = set() - + # 查询所有RecordGroup中绑定的文件 async for record_group in record_group_collection.find( {"docs.id": {"$in": file_ids}}, @@ -237,36 +240,36 @@ class BaseVariablePool(ABC): doc_id = doc.get("id") or doc.get("_id") if doc_id in file_ids: protected_ids.add(doc_id) - + return protected_ids - + except Exception as e: logger.error(f"检查文件历史记录绑定状态失败: {e}") # 出错时保护所有文件,避免误删 return set(file_ids) - + async def get_variable(self, name: str) -> Optional[BaseVariable]: """获取变量""" return self._variables.get(name) - + async def list_variables(self, include_system: bool = True) -> List[BaseVariable]: """列出所有变量""" if include_system: return list(self._variables.values()) else: # 只返回非系统变量 - return [var for var in self._variables.values() - if not (hasattr(var.metadata, 'is_system') and var.metadata.is_system)] - + return [var for var in self._variables.values() + if not (hasattr(var.metadata, 'is_system') and var.metadata.is_system)] + async def list_system_variables(self) -> List[BaseVariable]: """列出系统变量""" - return [var for var in self._variables.values() - if hasattr(var.metadata, 'is_system') and var.metadata.is_system] - + return [var for var in self._variables.values() + if hasattr(var.metadata, 'is_system') and var.metadata.is_system] + async def has_variable(self, name: str) -> bool: """检查变量是否存在""" return name in self._variables - + async def copy_variables(self) -> Dict[str, BaseVariable]: """拷贝所有变量""" copied = {} @@ -286,60 +289,60 @@ class BaseVariablePool(ABC): # 创建新的变量实例 copied[name] = create_variable(new_metadata, variable.value) return copied - + async def _persist_variable(self, variable: BaseVariable): """持久化变量""" try: - collection = MongoDB().get_collection("variables") + collection = MongoDB.get_collection("variables") data = variable.serialize() - + # 构建查询条件 query = { "metadata.name": variable.name, "metadata.scope": variable.scope.value } - + # 添加池特定的查询条件 self._add_pool_query_conditions(query, variable) - + # 更新或插入 from pymongo import WriteConcern result = await collection.with_options( write_concern=WriteConcern(w="majority", j=True) ).replace_one(query, data, upsert=True) - + if not (result.acknowledged and (result.matched_count > 0 or result.upserted_id)): raise RuntimeError(f"变量持久化失败: {variable.name}") - + except Exception as e: logger.error(f"持久化变量失败: {e}") raise - + async def _delete_variable_from_db(self, variable: BaseVariable): """从数据库删除变量""" try: - collection = MongoDB().get_collection("variables") - + collection = MongoDB.get_collection("variables") + query = { "metadata.name": variable.name, "metadata.scope": variable.scope.value } - + # 添加池特定的查询条件 self._add_pool_query_conditions(query, variable) - + from pymongo import WriteConcern result = await collection.with_options( write_concern=WriteConcern(w="majority", j=True) ).delete_one(query) - + if not result.acknowledged: raise RuntimeError(f"变量删除失败: {variable.name}") - + except Exception as e: logger.error(f"删除变量失败: {e}") raise - + @abstractmethod def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): """添加池特定的查询条件""" @@ -350,13 +353,13 @@ class BaseVariablePool(ABC): try: from .type import VariableType from .file_utils import FileVariableHelper - + if variable.metadata.var_type not in [VariableType.FILE, VariableType.ARRAY_FILE]: return [] - + if not isinstance(variable.value, dict): return [] - + if variable.metadata.var_type == VariableType.FILE: # 使用辅助函数统一处理 file_id = FileVariableHelper.get_file_id(variable.value) @@ -364,35 +367,39 @@ class BaseVariablePool(ABC): else: # ARRAY_FILE # 使用辅助函数统一处理,包含兼容性逻辑 return FileVariableHelper.get_file_ids(variable.value) - + except Exception as e: logger.error(f"提取变量 {variable.name} 的文件ID失败: {e}") return [] - + async def _cleanup_replaced_files(self, variable: BaseVariable, old_file_ids: list[str], new_file_ids: list[str]) -> None: """清理被替换的文件(不在新文件列表中的旧文件,但保护已绑定历史记录的文件)""" try: # 找出被替换的文件ID - replaced_file_ids = [fid for fid in old_file_ids if fid not in new_file_ids] - + replaced_file_ids = [ + fid for fid in old_file_ids if fid not in new_file_ids] + if replaced_file_ids: # 🔑 修正:检查被替换的文件是否已绑定历史记录 protected_file_ids = await self._get_protected_file_ids(replaced_file_ids) - actual_cleanup_ids = [fid for fid in replaced_file_ids if fid not in protected_file_ids] - + actual_cleanup_ids = [ + fid for fid in replaced_file_ids if fid not in protected_file_ids] + if actual_cleanup_ids: user_id = getattr(variable.metadata, 'created_by', None) if user_id: from apps.services.document import DocumentManager await DocumentManager.delete_document(user_id, actual_cleanup_ids) - + if protected_file_ids: - logger.info(f"保护了变量 {variable.name} 中 {len(protected_file_ids)} 个已绑定历史记录的文件") + logger.info( + f"保护了变量 {variable.name} 中 {len(protected_file_ids)} 个已绑定历史记录的文件") else: - logger.warning(f"无法确定变量 {variable.name} 的创建者,跳过被替换文件的清理") + logger.warning( + f"无法确定变量 {variable.name} 的创建者,跳过被替换文件的清理") else: logger.info(f"变量 {variable.name} 被替换的文件都已绑定历史记录,跳过清理") - + except Exception as e: logger.error(f"清理变量 {variable.name} 被替换的文件失败: {e}") # 不抛出异常,避免影响变量更新流程 @@ -400,20 +407,20 @@ class BaseVariablePool(ABC): class UserVariablePool(BaseVariablePool): """用户变量池""" - + def __init__(self, user_id: str): super().__init__(user_id, VariableScope.USER) self.user_id = user_id - + async def _load_variables(self): """从数据库加载用户变量""" try: - collection = MongoDB().get_collection("variables") + collection = MongoDB.get_collection("variables") cursor = collection.find({ "metadata.scope": VariableScope.USER.value, "metadata.user_sub": self.user_id }) - + loaded_count = 0 async for doc in cursor: try: @@ -428,20 +435,20 @@ class UserVariablePool(BaseVariablePool): except Exception as e: var_name = doc.get("metadata", {}).get("name", "unknown") logger.warning(f"用户变量 {var_name} 数据损坏: {e}") - + logger.debug(f"用户 {self.user_id} 加载变量完成: {loaded_count} 个") - + except Exception as e: logger.error(f"加载用户变量失败: {e}") - + async def _setup_default_variables(self): """用户变量池不需要默认变量""" pass - + def can_modify(self) -> bool: """用户变量允许修改""" return True - + def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): """添加用户变量池的查询条件""" query["metadata.user_sub"] = self.user_id @@ -449,29 +456,30 @@ class UserVariablePool(BaseVariablePool): class FlowVariablePool(BaseVariablePool): """流程变量池(环境变量 + 系统变量模板 + 对话变量模板)""" - + def __init__(self, flow_id: str, parent_flow_id: Optional[str] = None): super().__init__(flow_id, VariableScope.ENVIRONMENT) # 保持主要scope为ENVIRONMENT self.flow_id = flow_id self.parent_flow_id = parent_flow_id - + # 分别存储不同类型的变量 # _variables 继续存储环境变量(保持向后兼容) self._system_templates: Dict[str, BaseVariable] = {} # 系统变量模板 self._conversation_templates: Dict[str, BaseVariable] = {} # 对话变量模板 - + async def _load_variables(self): """从数据库加载所有类型的变量(环境变量 + 模板变量)""" try: - collection = MongoDB().get_collection("variables") - loaded_counts = {"environment": 0, "system_templates": 0, "conversation_templates": 0} - + collection = MongoDB.get_collection("variables") + loaded_counts = {"environment": 0, + "system_templates": 0, "conversation_templates": 0} + # 1. 加载环境变量 env_cursor = collection.find({ "metadata.scope": VariableScope.ENVIRONMENT.value, "metadata.flow_id": self.flow_id }) - + async for doc in env_cursor: try: variable_class_name = doc.get("class") @@ -485,14 +493,14 @@ class FlowVariablePool(BaseVariablePool): except Exception as e: var_name = doc.get("metadata", {}).get("name", "unknown") logger.warning(f"环境变量 {var_name} 数据损坏: {e}") - + # 2. 加载系统变量模板 system_template_cursor = collection.find({ "metadata.scope": VariableScope.SYSTEM.value, "metadata.flow_id": self.flow_id, "metadata.is_template": True }) - + async for doc in system_template_cursor: try: variable_class_name = doc.get("class") @@ -506,14 +514,14 @@ class FlowVariablePool(BaseVariablePool): except Exception as e: var_name = doc.get("metadata", {}).get("name", "unknown") logger.warning(f"系统变量模板 {var_name} 数据损坏: {e}") - + # 3. 加载对话变量模板 conv_template_cursor = collection.find({ "metadata.scope": VariableScope.CONVERSATION.value, "metadata.flow_id": self.flow_id, "metadata.is_template": True }) - + async for doc in conv_template_cursor: try: variable_class_name = doc.get("class") @@ -527,19 +535,19 @@ class FlowVariablePool(BaseVariablePool): except Exception as e: var_name = doc.get("metadata", {}).get("name", "unknown") logger.warning(f"对话变量模板 {var_name} 数据损坏: {e}") - + total_loaded = sum(loaded_counts.values()) logger.debug(f"流程 {self.flow_id} 加载变量完成: 环境变量{loaded_counts['environment']}个, " - f"系统模板{loaded_counts['system_templates']}个, " - f"对话模板{loaded_counts['conversation_templates']}个, 总计{total_loaded}个") - + f"系统模板{loaded_counts['system_templates']}个, " + f"对话模板{loaded_counts['conversation_templates']}个, 总计{total_loaded}个") + except Exception as e: logger.error(f"加载流程变量失败: {e}") - + async def _setup_default_variables(self): """设置默认的系统变量模板""" from datetime import datetime, UTC - + # 定义系统变量模板(这些是模板,不是实例) system_var_templates = [ ("query", VariableType.STRING, "用户查询内容", ""), @@ -552,7 +560,7 @@ class FlowVariablePool(BaseVariablePool): ("conversation_id", VariableType.STRING, "对话ID", ""), ("timestamp", VariableType.NUMBER, "当前时间戳", 0), ] - + created_count = 0 for var_name, var_type, description, default_value in system_var_templates: # 如果系统变量模板不存在,才创建 @@ -569,7 +577,7 @@ class FlowVariablePool(BaseVariablePool): ) variable = create_variable(metadata, default_value) self._system_templates[var_name] = variable - + # 持久化模板到数据库 try: await self._persist_variable(variable) @@ -577,30 +585,30 @@ class FlowVariablePool(BaseVariablePool): logger.debug(f"已持久化系统变量模板: {var_name}") except Exception as e: logger.error(f"持久化系统变量模板失败: {var_name} - {e}") - + if created_count > 0: logger.info(f"已为流程 {self.flow_id} 初始化 {created_count} 个系统变量模板") - + def can_modify(self) -> bool: """环境变量允许修改""" return True - + # === 系统变量模板相关方法 === - + async def get_system_template(self, name: str) -> Optional[BaseVariable]: """获取系统变量模板""" return self._system_templates.get(name) - + async def list_system_templates(self) -> List[BaseVariable]: """列出所有系统变量模板""" return list(self._system_templates.values()) - - async def add_system_template(self, name: str, var_type: VariableType, + + async def add_system_template(self, name: str, var_type: VariableType, default_value: Any = None, description: str = None) -> BaseVariable: """添加系统变量模板""" if name in self._system_templates: raise ValueError(f"系统变量模板 {name} 已存在") - + metadata = VariableMetadata( name=name, var_type=var_type, @@ -611,33 +619,33 @@ class FlowVariablePool(BaseVariablePool): is_system=True, is_template=True ) - + variable = create_variable(metadata, default_value) self._system_templates[name] = variable - + # 持久化到数据库 await self._persist_variable(variable) - + logger.info(f"已添加系统变量模板: {name} 到流程 {self.flow_id}") return variable - + # === 对话变量模板相关方法 === - + async def get_conversation_template(self, name: str) -> Optional[BaseVariable]: """获取对话变量模板""" return self._conversation_templates.get(name) - + async def list_conversation_templates(self) -> List[BaseVariable]: """列出所有对话变量模板""" return list(self._conversation_templates.values()) - + async def add_conversation_template(self, name: str, var_type: VariableType, - default_value: Any = None, description: str = None, - created_by: str = None) -> BaseVariable: + default_value: Any = None, description: str = None, + created_by: str = None) -> BaseVariable: """添加对话变量模板""" if name in self._conversation_templates: raise ValueError(f"对话变量模板 {name} 已存在") - + metadata = VariableMetadata( name=name, var_type=var_type, @@ -648,23 +656,23 @@ class FlowVariablePool(BaseVariablePool): is_system=False, is_template=True ) - + variable = create_variable(metadata, default_value) self._conversation_templates[name] = variable - + # 持久化到数据库 await self._persist_variable(variable) - + # 🔑 重要:清除对话模板池缓存,确保下次继承时使用最新的模板 from .pool_manager import get_pool_manager pool_manager = await get_pool_manager() pool_manager.clear_conversation_template_cache(self.flow_id) - + logger.info(f"已添加对话变量模板: {name} 到流程 {self.flow_id}") return variable - + # === 重写基类方法支持多scope查询 === - + async def get_variable_by_scope(self, name: str, scope: VariableScope) -> Optional[BaseVariable]: """根据作用域获取变量""" if scope == VariableScope.ENVIRONMENT: @@ -675,7 +683,7 @@ class FlowVariablePool(BaseVariablePool): return self._conversation_templates.get(name) else: return None - + async def list_variables_by_scope(self, scope: VariableScope) -> List[BaseVariable]: """根据作用域列出变量""" if scope == VariableScope.ENVIRONMENT: @@ -686,27 +694,27 @@ class FlowVariablePool(BaseVariablePool): return list(self._conversation_templates.values()) else: return [] - + # === 重写基类方法支持多字典操作 === - - async def update_variable(self, name: str, value: Any = None, - var_type: Optional[VariableType] = None, - description: Optional[str] = None, - force_system_update: bool = False) -> BaseVariable: + + async def update_variable(self, name: str, value: Any = None, + var_type: Optional[VariableType] = None, + description: Optional[str] = None, + force_system_update: bool = False) -> BaseVariable: """更新变量(支持多字典查找)""" - + # 先在环境变量中查找 if name in self._variables: return await super().update_variable(name, value, var_type, description, force_system_update) - + # 在系统变量模板中查找 elif name in self._system_templates: variable = self._system_templates[name] - + # 检查权限 if not force_system_update and getattr(variable.metadata, 'is_system', False): raise PermissionError(f"系统变量 {name} 不允许直接修改") - + # 🔑 重要:如果类型改变,需要重新创建变量对象 if var_type is not None and var_type != variable.metadata.var_type: from .variables import create_variable @@ -715,14 +723,14 @@ class FlowVariablePool(BaseVariablePool): old_metadata.var_type = var_type if description is not None: old_metadata.description = description - + # 创建新类型的变量对象 variable = create_variable(old_metadata, value) - + # 更新时间戳 from datetime import datetime, UTC variable.metadata.updated_at = datetime.now(UTC) - + # 更新到字典中 self._system_templates[name] = variable else: @@ -731,19 +739,19 @@ class FlowVariablePool(BaseVariablePool): variable.metadata.description = description if value is not None: variable.value = value - + # 更新时间戳 from datetime import datetime, UTC variable.metadata.updated_at = datetime.now(UTC) - + # 持久化 await self._persist_variable(variable) return variable - + # 在对话变量模板中查找 elif name in self._conversation_templates: variable = self._conversation_templates[name] - + # 🔑 重要:如果类型改变,需要重新创建变量对象 if var_type is not None and var_type != variable.metadata.var_type: from .variables import create_variable @@ -752,14 +760,14 @@ class FlowVariablePool(BaseVariablePool): old_metadata.var_type = var_type if description is not None: old_metadata.description = description - + # 创建新类型的变量对象 variable = create_variable(old_metadata, value) - + # 更新时间戳 from datetime import datetime, UTC variable.metadata.updated_at = datetime.now(UTC) - + # 更新到字典中 self._conversation_templates[name] = variable else: @@ -768,75 +776,75 @@ class FlowVariablePool(BaseVariablePool): variable.metadata.description = description if value is not None: variable.value = value - + # 更新时间戳 from datetime import datetime, UTC variable.metadata.updated_at = datetime.now(UTC) - + # 持久化 await self._persist_variable(variable) - + # 🔑 重要:清除对话模板池缓存,确保下次继承时使用最新的模板 from .pool_manager import get_pool_manager pool_manager = await get_pool_manager() pool_manager.clear_conversation_template_cache(self.flow_id) - + return variable - + else: raise ValueError(f"变量 {name} 不存在") - + async def delete_variable(self, name: str) -> bool: """删除变量(支持多字典查找)""" - + # 先在环境变量中查找 if name in self._variables: return await super().delete_variable(name) - + # 在系统变量模板中查找 elif name in self._system_templates: variable = self._system_templates[name] - + # 检查权限 if getattr(variable.metadata, 'is_system', False): raise PermissionError(f"系统变量模板 {name} 不允许删除") - + del self._system_templates[name] await self._delete_variable_from_db(variable) return True - + # 在对话变量模板中查找 elif name in self._conversation_templates: variable = self._conversation_templates[name] del self._conversation_templates[name] await self._delete_variable_from_db(variable) return True - + else: return False - + async def get_variable(self, name: str) -> Optional[BaseVariable]: """获取变量(支持多字典查找)""" - + # 先在环境变量中查找 if name in self._variables: return self._variables[name] - + # 在系统变量模板中查找 elif name in self._system_templates: return self._system_templates[name] - + # 在对话变量模板中查找 elif name in self._conversation_templates: return self._conversation_templates[name] - + else: return None - + def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): """添加环境变量池的查询条件""" query["metadata.flow_id"] = self.flow_id - + async def inherit_from_parent(self, parent_pool: "FlowVariablePool"): """从父流程继承环境变量""" parent_variables = await parent_pool.copy_variables() @@ -846,27 +854,28 @@ class FlowVariablePool(BaseVariablePool): self._variables[name] = variable # 持久化继承的变量 await self._persist_variable(variable) - - logger.info(f"流程 {self.flow_id} 从父流程 {parent_pool.flow_id} 继承了 {len(parent_variables)} 个环境变量") + + logger.info( + f"流程 {self.flow_id} 从父流程 {parent_pool.flow_id} 继承了 {len(parent_variables)} 个环境变量") class ConversationVariablePool(BaseVariablePool): """对话变量池 - 包含系统变量和对话变量""" - + def __init__(self, conversation_id: str, flow_id: str): super().__init__(conversation_id, VariableScope.CONVERSATION) self.conversation_id = conversation_id self.flow_id = flow_id - + async def _load_variables(self): """从数据库加载对话变量""" try: - collection = MongoDB().get_collection("variables") + collection = MongoDB.get_collection("variables") cursor = collection.find({ "metadata.scope": VariableScope.CONVERSATION.value, "metadata.conversation_id": self.conversation_id }) - + loaded_count = 0 async for doc in cursor: try: @@ -881,26 +890,26 @@ class ConversationVariablePool(BaseVariablePool): except Exception as e: var_name = doc.get("metadata", {}).get("name", "unknown") logger.warning(f"对话变量 {var_name} 数据损坏: {e}") - + logger.debug(f"对话 {self.conversation_id} 加载变量完成: {loaded_count} 个") - + except Exception as e: logger.error(f"加载对话变量失败: {e}") - + async def _setup_default_variables(self): """从flow模板继承系统变量和对话变量""" from .pool_manager import get_pool_manager - + try: pool_manager = await get_pool_manager() flow_pool = await pool_manager.get_flow_pool(self.flow_id) - + if not flow_pool: logger.warning(f"未找到流程池 {self.flow_id},无法继承变量模板") return - + created_count = 0 - + # 1. 从系统变量模板创建系统变量实例 system_templates = await flow_pool.list_system_templates() for template in system_templates: @@ -917,11 +926,11 @@ class ConversationVariablePool(BaseVariablePool): is_system=True, # 标记为系统变量 is_template=False # 这是实例,不是模板 ) - + # 使用模板的默认值创建实例 variable = create_variable(metadata, template.value) self._variables[template.name] = variable - + # 持久化系统变量实例 try: await self._persist_variable(variable) @@ -929,7 +938,7 @@ class ConversationVariablePool(BaseVariablePool): logger.debug(f"已从模板创建系统变量实例: {template.name}") except Exception as e: logger.error(f"持久化系统变量实例失败: {template.name} - {e}") - + # 2. 从对话变量模板创建对话变量实例 conversation_templates = await flow_pool.list_conversation_templates() for template in conversation_templates: @@ -946,11 +955,11 @@ class ConversationVariablePool(BaseVariablePool): is_system=False, # 对话变量 is_template=False # 这是实例,不是模板 ) - + # 使用模板的默认值创建实例 variable = create_variable(metadata, template.value) self._variables[template.name] = variable - + # 持久化对话变量实例 try: await self._persist_variable(variable) @@ -958,22 +967,23 @@ class ConversationVariablePool(BaseVariablePool): logger.debug(f"已从模板创建对话变量实例: {template.name}") except Exception as e: logger.error(f"持久化对话变量实例失败: {template.name} - {e}") - + if created_count > 0: - logger.info(f"已为对话 {self.conversation_id} 从流程模板继承 {created_count} 个变量") - + logger.info( + f"已为对话 {self.conversation_id} 从流程模板继承 {created_count} 个变量") + except Exception as e: logger.error(f"从流程模板继承变量失败: {e}") - + def can_modify(self) -> bool: """对话变量允许修改""" return True - + def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): """添加对话变量池的查询条件""" query["metadata.conversation_id"] = self.conversation_id query["metadata.flow_id"] = self.flow_id - + async def update_system_variable(self, name: str, value: Any) -> bool: """更新系统变量的值(系统内部调用)""" try: @@ -982,23 +992,23 @@ class ConversationVariablePool(BaseVariablePool): except Exception as e: logger.error(f"更新系统变量失败: {name} - {e}") return False - + async def inherit_from_conversation_template(self, template_pool: Optional["ConversationVariablePool"] = None): """从对话模板池继承变量(如果存在)""" if template_pool: template_variables = await template_pool.copy_variables() inherited_count = 0 - + for name, template_variable in template_variables.items(): # 只继承非系统变量 if not (hasattr(template_variable.metadata, 'is_system') and template_variable.metadata.is_system): # 🔑 重要修复:检查变量是否已存在,如果存在则保留现有实例 if name in self._variables: continue - + # 创建新的变量实例(从模板创建实例) from .variables import create_variable - + # 创建新的metadata,将模板转换为实例 instance_metadata = VariableMetadata( name=template_variable.name, @@ -1011,12 +1021,13 @@ class ConversationVariablePool(BaseVariablePool): is_system=False, # 对话变量实例 is_template=False # 这是实例,不是模板 ) - + # 使用模板的值创建实例 - instance_variable = create_variable(instance_metadata, template_variable.value) + instance_variable = create_variable( + instance_metadata, template_variable.value) self._variables[name] = instance_variable inherited_count += 1 - + # 持久化实例到数据库 try: await self._persist_variable(instance_variable) @@ -1024,5 +1035,6 @@ class ConversationVariablePool(BaseVariablePool): logger.debug(f"已从模板继承并持久化对话变量实例: {name}") except Exception as e: logger.error(f"持久化继承的对话变量实例失败: {name} - {e}") - - logger.info(f"对话 {self.conversation_id} 从模板继承了 {inherited_count} 个变量") \ No newline at end of file + + logger.info( + f"对话 {self.conversation_id} 从模板继承了 {inherited_count} 个变量") diff --git a/apps/scheduler/variable/pool_manager.py b/apps/scheduler/variable/pool_manager.py index 78691a5591382d19d7bd41e9765ee51ec4193c71..769e623b9536c7cec516b800a6dd562b23a9ffa6 100644 --- a/apps/scheduler/variable/pool_manager.py +++ b/apps/scheduler/variable/pool_manager.py @@ -53,7 +53,7 @@ class VariablePoolManager: try: # 这里应该从相应的用户和流程数据库表中加载 # 目前先从变量表中推断存在的实体 - collection = MongoDB().get_collection("variables") + collection = MongoDB.get_collection("variables") # 获取所有唯一的用户ID user_ids = await collection.distinct("metadata.user_sub", { @@ -299,7 +299,7 @@ class VariablePoolManager: return self._conversation_template_pools[flow_id] # 从MongoDB查询对话变量模板 (只查询is_template=True的) - collection = MongoDB().get_collection("variables") + collection = MongoDB.get_collection("variables") cursor = collection.find({ "metadata.scope": VariableScope.CONVERSATION.value, "metadata.flow_id": flow_id, diff --git a/apps/scheduler/variable/security.py b/apps/scheduler/variable/security.py index 07fde158c01e7bf9375d7f2a555e3bc2d5b1806b..32c88350af1f9c3937c2b3454bbf308910f7affa 100644 --- a/apps/scheduler/variable/security.py +++ b/apps/scheduler/variable/security.py @@ -256,7 +256,7 @@ class SecretVariableSecurity: List[Dict[str, Any]]: 访问日志列表 """ try: - collection = MongoDB().get_collection("variable_access_logs") + collection = MongoDB.get_collection("variable_access_logs") # 构建查询条件 query = { @@ -306,7 +306,7 @@ class SecretVariableSecurity: """检查访问频率限制""" try: # 获取最近1分钟的访问记录 - collection = MongoDB().get_collection("variable_access_logs") + collection = MongoDB.get_collection("variable_access_logs") count = await collection.count_documents({ "user_sub": user_sub, "access_time": { @@ -366,7 +366,7 @@ class SecretVariableSecurity: } # 保存到数据库 - collection = MongoDB().get_collection("variable_access_logs") + collection = MongoDB.get_collection("variable_access_logs") await collection.insert_one(log_entry) except Exception as e: @@ -375,7 +375,7 @@ class SecretVariableSecurity: async def _save_audit_record(self, audit_record: Dict[str, Any]): """保存审计记录""" try: - collection = MongoDB().get_collection("variable_audit_logs") + collection = MongoDB.get_collection("variable_audit_logs") await collection.insert_one(audit_record) except Exception as e: logger.error(f"保存审计记录失败: {e}") @@ -386,13 +386,13 @@ class SecretVariableSecurity: cutoff_time = datetime.now(UTC) - timedelta(days=self.audit_retention_days) # 清理访问日志 - access_collection = MongoDB().get_collection("variable_access_logs") + access_collection = MongoDB.get_collection("variable_access_logs") result1 = await access_collection.delete_many({ "access_time": {"$lt": cutoff_time} }) # 清理审计日志 - audit_collection = MongoDB().get_collection("variable_audit_logs") + audit_collection = MongoDB.get_collection("variable_audit_logs") result2 = await audit_collection.delete_many({ "access_time": {"$lt": cutoff_time} }) diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index fd6912027fea007460344437d9f7138bf2a16b83..51554c731e4d66a42166dd92a8cb92b4444a8edb 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -103,7 +103,8 @@ class LLMIteam(BaseModel): icon: str = Field(default=llm_provider_dict["ollama"]["icon"]) llm_id: str = Field(alias="llmId", default="") - model_name: str = Field(alias="modelName", default=Config().get_config().llm.model) + model_name: str = Field( + alias="modelName", default=Config().get_config().llm.model) class KbIteam(BaseModel): @@ -272,7 +273,8 @@ class TeamKnowledgeBaseItem(BaseModel): class ListTeamKnowledgeMsg(BaseModel): """GET /api/knowledge Result数据结构""" - team_kb_list: list[TeamKnowledgeBaseItem] = Field(default=[], alias="teamKbList", description="团队知识库列表") + team_kb_list: list[TeamKnowledgeBaseItem] = Field( + default=[], alias="teamKbList", description="团队知识库列表") class ListTeamKnowledgeRsp(ResponseData): @@ -298,7 +300,8 @@ class GetAppPropertyMsg(AppData): app_id: str = Field(..., alias="appId", description="应用ID") published: bool = Field(..., description="是否已发布") - mcp_service: list[AppMcpServiceInfo] = Field(default=[], alias="mcpService", description="MCP服务信息列表") + mcp_service: list[AppMcpServiceInfo] = Field( + default=[], alias="mcpService", description="MCP服务信息列表") llm: LLMIteam | None = Field(alias="llm", default=None) @@ -413,8 +416,10 @@ class GetServiceDetailMsg(BaseModel): service_id: str = Field(..., alias="serviceId", description="服务ID") name: str = Field(..., description="服务名称") - apis: list[ServiceApiData] | None = Field(default=None, description="解析后的接口列表") - data: dict[str, Any] | None = Field(default=None, description="YAML 内容数据对象") + apis: list[ServiceApiData] | None = Field( + default=None, description="解析后的接口列表") + data: dict[str, Any] | None = Field( + default=None, description="YAML 内容数据对象") class GetServiceDetailRsp(ResponseData): @@ -457,14 +462,17 @@ class NodeServiceListRsp(ResponseData): class MCPServiceCardItem(BaseModel): """插件中心:MCP服务卡片数据结构""" - mcpservice_id: str = Field(..., alias="mcpserviceId", description="mcp服务ID") + mcpservice_id: str = Field(..., alias="mcpserviceId", + description="mcp服务ID") name: str = Field(..., description="mcp服务名称") description: str = Field(..., description="mcp服务简介") icon: str = Field(..., description="mcp服务图标") author: str = Field(..., description="mcp服务作者") author_name: str = Field(..., alias="authorName", description="mcp服务作者用户名") - is_active: bool = Field(default=False, alias="isActive", description="mcp服务是否激活") - status: MCPInstallStatus = Field(default=MCPInstallStatus.INSTALLING, description="mcp服务状态") + is_active: bool = Field( + default=False, alias="isActive", description="mcp服务是否激活") + status: MCPInstallStatus = Field( + default=MCPInstallStatus.INSTALLING, description="mcp服务状态") class BaseMCPServiceOperationMsg(BaseModel): @@ -543,7 +551,8 @@ class EditMCPServiceMsg(BaseModel): class GetMCPServiceDetailRsp(ResponseData): """GET /api/service/{serviceId} 返回数据结构""" - result: GetMCPServiceDetailMsg | EditMCPServiceMsg = Field(..., title="Result") + result: GetMCPServiceDetailMsg | EditMCPServiceMsg = Field( + ..., title="Result") class DeleteMCPServiceRsp(ResponseData): @@ -574,7 +583,8 @@ class PutFlowReq(BaseModel): """创建/修改流拓扑结构""" flow: FlowItem - focus_point: PositionItem = Field(alias="focusPoint", default=PositionItem()) + focus_point: PositionItem = Field( + alias="focusPoint", default=PositionItem()) class FlowStructurePutMsg(BaseModel): @@ -631,7 +641,8 @@ class LLMProvider(BaseModel): icon: str = Field(description="LLM提供商图标") alias_zh: str = Field(default="", description="中文名称", alias="aliasZh") alias_en: str = Field(default="", description="英文名称", alias="aliasEn") - type: str = Field(default="public", description="类型:public(公网) 或 private(私有)") + type: str = Field(default="public", + description="类型:public(公网) 或 private(私有)") class ListLLMProviderRsp(ResponseData): @@ -658,43 +669,67 @@ class LLMProviderInfo(BaseModel): default="", ) model_name: str = Field(description="模型名称", alias="modelName") - max_tokens: int | None = Field(default=None, description="最大token数", alias="maxTokens") - is_editable: bool = Field(default=True, description="是否可编辑", alias="isEditable") + max_tokens: int | None = Field( + default=None, description="最大token数", alias="maxTokens") + is_editable: bool = Field( + default=True, description="是否可编辑", alias="isEditable") type: list[str] = Field(default=['chat'], description="模型类型列表") - + # 模型能力字段 - 基础能力 provider: str = Field(default="", description="模型提供商") - supports_streaming: bool = Field(default=True, description="是否支持流式输出", alias="supportsStreaming") - supports_function_calling: bool = Field(default=True, description="是否支持函数调用", alias="supportsFunctionCalling") - supports_json_mode: bool = Field(default=True, description="是否支持JSON模式", alias="supportsJsonMode") - supports_structured_output: bool = Field(default=False, description="是否支持结构化输出", alias="supportsStructuredOutput") - + supports_streaming: bool = Field( + default=True, description="是否支持流式输出", alias="supportsStreaming") + supports_function_calling: bool = Field( + default=True, description="是否支持函数调用", alias="supportsFunctionCalling") + supports_json_mode: bool = Field( + default=True, description="是否支持JSON模式", alias="supportsJsonMode") + supports_structured_output: bool = Field( + default=False, description="是否支持结构化输出", alias="supportsStructuredOutput") + # 推理能力 - supports_thinking: bool = Field(default=False, description="是否支持思维链", alias="supportsThinking") - can_toggle_thinking: bool = Field(default=False, description="是否支持开关思维链", alias="canToggleThinking") - supports_reasoning_content: bool = Field(default=False, description="是否返回reasoning_content字段", alias="supportsReasoningContent") - + supports_thinking: bool = Field( + default=False, description="是否支持思维链", alias="supportsThinking") + can_toggle_thinking: bool = Field( + default=False, description="是否支持开关思维链", alias="canToggleThinking") + supports_reasoning_content: bool = Field( + default=False, description="是否返回reasoning_content字段", alias="supportsReasoningContent") + # 参数支持 - max_tokens_param: str = Field(default="max_tokens", description="最大token参数名", alias="maxTokensParam") - supports_temperature: bool = Field(default=True, description="是否支持temperature参数", alias="supportsTemperature") - supports_top_p: bool = Field(default=True, description="是否支持top_p参数", alias="supportsTopP") - supports_top_k: bool = Field(default=False, description="是否支持top_k参数", alias="supportsTopK") - supports_frequency_penalty: bool = Field(default=False, description="是否支持frequency_penalty参数", alias="supportsFrequencyPenalty") - supports_presence_penalty: bool = Field(default=False, description="是否支持presence_penalty参数", alias="supportsPresencePenalty") - supports_min_p: bool = Field(default=False, description="是否支持min_p参数", alias="supportsMinP") - + max_tokens_param: str = Field( + default="max_tokens", description="最大token参数名", alias="maxTokensParam") + supports_temperature: bool = Field( + default=True, description="是否支持temperature参数", alias="supportsTemperature") + supports_top_p: bool = Field( + default=True, description="是否支持top_p参数", alias="supportsTopP") + supports_top_k: bool = Field( + default=False, description="是否支持top_k参数", alias="supportsTopK") + supports_frequency_penalty: bool = Field( + default=False, description="是否支持frequency_penalty参数", alias="supportsFrequencyPenalty") + supports_presence_penalty: bool = Field( + default=False, description="是否支持presence_penalty参数", alias="supportsPresencePenalty") + supports_min_p: bool = Field( + default=False, description="是否支持min_p参数", alias="supportsMinP") + # 高级功能 - supports_response_format: bool = Field(default=True, description="是否支持response_format参数", alias="supportsResponseFormat") - supports_tools: bool = Field(default=True, description="是否支持tools参数", alias="supportsTools") - supports_tool_choice: bool = Field(default=True, description="是否支持tool_choice参数", alias="supportsToolChoice") - supports_extra_body: bool = Field(default=True, description="是否支持extra_body参数", alias="supportsExtraBody") - supports_stream_options: bool = Field(default=True, description="是否支持stream_options参数", alias="supportsStreamOptions") - + supports_response_format: bool = Field( + default=True, description="是否支持response_format参数", alias="supportsResponseFormat") + supports_tools: bool = Field( + default=True, description="是否支持tools参数", alias="supportsTools") + supports_tool_choice: bool = Field( + default=True, description="是否支持tool_choice参数", alias="supportsToolChoice") + supports_extra_body: bool = Field( + default=True, description="是否支持extra_body参数", alias="supportsExtraBody") + supports_stream_options: bool = Field( + default=True, description="是否支持stream_options参数", alias="supportsStreamOptions") + # 特殊参数 - supports_enable_thinking: bool = Field(default=False, description="是否支持enable_thinking参数", alias="supportsEnableThinking") - supports_thinking_budget: bool = Field(default=False, description="是否支持思维链token预算", alias="supportsThinkingBudget") - supports_enable_search: bool = Field(default=False, description="是否支持联网搜索", alias="supportsEnableSearch") - + supports_enable_thinking: bool = Field( + default=False, description="是否支持enable_thinking参数", alias="supportsEnableThinking") + supports_thinking_budget: bool = Field( + default=False, description="是否支持思维链token预算", alias="supportsThinkingBudget") + supports_enable_search: bool = Field( + default=False, description="是否支持联网搜索", alias="supportsEnableSearch") + # 其他信息 notes: str = Field(default="", description="备注信息") @@ -734,7 +769,8 @@ class GetParamsRsp(ResponseData): class OperateAndBindType(BaseModel): """操作和绑定类型数据结构""" - operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate = Field(description="操作类型") + operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate = Field( + description="操作类型") bind_type: ValueType = Field(description="绑定类型") diff --git a/apps/scripts/delete_user.py b/apps/scripts/delete_user.py index c3594a80824b3d45ec2f11fb36e20ae9c164eea5..cf1aacdbbaa20b114bbfaa3a38c3cc4f5f02ad51 100644 --- a/apps/scripts/delete_user.py +++ b/apps/scripts/delete_user.py @@ -22,7 +22,7 @@ async def _delete_user(timestamp: float) -> None: for user_id in user_ids: await UserManager.delete_userinfo_by_user_sub(user_id) # 查找用户关联的文件 - doc_collection = MongoDB().get_collection("document") + doc_collection = MongoDB.get_collection("document") docs = [doc["_id"] async for doc in doc_collection.find({"user_sub": user_id})] # 删除文件 try: diff --git a/apps/services/activity.py b/apps/services/activity.py index 7ba20cbb4300a5fa5cb60e37cdb887f38d54232f..5c3a71139ad715702c96abc12e4b0cca08174507 100644 --- a/apps/services/activity.py +++ b/apps/services/activity.py @@ -24,7 +24,7 @@ class Activity: """ # 检查用户是否正在提问 - active = await MongoDB().get_collection("activity").find_one( + active = await MongoDB.get_collection("activity").find_one( {"_id": active_id}, ) return bool(active) @@ -34,14 +34,15 @@ class Activity: """设置用户的活跃标识""" time = round(datetime.now(UTC).timestamp(), 3) # 设置用户活跃状态 - collection = MongoDB().get_collection("activity") + collection = MongoDB.get_collection("activity") # 查看用户活跃标识是否在滑动窗口内 if await collection.count_documents({"user_sub": user_sub, "timestamp": {"$gt": time - SLIDE_WINDOW_TIME}}) >= SLIDE_WINDOW_QUESTION_COUNT: err = "[Activity] 用户在滑动窗口内提问次数超过限制,请稍后再试。" raise ActivityError(err) await collection.delete_many( - {"user_sub": user_sub, "timestamp": {"$lte": time - SLIDE_WINDOW_TIME}}, + {"user_sub": user_sub, "timestamp": { + "$lte": time - SLIDE_WINDOW_TIME}}, ) # 插入新的活跃记录 tmp_record = { @@ -62,6 +63,6 @@ class Activity: :param user_sub: 用户实体ID """ # 清除用户当前活动标识 - await MongoDB().get_collection("activity").delete_one( + await MongoDB.get_collection("activity").delete_one( {"_id": active_id}, ) diff --git a/apps/services/api_key.py b/apps/services/api_key.py index 8f9711f81606040019ecd525a17a2ff13d89367a..994d19a7f2f90d7d7ba93bf3a234cb67607d0376 100644 --- a/apps/services/api_key.py +++ b/apps/services/api_key.py @@ -21,12 +21,11 @@ class ApiKeyManager: :param user_sub: 用户名 :return: API Key """ - mongo = MongoDB() api_key = str(uuid.uuid4().hex) api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] try: - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") await user_collection.update_one( {"_id": user_sub}, {"$set": {"api_key": api_key_hash}}, @@ -45,11 +44,10 @@ class ApiKeyManager: :param user_sub: 用户ID :return: 删除API Key是否成功 """ - mongo = MongoDB() if not await ApiKeyManager.api_key_exists(user_sub): return False try: - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") await user_collection.update_one( {"_id": user_sub}, {"$unset": {"api_key": ""}}, @@ -67,9 +65,8 @@ class ApiKeyManager: :param user_sub: 用户ID :return: API Key是否存在 """ - mongo = MongoDB() try: - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") user_data = await user_collection.find_one({"_id": user_sub}, {"_id": 0, "api_key": 1}) return user_data is not None and ("api_key" in user_data and user_data["api_key"]) except Exception: @@ -84,10 +81,9 @@ class ApiKeyManager: :param api_key: API Key :return: 用户ID """ - mongo = MongoDB() api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] try: - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") user_data = await user_collection.find_one({"api_key": api_key_hash}, {"_id": 1}) return user_data["_id"] if user_data else None except Exception: @@ -102,10 +98,9 @@ class ApiKeyManager: :param api_key: API Key :return: 验证API Key是否成功 """ - mongo = MongoDB() api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] try: - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") key_data = await user_collection.find_one({"api_key": api_key_hash}, {"_id": 1}) except Exception: logger.exception("[ApiKeyManager] 验证API Key失败") @@ -121,13 +116,12 @@ class ApiKeyManager: :param user_sub: 用户ID :return: 更新后的API Key """ - mongo = MongoDB() if not await ApiKeyManager.api_key_exists(user_sub): return None api_key = str(uuid.uuid4().hex) api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] try: - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") await user_collection.update_one( {"_id": user_sub}, {"$set": {"api_key": api_key_hash}}, diff --git a/apps/services/appcenter.py b/apps/services/appcenter.py index 6ab42c8d3c9f245047c37973413a53103637204e..4b758182cd0c733c0edf16822e7fe552b44b3ecf 100644 --- a/apps/services/appcenter.py +++ b/apps/services/appcenter.py @@ -94,7 +94,8 @@ class AppCenterManager: name=app.name, description=app.description, author=app.author, - authorName=getattr(app, 'author_name', app.author), # 使用author_name,备选author + # 使用author_name,备选author + authorName=getattr(app, 'author_name', app.author), favorited=(app.id in user_favorite_app_ids), published=app.published, ) @@ -111,8 +112,7 @@ class AppCenterManager: :param app_id: 应用唯一标识 :return: 应用元数据 """ - mongo = MongoDB() - app_collection = mongo.get_collection("app") + app_collection = MongoDB.get_collection("app") db_data = await app_collection.find_one({"_id": app_id}) if not db_data: msg = "应用不存在" @@ -182,8 +182,7 @@ class AppCenterManager: break # 更新数据库 - mongo = MongoDB() - app_collection = mongo.get_collection("app") + app_collection = MongoDB.get_collection("app") await app_collection.update_one( {"_id": app_id}, {"$set": {"published": published}}, @@ -208,9 +207,8 @@ class AppCenterManager: :param user_sub: 用户唯一标识 :param favorited: 是否收藏 """ - mongo = MongoDB() - app_collection = mongo.get_collection("app") - user_collection = mongo.get_collection("user") + app_collection = MongoDB.get_collection("app") + user_collection = MongoDB.get_collection("user") db_data = await app_collection.find_one({"_id": app_id}) if not db_data: msg = "应用不存在" @@ -245,8 +243,7 @@ class AppCenterManager: :param app_id: 应用唯一标识 :param user_sub: 用户唯一标识 """ - mongo = MongoDB() - app_collection = mongo.get_collection("app") + app_collection = MongoDB.get_collection("app") app_data = AppPool.model_validate(await app_collection.find_one({"_id": app_id})) if not app_data: msg = "应用不存在" @@ -269,9 +266,8 @@ class AppCenterManager: :param user_sub: 用户唯一标识 :return: 最近使用的应用列表 """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") - app_collection = mongo.get_collection("app") + user_collection = MongoDB.get_collection("user") + app_collection = MongoDB.get_collection("app") # 校验用户信息 user_data = User.model_validate(await user_collection.find_one({"_id": user_sub})) # 获取最近使用的应用ID列表,按最后使用时间倒序排序 @@ -287,14 +283,15 @@ class AppCenterManager: else: # 查询 MongoDB,获取符合条件的应用 apps = await app_collection.find({"_id": {"$in": app_ids}}, {"name": 1, "published": 1}).to_list(length=len(app_ids)) - app_map = {str(a["_id"]): {"name": a.get("name", ""), "published": a.get("published", False)} for a in apps} + app_map = {str(a["_id"]): {"name": a.get("name", ""), + "published": a.get("published", False)} for a in apps} return RecentAppList( applications=[ RecentAppListItem( - appId=app_id, + appId=app_id, name=app_map.get(app_id, {}).get("name", ""), published=app_map.get(app_id, {}).get("published", False) - ) + ) for app_id in app_ids ], ) @@ -311,14 +308,14 @@ class AppCenterManager: if not app_id: return True try: - mongo = MongoDB() - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") current_time = round(datetime.now(UTC).timestamp(), 3) result = await user_collection.update_one( {"_id": user_sub}, # 查询条件 { "$set": { - f"app_usage.{app_id}.last_used": current_time, # 更新最后使用时间 + # 更新最后使用时间 + f"app_usage.{app_id}.last_used": current_time, }, "$inc": { f"app_usage.{app_id}.count": 1, # 增加使用次数 @@ -344,8 +341,7 @@ class AppCenterManager: :return: 默认工作流ID """ try: - mongo = MongoDB() - app_collection = mongo.get_collection("app") + app_collection = MongoDB.get_collection("app") db_data = await app_collection.find_one({"_id": app_id}) if not db_data: logger.warning("[AppCenterManager] 应用不存在: %s", app_id) @@ -366,8 +362,7 @@ class AppCenterManager: page_size: int, ) -> tuple[list[AppPool], int]: """根据过滤条件搜索应用并计算总页数""" - mongo = MongoDB() - app_collection = mongo.get_collection("app") + app_collection = MongoDB.get_collection("app") total_apps = await app_collection.count_documents(search_conditions) db_data = ( await app_collection.find(search_conditions) @@ -391,8 +386,7 @@ class AppCenterManager: :raises ValueError: 应用不存在 :raises InstancePermissionError: 权限不足 """ - mongo = MongoDB() - app_collection = mongo.get_collection("app") + app_collection = MongoDB.get_collection("app") app_data = AppPool.model_validate(await app_collection.find_one({"_id": app_id})) if not app_data: msg = "应用不存在" @@ -409,7 +403,7 @@ class AppCenterManager: from apps.services.user import UserManager user_info = await UserManager.get_userinfo_by_user_sub(user_sub) author_name = user_info.user_name if user_info and user_info.user_name else user_sub - + return { "type": MetadataType.APP, "id": app_id, @@ -532,7 +526,7 @@ class AppCenterManager: else: # 在预期的条件下,如果在 data 或 app_data 中找不到 llm_id,则默认回退为空字符串。 metadata.llm_id = "" - + # 处理enable_thinking字段 if data is not None and hasattr(data, "enable_thinking"): # 创建或更新应用场景,使用传入的 enable_thinking 状态 @@ -547,23 +541,25 @@ class AppCenterManager: if metadata.llm_id: try: from apps.common.mongo import MongoDB - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") + llm_collection = MongoDB.get_collection("llm") llm_doc = await llm_collection.find_one({"_id": metadata.llm_id}) if llm_doc: - supports_thinking = llm_doc.get("supports_thinking", False) - can_toggle_thinking = llm_doc.get("can_toggle_thinking", False) + supports_thinking = llm_doc.get( + "supports_thinking", False) + can_toggle_thinking = llm_doc.get( + "can_toggle_thinking", False) # 只有同时支持思维链和开关思维链时,才默认启用 if supports_thinking and can_toggle_thinking: default_enable_thinking = True except Exception as e: logger.warning(f"[AppCenter] 获取LLM思维链能力失败: {e}") metadata.enable_thinking = default_enable_thinking - + # Agent 应用的发布状态逻辑 if published is not None: # 从 update_app_publish_status 调用,'published' 参数已提供 metadata.published = published - else: # 从 create_app 或 update_app 调用 (此时传递给 _create_metadata 的 'published' 参数为 None) + # 从 create_app 或 update_app 调用 (此时传递给 _create_metadata 的 'published' 参数为 None) + else: # 'published' 状态重置为 False。 metadata.published = False @@ -617,7 +613,8 @@ class AppCenterManager: # 设置权限 if data: - common_params["permission"] = AppCenterManager._create_permission(data.permission) + common_params["permission"] = AppCenterManager._create_permission( + data.permission) elif app_data: common_params["permission"] = app_data.permission @@ -663,7 +660,6 @@ class AppCenterManager: @staticmethod async def _get_favorite_app_ids_by_user(user_sub: str) -> list[str]: """获取用户收藏的应用ID""" - mongo = MongoDB() - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") user_data = User.model_validate(await user_collection.find_one({"_id": user_sub})) return user_data.fav_apps diff --git a/apps/services/application.py b/apps/services/application.py index dde22780423b1d7016e194cb366253549cfcd908..0884ac36d12a46635b232d14e410d340e0a95488 100644 --- a/apps/services/application.py +++ b/apps/services/application.py @@ -21,8 +21,7 @@ class AppManager: :param app_id: 应用id :return: 如果用户具有所需权限则返回True,否则返回False """ - mongo = MongoDB() - app_collection = mongo.get_collection("app") + app_collection = MongoDB.get_collection("app") query = { "_id": app_id, "$or": [ @@ -49,8 +48,7 @@ class AppManager: :param app_id: 应用id :return: 如果应用属于用户则返回True,否则返回False """ - mongo = MongoDB() - app_collection = mongo.get_collection("app") # 获取应用集合' + app_collection = MongoDB.get_collection("app") # 获取应用集合' query = { "_id": app_id, "author": user_sub, diff --git a/apps/services/audit_log.py b/apps/services/audit_log.py index f7b4843a4fd8615bc8e6942dade3f48eda88c56d..05677c4f68b036a2f8efd978750cf0df22860e83 100644 --- a/apps/services/audit_log.py +++ b/apps/services/audit_log.py @@ -18,6 +18,5 @@ class AuditLogManager: :param data: 审计日志数据 """ - mongo = MongoDB() - collection = mongo.get_collection("audit") + collection = MongoDB.get_collection("audit") await collection.insert_one(data.model_dump(by_alias=True)) diff --git a/apps/services/blacklist.py b/apps/services/blacklist.py index 7dad46f28c2a1e22de0e7da211828fd3d6275079..feb912a644ad9174e509c27d053e28d562befad4 100644 --- a/apps/services/blacklist.py +++ b/apps/services/blacklist.py @@ -22,7 +22,7 @@ class QuestionBlacklistManager: async def check_blacklisted_questions(input_question: str) -> bool: """给定问题,查找问题是否在黑名单里""" try: - blacklist_collection = MongoDB().get_collection("blacklist") + blacklist_collection = MongoDB.get_collection("blacklist") result = await blacklist_collection.find_one( {"question": {"$regex": f"/{re.escape(input_question)}/i"}, "is_audited": True}, {"_id": 1}, ) @@ -47,7 +47,7 @@ class QuestionBlacklistManager: is_deletion标识是否为删除操作 """ try: - blacklist_collection = MongoDB().get_collection("blacklist") + blacklist_collection = MongoDB.get_collection("blacklist") if is_deletion: await blacklist_collection.find_one_and_delete({"_id": blacklist_id}) @@ -71,7 +71,7 @@ class QuestionBlacklistManager: async def get_blacklisted_questions(limit: int, offset: int, *, is_audited: bool) -> list[Blacklist]: """分页式获取目前所有的问题(待审核或已拉黑)黑名单""" try: - blacklist_collection = MongoDB().get_collection("blacklist") + blacklist_collection = MongoDB.get_collection("blacklist") return [ Blacklist.model_validate(item) async for item in blacklist_collection.find({"is_audited": is_audited}).skip(offset).limit(limit) @@ -89,7 +89,7 @@ class UserBlacklistManager: async def get_blacklisted_users(limit: int, offset: int) -> list[str]: """获取当前所有黑名单用户""" try: - user_collection = MongoDB().get_collection("user") + user_collection = MongoDB.get_collection("user") return [ user["_id"] async for user in user_collection.find({"credit": {"$lte": 0}}, {"_id": 1}) @@ -105,7 +105,7 @@ class UserBlacklistManager: async def check_blacklisted_users(user_sub: str) -> bool: """检测某用户是否已被拉黑""" try: - user_collection = MongoDB().get_collection("user") + user_collection = MongoDB.get_collection("user") result = await user_collection.find_one( {"user_sub": user_sub, "credit": {"$lte": 0}, "is_whitelisted": False}, {"_id": 1}, ) @@ -123,7 +123,7 @@ class UserBlacklistManager: """修改用户的信用分""" try: # 获取用户当前信用分 - user_collection = MongoDB().get_collection("user") + user_collection = MongoDB.get_collection("user") result = await user_collection.find_one({"user_sub": user_sub}, {"_id": 0, "credit": 1}) # 用户不存在 if result is None: @@ -169,7 +169,7 @@ class AbuseManager: """存储用户举报详情""" try: # 判断record_id是否合法 - record_group_collection = MongoDB().get_collection("record_group") + record_group_collection = MongoDB.get_collection("record_group") record = await record_group_collection.aggregate( [ {"$match": {"user_sub": user_sub}}, @@ -190,7 +190,7 @@ class AbuseManager: record_data = RecordContent.model_validate_json(record_data) # 检查该条目类似内容是否已被举报过 - blacklist_collection = MongoDB().get_collection("question_blacklist") + blacklist_collection = MongoDB.get_collection("question_blacklist") query = await blacklist_collection.find_one({"_id": record_id}) if query is not None: logger.info("[AbuseManager] 问题已被举报过") @@ -217,7 +217,7 @@ class AbuseManager: async def audit_abuse_report(question_id: str, *, is_deletion: bool = False) -> bool: """对某一特定的待审问题进行操作,包括批准审核与删除未审问题""" try: - blacklist_collection = MongoDB().get_collection("blacklist") + blacklist_collection = MongoDB.get_collection("blacklist") if is_deletion: await blacklist_collection.delete_one({"_id": question_id, "is_audited": False}) return True diff --git a/apps/services/comment.py b/apps/services/comment.py index 60c13905ed8d991eec341b8368912a7e7298aa04..33f0b531a600459c1e1c6c8a136f4abba14ecfcb 100644 --- a/apps/services/comment.py +++ b/apps/services/comment.py @@ -20,7 +20,7 @@ class CommentManager: :param record_id: 问答ID :return: 评论内容 """ - record_group_collection = MongoDB().get_collection("record_group") + record_group_collection = MongoDB.get_collection("record_group") result = await record_group_collection.aggregate( [ {"$match": {"_id": group_id, "records.id": record_id}}, @@ -42,8 +42,7 @@ class CommentManager: :param record_id: 问答ID :param data: 评论内容 """ - mongo = MongoDB() - record_group_collection = mongo.get_collection("record_group") + record_group_collection = MongoDB.get_collection("record_group") await record_group_collection.update_one( {"_id": group_id, "records.id": record_id}, {"$set": {"records.$.comment": data.model_dump(by_alias=True)}}, diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 02702f2fe0efeeba5eb632a25a87b8852bdf0c78..89fe4dc885d82b69828b020e2ff0185f4de03ebe 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -25,7 +25,7 @@ class ConversationManager: @staticmethod async def get_conversation_by_user_sub(user_sub: str) -> list[Conversation]: """根据用户ID获取对话列表,按时间由近到远排序""" - conv_collection = MongoDB().get_collection("conversation") + conv_collection = MongoDB.get_collection("conversation") return [ Conversation(**conv) async for conv in conv_collection.find({"user_sub": user_sub, "debug": False}).sort({"created_at": 1}) @@ -34,7 +34,7 @@ class ConversationManager: @staticmethod async def get_conversation_by_conversation_id(user_sub: str, conversation_id: str) -> Conversation | None: """通过ConversationID查询对话信息""" - conv_collection = MongoDB().get_collection("conversation") + conv_collection = MongoDB.get_collection("conversation") result = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) if not result: return None @@ -98,12 +98,11 @@ class ConversationManager: kb_list=kb_item_list, debug=debug if debug else False, ) - mongo = MongoDB() try: - async with mongo.get_session() as session, await session.start_transaction(): - conv_collection = mongo.get_collection("conversation") + async with MongoDB.get_session() as session, await session.start_transaction(): + conv_collection = MongoDB.get_collection("conversation") await conv_collection.insert_one(conv.model_dump(by_alias=True), session=session) - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") update_data: dict[str, dict[str, Any]] = { "$push": {"conversations": conversation_id}, } @@ -128,8 +127,7 @@ class ConversationManager: @staticmethod async def update_conversation_by_conversation_id(user_sub: str, conversation_id: str, data: dict[str, Any]) -> bool: """通过ConversationID更新对话信息""" - mongo = MongoDB() - conv_collection = mongo.get_collection("conversation") + conv_collection = MongoDB.get_collection("conversation") result = await conv_collection.update_one( {"_id": conversation_id, "user_sub": user_sub}, {"$set": data}, @@ -139,10 +137,9 @@ class ConversationManager: @staticmethod async def delete_conversation_by_conversation_id(user_sub: str, conversation_id: str) -> None: """通过ConversationID删除对话""" - mongo = MongoDB() - user_collection = mongo.get_collection("user") - conv_collection = mongo.get_collection("conversation") - record_group_collection = mongo.get_collection("record_group") + user_collection = MongoDB.get_collection("user") + conv_collection = MongoDB.get_collection("conversation") + record_group_collection = MongoDB.get_collection("record_group") # 🔑 修正:获取所有需要清理的文件ID files_to_cleanup = [] @@ -162,7 +159,7 @@ class ConversationManager: elif "_id" in doc: files_to_cleanup.append(doc["_id"]) - async with mongo.get_session() as session, await session.start_transaction(): + async with MongoDB.get_session() as session, await session.start_transaction(): await conv_collection.delete_one( {"_id": conversation_id, "user_sub": user_sub}, session=session, ) diff --git a/apps/services/document.py b/apps/services/document.py index 980eeac17dfaf64ce71cd0ebdfe2e0ea1ddf7ece..9e00eb97d1c2ef92600da2db5e092fab795efeed 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -157,10 +157,8 @@ class DocumentManager: - 这里:上传文件到MinIO/MongoDB,然后更新变量池的file_id/file_ids """ uploaded_files = [] - - mongo = MongoDB() - doc_collection = mongo.get_collection("document") - conversation_collection = mongo.get_collection("conversation") + doc_collection = MongoDB.get_collection("document") + conversation_collection = MongoDB.get_collection("conversation") # 🔑 第一步:上传文件到MinIO和MongoDB for document in documents: @@ -267,8 +265,7 @@ class DocumentManager: """存储用户scope文件""" uploaded_files = [] - mongo = MongoDB() - doc_collection = mongo.get_collection("document") + doc_collection = MongoDB.get_collection("document") for document in documents: if document.filename is None or document.size is None: continue @@ -308,8 +305,7 @@ class DocumentManager: """存储环境scope文件""" uploaded_files = [] - mongo = MongoDB() - doc_collection = mongo.get_collection("document") + doc_collection = MongoDB.get_collection("document") for document in documents: if document.filename is None or document.size is None: continue @@ -349,8 +345,7 @@ class DocumentManager: """为变量系统存储文件 - 不包含RAG处理和变量池存储""" uploaded_files = [] - mongo = MongoDB() - doc_collection = mongo.get_collection("document") + doc_collection = MongoDB.get_collection("document") for document in files: if document.filename is None or document.size is None: @@ -376,7 +371,8 @@ class DocumentManager: # 如果有关联的conversation,更新conversation的unused_docs if conversation_id: - conversation_collection = mongo.get_collection("conversation") + conversation_collection = MongoDB.get_collection( + "conversation") await conversation_collection.update_one( {"_id": conversation_id}, { @@ -652,8 +648,7 @@ class DocumentManager: async def _get_flow_id_for_conversation(cls, conversation_id: str) -> str: """获取对话对应的flow_id""" try: - mongo = MongoDB() - conversation_collection = mongo.get_collection("conversation") + conversation_collection = MongoDB.get_collection("conversation") conversation = await conversation_collection.find_one({"_id": conversation_id}) if conversation and conversation.get("app_id"): @@ -669,9 +664,8 @@ class DocumentManager: @classmethod async def get_unused_docs(cls, user_sub: str, conversation_id: str) -> list[Document]: """获取Conversation中未使用的文件""" - mongo = MongoDB() - conv_collection = mongo.get_collection("conversation") - doc_collection = mongo.get_collection("document") + conv_collection = MongoDB.get_collection("conversation") + doc_collection = MongoDB.get_collection("document") conv = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) if not conv: @@ -685,9 +679,8 @@ class DocumentManager: async def get_used_docs_by_record_group( cls, user_sub: str, record_group_id: str, type: str | None = None) -> list[RecordDocument]: """获取RecordGroup关联的文件""" - mongo = MongoDB() - record_group_collection = mongo.get_collection("record_group") - document_collection = mongo.get_collection("document") + record_group_collection = MongoDB.get_collection("record_group") + document_collection = MongoDB.get_collection("document") if type not in ["question", "answer", None]: raise ValueError("type must be 'question', 'answer' or None") record_group = await record_group_collection.find_one({"_id": record_group_id, "user_sub": user_sub}) @@ -726,9 +719,8 @@ class DocumentManager: async def get_used_docs_by_record_groups( cls, user_sub: str, record_group_ids: list[str], type: str | None = None) -> list[RecordDocument]: """获取多个RecordGroup关联的文件""" - mongo = MongoDB() - record_group_collection = mongo.get_collection("record_group") - document_collection = mongo.get_collection("document") + record_group_collection = MongoDB.get_collection("record_group") + document_collection = MongoDB.get_collection("document") if type not in ["question", "answer", None]: raise ValueError("type must be 'question', 'answer' or None") docs = [] @@ -775,9 +767,8 @@ class DocumentManager: async def get_used_docs( cls, user_sub: str, conversation_id: str, record_num: int | None = 10, type: str | None = None) -> list[Document]: """获取最后n次问答所用到的文件""" - mongo = MongoDB() - docs_collection = mongo.get_collection("document") - record_group_collection = mongo.get_collection("record_group") + docs_collection = MongoDB.get_collection("document") + record_group_collection = MongoDB.get_collection("record_group") if type not in ["question", "answer", None]: raise ValueError("type must be 'question', 'answer' or None") if record_num: @@ -810,11 +801,10 @@ class DocumentManager: @classmethod async def delete_document(cls, user_sub: str, document_list: list[str]) -> bool: """从未使用文件列表中删除一个文件""" - mongo = MongoDB() - doc_collection = mongo.get_collection("document") - conv_collection = mongo.get_collection("conversation") + doc_collection = MongoDB.get_collection("document") + conv_collection = MongoDB.get_collection("conversation") try: - async with mongo.get_session() as session, await session.start_transaction(): + async with MongoDB.get_session() as session, await session.start_transaction(): for doc in document_list: doc_info = await doc_collection.find_one_and_delete( {"_id": doc, "user_sub": user_sub}, session=session, @@ -846,11 +836,10 @@ class DocumentManager: @classmethod async def delete_document_by_conversation_id(cls, user_sub: str, conversation_id: str) -> list[str]: """通过ConversationID删除文件""" - mongo = MongoDB() - doc_collection = mongo.get_collection("document") + doc_collection = MongoDB.get_collection("document") doc_ids = [] - async with mongo.get_session() as session, await session.start_transaction(): + async with MongoDB.get_session() as session, await session.start_transaction(): async for doc in doc_collection.find( {"user_sub": user_sub, "conversation_id": conversation_id}, session=session, ): @@ -869,16 +858,14 @@ class DocumentManager: @classmethod async def get_doc_count(cls, user_sub: str, conversation_id: str) -> int: """获取对话文件数量""" - mongo = MongoDB() - doc_collection = mongo.get_collection("document") + doc_collection = MongoDB.get_collection("document") return await doc_collection.count_documents({"user_sub": user_sub, "conversation_id": conversation_id}) @classmethod async def change_doc_status(cls, user_sub: str, conversation_id: str, record_group_id: str) -> None: """文件状态由unused改为used""" - mongo = MongoDB() - record_group_collection = mongo.get_collection("record_group") - conversation_collection = mongo.get_collection("conversation") + record_group_collection = MongoDB.get_collection("record_group") + conversation_collection = MongoDB.get_collection("conversation") # 查找Conversation中的unused_docs conversation = await conversation_collection.find_one({"user_sub": user_sub, "_id": conversation_id}) @@ -901,8 +888,7 @@ class DocumentManager: @classmethod async def save_answer_doc(cls, user_sub: str, record_group_id: str, doc_infos: list[RecordDocument]) -> None: """保存与答案关联的文件""" - mongo = MongoDB() - record_group_collection = mongo.get_collection("record_group") + record_group_collection = MongoDB.get_collection("record_group") for doc_info in doc_infos: await record_group_collection.update_one( {"_id": record_group_id, "user_sub": user_sub}, diff --git a/apps/services/domain.py b/apps/services/domain.py index 7b512a8343cb06e298b18df3b8db750be41d7cf0..d50ce67086788def4f3b6591be6e36e6ef016727 100644 --- a/apps/services/domain.py +++ b/apps/services/domain.py @@ -21,11 +21,9 @@ class DomainManager: :return: 领域信息列表 """ - mongo = MongoDB() - domain_collection = mongo.get_collection("domain") + domain_collection = MongoDB.get_collection("domain") return [Domain(**domain) async for domain in domain_collection.find()] - @staticmethod async def get_domain_by_domain_name(domain_name: str) -> Domain | None: """ @@ -34,14 +32,12 @@ class DomainManager: :param domain_name: 领域名称 :return: 领域信息 """ - mongo = MongoDB() - domain_collection = mongo.get_collection("domain") + domain_collection = MongoDB.get_collection("domain") domain_data = await domain_collection.find_one({"domain_name": domain_name}) if domain_data: return Domain(**domain_data) return None - @staticmethod async def add_domain(domain_data: PostDomainData) -> None: """ @@ -49,15 +45,13 @@ class DomainManager: :param domain_data: 领域信息 """ - mongo = MongoDB() domain = Domain( name=domain_data.domain_name, definition=domain_data.domain_description, ) - domain_collection = mongo.get_collection("domain") + domain_collection = MongoDB.get_collection("domain") await domain_collection.insert_one(domain.model_dump(by_alias=True)) - @staticmethod async def update_domain_by_domain_name(domain_data: PostDomainData) -> Domain: """ @@ -66,12 +60,11 @@ class DomainManager: :param domain_data: 领域信息 :return: 更新后的领域信息 """ - mongo = MongoDB() update_dict = { "definition": domain_data.domain_description, "updated_at": round(datetime.now(tz=UTC).timestamp(), 3), } - domain_collection = mongo.get_collection("domain") + domain_collection = MongoDB.get_collection("domain") await domain_collection.update_one( {"name": domain_data.domain_name}, {"$set": update_dict}, @@ -85,6 +78,5 @@ class DomainManager: :param domain_data: 领域信息 """ - mongo = MongoDB() - domain_collection = mongo.get_collection("domain") + domain_collection = MongoDB.get_collection("domain") await domain_collection.delete_one({"name": domain_data.domain_name}) diff --git a/apps/services/flow.py b/apps/services/flow.py index 997e2fb9852a0326a7953485906b2e428b29e7c9..18c146b257bc713823a4bf635247151c1d069b24 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -41,8 +41,8 @@ class FlowManager: :param service_id: 服务id :return: 如果用户具有所需权限则返回True,否则返回False """ - node_pool_collection = MongoDB().get_collection("node") - service_collection = MongoDB().get_collection("service") + node_pool_collection = MongoDB.get_collection("node") + service_collection = MongoDB.get_collection("service") try: node_pool_record = await node_pool_collection.find_one({"_id": node_meta_data_id}) @@ -83,7 +83,7 @@ class FlowManager: :param service_id: 服务id :return: 节点元数据的列表 """ - node_pool_collection = MongoDB().get_collection("node") # 获取节点集合 + node_pool_collection = MongoDB.get_collection("node") # 获取节点集合 try: cursor = node_pool_collection.find( {"service_id": service_id}).sort("created_at", ASCENDING) @@ -147,8 +147,8 @@ class FlowManager: :user_sub: 用户的唯一标识符 :return: service的列表 """ - service_collection = MongoDB().get_collection("service") - user_collection = MongoDB().get_collection("user") + service_collection = MongoDB.get_collection("service") + user_collection = MongoDB.get_collection("user") try: db_result = await user_collection.find_one({"_id": user_sub}) user = User.model_validate(db_result) @@ -220,7 +220,7 @@ class FlowManager: :param node_meta_data_id: node_meta_data的id :return: node meta data id对应的节点源数据信息 """ - node_pool_collection = MongoDB().get_collection("node") # 获取节点集合 + node_pool_collection = MongoDB.get_collection("node") # 获取节点集合 try: node_pool_record = await node_pool_collection.find_one({"_id": node_meta_data_id}) if node_pool_record is None: @@ -253,7 +253,7 @@ class FlowManager: :return: 流的item和用户在这个流上的视觉焦点 """ try: - app_collection = MongoDB().get_collection("app") + app_collection = MongoDB.get_collection("app") app_record = await app_collection.find_one({"_id": app_id}) if app_record is None: logger.error("[FlowManager] 应用 %s 不存在", app_id) @@ -462,7 +462,7 @@ class FlowManager: import time st = time.time() try: - app_collection = MongoDB().get_collection("app") + app_collection = MongoDB.get_collection("app") app_record = await app_collection.find_one({"_id": app_id}) if app_record is None: logger.error("[FlowManager] 应用 %s 不存在", app_id) @@ -612,7 +612,7 @@ class FlowManager: """ try: - app_collection = MongoDB().get_collection("app") + app_collection = MongoDB.get_collection("app") key = f"flow/{flow_id}.yaml" await app_collection.update_one({"_id": app_id}, {"$unset": {f"hashes.{key}": ""}}) await app_collection.update_one({"_id": app_id}, {"$pull": {"flows": {"id": flow_id}}}) @@ -902,7 +902,7 @@ class FlowManager: :return: 子工作流的item """ try: - app_collection = MongoDB().get_collection("app") + app_collection = MongoDB.get_collection("app") app_record = await app_collection.find_one({"_id": app_id}) if app_record is None: logger.error("[FlowManager] 应用 %s 不存在", app_id) diff --git a/apps/services/knowledge.py b/apps/services/knowledge.py index bd8dfc9e808f642a8eecf493a96f762dad8ae7b9..7b23768eaafd18ec8cdc50760f555383203768a6 100644 --- a/apps/services/knowledge.py +++ b/apps/services/knowledge.py @@ -29,8 +29,7 @@ class KnowledgeBaseManager: :param conversation_id: 对话ID :return: 知识库ID列表 """ - mongo = MongoDB() - conv_collection = mongo.get_collection("conversation") + conv_collection = MongoDB.get_collection("conversation") result = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) if not result: err_msg = "[KnowledgeBaseManager] 获取知识库ID失败,未找到对话" @@ -80,8 +79,7 @@ class KnowledgeBaseManager: :param user_sub: 用户sub :return: 知识库ID列表 """ - mongo = MongoDB() - conv_collection = mongo.get_collection("conversation") + conv_collection = MongoDB.get_collection("conversation") result = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) if not result: err_msg = "[KnowledgeBaseManager] 获取知识库ID失败,未找到对话" @@ -128,9 +126,7 @@ class KnowledgeBaseManager: :return: 是否更新成功 """ kb_ids = list(set(kb_ids)) - mongo = MongoDB() - - conv_collection = mongo.get_collection("conversation") + conv_collection = MongoDB.get_collection("conversation") conv_dict = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) if not conv_dict: err_msg = "[KnowledgeBaseManager] 更新知识库失败,未找到对话" diff --git a/apps/services/llm.py b/apps/services/llm.py index a57e714e9445ad36e3d357dfe1c6f830944fd9b4..595551323c1b90d9b6546b812c06fb0d5baaba93 100644 --- a/apps/services/llm.py +++ b/apps/services/llm.py @@ -27,7 +27,7 @@ class LLMManager: def _create_llm_provider_info(llm: dict) -> LLMProviderInfo: """ 从数据库 LLM 文档创建 LLMProviderInfo 对象的辅助方法 - + :param llm: 数据库中的 LLM 文档 :return: LLMProviderInfo 对象 """ @@ -45,40 +45,44 @@ class LLMManager: maxTokens=llm["max_tokens"], isEditable=bool(llm.get("user_sub")), # 系统模型(user_sub="")不可编辑 type=llm_type, # 始终返回列表格式 - + # 模型能力字段 - 基础能力 provider=llm.get("provider", ""), supportsStreaming=llm.get("supports_streaming", True), supportsFunctionCalling=llm.get("supports_function_calling", True), supportsJsonMode=llm.get("supports_json_mode", True), - supportsStructuredOutput=llm.get("supports_structured_output", False), - + supportsStructuredOutput=llm.get( + "supports_structured_output", False), + # 推理能力 supportsThinking=llm.get("supports_thinking", False), canToggleThinking=llm.get("can_toggle_thinking", False), - supportsReasoningContent=llm.get("supports_reasoning_content", False), - + supportsReasoningContent=llm.get( + "supports_reasoning_content", False), + # 参数支持 maxTokensParam=llm.get("max_tokens_param", "max_tokens"), supportsTemperature=llm.get("supports_temperature", True), supportsTopP=llm.get("supports_top_p", True), supportsTopK=llm.get("supports_top_k", False), - supportsFrequencyPenalty=llm.get("supports_frequency_penalty", False), - supportsPresencePenalty=llm.get("supports_presence_penalty", False), + supportsFrequencyPenalty=llm.get( + "supports_frequency_penalty", False), + supportsPresencePenalty=llm.get( + "supports_presence_penalty", False), supportsMinP=llm.get("supports_min_p", False), - + # 高级功能 supportsResponseFormat=llm.get("supports_response_format", True), supportsTools=llm.get("supports_tools", True), supportsToolChoice=llm.get("supports_tool_choice", True), supportsExtraBody=llm.get("supports_extra_body", True), supportsStreamOptions=llm.get("supports_stream_options", True), - + # 特殊参数 supportsEnableThinking=llm.get("supports_enable_thinking", False), supportsThinkingBudget=llm.get("supports_thinking_budget", False), supportsEnableSearch=llm.get("supports_enable_search", False), - + # 其他信息 notes=llm.get("notes", ""), ) @@ -113,8 +117,7 @@ class LLMManager: :param conversation_id: 对话ID :return: 大模型ID """ - mongo = MongoDB() - conv_collection = mongo.get_collection("conversation") + conv_collection = MongoDB.get_collection("conversation") result = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) if not result: return "" @@ -128,7 +131,7 @@ class LLMManager: :param llm_id: 大模型ID :return: 大模型对象 """ - llm_collection = MongoDB().get_collection("llm") + llm_collection = MongoDB.get_collection("llm") result = await llm_collection.find_one({"_id": llm_id}) @@ -148,7 +151,7 @@ class LLMManager: :param llm_id: 大模型ID :return: 大模型对象 """ - llm_collection = MongoDB().get_collection("llm") + llm_collection = MongoDB.get_collection("llm") result = await llm_collection.find_one({"_id": llm_id, "user_sub": user_sub}) @@ -169,8 +172,7 @@ class LLMManager: :param model_type: 模型类型,可选值:'chat', 'image', 'video', 'speech', 'embedding', 'reranker', 'function_call' :return: 大模型列表 """ - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") + llm_collection = MongoDB.get_collection("llm") # 构建查询条件:包含用户模型和系统模型 base_query = {"$or": [{"user_sub": user_sub}, {"user_sub": ""}]} @@ -198,8 +200,7 @@ class LLMManager: :param req: 创建大模型请求体 :return: 大模型ID """ - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") + llm_collection = MongoDB.get_collection("llm") # 推断模型能力 provider = req.provider or get_provider_from_endpoint( @@ -229,23 +230,24 @@ class LLMManager: # 从model_registry获取完整的模型能力 from apps.llm.model_types import ModelType - capabilities_obj = model_registry.get_model_capabilities(provider, req.model_name, ModelType.CHAT) + capabilities_obj = model_registry.get_model_capabilities( + provider, req.model_name, ModelType.CHAT) # 使用请求中的能力信息,如果没有则从capabilities_obj获取,最后使用默认值 capabilities = { "provider": provider, - + # 基础能力 "supports_streaming": req.supports_streaming if hasattr(req, 'supports_streaming') and req.supports_streaming is not None else (capabilities_obj.supports_streaming if capabilities_obj else True), "supports_function_calling": req.supports_function_calling if req.supports_function_calling is not None else (capabilities_obj.supports_function_calling if capabilities_obj else True), "supports_json_mode": req.supports_json_mode if req.supports_json_mode is not None else (capabilities_obj.supports_json_mode if capabilities_obj else True), "supports_structured_output": req.supports_structured_output if req.supports_structured_output is not None else (capabilities_obj.supports_structured_output if capabilities_obj else False), - + # 推理能力 "supports_thinking": req.supports_thinking if req.supports_thinking is not None else (capabilities_obj.supports_thinking if capabilities_obj else False), "can_toggle_thinking": req.can_toggle_thinking if req.can_toggle_thinking is not None else (capabilities_obj.can_toggle_thinking if capabilities_obj else False), "supports_reasoning_content": req.supports_reasoning_content if hasattr(req, 'supports_reasoning_content') and req.supports_reasoning_content is not None else (capabilities_obj.supports_reasoning_content if capabilities_obj else False), - + # 参数支持 "max_tokens_param": req.max_tokens_param or (capabilities_obj.max_tokens_param if capabilities_obj else "max_tokens"), "supports_temperature": req.supports_temperature if hasattr(req, 'supports_temperature') and req.supports_temperature is not None else (capabilities_obj.supports_temperature if capabilities_obj else True), @@ -254,19 +256,19 @@ class LLMManager: "supports_frequency_penalty": req.supports_frequency_penalty if hasattr(req, 'supports_frequency_penalty') and req.supports_frequency_penalty is not None else (capabilities_obj.supports_frequency_penalty if capabilities_obj else False), "supports_presence_penalty": req.supports_presence_penalty if hasattr(req, 'supports_presence_penalty') and req.supports_presence_penalty is not None else (capabilities_obj.supports_presence_penalty if capabilities_obj else False), "supports_min_p": req.supports_min_p if hasattr(req, 'supports_min_p') and req.supports_min_p is not None else (capabilities_obj.supports_min_p if capabilities_obj else False), - + # 高级功能 "supports_response_format": req.supports_response_format if hasattr(req, 'supports_response_format') and req.supports_response_format is not None else (capabilities_obj.supports_response_format if capabilities_obj else True), "supports_tools": req.supports_tools if hasattr(req, 'supports_tools') and req.supports_tools is not None else (capabilities_obj.supports_tools if capabilities_obj else True), "supports_tool_choice": req.supports_tool_choice if hasattr(req, 'supports_tool_choice') and req.supports_tool_choice is not None else (capabilities_obj.supports_tool_choice if capabilities_obj else True), "supports_extra_body": req.supports_extra_body if hasattr(req, 'supports_extra_body') and req.supports_extra_body is not None else (capabilities_obj.supports_extra_body if capabilities_obj else True), "supports_stream_options": req.supports_stream_options if hasattr(req, 'supports_stream_options') and req.supports_stream_options is not None else (capabilities_obj.supports_stream_options if capabilities_obj else True), - + # 特殊参数 "supports_enable_thinking": req.supports_enable_thinking if hasattr(req, 'supports_enable_thinking') and req.supports_enable_thinking is not None else (capabilities_obj.supports_enable_thinking if capabilities_obj else False), "supports_thinking_budget": req.supports_thinking_budget if hasattr(req, 'supports_thinking_budget') and req.supports_thinking_budget is not None else (capabilities_obj.supports_thinking_budget if capabilities_obj else False), "supports_enable_search": req.supports_enable_search if hasattr(req, 'supports_enable_search') and req.supports_enable_search is not None else (capabilities_obj.supports_enable_search if capabilities_obj else False), - + # 其他信息 "notes": req.notes or (capabilities_obj.notes if capabilities_obj and hasattr(capabilities_obj, 'notes') else ""), } @@ -328,9 +330,8 @@ class LLMManager: logger.error(err) raise ValueError(err) - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") - conv_collection = mongo.get_collection("conversation") + llm_collection = MongoDB.get_collection("llm") + conv_collection = MongoDB.get_collection("conversation") llm_config = await llm_collection.find_one({"_id": llm_id}) if not llm_config: @@ -379,9 +380,8 @@ class LLMManager: llm_id: str, ) -> str: """更新对话的LLM""" - mongo = MongoDB() - conv_collection = mongo.get_collection("conversation") - llm_collection = mongo.get_collection("llm") + conv_collection = MongoDB.get_collection("conversation") + llm_collection = MongoDB.get_collection("llm") # 如果llm_id为空,则使用系统默认chat模型 if not llm_id: @@ -441,8 +441,7 @@ class LLMManager: :param user_sub: 用户ID,为空时返回系统级别的模型 :return: embedding模型列表 """ - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") + llm_collection = MongoDB.get_collection("llm") query = {"type": "embedding", "user_sub": user_sub} @@ -462,8 +461,7 @@ class LLMManager: :param user_sub: 用户ID :return: embedding模型列表 """ - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") + llm_collection = MongoDB.get_collection("llm") # 使用$or查询同时获取系统模型和用户模型 query = { @@ -487,8 +485,7 @@ class LLMManager: :param user_sub: 用户ID,为空时返回系统级别的模型 :return: reranker模型列表 """ - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") + llm_collection = MongoDB.get_collection("llm") query = {"type": "reranker", "user_sub": user_sub} @@ -508,8 +505,7 @@ class LLMManager: :param user_sub: 用户ID :return: reranker模型列表 """ - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") + llm_collection = MongoDB.get_collection("llm") # 使用$or查询同时获取系统模型和用户模型 query = { @@ -542,14 +538,13 @@ class LLMManager: f"[LLMManager] 跳过系统{model_type}模型初始化(将使用jaccard算法作为默认)") return - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") + llm_collection = MongoDB.get_collection("llm") # 推断模型能力 # 优先使用配置文件中明确指定的provider,如果没有则从endpoint推断 provider = getattr(model_config, 'provider', '') or getattr( model_config, 'backend', '') or get_provider_from_endpoint(model_config.endpoint) - + # 根据模型类型选择正确的ModelType from apps.llm.model_types import ModelType model_type_map = { @@ -559,11 +554,11 @@ class LLMManager: "function_call": ModelType.FUNCTION_CALL, } registry_model_type = model_type_map.get(model_type, ModelType.CHAT) - + # 从model_registry获取完整的模型能力 capabilities = model_registry.get_model_capabilities( provider, model_config.model, registry_model_type) - + # 获取图标 provider_icon = llm_provider_dict.get(provider, {}).get( "icon", getattr(model_config, 'icon', '')) @@ -581,39 +576,61 @@ class LLMManager: max_tokens=getattr(model_config, 'max_tokens', None), type=[model_type], # 使用列表格式 provider=provider, - + # 基础能力 - 使用getattr安全访问,适用于不同类型的能力对象 - supports_streaming=getattr(capabilities, 'supports_streaming', True) if capabilities else True, - supports_function_calling=getattr(capabilities, 'supports_function_calling', True) if capabilities else True, - supports_json_mode=getattr(capabilities, 'supports_json_mode', True) if capabilities else True, - supports_structured_output=getattr(capabilities, 'supports_structured_output', False) if capabilities else False, - + supports_streaming=getattr( + capabilities, 'supports_streaming', True) if capabilities else True, + supports_function_calling=getattr( + capabilities, 'supports_function_calling', True) if capabilities else True, + supports_json_mode=getattr( + capabilities, 'supports_json_mode', True) if capabilities else True, + supports_structured_output=getattr( + capabilities, 'supports_structured_output', False) if capabilities else False, + # 推理能力 - supports_thinking=getattr(capabilities, 'supports_thinking', False) if capabilities else False, - can_toggle_thinking=getattr(capabilities, 'can_toggle_thinking', False) if capabilities else False, - supports_reasoning_content=getattr(capabilities, 'supports_reasoning_content', False) if capabilities else False, - + supports_thinking=getattr( + capabilities, 'supports_thinking', False) if capabilities else False, + can_toggle_thinking=getattr( + capabilities, 'can_toggle_thinking', False) if capabilities else False, + supports_reasoning_content=getattr( + capabilities, 'supports_reasoning_content', False) if capabilities else False, + # 参数支持 - max_tokens_param=getattr(capabilities, 'max_tokens_param', "max_tokens") if capabilities else "max_tokens", - supports_temperature=getattr(capabilities, 'supports_temperature', True) if capabilities else True, - supports_top_p=getattr(capabilities, 'supports_top_p', True) if capabilities else True, - supports_top_k=getattr(capabilities, 'supports_top_k', False) if capabilities else False, - supports_frequency_penalty=getattr(capabilities, 'supports_frequency_penalty', False) if capabilities else False, - supports_presence_penalty=getattr(capabilities, 'supports_presence_penalty', False) if capabilities else False, - supports_min_p=getattr(capabilities, 'supports_min_p', False) if capabilities else False, - + max_tokens_param=getattr( + capabilities, 'max_tokens_param', "max_tokens") if capabilities else "max_tokens", + supports_temperature=getattr( + capabilities, 'supports_temperature', True) if capabilities else True, + supports_top_p=getattr( + capabilities, 'supports_top_p', True) if capabilities else True, + supports_top_k=getattr( + capabilities, 'supports_top_k', False) if capabilities else False, + supports_frequency_penalty=getattr( + capabilities, 'supports_frequency_penalty', False) if capabilities else False, + supports_presence_penalty=getattr( + capabilities, 'supports_presence_penalty', False) if capabilities else False, + supports_min_p=getattr( + capabilities, 'supports_min_p', False) if capabilities else False, + # 高级功能 - supports_response_format=getattr(capabilities, 'supports_response_format', True) if capabilities else True, - supports_tools=getattr(capabilities, 'supports_tools', True) if capabilities else True, - supports_tool_choice=getattr(capabilities, 'supports_tool_choice', True) if capabilities else True, - supports_extra_body=getattr(capabilities, 'supports_extra_body', True) if capabilities else True, - supports_stream_options=getattr(capabilities, 'supports_stream_options', True) if capabilities else True, - + supports_response_format=getattr( + capabilities, 'supports_response_format', True) if capabilities else True, + supports_tools=getattr( + capabilities, 'supports_tools', True) if capabilities else True, + supports_tool_choice=getattr( + capabilities, 'supports_tool_choice', True) if capabilities else True, + supports_extra_body=getattr( + capabilities, 'supports_extra_body', True) if capabilities else True, + supports_stream_options=getattr( + capabilities, 'supports_stream_options', True) if capabilities else True, + # 特殊参数 - supports_enable_thinking=getattr(capabilities, 'supports_enable_thinking', False) if capabilities else False, - supports_thinking_budget=getattr(capabilities, 'supports_thinking_budget', False) if capabilities else False, - supports_enable_search=getattr(capabilities, 'supports_enable_search', False) if capabilities else False, - + supports_enable_thinking=getattr( + capabilities, 'supports_enable_thinking', False) if capabilities else False, + supports_thinking_budget=getattr( + capabilities, 'supports_thinking_budget', False) if capabilities else False, + supports_enable_search=getattr( + capabilities, 'supports_enable_search', False) if capabilities else False, + # 其他信息 notes=getattr(capabilities, 'notes', "") if capabilities else "", ) @@ -640,8 +657,7 @@ class LLMManager: :return: function call模型ID或chat模型ID,如果都不存在则返回None """ try: - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") + llm_collection = MongoDB.get_collection("llm") # 🔑 第一优先级:应用配置的模型(最高优先级) if app_llm_id: @@ -732,8 +748,7 @@ class LLMManager: 在初始化之前,先清理所有系统模型 """ config = Config().get_config() - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") + llm_collection = MongoDB.get_collection("llm") # 清理所有系统模型(user_sub为空的模型) delete_result = await llm_collection.delete_many({"user_sub": ""}) @@ -781,8 +796,7 @@ class LLMManager: from apps.llm.model_types import ModelType # 获取模型信息(支持系统模型和用户模型) - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") + llm_collection = MongoDB.get_collection("llm") result = await llm_collection.find_one({ "_id": llm_id, diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index c3924ef0994010b90c361f73f33a59b2dd85ddfd..5ad320d33f53223efeceb4de2d4998a445fe5c29 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -51,8 +51,7 @@ class MCPServiceManager: :param str mcp_id: MCP服务ID :return: 是否激活 """ - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") mcp_list = await mcp_collection.find({"_id": mcp_id}, {"activated": True}).to_list(None) return any(user_sub in db_item.get("activated", []) for db_item in mcp_list) @@ -64,8 +63,7 @@ class MCPServiceManager: :param str mcp_id: MCP服务ID :return: MCP服务状态 """ - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") mcp_list = await mcp_collection.find({"_id": mcp_id}, {"status": True}).to_list(None) for db_item in mcp_list: status = db_item.get("status") @@ -131,7 +129,7 @@ class MCPServiceManager: :return: MCP服务详细信息 """ # 验证用户权限 - mcpservice_collection = MongoDB().get_collection("mcp") + mcpservice_collection = MongoDB.get_collection("mcp") db_service = await mcpservice_collection.find_one({"_id": mcpservice_id}) if not db_service: msg = "[MCPServiceManager] MCP服务未找到" @@ -161,7 +159,7 @@ class MCPServiceManager: :return: MCP工具详细信息列表 """ # 获取服务名称 - service_collection = MongoDB().get_collection("mcp") + service_collection = MongoDB.get_collection("mcp") data = await service_collection.find({"_id": service_id}, {"tools": True}).to_list(None) result = [] for item in data: @@ -183,7 +181,7 @@ class MCPServiceManager: :param page: int: 页码 :return: MCP列表 """ - mcpservice_collection = MongoDB().get_collection("mcp") + mcpservice_collection = MongoDB.get_collection("mcp") # 分页查询 skip = (page - 1) * SERVICE_PAGE_SIZE db_mcpservices = await mcpservice_collection.find(search_conditions).skip(skip).limit( @@ -208,7 +206,7 @@ class MCPServiceManager: :param page: int: 页码 :return: (MCP列表, 总数量) """ - mcpservice_collection = MongoDB().get_collection("mcp") + mcpservice_collection = MongoDB.get_collection("mcp") # 获取总数量 total_count = await mcpservice_collection.count_documents(search_conditions) @@ -280,7 +278,7 @@ class MCPServiceManager: ) # 检查是否存在相同服务 - mcp_collection = MongoDB().get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") db_service = await mcp_collection.find_one({"name": mcp_server.name}) mcp_id = sqids.encode([random.randint(0, 1000000) for _ in range(5)])[:6] # noqa: S311 if db_service: @@ -321,7 +319,7 @@ class MCPServiceManager: msg = "[MCPServiceManager] MCP服务ID为空" raise ValueError(msg) - mcp_collection = MongoDB().get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") db_service = await mcp_collection.find_one({"_id": data.service_id, "author": user_sub}) if not db_service: msg = "[MCPServiceManager] MCP服务未找到或无权限" @@ -369,7 +367,7 @@ class MCPServiceManager: await MCPLoader.delete_mcp(service_id) # 遍历所有应用,将其中的MCP依赖删除 - app_collection = MongoDB().get_collection("application") + app_collection = MongoDB.get_collection("application") await app_collection.update_many( {"mcp_service": service_id}, {"$pull": {"mcp_service": service_id}}, @@ -388,7 +386,7 @@ class MCPServiceManager: :param service_id: str: MCP服务ID :return: 无 """ - mcp_collection = MongoDB().get_collection("mcp") + mcp_collection = MongoDB.get_collection("mcp") status = await mcp_collection.find({"_id": service_id}, {"status": 1}).to_list() for item in status: mcp_status = item.get("status", MCPInstallStatus.INSTALLING) @@ -466,7 +464,7 @@ class MCPServiceManager: :param install: bool: 是否安装 :return: 无 """ - service_collection = MongoDB().get_collection("mcp") + service_collection = MongoDB.get_collection("mcp") db_service = await service_collection.find_one({"_id": service_id, "author": user_sub}) db_service = MCPCollection.model_validate(db_service) if install: diff --git a/apps/services/node.py b/apps/services/node.py index a09af2e4f46f6c84f8b279a5b23b45c8979c4d70..9a3505a9b2bbc2c5114d855a96f0ead17a7683d2 100644 --- a/apps/services/node.py +++ b/apps/services/node.py @@ -31,7 +31,7 @@ class NodeManager: return SpecialCallType.PLUGIN.value # 其他节点类型:从数据库查询 - node_collection = MongoDB().get_collection("node") + node_collection = MongoDB.get_collection("node") node = await node_collection.find_one({"_id": node_id}, {"call_id": 1}) if not node: err = f"[NodeManager] Node call_id {node_id} not found." @@ -41,7 +41,7 @@ class NodeManager: @staticmethod async def get_node(node_id: str) -> NodePool: """获取Node的类型""" - node_collection = MongoDB().get_collection("node") + node_collection = MongoDB.get_collection("node") node = await node_collection.find_one({"_id": node_id}) if not node: err = f"[NodeManager] Node {node_id} not found." @@ -51,7 +51,7 @@ class NodeManager: @staticmethod async def get_node_name(node_id: str) -> str: """获取node的名称""" - node_collection = MongoDB().get_collection("node") + node_collection = MongoDB.get_collection("node") # 查询 Node 集合获取对应的 name node_doc = await node_collection.find_one({"_id": node_id}, {"name": 1}) if not node_doc: @@ -126,7 +126,7 @@ class NodeManager: # 查找Node信息 logger.info("[NodeManager] 获取节点 %s", node_id) - node_collection = MongoDB().get_collection("node") + node_collection = MongoDB.get_collection("node") node = await node_collection.find_one({"_id": node_id}) if not node: err = f"[NodeManager] Node {node_id} not found." diff --git a/apps/services/predecessor_cache_service.py b/apps/services/predecessor_cache_service.py index 0e88c6bd3527529a05726f670644fc3e77281893..57f1429070c043ae7c4870d3574445de089b8b15 100644 --- a/apps/services/predecessor_cache_service.py +++ b/apps/services/predecessor_cache_service.py @@ -361,7 +361,7 @@ class PredecessorCacheService: try: from apps.common.mongo import MongoDB - app_collection = MongoDB().get_collection("app") + app_collection = MongoDB.get_collection("app") # 查询包含此flow_id的app,同时获取app_id app_record = await app_collection.find_one( diff --git a/apps/services/record.py b/apps/services/record.py index b58df7453c0d02e93de1230560f02f98891e06fe..054b0387a796f78f564f78212328fabf2eee28a4 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -14,17 +14,14 @@ from apps.schemas.enum_var import FlowStatus logger = logging.getLogger(__name__) - - class RecordManager: """问答对相关操作""" @staticmethod async def create_record_group(group_id: str, user_sub: str, conversation_id: str) -> str | None: """创建问答组""" - mongo = MongoDB() - record_group_collection = mongo.get_collection("record_group") - conversation_collection = mongo.get_collection("conversation") + record_group_collection = MongoDB.get_collection("record_group") + conversation_collection = MongoDB.get_collection("conversation") record_group = RecordGroup( _id=group_id, user_sub=user_sub, @@ -32,7 +29,7 @@ class RecordManager: ) try: - async with mongo.get_session() as session, await session.start_transaction(): + async with MongoDB.get_session() as session, await session.start_transaction(): # RecordGroup里面加一条记录 await record_group_collection.insert_one(record_group.model_dump(by_alias=True), session=session) # Conversation里面加一个ID @@ -48,8 +45,7 @@ class RecordManager: @staticmethod async def insert_record_data_into_record_group(user_sub: str, group_id: str, record: Record) -> str | None: """加密问答对,并插入MongoDB中的特定问答组""" - mongo = MongoDB() - group_collection = mongo.get_collection("record_group") + group_collection = MongoDB.get_collection("record_group") try: await group_collection.update_one( {"_id": group_id, "user_sub": user_sub}, @@ -79,8 +75,7 @@ class RecordManager: """ sort_order = -1 if order == "desc" else 1 - mongo = MongoDB() - record_group_collection = mongo.get_collection("record_group") + record_group_collection = MongoDB.get_collection("record_group") try: # 得到conversation的全部record_group id record_groups = await record_group_collection.aggregate( @@ -105,7 +100,8 @@ class RecordManager: ) record = await record.to_list(length=1) if not record: - logger.info("[RecordManager] 问答组 %s 没有问答对", record_group_id) + logger.info("[RecordManager] 问答组 %s 没有问答对", + record_group_id) continue records.append(Record.model_validate(record[0]["records"])) @@ -124,7 +120,7 @@ class RecordManager: 包含全部record_group及其关联的record """ - record_group_collection = MongoDB().get_collection("record_group") + record_group_collection = MongoDB.get_collection("record_group") try: pipeline = [ {"$match": {"conversation_id": conversation_id}}, @@ -134,7 +130,7 @@ class RecordManager: pipeline.append({"$limit": total_pairs}) records = await record_group_collection.aggregate(pipeline) - + return [RecordGroup.model_validate(record) async for record in records] except Exception: logger.exception("[RecordManager] 查询问答组失败") @@ -143,11 +139,13 @@ class RecordManager: @staticmethod async def update_record_flow_status_to_cancelled_by_task_ids(task_ids: list[str]) -> None: """更新Record关联的Flow状态""" - record_group_collection = MongoDB().get_collection("record_group") + record_group_collection = MongoDB.get_collection("record_group") try: await record_group_collection.update_many( - {"records.task_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}}, - {"$set": {"records.$[elem].flow.flow_status": FlowStatus.CANCELLED}}, + {"records.task_id": {"$in": task_ids}, "records.flow.flow_status": { + "$nin": [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}}, + {"$set": { + "records.$[elem].flow.flow_status": FlowStatus.CANCELLED}}, array_filters=[{"elem.flow.flow_id": {"$in": task_ids}}], ) except Exception: @@ -162,7 +160,7 @@ class RecordManager: :return: 记录是否存在 """ try: - record_group_collection = MongoDB().get_collection("record_group") + record_group_collection = MongoDB.get_collection("record_group") record_data = await record_group_collection.find_one( {"_id": group_id, "user_sub": user_sub, "records.id": record_id}, ) @@ -174,7 +172,7 @@ class RecordManager: @staticmethod async def check_group_id(group_id: str, user_sub: str) -> bool: """检查group_id是否存在""" - record_group_collection = MongoDB().get_collection("record_group") + record_group_collection = MongoDB.get_collection("record_group") try: result = await record_group_collection.find_one({"_id": group_id, "user_sub": user_sub}) return bool(result) diff --git a/apps/services/service.py b/apps/services/service.py index 689d49cc0f27f61c4d0d4537af69a9463fceb67d..4e501a65d64f4987a635fc64160e2ef309d841ba 100644 --- a/apps/services/service.py +++ b/apps/services/service.py @@ -122,7 +122,7 @@ class ServiceCenterManager: # 校验 OpenAPI 规范的 JSON Schema validated_data = await ServiceCenterManager._validate_service_data(data) # 检查是否存在相同服务 - service_collection = MongoDB().get_collection("service") + service_collection = MongoDB.get_collection("service") db_service = await service_collection.find_one( { "name": validated_data.id, @@ -155,7 +155,7 @@ class ServiceCenterManager: ) -> str: """更新服务""" # 验证用户权限 - service_collection = MongoDB().get_collection("service") + service_collection = MongoDB.get_collection("service") db_service = await service_collection.find_one({"_id": service_id}) if not db_service: msg = "Service not found" @@ -186,14 +186,14 @@ class ServiceCenterManager: ) -> tuple[str, list[ServiceApiData]]: """获取服务API列表""" # 获取服务名称 - service_collection = MongoDB().get_collection("service") + service_collection = MongoDB.get_collection("service") db_service = await service_collection.find_one({"_id": service_id}) if not db_service: msg = "Service not found" raise ServiceIDError(msg) service_pool_store = ServicePool.model_validate(db_service) # 根据 service_id 获取 API 列表 - node_collection = MongoDB().get_collection("node") + node_collection = MongoDB.get_collection("node") db_nodes = await node_collection.find({"service_id": service_id}).to_list() api_list: list[ServiceApiData] = [] for db_node in db_nodes: @@ -216,7 +216,7 @@ class ServiceCenterManager: ) -> tuple[str, dict[str, Any]]: """获取服务数据""" # 验证用户权限 - service_collection = MongoDB().get_collection("service") + service_collection = MongoDB.get_collection("service") match_conditions = [ {"author": user_sub}, {"permission.type": PermissionType.PUBLIC.value}, @@ -249,7 +249,7 @@ class ServiceCenterManager: service_id: str, ) -> ServiceMetadata: """获取服务元数据""" - service_collection = MongoDB().get_collection("service") + service_collection = MongoDB.get_collection("service") match_conditions = [ {"author": user_sub}, {"permission.type": PermissionType.PUBLIC.value}, @@ -279,8 +279,8 @@ class ServiceCenterManager: service_id: str, ) -> bool: """删除服务""" - service_collection = MongoDB().get_collection("service") - user_collection = MongoDB().get_collection("user") + service_collection = MongoDB.get_collection("service") + user_collection = MongoDB.get_collection("user") db_service = await service_collection.find_one({"_id": service_id}) if not db_service: msg = "[ServiceCenterManager] Service未找到" @@ -308,8 +308,8 @@ class ServiceCenterManager: favorited: bool, ) -> bool: """修改收藏状态""" - service_collection = MongoDB().get_collection("service") - user_collection = MongoDB().get_collection("user") + service_collection = MongoDB.get_collection("service") + user_collection = MongoDB.get_collection("user") db_service = await service_collection.find_one({"_id": service_id}) if not db_service: msg = f"[ServiceCenterManager] Service未找到: {service_id}" @@ -343,7 +343,7 @@ class ServiceCenterManager: page_size: int, ) -> tuple[list[ServicePool], int]: """基于输入条件获取服务数据""" - service_collection = MongoDB().get_collection("service") + service_collection = MongoDB.get_collection("service") # 获取服务总数 total = await service_collection.count_documents(search_conditions) # 分页查询 @@ -358,7 +358,7 @@ class ServiceCenterManager: @staticmethod async def _get_favorite_service_ids_by_user(user_sub: str) -> list[str]: """获取用户收藏的服务ID""" - user_collection = MongoDB().get_collection("user") + user_collection = MongoDB.get_collection("user") user_doc = await user_collection.find_one({"_id": user_sub}) if user_doc is None: # 用户不存在,返回空的收藏列表 diff --git a/apps/services/session.py b/apps/services/session.py index ebb573218744a345f22193218833d79464aa60f7..e5a3d7026950709d90ae07dd9d07e8399c38d544 100644 --- a/apps/services/session.py +++ b/apps/services/session.py @@ -41,11 +41,11 @@ class SessionManager: if user_sub is not None: data.user_sub = user_sub - + if user_name is not None: data.user_name = user_name - collection = MongoDB().get_collection("session") + collection = MongoDB.get_collection("session") await collection.insert_one(data.model_dump(exclude_none=True, by_alias=True)) await collection.create_index( "expired_at", expireAfterSeconds=0, @@ -57,7 +57,7 @@ class SessionManager: """删除浏览器Session""" if not session_id: return - collection = MongoDB().get_collection("session") + collection = MongoDB.get_collection("session") await collection.delete_one({"_id": session_id}) @staticmethod @@ -67,8 +67,7 @@ class SessionManager: return await SessionManager.create_session(session_ip) ip = None - mongo = MongoDB() - collection = mongo.get_collection("session") + collection = MongoDB.get_collection("session") data = await collection.find_one({"_id": session_id}) if not data: return await SessionManager.create_session(session_ip) @@ -81,8 +80,7 @@ class SessionManager: @staticmethod async def verify_user(session_id: str) -> bool: """验证用户是否在Session中""" - mongo = MongoDB() - collection = mongo.get_collection("session") + collection = MongoDB.get_collection("session") data = await collection.find_one({"_id": session_id}) if not data: return False @@ -91,8 +89,7 @@ class SessionManager: @staticmethod async def get_user(session_id: str) -> str | None: """从Session中获取用户""" - mongo = MongoDB() - collection = mongo.get_collection("session") + collection = MongoDB.get_collection("session") data = await collection.find_one({"_id": session_id}) if not data: return None @@ -109,8 +106,7 @@ class SessionManager: @staticmethod async def get_session_by_user_sub(user_sub: str) -> str | None: """根据用户sub获取Session""" - mongo = MongoDB() - collection = mongo.get_collection("session") + collection = MongoDB.get_collection("session") data = await collection.find_one({"user_sub": user_sub}) if not data: return None diff --git a/apps/services/task.py b/apps/services/task.py index 0f9438ba0babdba48c5dc0fa5c35accc1b610d57..fb42c9efd3f2b270ee8fb03eeb6f4b82a4615541 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -37,7 +37,7 @@ class TaskManager: task_id = last_group.task_id # 查询最后一条问答组关联的任务 - task_collection = MongoDB().get_collection("task") + task_collection = MongoDB.get_collection("task") task = await task_collection.find_one({"_id": task_id}) if not task: # 任务不存在,新建Task @@ -49,8 +49,8 @@ class TaskManager: @staticmethod async def get_task_by_group_id(group_id: str, conversation_id: str) -> Task | None: """获取组ID的最后一条问答组关联的任务""" - task_collection = MongoDB().get_collection("task") - record_group_collection = MongoDB().get_collection("record_group") + task_collection = MongoDB.get_collection("task") + record_group_collection = MongoDB.get_collection("record_group") record_group = await record_group_collection.find_one({"conversation_id": conversation_id, "_id": group_id}) if not record_group: return None @@ -61,7 +61,7 @@ class TaskManager: @staticmethod async def get_task_by_task_id(task_id: str) -> Task | None: """根据task_id获取任务""" - task_collection = MongoDB().get_collection("task") + task_collection = MongoDB.get_collection("task") task = await task_collection.find_one({"_id": task_id}) if not task: return None @@ -70,8 +70,8 @@ class TaskManager: @staticmethod async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[FlowStepHistory]: """根据record_group_id获取flow信息""" - record_group_collection = MongoDB().get_collection("record_group") - flow_context_collection = MongoDB().get_collection("flow_context") + record_group_collection = MongoDB.get_collection("record_group") + flow_context_collection = MongoDB.get_collection("flow_context") try: record_group = await record_group_collection.aggregate([ {"$match": {"_id": record_group_id}}, @@ -97,8 +97,8 @@ class TaskManager: @staticmethod async def get_context_by_record_ids(record_group_ids: List[str], record_ids: List[str]) -> List[FlowStepHistory]: """根据record_group_ids获取flow信息""" - record_group_collection = MongoDB().get_collection("record_group") - flow_context_collection = MongoDB().get_collection("flow_context") + record_group_collection = MongoDB.get_collection("record_group") + flow_context_collection = MongoDB.get_collection("flow_context") flow_context_list = [] # 查询所有符合条件的记录 try: @@ -123,7 +123,7 @@ class TaskManager: @staticmethod async def get_context_by_task_id(task_id: str, length: int | None = None) -> list[FlowStepHistory]: """根据task_id获取flow信息""" - flow_context_collection = MongoDB().get_collection("flow_context") + flow_context_collection = MongoDB.get_collection("flow_context") flow_context = [] try: @@ -165,7 +165,7 @@ class TaskManager: @staticmethod async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None: """保存flow信息到flow_context""" - flow_context_collection = MongoDB().get_collection("flow_context") + flow_context_collection = MongoDB.get_collection("flow_context") try: # 删除旧的flow_context await flow_context_collection.delete_many({"task_id": task_id}) @@ -182,8 +182,7 @@ class TaskManager: @staticmethod async def delete_task_by_task_id(task_id: str) -> None: """通过task_id删除Task信息""" - mongo = MongoDB() - task_collection = mongo.get_collection("task") + task_collection = MongoDB.get_collection("task") task = await task_collection.find_one({"_id": task_id}, {"_id": 1}) if task: @@ -192,8 +191,7 @@ class TaskManager: @staticmethod async def delete_tasks_by_conversation_id(conversation_id: str) -> list[str]: """通过ConversationID删除Task信息""" - mongo = MongoDB() - task_collection = mongo.get_collection("task") + task_collection = MongoDB.get_collection("task") task_ids = [] try: async for task in task_collection.find( @@ -211,11 +209,10 @@ class TaskManager: @staticmethod async def delete_tasks_and_flow_context_by_conversation_id(conversation_id: str) -> None: """通过ConversationID删除Task信息""" - mongo = MongoDB() - task_collection = mongo.get_collection("task") - flow_context_collection = mongo.get_collection("flow_context") + task_collection = MongoDB.get_collection("task") + flow_context_collection = MongoDB.get_collection("flow_context") - async with mongo.get_session() as session, await session.start_transaction(): + async with MongoDB.get_session() as session, await session.start_transaction(): task_ids = [ task["_id"] async for task in task_collection.find( {"conversation_id": conversation_id}, @@ -229,7 +226,7 @@ class TaskManager: @classmethod async def save_task(cls, task_id: str, task: Task) -> None: """保存任务块""" - task_collection = MongoDB().get_collection("task") + task_collection = MongoDB.get_collection("task") # 更新已有的Task记录 await task_collection.update_one( diff --git a/apps/services/token.py b/apps/services/token.py index 8c83d3aa2710fd978f62ebaf59b81fa5b2ed4975..2ec4185d0fc4782fc3845706c717d69d635a8fec 100644 --- a/apps/services/token.py +++ b/apps/services/token.py @@ -28,8 +28,7 @@ class TokenManager: err = "用户不存在!" raise ValueError(err) - mongo = MongoDB() - collection = mongo.get_collection("session") + collection = MongoDB.get_collection("session") token_data = await collection.find_one({ "_id": f"{plugin_name}_token_{user_sub}", }) @@ -67,8 +66,7 @@ class TokenManager: expire_time: int, ) -> str | None: """生成插件Token""" - mongo = MongoDB() - collection = mongo.get_collection("session") + collection = MongoDB.get_collection("session") # 获取OIDC token oidc_token = await collection.find_one({ @@ -141,8 +139,7 @@ class TokenManager: @staticmethod async def delete_plugin_token(user_sub: str) -> None: """删除插件token""" - mongo = MongoDB() - collection = mongo.get_collection("token") + collection = MongoDB.get_collection("token") await collection.delete_many({ "user_sub": user_sub, "$or": [ diff --git a/apps/services/user.py b/apps/services/user.py index 8ba8ee2d2ef12887aca54884563d129839c4a4ef..b160521ecb780ffe274acd7d0b7b3a15a5333da7 100644 --- a/apps/services/user.py +++ b/apps/services/user.py @@ -20,33 +20,34 @@ class UserManager: async def _handle_admin_user_creation(user_sub: str, user_name: str) -> str: """处理管理员用户创建时的user_sub逻辑(仅适用于Authelia)""" from apps.common.config import Config - + config = Config().get_config() - + # 只有在使用Authelia provider时才应用此逻辑 if config.login.provider != "authelia": return user_sub - + # 检查是否启用了管理员配置且用户名匹配 if not config.admin.enable or user_name != config.admin.user_name: return user_sub - + # 检查数据库中是否已存在管理员用户 try: - mongo = MongoDB() - user_collection = mongo.get_collection("user") - + user_collection = MongoDB.get_collection("user") + existing_admin = await user_collection.find_one({"_id": config.admin.user_sub}) - + if existing_admin: # 数据库中已存在管理员用户,使用原始的user_sub - logger.info(f"[_handle_admin_user_creation] 管理员用户已存在,使用原始user_sub: {user_sub}") + logger.info( + f"[_handle_admin_user_creation] 管理员用户已存在,使用原始user_sub: {user_sub}") return user_sub else: # 数据库中不存在管理员用户,使用配置的管理员user_sub - logger.info(f"[_handle_admin_user_creation] 管理员用户不存在,使用配置的user_sub: {config.admin.user_sub}") + logger.info( + f"[_handle_admin_user_creation] 管理员用户不存在,使用配置的user_sub: {config.admin.user_sub}") return config.admin.user_sub - + except Exception as e: logger.error(f"[_handle_admin_user_creation] 检查管理员用户时出错: {e}") # 出错时使用原始的user_sub @@ -60,12 +61,11 @@ class UserManager: :param user_sub: 用户sub :param user_name: 用户名 """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") - + user_collection = MongoDB.get_collection("user") + # 管理员用户特殊处理:检查是否应该使用配置的管理员user_sub final_user_sub = await UserManager._handle_admin_user_creation(user_sub, user_name) - + await user_collection.insert_one(User( _id=final_user_sub, user_name=user_name, @@ -78,8 +78,7 @@ class UserManager: :return: 所有用户的sub列表 """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") total = await user_collection.count_documents({}) - len(filter_user_subs) users = await user_collection.find( @@ -96,8 +95,7 @@ class UserManager: :param user_sub: 用户sub :return: 用户信息 """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") user_data = await user_collection.find_one({"_id": user_sub}) return User(**user_data) if user_data else None @@ -110,8 +108,7 @@ class UserManager: :param data: 用户更新信息 :return: 是否更新成功 """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") update_dict = { "$set": { "auto_execute": data.auto_execute, @@ -129,7 +126,6 @@ class UserManager: :param user_name: 用户名(仅在创建新用户时使用) :return: 更新后的用户信息 """ - mongo = MongoDB() user_data = await UserManager.get_userinfo_by_user_sub(user_sub) if not user_data: await UserManager.add_userinfo(user_sub, user_name) @@ -141,7 +137,7 @@ class UserManager: if refresh_revision: update_dict["$set"]["status"] = "init" # type: ignore[assignment] - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") result = await user_collection.update_one({"_id": user_sub}, update_dict) return result.modified_count > 0 @@ -153,8 +149,7 @@ class UserManager: :param login_time: 登录时间 :return: 用户sub列表 """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") return [user["_id"] async for user in user_collection.find({"login_time": {"$lt": login_time}}, {"_id": 1})] @staticmethod @@ -164,8 +159,7 @@ class UserManager: :param user_sub: 用户sub """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") result = await user_collection.find_one_and_delete({"_id": user_sub}) if not result: return @@ -182,26 +176,29 @@ class UserManager: :param user_sub: 用户sub :param data: 用户偏好设置更新信息 """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") - + user_collection = MongoDB.get_collection("user") + # 构建更新字典,只更新非None的字段,使用别名字段名以保持与模型一致 preferences_update = {} if data.reasoning_model_preference is not None: - preferences_update["preferences.reasoningModelPreference"] = data.reasoning_model_preference.model_dump(by_alias=True) + preferences_update["preferences.reasoningModelPreference"] = data.reasoning_model_preference.model_dump( + by_alias=True) if data.embedding_model_preference is not None: - preferences_update["preferences.embeddingModelPreference"] = data.embedding_model_preference.model_dump(by_alias=True) + preferences_update["preferences.embeddingModelPreference"] = data.embedding_model_preference.model_dump( + by_alias=True) if data.reranker_preference is not None: - preferences_update["preferences.rerankerPreference"] = data.reranker_preference.model_dump(by_alias=True) + preferences_update["preferences.rerankerPreference"] = data.reranker_preference.model_dump( + by_alias=True) if data.function_call_model_preference is not None: - preferences_update["preferences.functionCallModelPreference"] = data.function_call_model_preference.model_dump(by_alias=True) + preferences_update["preferences.functionCallModelPreference"] = data.function_call_model_preference.model_dump( + by_alias=True) if data.search_method_preference is not None: preferences_update["preferences.searchMethodPreference"] = data.search_method_preference if data.chain_of_thought_preference is not None: preferences_update["preferences.chainOfThoughtPreference"] = data.chain_of_thought_preference if data.auto_execute_preference is not None: preferences_update["preferences.autoExecutePreference"] = data.auto_execute_preference - + if preferences_update: update_dict = {"$set": preferences_update} await user_collection.update_one({"_id": user_sub}, update_dict) @@ -214,43 +211,49 @@ class UserManager: :param user_sub: 用户sub :return: 用户偏好设置 """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") user_data = await user_collection.find_one({"_id": user_sub}, {"preferences": 1}) if user_data and "preferences" in user_data: preferences_data = user_data["preferences"] - + # 数据迁移:将旧的下划线格式字段名转换为驼峰格式 migration_needed = False if "reasoning_model_preference" in preferences_data: - preferences_data["reasoningModelPreference"] = preferences_data.pop("reasoning_model_preference") + preferences_data["reasoningModelPreference"] = preferences_data.pop( + "reasoning_model_preference") migration_needed = True if "embedding_model_preference" in preferences_data: - preferences_data["embeddingModelPreference"] = preferences_data.pop("embedding_model_preference") + preferences_data["embeddingModelPreference"] = preferences_data.pop( + "embedding_model_preference") migration_needed = True if "reranker_preference" in preferences_data: - preferences_data["rerankerPreference"] = preferences_data.pop("reranker_preference") + preferences_data["rerankerPreference"] = preferences_data.pop( + "reranker_preference") migration_needed = True if "function_call_model_preference" in preferences_data: - preferences_data["functionCallModelPreference"] = preferences_data.pop("function_call_model_preference") + preferences_data["functionCallModelPreference"] = preferences_data.pop( + "function_call_model_preference") migration_needed = True if "search_method_preference" in preferences_data: - preferences_data["searchMethodPreference"] = preferences_data.pop("search_method_preference") + preferences_data["searchMethodPreference"] = preferences_data.pop( + "search_method_preference") migration_needed = True if "chain_of_thought_preference" in preferences_data: - preferences_data["chainOfThoughtPreference"] = preferences_data.pop("chain_of_thought_preference") + preferences_data["chainOfThoughtPreference"] = preferences_data.pop( + "chain_of_thought_preference") migration_needed = True if "auto_execute_preference" in preferences_data: - preferences_data["autoExecutePreference"] = preferences_data.pop("auto_execute_preference") + preferences_data["autoExecutePreference"] = preferences_data.pop( + "auto_execute_preference") migration_needed = True - + # 如果进行了迁移,更新数据库 if migration_needed: await user_collection.update_one( - {"_id": user_sub}, + {"_id": user_sub}, {"$set": {"preferences": preferences_data}} ) - + # 使用model_validate来处理从数据库读取的数据,这样会正确处理别名映射 return UserPreferences.model_validate(preferences_data) else: diff --git a/apps/services/user_domain.py b/apps/services/user_domain.py index 8ce47645dc7e72bd3d78fca8bb83e96837483bbe..e69026f926f01e30746fe268d1945a048373bddf 100644 --- a/apps/services/user_domain.py +++ b/apps/services/user_domain.py @@ -15,8 +15,7 @@ class UserDomainManager: @staticmethod async def get_user_domain_by_user_sub_and_topk(user_sub: str, topk: int) -> list[str]: """根据用户ID,查询用户最常涉及的n个领域""" - mongo = MongoDB() - user_collection = mongo.get_collection("user") + user_collection = MongoDB.get_collection("user") domains = await user_collection.aggregate( [ {"$project": {"_id": 1, "domains": 1}}, @@ -32,9 +31,8 @@ class UserDomainManager: @staticmethod async def update_user_domain_by_user_sub_and_domain_name(user_sub: str, domain_name: str) -> None: """增加特定用户特定领域的频次""" - mongo = MongoDB() - domain_collection = mongo.get_collection("domain") - user_collection = mongo.get_collection("user") + domain_collection = MongoDB.get_collection("domain") + user_collection = MongoDB.get_collection("user") # 检查领域是否存在 domain = await domain_collection.find_one({"_id": domain_name})