diff --git a/.gitignore b/.gitignore index 11b4659f0719370830cd16dea54e5e212a4187a7..72be2dc78dfc6ba9933f7e9102645eacf4424aa9 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ apps/utils/init *.bak apps/embedding -logs \ No newline at end of file +logs +test/run_api.py \ No newline at end of file diff --git a/apps/common/queue.py b/apps/common/queue.py index 0aac111dc6e54cdb5affd47f519718402b5ec9ac..eab5b7aab6ba361def2083f9cbdd4544debf2979 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -65,7 +65,7 @@ class MessageQueue: if not history_ids: # 如果new_history为空,则说明是第一次执行,创建一个空值 flow = MessageFlow( - appId=tcb.flow_state.plugin_id, + appId=tcb.flow_state.app_id, flowId=tcb.flow_state.name, stepId="start", stepStatus=StepStatus.RUNNING, @@ -75,7 +75,8 @@ class MessageQueue: history = tcb.flow_context[tcb.flow_state.step_id] flow = MessageFlow( - appId=history.plugin_id, + # TODO:appId 和 flowId 暂时使用flow_id + appId=history.flow_id, flowId=history.flow_id, stepId=history.step_id, stepStatus=history.status, @@ -86,9 +87,9 @@ class MessageQueue: message = MessageBase( event=event_type, id=tcb.record.id, - group_id=tcb.record.group_id, - conversation_id=tcb.record.conversation_id, - task_id=tcb.record.task_id, + groupId=tcb.record.group_id, + conversationId=tcb.record.conversation_id, + taskId=tcb.record.task_id, metadata=metadata, flow=flow, content=data, diff --git a/apps/dependency/user.py b/apps/dependency/user.py index 6c45ce8c0a8a12f438b85f1dfff1ea614bfadb14..b12b5ba755e4508e81c026a7273ba425f40b2154 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -20,9 +20,10 @@ async def verify_user(request: HTTPConnection) -> None: :param request: HTTP请求 :return: """ - session_id = request.cookies["ECSESSION"] - if not await SessionManager.verify_user(session_id): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") + pass + # session_id = request.cookies["ECSESSION"] + # if not await SessionManager.verify_user(session_id): + # raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") async def get_session(request: HTTPConnection) -> str: """验证Session是否已鉴权,并返回Session ID;未鉴权则抛出HTTP 401;参数级dependence @@ -35,18 +36,20 @@ async def get_session(request: HTTPConnection) -> str: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") return session_id -async def get_user(request: HTTPConnection) -> str: - """验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401;参数级dependence +# async def get_user(request: HTTPConnection) -> str: +# """验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401;参数级dependence - :param request: HTTP请求体 - :return: 用户sub - """ - session_id = request.cookies["ECSESSION"] - user = await SessionManager.get_user(session_id) - if not user: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") - return user +# :param request: HTTP请求体 +# :return: 用户sub +# """ +# session_id = request.cookies["ECSESSION"] +# user = await SessionManager.get_user(session_id) +# if not user: +# raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") +# return user +async def get_user(request: HTTPConnection) -> str: + return "test" async def verify_api_key(api_key: str = Depends(oauth2_scheme)) -> None: """验证API Key是否有效;无效则抛出HTTP 401;接口级dependence diff --git a/apps/entities/flow.py b/apps/entities/flow.py index ad6caa1e5edc5a68a87f2ae45e6f36876f473537..c674afb1308205d937722d73a36a7438cc3ffa01 100644 --- a/apps/entities/flow.py +++ b/apps/entities/flow.py @@ -122,7 +122,7 @@ class AppLink(BaseModel): """App的相关链接""" title: str = Field(description="链接标题") - url: HttpUrl = Field(..., description="链接地址") + url: str = Field(..., description="链接地址") class Permission(BaseModel): diff --git a/apps/entities/plugin.py b/apps/entities/plugin.py index af30820824b768343d7943dcbc2eb3c6622f8996..52074931ae00471fe6d2d80b2ed16a9623f84022 100644 --- a/apps/entities/plugin.py +++ b/apps/entities/plugin.py @@ -7,7 +7,7 @@ from typing import Any from pydantic import BaseModel, Field from apps.common.queue import MessageQueue -from apps.entities.task import FlowHistory, RequestDataPlugin +from apps.entities.task import FlowHistory, RequestDataApp class SysCallVars(BaseModel): @@ -42,7 +42,7 @@ class SysExecVars(BaseModel): question: str = Field(description="当前Agent的目标") task_id: str = Field(description="当前Executor关联的TaskID") session_id: str = Field(description="当前用户的Session ID") - plugin_data: RequestDataPlugin = Field(description="传递给Executor中Call的参数") + App_data: RequestDataApp = Field(description="传递给Executor中Call的参数") background: ExecutorBackground = Field(description="当前Executor的背景信息") class Config: diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index 9061c4b68d1befb0e8c2f3c541a559744ef97741..50ab6a6769ccef503543c4002c3ee1dc9e26c5d0 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -22,11 +22,11 @@ class RequestData(BaseModel): """POST /api/chat 请求的总的数据结构""" question: str = Field(max_length=2000, description="用户输入") - conversation_id: str - group_id: str + conversation_id: str = Field(default=None, alias="conversationId", description="会话ID") + group_id: Optional[str] = Field(default=None, alias="groupId", description="群组ID") language: str = Field(default="zh", description="语言") - files: list[str] = Field(default=[]) - apps: list[RequestDataApp] = Field(default=[]) + files: list[str] = Field(default=[], description="文件列表") + app: list[RequestDataApp] = Field(default=[], description="应用列表") features: RequestDataFeatures = Field(description="消息功能设置") diff --git a/apps/entities/response_data.py b/apps/entities/response_data.py index e2c5580041005bf303de6d699ad3a0b56c429928..534ba0ef78e31fbe56ffedd89f368c5f323fdde9 100644 --- a/apps/entities/response_data.py +++ b/apps/entities/response_data.py @@ -284,7 +284,7 @@ class GetAppListMsg(BaseModel): """GET /api/app Result数据结构""" page_number: int = Field(..., alias="currentPage", description="当前页码") - page_count: int = Field(..., alias="totalPages", description="总页数") + app_count: int = Field(..., alias="total", description="总页数") applications: list[AppCenterCardItem] = Field(..., description="应用列表") diff --git a/apps/manager/appcenter.py b/apps/manager/appcenter.py index f94866bb338b77416e7957826fac725603ec3736..ea8c9cedcf99c75da50fab5986547ce47480b574 100644 --- a/apps/manager/appcenter.py +++ b/apps/manager/appcenter.py @@ -2,6 +2,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. """ +from datetime import datetime, timezone import uuid from enum import Enum from typing import Any, Optional @@ -112,12 +113,16 @@ class AppCenterManager: "_id": {"$in": fav_app}, "published": True, } + print(base_filter) filters: dict[str, Any] = AppCenterManager._build_filters( base_filter, search_type, keyword, ) if keyword else base_filter + print(filters) + print(page, page_size) apps, total_pages = await AppCenterManager._search_apps_by_filter(filters, page, page_size) + print(apps) return [ AppCenterCardItem( appId=app.id, @@ -197,7 +202,8 @@ class AppCenterManager: } if published_false_needed: update_data["published"] = False - await app_collection.update_one({"_id": app_id}, {"$set": update_data}) + #TODO: 格式有問題 + await app_collection.update_one({"_id": app_id}, {"$set": jsonable_encoder(update_data)}) return True except Exception as e: LOGGER.error(f"[AppCenterManager] Update app failed: {e}") @@ -329,7 +335,9 @@ class AppCenterManager: try: app_collection = MongoDB.get_collection("app") total_apps = await app_collection.count_documents(search_conditions) - total_pages = (total_apps + page_size - 1) // page_size + # TODO: 暂时修改为 total_apps + total_pages = total_apps + # total_pages = (total_apps + page_size - 1) // page_size db_data = await app_collection.find(search_conditions) \ .sort("created_at", -1) \ .skip((page - 1) * page_size) \ @@ -351,3 +359,47 @@ class AppCenterManager: except Exception as e: LOGGER.info(f"[AppCenterManager] Get favorite app ids by user_sub failed: {e}") return [] + + @staticmethod + async def update_recent_app(user_sub: str, app_id: str) -> bool: + """更新用户的最近使用应用列表 + + :param user_sub: 用户唯一标识 + :param app_id: 应用唯一标识 + :return: 更新是否成功 + """ + try: + # 获取 user 集合 + user_collection = MongoDB.get_collection("user") + + # 获取当前时间戳 + current_time = round(datetime.now(tz=timezone.utc).timestamp(), 3) + + # 更新用户的 app_usage 字段 + result = await user_collection.update_one( + {"_id": user_sub}, # 查询条件 + { + "$set": { + f"app_usage.{app_id}.last_used": current_time # 更新最后使用时间 + }, + "$inc": { + f"app_usage.{app_id}.count": 1 # 增加使用次数 + } + }, + upsert=True # 如果 app_usage 字段或 app_id 不存在,则创建 + ) + + # 检查更新是否成功 + if result.modified_count > 0 or result.upserted_id is not None: + print("YES") + LOGGER.info(f"[AppCenterManager] Updated recent app for user {user_sub}: {app_id}") + return True + else: + print("NO") + LOGGER.warning(f"[AppCenterManager] No changes made for user {user_sub}") + return False + + except Exception as e: + print(e) + LOGGER.error(f"[AppCenterManager] Failed to update recent app: {e}") + return False \ No newline at end of file diff --git a/apps/manager/conversation.py b/apps/manager/conversation.py index f595d0c5b1103eea9e52319ba0304d7e7beac25d..f816ef6b05da4eeb925d949b4af595655539e769 100644 --- a/apps/manager/conversation.py +++ b/apps/manager/conversation.py @@ -31,6 +31,7 @@ class ConversationManager: try: conv_collection = MongoDB.get_collection("conversation") result = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) + print(result, conversation_id, user_sub) if not result: return None return Conversation.model_validate(result) @@ -48,6 +49,7 @@ class ConversationManager: app_id=app_id, is_debug=is_debug, ) + print(conv) try: async with MongoDB.get_session() as session, await session.start_transaction(): conv_collection = MongoDB.get_collection("conversation") @@ -56,18 +58,20 @@ class ConversationManager: update_data: dict[str, dict[str, Any]] = { "$push": {"conversations": conversation_id}, } + if app_id: # 非调试模式下更新应用使用情况 - if not is_debug: - update_data["$set"] = {f"app_usage.{app_id}.last_used": round(datetime.now(timezone.utc).timestamp(), 3)} - update_data["$inc"] = {f"app_usage.{app_id}.count": 1} - await user_collection.update_one( - {"_id": user_sub}, - update_data, - session=session, - ) - await session.commit_transaction() + if not is_debug: + update_data["$set"] = {f"app_usage.{app_id}.last_used": round(datetime.now(timezone.utc).timestamp(), 3)} + update_data["$inc"] = {f"app_usage.{app_id}.count": 1} + await user_collection.update_one( + {"_id": user_sub}, + update_data, + session=session, + ) + await session.commit_transaction() return conv except Exception as e: + print(e) LOGGER.info(f"[ConversationManager] Add conversation by user_sub failed: {e}") return None diff --git a/apps/manager/record.py b/apps/manager/record.py index 54731ad9136fe3ce7ab5bb7777d6e1641ad6987e..a6079db35622f9a17e4d936a040c75fb5811db9b 100644 --- a/apps/manager/record.py +++ b/apps/manager/record.py @@ -37,6 +37,7 @@ class RecordManager: # Conversation里面加一个ID await conversation_collection.update_one({"_id": conversation_id}, {"$push": {"record_groups": group_id}}, session=session) except Exception as e: + print(e) LOGGER.info(f"Create record group failed: {e}") return None diff --git a/apps/models/mongo.py b/apps/models/mongo.py index 7cb254198995a7c8cafed8dc0076971542969fe7..27d1ac65d8ebe55103ffad231bbf9c6221fceb4b 100644 --- a/apps/models/mongo.py +++ b/apps/models/mongo.py @@ -21,8 +21,7 @@ class MongoDB: """MongoDB连接""" _client: AsyncMongoClient = AsyncMongoClient( - f"mongodb://{urllib.parse.quote_plus(config['MONGODB_USER'])}:{urllib.parse.quote_plus(config['MONGODB_PWD'])}@{config['MONGODB_HOST']}:{config['MONGODB_PORT']}/?directConnection=true&replicaSet=rs0", - ) + f"mongodb://{urllib.parse.quote_plus(config['MONGODB_USER'])}:{urllib.parse.quote_plus(config['MONGODB_PWD'])}@{config['MONGODB_HOST']}:{config['MONGODB_PORT']}/?directConnection=true&replicaSet=rs0", ) @classmethod def get_collection(cls, collection_name: str) -> AsyncCollection: diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py index 8ba2f16c30552788ab320ab91afbe4e0a841020b..45a1ee0fa7e88ad1335be1c8a87be410cac2f178 100644 --- a/apps/routers/appcenter.py +++ b/apps/routers/appcenter.py @@ -5,10 +5,11 @@ Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. from typing import Annotated, Optional, Union from fastapi import APIRouter, Body, Depends, Path, Query, status +from fastapi.requests import HTTPConnection from fastapi.responses import JSONResponse -from apps.dependency.csrf import verify_csrf_token -from apps.dependency.user import get_user, verify_user +# from apps.dependency.csrf import verify_csrf_token +# from apps.dependency.user import get_user, verify_user from apps.entities.appcenter import AppPermissionData from apps.entities.enum_var import SearchType from apps.entities.request_data import CreateAppRequest, ModFavAppRequest @@ -30,9 +31,10 @@ from apps.manager.flow import FlowManager router = APIRouter( prefix="/api/app", tags=["appcenter"], - dependencies=[Depends(verify_user)], + # dependencies=[Depends(verify_user)], ) - +async def get_user(request: HTTPConnection) -> str: + return "test" @router.get("", response_model=Union[GetAppListRsp, ResponseData]) async def get_applications( # noqa: ANN201, PLR0913 @@ -53,34 +55,35 @@ async def get_applications( # noqa: ANN201, PLR0913 result={}, ) - app_cards, total_pages = [], -1 + app_cards, total_apps = [], -1 if my_app: # 筛选我创建的 - app_cards, total_pages = await AppCenterManager.fetch_user_apps( + app_cards, total_apps = await AppCenterManager.fetch_user_apps( user_sub, search_type, keyword, page, page_size) elif my_fav: # 筛选已收藏的 - app_cards, total_pages = await AppCenterManager.fetch_favorite_apps( + app_cards, total_apps = await AppCenterManager.fetch_favorite_apps( user_sub, search_type, keyword, page, page_size) else: # 获取所有应用 - app_cards, total_pages = await AppCenterManager.fetch_all_apps( + app_cards, total_apps = await AppCenterManager.fetch_all_apps( user_sub, search_type, keyword, page, page_size) - if total_pages == -1: + if total_apps == -1: return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( code=status.HTTP_500_INTERNAL_SERVER_ERROR, message="查询失败", result={}, ).model_dump(exclude_none=True, by_alias=True)) + #TODO: 返回总量数 return JSONResponse(status_code=status.HTTP_200_OK, content=GetAppListRsp( code=status.HTTP_200_OK, message="查询成功", result=GetAppListMsg( currentPage=page, - totalPages=total_pages, + total=total_apps, applications=app_cards, ), ).model_dump(exclude_none=True, by_alias=True)) -@router.post("", dependencies=[Depends(verify_csrf_token)], response_model=Union[BaseAppOperationRsp, ResponseData]) +@router.post("", response_model=Union[BaseAppOperationRsp, ResponseData]) async def create_or_update_application( # noqa: ANN201 request: Annotated[CreateAppRequest, Body(...)], user_sub: Annotated[str, Depends(get_user)], @@ -124,6 +127,7 @@ async def get_recently_used_applications( # noqa: ANN201 """获取最近使用的应用""" recent_apps = await AppCenterManager.get_recently_used_apps( count, user_sub) + print(recent_apps) if recent_apps is None: return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -173,7 +177,7 @@ async def get_application( # noqa: ANN201 @router.delete( "/{appId}", - dependencies=[Depends(verify_csrf_token)], + # dependencies=[Depends(verify_csrf_token)], response_model=Union[BaseAppOperationRsp, ResponseData], ) async def delete_application( # noqa: ANN201 @@ -217,7 +221,9 @@ async def delete_application( # noqa: ANN201 ).model_dump(exclude_none=True, by_alias=True)) -@router.post("/{appId}", dependencies=[Depends(verify_csrf_token)], response_model=BaseAppOperationRsp) +@router.post("/{appId}", + # dependencies=[Depends(verify_csrf_token)], + response_model=BaseAppOperationRsp) async def publish_application( # noqa: ANN201 app_id: Annotated[str, Path(..., alias="appId", description="应用ID")], user_sub: Annotated[str, Depends(get_user)], @@ -251,7 +257,7 @@ async def publish_application( # noqa: ANN201 ).model_dump(exclude_none=True, by_alias=True)) -@router.put("/{appId}", dependencies=[Depends(verify_csrf_token)], response_model=Union[ModFavAppRsp, ResponseData]) +@router.put("/{appId}", response_model=Union[ModFavAppRsp, ResponseData]) async def modify_favorite_application( # noqa: ANN201 app_id: Annotated[str, Path(..., alias="appId", description="应用ID")], request: Annotated[ModFavAppRequest, Body(...)], diff --git a/apps/routers/chat.py b/apps/routers/chat.py index ded2007c60cfbfc8358ec78f295705b776b9685c..3dced9e04ede85dab057654a3eef6c031c5dd77e 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -9,17 +9,19 @@ from collections.abc import AsyncGenerator from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.requests import HTTPConnection from fastapi.responses import JSONResponse, StreamingResponse from apps.common.queue import MessageQueue from apps.common.wordscheck import WordsCheck from apps.constants import LOGGER -from apps.dependency import ( - get_session, - get_user, - verify_csrf_token, - verify_user, -) + +# from apps.dependency import ( +# # get_session, +# # get_user, +# # verify_csrf_token, +# # verify_user, +# ) from apps.entities.request_data import RequestData from apps.entities.response_data import ResponseData from apps.manager import ( @@ -27,6 +29,7 @@ from apps.manager import ( TaskManager, UserBlacklistManager, ) +from apps.manager.appcenter import AppCenterManager from apps.scheduler.scheduler import Scheduler from apps.service.activity import Activity @@ -36,7 +39,8 @@ router = APIRouter( prefix="/api", tags=["chat"], ) - +async def get_user(request: HTTPConnection) -> str: + return "test" async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: """进行实际问答,并从MQ中获取消息""" @@ -112,11 +116,14 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) await Activity.remove_active(user_sub) -@router.post("/chat", dependencies=[Depends(verify_csrf_token), Depends(verify_user)]) +@router.post("/chat", + # dependencies=[Depends(verify_csrf_token), Depends(verify_user)] + ) async def chat( post_body: RequestData, user_sub: Annotated[str, Depends(get_user)], - session_id: Annotated[str, Depends(get_session)], + # session_id: Annotated[str, Depends(get_session)], + session_id : str = "1234567890" ) -> StreamingResponse: """LLM流式对话接口""" # 问题黑名单检测 @@ -126,9 +133,12 @@ async def chat( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="question is blacklisted") # 限流检查 - if await Activity.is_active(user_sub): - raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") + # if await Activity.is_active(user_sub): + # raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") + # print(post_body.app) + if post_body.app and post_body.app[0].app_id: + await AppCenterManager.update_recent_app(user_sub, post_body.app[0].app_id) res = chat_generator(post_body, user_sub, session_id) return StreamingResponse( content=res, @@ -139,7 +149,10 @@ async def chat( ) -@router.post("/stop", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)]) +@router.post("/stop", response_model=ResponseData, + # dependencies=[Depends(verify_csrf_token)] + ) + async def stop_generation(user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 """停止生成""" await Activity.remove_active(user_sub) diff --git a/apps/routers/conversation.py b/apps/routers/conversation.py index 80f36d7ef660e51605501357c24980908386e73c..4b8d38068c5a0bc79cdc3b448e8f84d42856f94a 100644 --- a/apps/routers/conversation.py +++ b/apps/routers/conversation.py @@ -5,12 +5,13 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. from datetime import datetime from typing import Annotated, Optional +from fastapi.requests import HTTPConnection import pytz from fastapi import APIRouter, Depends, Query, Request, status from fastapi.responses import JSONResponse from apps.constants import LOGGER -from apps.dependency import get_user, verify_csrf_token, verify_user +# from apps.dependency import get_user, verify_csrf_token, verify_user from apps.entities.collection import Audit, Conversation from apps.entities.request_data import ( DeleteConversationData, @@ -39,13 +40,15 @@ router = APIRouter( prefix="/api/conversation", tags=["conversation"], dependencies=[ - Depends(verify_user), + # Depends(verify_user), ], ) +async def get_user(request: HTTPConnection) -> str: + return "test" async def create_new_conversation( - user_sub: str, + user_sub: Annotated[str,Depends(get_user)], conv_list: list[Conversation], app_id: str = "", is_debug: bool = False, @@ -59,10 +62,11 @@ async def create_new_conversation( conv_records = await RecordManager.query_record_by_conversation_id(user_sub, last_conv.id, 1, "desc") if len(conv_records) > 0: create_new = True + # return last_conv # 新建对话 if create_new: - if not AppManager.validate_user_app_access(user_sub, app_id): + if app_id and not await AppManager.validate_user_app_access(user_sub, app_id): err = "Invalid app_id." raise RuntimeError(err) new_conv = await ConversationManager.add_conversation_by_user_sub(user_sub, @@ -123,17 +127,21 @@ async def get_conversation_list(user_sub: Annotated[str, Depends(get_user)]): # -@router.post("", dependencies=[Depends(verify_csrf_token)], response_model=AddConversationRsp) -async def add_conversation( +@router.post("", + # dependencies=[Depends(verify_csrf_token)], + response_model=AddConversationRsp) +async def add_conversation( # noqa: ANN201 user_sub: Annotated[str, Depends(get_user)], appId: Optional[str] = None, # noqa: N803 isDebug: Optional[bool] = None, # noqa: N803 -): +): """手动创建新对话""" conversations = await ConversationManager.get_conversation_by_user_sub(user_sub) # 尝试创建新对话 try: app_id = appId if appId else "" + if appId: + conversations = [] is_debug = isDebug if isDebug is not None else False new_conv = await create_new_conversation(user_sub, conversations, app_id=app_id, is_debug=is_debug) @@ -157,7 +165,9 @@ async def add_conversation( ).model_dump(exclude_none=True, by_alias=True)) -@router.put("", response_model=UpdateConversationRsp, dependencies=[Depends(verify_csrf_token)]) +@router.put("", response_model=UpdateConversationRsp, + # dependencies=[Depends(verify_csrf_token)] + ) async def update_conversation( # noqa: ANN201 post_body: ModifyConversationData, conversationId: Annotated[str, Query()], # noqa: N803 @@ -206,7 +216,9 @@ async def update_conversation( # noqa: ANN201 ) -@router.delete("", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)]) +@router.delete("", response_model=ResponseData, + # dependencies=[Depends(verify_csrf_token)] + ) async def delete_conversation(request: Request, post_body: DeleteConversationData, user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 """删除特定对话""" deleted_conversation = [] diff --git a/apps/routers/record.py b/apps/routers/record.py index 361adc749767e96ee62d90258bc7ba1e93a8112d..7c66e4c2a5deb5bd2d8ec8118e056b481ee717a5 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -6,10 +6,11 @@ import json from typing import Annotated from fastapi import APIRouter, Depends, status +from fastapi.requests import HTTPConnection from fastapi.responses import JSONResponse from apps.common.security import Security -from apps.dependency import get_user, verify_user +# from apps.dependency import get_user, verify_user from apps.entities.collection import ( RecordContent, ) @@ -28,14 +29,16 @@ router = APIRouter( prefix="/api/record", tags=["record"], dependencies=[ - Depends(verify_user), + # Depends(verify_user), ], ) - +async def get_user(request: HTTPConnection) -> str: + return "test" @router.get("/{conversation_id}", response_model=RecordListRsp, responses={status.HTTP_403_FORBIDDEN: {"model": ResponseData}}) async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 """获取某个对话的所有问答对""" + print(user_sub, conversation_id) cur_conv = await ConversationManager.get_conversation_by_conversation_id(user_sub, conversation_id) # 判断conversation是否合法 if not cur_conv: diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py index 4804cce66cab3c540d4b92ea0951b18cbb7f097c..155203268f769e4d013e09adcaf944f5a6d401bc 100644 --- a/apps/scheduler/scheduler/flow.py +++ b/apps/scheduler/scheduler/flow.py @@ -4,12 +4,12 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from typing import Optional -from apps.entities.task import RequestDataPlugin +from apps.entities.task import RequestDataApp from apps.llm.patterns import Select from apps.scheduler.pool.pool import Pool -async def choose_flow(task_id: str, question: str, origin_plugin_list: list[RequestDataPlugin]) -> tuple[str, Optional[RequestDataPlugin]]: +async def choose_flow(task_id: str, question: str, origin_app: RequestDataApp) -> tuple[str, Optional[RequestDataApp]]: """依据用户的输入和选择,构造对应的Flow。 - 当用户没有选择任何Plugin时,直接进行智能问答 @@ -17,61 +17,8 @@ async def choose_flow(task_id: str, question: str, origin_plugin_list: list[Requ - 当用户选择Plugin时,在plugin内挑选最适合的flow :param question: 用户输入(用户问题) - :param origin_plugin_list: 用户选择的插件,可以一次选择多个 - :result: 经LLM选择的Plugin ID和Flow ID + :param origin_app: 用户选择的app信息 + :result: 经LLM选择的App ID和Flow ID """ - # 去掉无效的插件选项:plugin_id为空 - plugin_ids = [] - flow_ids = [] - for item in origin_plugin_list: - if not item.plugin_id: - continue - plugin_ids.append(item.plugin_id) - if item.flow_id: - flow_ids.append(item) - - # 用户什么都不选,直接智能问答 - if len(plugin_ids) == 0: - return "", None - - # 用户只选了auto - if len(plugin_ids) == 1 and plugin_ids[0] == "auto": - # 用户要求自动识别 - plugin_top = Pool().get_k_plugins(question) - # 聚合插件的Flow - plugin_ids = [str(plugin.name) for plugin in plugin_top] - - # 用户固定了Flow的ID - if len(flow_ids) > 0: - # 直接使用对应的Flow,不选择 - return plugin_ids[0], flow_ids[0] - - # 用户选了插件 - flows = Pool().get_k_flows(question, plugin_ids) - - # 使用大模型选择Top1 Flow - flow_list = [{ - "name": str(item.plugin) + "/" + str(item.name), - "description": str(item.description), - } for item in flows] - - if len(plugin_ids) == 1 and plugin_ids[0] == "auto": - # 用户选择自动识别时,包含智能问答 - flow_list += [{ - "name": "KnowledgeBase", - "description": "当上述工具无法直接解决用户问题时,使用知识库进行回答。", - }] - - # 返回top1 Flow的ID - selected_id = await Select().generate(task_id=task_id, choices=flow_list, question=question) - if selected_id == "KnowledgeBase": - return "", None - - plugin_id = selected_id.split("/")[0] - flow_id = selected_id.split("/")[1] - return plugin_id, RequestDataPlugin( - plugin_id=plugin_id, - flow_id=flow_id, - params={}, - auth={}, - ) + # TODO: 根据用户选择的App,选一次top_k flow + return "", None \ No newline at end of file diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index fb7bd87bfb52a87466e2e75670ff9b8911e12f89..9771b1081acb384872a46e87d31281a2612cdd85 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -34,17 +34,17 @@ async def push_init_message(task_id: str, queue: MessageQueue, post_body: Reques # 组装feature if is_flow: feature = InitContentFeature( - max_tokens=post_body.features.max_tokens, - context_num=post_body.features.context_num, - enable_feedback=False, - enable_regenerate=False, + maxTokens=post_body.features.max_tokens, + contextNum=post_body.features.context_num, + enableFeedback=False, + enableRegenerate=False, ) else: feature = InitContentFeature( - max_tokens=post_body.features.max_tokens, - context_num=post_body.features.context_num, - enable_feedback=True, - enable_regenerate=True, + maxTokens=post_body.features.max_tokens, + contextNum=post_body.features.context_num, + enableFeedback=True, + enableRegenerate=True, ) # 保存必要信息到Task @@ -54,7 +54,7 @@ async def push_init_message(task_id: str, queue: MessageQueue, post_body: Reques await TaskManager.set_task(task_id, task) # 推送初始化消息 - await queue.push_output(event_type=EventType.INIT, data=InitContent(feature=feature, created_at=created_at).model_dump(exclude_none=True, by_alias=True)) + await queue.push_output(event_type=EventType.INIT, data=InitContent(feature=feature, createdAt=created_at).model_dump(exclude_none=True, by_alias=True)) async def push_rag_message(task_id: str, queue: MessageQueue, user_sub: str, rag_data: RAGQueryReq) -> None: @@ -108,9 +108,9 @@ async def _push_rag_chunk(task_id: str, queue: MessageQueue, content: str, rag_i async def push_document_message(queue: MessageQueue, doc: Union[RecordDocument, Document]) -> None: """推送文档消息""" content = DocumentAddContent( - document_id=doc.id, - document_name=doc.name, - document_type=doc.type, - document_size=round(doc.size, 2), + documentId=doc.id, + documentName=doc.name, + documentType=doc.type, + documentSize=round(doc.size, 2), ) await queue.push_output(event_type=EventType.DOCUMENT_ADD, data=content.model_dump(exclude_none=True, by_alias=True)) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 4c078d079de5b28d68ce7e9a8daf688bc26ea8ff..bacd40ee5597411c78daa50463530700ad61f969 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -19,14 +19,14 @@ from apps.entities.plugin import ExecutorBackground, SysExecVars from apps.entities.rag_data import RAGQueryReq from apps.entities.record import RecordDocument from apps.entities.request_data import RequestData -from apps.entities.task import RequestDataPlugin +from apps.entities.task import RequestDataApp from apps.manager import ( DocumentManager, RecordManager, TaskManager, UserManager, ) -from apps.scheduler.executor import Executor +# from apps.scheduler.executor import Executor from apps.scheduler.scheduler.context import generate_facts, get_context from apps.scheduler.scheduler.flow import choose_flow from apps.scheduler.scheduler.message import ( @@ -73,7 +73,8 @@ class Scheduler: # 捕获所有异常:出现问题就输出日志,并停止queue try: # 根据用户的请求,返回插件ID列表,选择Flow - self._plugin_id, user_selected_flow = await choose_flow(self._task_id, post_body.question, post_body.plugins) + # self._plugin_id, user_selected_flow = await choose_flow(self._task_id, post_body.question, post_body.apps) + user_selected_flow = None # 获取当前问答可供关联的文档 docs, doc_ids = await self._get_docs(user_sub, post_body) # 获取上下文;最多20轮 @@ -89,7 +90,7 @@ class Scheduler: question=post_body.question, language=post_body.language, document_ids=doc_ids, - kb_sn=None if not user_info.kb_id else user_info.kb_id, + kb_sn=None if user_info is None or not user_info.kb_id else user_info.kb_id, history=context, top_k=5, ) @@ -121,14 +122,14 @@ class Scheduler: # 如果需要生成推荐问题,则生成 if need_recommend: routine_results = await asyncio.gather( - generate_facts(self._task_id, post_body.question), - plan_next_flow(user_sub, self._task_id, self._queue, post_body.plugins), + # generate_facts(self._task_id, post_body.question), + plan_next_flow(user_sub, self._task_id, self._queue, post_body.app[0]), ) - else: - routine_results = await asyncio.gather(generate_facts(self._task_id, post_body.question)) + # else: + # routine_results = await asyncio.gather(generate_facts(self._task_id, post_body.question)) - # 保存事实信息 - self._facts = routine_results[0] + # # 保存事实信息 + # self._facts = routine_results[0] # 发送结束消息 await self._queue.push_output(event_type=EventType.DONE, data={}) @@ -139,7 +140,7 @@ class Scheduler: await self._queue.close() - async def run_executor(self, session_id: str, post_body: RequestData, background: ExecutorBackground, user_selected_flow: RequestDataPlugin) -> bool: + async def run_executor(self, session_id: str, post_body: RequestData, background: ExecutorBackground, user_selected_flow: RequestDataApp) -> bool: """构造FlowExecutor,并执行所选择的流""" # 获取当前Task task = await TaskManager.get_task(self._task_id) @@ -201,7 +202,9 @@ class Scheduler: user_sub=user_sub, data=encrypt_data, key=encrypt_config, - facts=self._facts, + # facts=self._facts, + #TODO:暂停 + facts=[], metadata=task.record.metadata, created_at=task.record.created_at, flow=task.new_context, @@ -214,7 +217,7 @@ class Scheduler: if not record_group: LOGGER.error("[Scheduler] Create record group failed.") return - + print(record_group) # 修改文件状态 await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group) # 保存Record diff --git a/apps/service/suggestion.py b/apps/service/suggestion.py index b23e95e020f64b328c28d1a73d5f73b14d424c6b..4e4deb9d5a20a91dd7d9305aa1d580aa59d07f71 100644 --- a/apps/service/suggestion.py +++ b/apps/service/suggestion.py @@ -11,7 +11,7 @@ from apps.constants import LOGGER from apps.entities.collection import RecordContent from apps.entities.enum_var import EventType from apps.entities.message import SuggestContent -from apps.entities.task import RequestDataPlugin +from apps.entities.task import RequestDataApp from apps.llm.patterns.recommend import Recommend from apps.manager import ( RecordManager, @@ -28,7 +28,7 @@ USER_TOP_DOMAINS_NUM = 5 HISTORY_QUESTIONS_NUM = 4 -async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, user_selected_plugins: list[RequestDataPlugin]) -> None: # noqa: C901, PLR0912 +async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, user_selected_plugins: RequestDataApp) -> None: # noqa: C901, PLR0912 """生成用户“下一步”Flow的推荐。 - 若Flow的配置文件中已定义`next_flow[]`字段,则直接使用该字段给定的值 @@ -69,99 +69,102 @@ async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, user_ generated_questions += f"{question}\n" content = SuggestContent( question=question, - plugin_id="", - flow_id="", - flow_description="", + appId="", + flowId="", + flowDescription="", ) await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) return # 当前使用了Flow flow_id = task.flow_state.name - plugin_id = task.flow_state.plugin_id - _, flow_data = Pool().get_flow(flow_id, plugin_id) - if flow_data is None: - err = "Flow数据不存在" - raise ValueError(err) - - if flow_data.next_flow is None: - # 根据用户选择的插件,选一次top_k flow - plugin_ids = [] - for plugin in user_selected_plugins: - if plugin.plugin_id and plugin.plugin_id not in plugin_ids: - plugin_ids.append(plugin.plugin_id) - result = Pool().get_k_flows(task.record.content.question, plugin_ids) - for i, flow in enumerate(result): - if i >= MAX_RECOMMEND: - break - # 改写问题 - rewrite_question = await Recommend().generate( - task_id=task_id, - action_description=flow.description, - history_questions=last_n_questions, - recent_question=current_record, - user_preference=str(user_domain), - shown_questions=generated_questions, - ) - generated_questions += f"{rewrite_question}\n" - - content = SuggestContent( - plugin_id=plugin_id, - flow_id=flow_id, - flow_description=str(flow.description), - question=rewrite_question, - ) - await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) - return - - # 当前有next_flow - for i, next_flow in enumerate(flow_data.next_flow): - # 取前MAX_RECOMMEND个Flow,保持顺序 - if i >= MAX_RECOMMEND: - break - - if next_flow.plugin is not None: - next_flow_plugin_id = next_flow.plugin - else: - next_flow_plugin_id = plugin_id - - flow_metadata, _ = Pool().get_flow( - next_flow.id, - next_flow_plugin_id, - ) - - # flow不合法 - if flow_metadata is None: - LOGGER.error(f"Flow {next_flow.id} in {next_flow_plugin_id} not found") - continue - - # 如果设置了question,直接使用这个question - if next_flow.question is not None: - content = SuggestContent( - plugin_id=next_flow_plugin_id, - flow_id=next_flow.id, - flow_description=str(flow_metadata.description), - question=next_flow.question, - ) - await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) - continue - - # 没有设置question,则需要生成问题 - rewrite_question = await Recommend().generate( - task_id=task_id, - action_description=flow_metadata.description, - history_questions=last_n_questions, - recent_question=current_record, - user_preference=str(user_domain), - shown_questions=generated_questions, - ) - generated_questions += f"{rewrite_question}\n" - content = SuggestContent( - plugin_id=next_flow_plugin_id, - flow_id=next_flow.id, - flow_description=str(flow_metadata.description), - question=rewrite_question, - ) - await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) - continue + app_id = task.flow_state.app_id return + # TODO: 推荐flow待完善 + # _, flow_data = Pool().get_flow(flow_id, app_id) + # if flow_data is None: + # err = "Flow数据不存在" + # raise ValueError(err) + + # if flow_data.next_flow is None: + # # 根据用户选择的插件,选一次top_k flow + # app_ids = [] + # for plugin in user_selected_plugins: + # if plugin.app_id and plugin.app_id not in app_ids: + # app_ids.append(plugin.app_id) + # result = Pool().get_k_flows(task.record.content.question, app_ids) + # for i, flow in enumerate(result): + # if i >= MAX_RECOMMEND: + # break + # # 改写问题 + # rewrite_question = await Recommend().generate( + # task_id=task_id, + # action_description=flow.description, + # history_questions=last_n_questions, + # recent_question=current_record, + # user_preference=str(user_domain), + # shown_questions=generated_questions, + # ) + # generated_questions += f"{rewrite_question}\n" + + # content = SuggestContent( + # app_id=app_id, + # flow_id=flow_id, + # flow_description=str(flow.description), + # question=rewrite_question, + # ) + # await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) + # return + + # # 当前有next_flow + # for i, next_flow in enumerate(flow_data.next_flow): + # # 取前MAX_RECOMMEND个Flow,保持顺序 + # if i >= MAX_RECOMMEND: + # break + + # if next_flow.plugin is not None: + # next_flow_app_id = next_flow.plugin + # else: + # next_flow_app_id = app_id + + # flow_metadata, _ = next_flow.id, next_flow_app_id, + # # flow_metadata, _ = Pool().get_flow( + # # next_flow.id, + # # next_flow_app_id, + # # ) + + # # flow不合法 + # if flow_metadata is None: + # LOGGER.error(f"Flow {next_flow.id} in {next_flow_app_id} not found") + # continue + + # # 如果设置了question,直接使用这个question + # if next_flow.question is not None: + # content = SuggestContent( + # appId=next_flow_app_id, + # flowId=next_flow.id, + # flowDescription=str(flow_metadata.description), + # question=next_flow.question, + # ) + # await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) + # continue + + # # 没有设置question,则需要生成问题 + # rewrite_question = await Recommend().generate( + # task_id=task_id, + # action_description=flow_metadata.description, + # history_questions=last_n_questions, + # recent_question=current_record, + # user_preference=str(user_domain), + # shown_questions=generated_questions, + # ) + # generated_questions += f"{rewrite_question}\n" + # content = SuggestContent( + # appId=next_flow_app_id, + # flowId=next_flow.id, + # flowDescription=str(flow_metadata.description), + # question=rewrite_question, + # ) + # await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) + # continue + # return