diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java index 87c6af271cf54aa6e1872e2a59dd99314855029a..ab495abee4cdbdaa8e594e7a54633e484d0fd621 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java @@ -36,7 +36,7 @@ public class ChatGPTConfig { private ChatCompletionResponse chatCompletionResponse; //暂时没用 private HttpLoggingInterceptor httpLoggingInterceptor; - public ChatGPTConfig(String customization) throws InterruptedException { + public ChatGPTConfig(String customization) { if (customization == OPENAI_CLIENT) { initHttpLoggingInterceptor(); initOkHttpClient(); diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java index 4df66cc5d226cbcfe8cc73c3a1fcd710d028e321..1bf4335d5f730d2dc4dc4e68f8baa6481306f86b 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java @@ -19,14 +19,7 @@ public class ChatGPTMessageHandler { public String responseHandler(String jsonResponse) { -// // 使用 Fastjson 解析 JSON 数据 -// JSONObject jsonObject = JSON.parseObject(jsonResponse); -// -// // 提取 choices 数组中第一个元素的 delta 对象下的 content 字段 -// String content = jsonObject.getJSONArray("choices").getJSONObject(1) -// .getJSONObject("delta").getString("content"); -// //System.out.println(fullSentence.toString()); // 输出完整句子到控制台 -// return content; + if (Objects.equals(jsonResponse, CHATGPT_DONE)) { return null; } @@ -44,14 +37,14 @@ public class ChatGPTMessageHandler { return content; //responses.add(content); } else { - //System.out.println("Content is empty or not available."); + System.out.println("Content is empty or not available."); } } else { - //System.out.println("Delta is null or does not contain 'content'."); + System.out.println("Delta is null or does not contain 'content'."); } } } else { - //System.out.println("Choices array is empty."); + System.out.println("Choices array is empty."); } return null; } diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java index f09701e9f6b9a9c3b653d2e1938547c43b319611..048d7d5309b4e4ff93b1d3f7adf3278afadf7459 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java @@ -31,7 +31,7 @@ public class CustomChatGPT { //阻塞 // @Autowired - public CustomChatGPT(String customization) throws InterruptedException { + public CustomChatGPT(String customization) { chatGPTConfig = new ChatGPTConfig(customization); openAiClientType = customization; } diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java index 3a775fe0258fee4573424a14599ceb3179685038..c9389347c213d96ce57251ac0c4b910a7214ca5d 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java @@ -15,6 +15,8 @@ import static com.groupshell.constant.ChatGPTConstants.CHATGPT_DONE; public class CustomEventSourceListener extends ConsoleEventSourceListener { ChatGPTMessageHandler messageHandler = new ChatGPTMessageHandler(); + Boolean isChatGPTDone = false; + @Override public void onOpen(EventSource eventSource, Response response) { log.info("OpenAI建立sse连接..."); @@ -23,12 +25,14 @@ public class CustomEventSourceListener extends ConsoleEventSourceListener { @Override public void onEvent(EventSource eventSource, String id, String type, String data) { //log.info("OpenAI返回数据:{}", data); + isChatGPTDone = false; String response = messageHandler.responseHandler(data); messageHandler.responses.add(response); if (data.equals(CHATGPT_DONE)) { log.info("OpenAI返回数据结束了"); messageHandler.printResponses(); messageHandler.responseFormat(messageHandler.responses); + isChatGPTDone = true; //eventSource.request(); return; } @@ -59,4 +63,8 @@ public class CustomEventSourceListener extends ConsoleEventSourceListener { public ChatGPTMessageHandler getChatGPTMessageHandler() { return messageHandler; } + + public Boolean getChatGPTDone() { + return isChatGPTDone; + } } diff --git a/pom.xml b/pom.xml index f13229623d06184505cb77c38ca861cb670f38ae..bab936ecb2afb0e32f4659a21e74e7f7b89188fc 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ - + @@ -36,6 +36,6 @@ - + \ No newline at end of file diff --git a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java index e1b0c2788d31a7b9cef4781d88650e5074677b29..ecd5e2b6f70e88fdf641a5f68dc781e668978423 100644 --- a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java +++ b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java @@ -6,6 +6,7 @@ import com.groupshell.dto.MessageDTO; import com.groupshell.entity.Message; import com.groupshell.entity.User; import com.groupshell.entity.UserGroup; +import com.groupshell.entity.chatgpt.CustomChatGPT; import com.groupshell.service.MessageService; import com.groupshell.service.UserGroupService; import com.groupshell.service.UserService; @@ -27,7 +28,12 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static com.groupshell.constant.ChatGPTConstants.OPENAI_STREAM_CLIENT; /** * WebSocket服务 @@ -42,6 +48,7 @@ public class WebSocketServer private final MessageService messageService; //存放会话对象 private static final Map sessionMap=new ConcurrentHashMap<>(); + private final CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT); private static final Map> groupId2userIds=new ConcurrentHashMap<>(); @@ -130,27 +137,61 @@ public class WebSocketServer //如果是向gpt提问,还需要存储并推送gpt的回答 if(messageDTO.getGpt()) { - //todo 输入messageDTO.getContent(),输出gpt的回答answer - String answer=""; - //存数据库 - Message message2=Message.builder() - .groupId(messageDTO.getGroupId()) - .content(answer) - .createTime(LocalDateTime.now()) - .build(); - messageService.save(message2); - //向浏览器推送消息 - MessageVO messageVO2=MessageVO.builder() - .system(false) - .reminder(false) - .chat(true) - .username("GroupShellGPT") - .groupId(messageDTO.getGroupId()) - .content(answer) - .createTime(message2.getCreateTime() - .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm"))) - .build(); - send(JSON.toJSONString(messageVO2),userIds); + try { + CompletableFuture future = CompletableFuture.supplyAsync(() -> { + customChatGPT.sendMessageToChatGPT(messageDTO.getContent()); + + while (!customChatGPT.getChatGPTConfig().getEventSourceListener().getChatGPTDone()) { + //轮询:性能开销比较大 + //若gpt回复还未结束,则等待其回复完整 + try { + Thread.sleep(100); // Reduce CPU usage with a small sleep + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IllegalStateException("Task interrupted", e); + } + } + String answer = customChatGPT.getChatGPTConfig() + .getEventSourceListener() + .getChatGPTMessageHandler().responsesFomatted; + return answer; + }); + + future.orTimeout(20, TimeUnit.SECONDS) //20秒超时 + .thenAccept(answer -> { + //存数据库 + Message message2=Message.builder() + .groupId(messageDTO.getGroupId()) + .content(answer) + .createTime(LocalDateTime.now()) + .build(); + messageService.save(message2); + //向浏览器推送消息 + MessageVO messageVO2=MessageVO.builder() + .system(false) + .reminder(false) + .chat(true) + .username("GroupShellGPT") + .groupId(messageDTO.getGroupId()) + .content(answer) + .createTime(message2.getCreateTime() + .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm"))) + .build(); + send(JSON.toJSONString(messageVO2),userIds); + }).exceptionally(e -> { + if (e instanceof TimeoutException) { + log.error("ChatGPT response timeout", e); + // Handle timeout scenario + } else { + log.error("处理 ChatGPT 回答时发生错误", e); + } + return null; + }); + } + catch (Exception e) { + e.printStackTrace(); + } + } log.info("处理完成"); }