From 444171c2d1e8caa29cdc6f74af2c95dc57be9cfe Mon Sep 17 00:00:00 2001 From: EricCheng <12955029+ericchengscut@user.noreply.gitee.com> Date: Thu, 9 May 2024 17:07:14 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E5=90=91GPT=E5=8F=91=E9=80=81=E6=B6=88=E6=81=AF=E4=B8=8E?= =?UTF-8?q?=E6=8E=A5=E5=8F=97=E5=85=B6=E5=9B=9E=E5=A4=8D=E7=9A=84=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=88=E5=90=8C=E6=AD=A5=E5=BC=8F=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../entity/chatgpt/ChatGPTConfig.java | 2 +- .../entity/chatgpt/CustomChatGPT.java | 2 +- .../chatgpt/CustomEventSourceListener.java | 8 +++ pom.xml | 4 +- .../groupshell/websocket/WebSocketServer.java | 62 ++++++++++++------- 5 files changed, 52 insertions(+), 26 deletions(-) diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java index 87c6af2..ab495ab 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java @@ -36,7 +36,7 @@ public class ChatGPTConfig { private ChatCompletionResponse chatCompletionResponse; //暂时没用 private HttpLoggingInterceptor httpLoggingInterceptor; - public ChatGPTConfig(String customization) throws InterruptedException { + public ChatGPTConfig(String customization) { if (customization == OPENAI_CLIENT) { initHttpLoggingInterceptor(); initOkHttpClient(); diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java index f09701e..048d7d5 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java @@ -31,7 +31,7 @@ public class CustomChatGPT { //阻塞 // @Autowired - public CustomChatGPT(String customization) throws InterruptedException { + public CustomChatGPT(String customization) { chatGPTConfig = new ChatGPTConfig(customization); openAiClientType = customization; } diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java index 3a775fe..c938934 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java @@ -15,6 +15,8 @@ import static com.groupshell.constant.ChatGPTConstants.CHATGPT_DONE; public class CustomEventSourceListener extends ConsoleEventSourceListener { ChatGPTMessageHandler messageHandler = new ChatGPTMessageHandler(); + Boolean isChatGPTDone = false; + @Override public void onOpen(EventSource eventSource, Response response) { log.info("OpenAI建立sse连接..."); @@ -23,12 +25,14 @@ public class CustomEventSourceListener extends ConsoleEventSourceListener { @Override public void onEvent(EventSource eventSource, String id, String type, String data) { //log.info("OpenAI返回数据:{}", data); + isChatGPTDone = false; String response = messageHandler.responseHandler(data); messageHandler.responses.add(response); if (data.equals(CHATGPT_DONE)) { log.info("OpenAI返回数据结束了"); messageHandler.printResponses(); messageHandler.responseFormat(messageHandler.responses); + isChatGPTDone = true; //eventSource.request(); return; } @@ -59,4 +63,8 @@ public class CustomEventSourceListener extends ConsoleEventSourceListener { public ChatGPTMessageHandler getChatGPTMessageHandler() { return messageHandler; } + + public Boolean getChatGPTDone() { + return isChatGPTDone; + } } diff --git a/pom.xml b/pom.xml index f132296..bab936e 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ - + @@ -36,6 +36,6 @@ - + \ No newline at end of file diff --git a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java index b602062..495e8c3 100644 --- a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java +++ b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java @@ -5,6 +5,7 @@ import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.groupshell.dto.MessageDTO; import com.groupshell.entity.Message; import com.groupshell.entity.UserGroup; +import com.groupshell.entity.chatgpt.CustomChatGPT; import com.groupshell.service.MessageService; import com.groupshell.service.UserGroupService; import com.groupshell.service.UserService; @@ -27,6 +28,8 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import static com.groupshell.constant.ChatGPTConstants.OPENAI_STREAM_CLIENT; + /** * WebSocket服务 */ @@ -40,7 +43,7 @@ public class WebSocketServer private final MessageService messageService; //存放会话对象 private static final Map sessionMap=new ConcurrentHashMap<>(); - + private CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT); public WebSocketServer(MessageService messageService,UserService userService,UserGroupService userGroupService) { @@ -104,27 +107,42 @@ public class WebSocketServer //如果是向gpt提问,还需要存储并推送gpt的回答 if(messageDTO.getGpt()) { - //todo 输入messageDTO.getContent(),输出gpt的回答answer - String answer=""; - //存数据库 - Message message2=Message.builder() - .groupId(messageDTO.getGroupId()) - .content(answer) - .createTime(LocalDateTime.now()) - .build(); - messageService.save(message2); - //向浏览器推送消息 - MessageVO messageVO2=MessageVO.builder() - .system(false) - .reminder(false) - .chat(true) - .username("GroupShellGPT") - .groupId(messageDTO.getGroupId()) - .content(answer) - .createTime(message2.getCreateTime() - .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm"))) - .build(); - send(JSON.toJSONString(messageVO2),userIds); + try { + //todo 输入messageDTO.getContent(),输出gpt的回答answer + //每次都new一个对象性能开销会不会很大 + //CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT); + customChatGPT.sendMessageToChatGPT(messageDTO.getContent()); + String answer = ""; + while (!customChatGPT.getChatGPTConfig().getEventSourceListener().getChatGPTDone()) { + //若gpt回复还未结束,则等待其回复完整 + } + answer = customChatGPT.getChatGPTConfig() + .getEventSourceListener() + .getChatGPTMessageHandler().responsesFomatted; + //存数据库 + Message message2=Message.builder() + .groupId(messageDTO.getGroupId()) + .content(answer) + .createTime(LocalDateTime.now()) + .build(); + messageService.save(message2); + //向浏览器推送消息 + MessageVO messageVO2=MessageVO.builder() + .system(false) + .reminder(false) + .chat(true) + .username("GroupShellGPT") + .groupId(messageDTO.getGroupId()) + .content(answer) + .createTime(message2.getCreateTime() + .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm"))) + .build(); + send(JSON.toJSONString(messageVO2),userIds); + } + catch (Exception e) { + e.printStackTrace(); + } + } log.info("处理完成"); } -- Gitee From 8632f8c710aaa1f1cb0b1288e842c4240865d561 Mon Sep 17 00:00:00 2001 From: EricCheng <12955029+ericchengscut@user.noreply.gitee.com> Date: Thu, 9 May 2024 17:51:58 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E5=90=91GPT=E5=8F=91=E9=80=81=E6=B6=88=E6=81=AF=E4=B8=8E?= =?UTF-8?q?=E6=8E=A5=E5=8F=97=E5=85=B6=E5=9B=9E=E5=A4=8D=E7=9A=84=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=88=E5=9C=A8=E7=AD=89=E5=BE=85gpt=E5=9B=9E?= =?UTF-8?q?=E5=A4=8D=E6=97=B6=EF=BC=8C=E5=88=9B=E5=BB=BA=E6=96=B0=E7=BA=BF?= =?UTF-8?q?=E7=A8=8B=EF=BC=8C=E5=BC=82=E6=AD=A5=E7=AD=89=E5=BE=85=EF=BC=8C?= =?UTF-8?q?=E4=BB=A5=E5=85=8D=E9=98=BB=E5=A1=9E=E4=B8=BB=E7=A8=8B=E5=BA=8F?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../entity/chatgpt/ChatGPTMessageHandler.java | 15 +--- .../groupshell/websocket/WebSocketServer.java | 84 ++++++++++++------- 2 files changed, 57 insertions(+), 42 deletions(-) diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java index 4df66cc..1bf4335 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java @@ -19,14 +19,7 @@ public class ChatGPTMessageHandler { public String responseHandler(String jsonResponse) { -// // 使用 Fastjson 解析 JSON 数据 -// JSONObject jsonObject = JSON.parseObject(jsonResponse); -// -// // 提取 choices 数组中第一个元素的 delta 对象下的 content 字段 -// String content = jsonObject.getJSONArray("choices").getJSONObject(1) -// .getJSONObject("delta").getString("content"); -// //System.out.println(fullSentence.toString()); // 输出完整句子到控制台 -// return content; + if (Objects.equals(jsonResponse, CHATGPT_DONE)) { return null; } @@ -44,14 +37,14 @@ public class ChatGPTMessageHandler { return content; //responses.add(content); } else { - //System.out.println("Content is empty or not available."); + System.out.println("Content is empty or not available."); } } else { - //System.out.println("Delta is null or does not contain 'content'."); + System.out.println("Delta is null or does not contain 'content'."); } } } else { - //System.out.println("Choices array is empty."); + System.out.println("Choices array is empty."); } return null; } diff --git a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java index 495e8c3..79a4c9d 100644 --- a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java +++ b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java @@ -26,7 +26,10 @@ import java.time.format.DateTimeFormatter; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import static com.groupshell.constant.ChatGPTConstants.OPENAI_STREAM_CLIENT; @@ -43,7 +46,7 @@ public class WebSocketServer private final MessageService messageService; //存放会话对象 private static final Map sessionMap=new ConcurrentHashMap<>(); - private CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT); + private final CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT); public WebSocketServer(MessageService messageService,UserService userService,UserGroupService userGroupService) { @@ -108,36 +111,55 @@ public class WebSocketServer if(messageDTO.getGpt()) { try { - //todo 输入messageDTO.getContent(),输出gpt的回答answer - //每次都new一个对象性能开销会不会很大 - //CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT); - customChatGPT.sendMessageToChatGPT(messageDTO.getContent()); - String answer = ""; - while (!customChatGPT.getChatGPTConfig().getEventSourceListener().getChatGPTDone()) { - //若gpt回复还未结束,则等待其回复完整 - } - answer = customChatGPT.getChatGPTConfig() - .getEventSourceListener() - .getChatGPTMessageHandler().responsesFomatted; - //存数据库 - Message message2=Message.builder() - .groupId(messageDTO.getGroupId()) - .content(answer) - .createTime(LocalDateTime.now()) - .build(); - messageService.save(message2); - //向浏览器推送消息 - MessageVO messageVO2=MessageVO.builder() - .system(false) - .reminder(false) - .chat(true) - .username("GroupShellGPT") - .groupId(messageDTO.getGroupId()) - .content(answer) - .createTime(message2.getCreateTime() - .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm"))) - .build(); - send(JSON.toJSONString(messageVO2),userIds); + CompletableFuture future = CompletableFuture.supplyAsync(() -> { + customChatGPT.sendMessageToChatGPT(messageDTO.getContent()); + + while (!customChatGPT.getChatGPTConfig().getEventSourceListener().getChatGPTDone()) { + //轮询:性能开销比较大 + //若gpt回复还未结束,则等待其回复完整 + try { + Thread.sleep(100); // Reduce CPU usage with a small sleep + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IllegalStateException("Task interrupted", e); + } + } + String answer = customChatGPT.getChatGPTConfig() + .getEventSourceListener() + .getChatGPTMessageHandler().responsesFomatted; + return answer; + }); + + future.orTimeout(20, TimeUnit.SECONDS) //20秒超时 + .thenAccept(answer -> { + //存数据库 + Message message2=Message.builder() + .groupId(messageDTO.getGroupId()) + .content(answer) + .createTime(LocalDateTime.now()) + .build(); + messageService.save(message2); + //向浏览器推送消息 + MessageVO messageVO2=MessageVO.builder() + .system(false) + .reminder(false) + .chat(true) + .username("GroupShellGPT") + .groupId(messageDTO.getGroupId()) + .content(answer) + .createTime(message2.getCreateTime() + .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm"))) + .build(); + send(JSON.toJSONString(messageVO2),userIds); + }).exceptionally(e -> { + if (e instanceof TimeoutException) { + log.error("ChatGPT response timeout", e); + // Handle timeout scenario + } else { + log.error("处理 ChatGPT 回答时发生错误", e); + } + return null; + }); } catch (Exception e) { e.printStackTrace(); -- Gitee