diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java
index 87c6af271cf54aa6e1872e2a59dd99314855029a..ab495abee4cdbdaa8e594e7a54633e484d0fd621 100644
--- a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java
+++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java
@@ -36,7 +36,7 @@ public class ChatGPTConfig {
private ChatCompletionResponse chatCompletionResponse; //暂时没用
private HttpLoggingInterceptor httpLoggingInterceptor;
- public ChatGPTConfig(String customization) throws InterruptedException {
+ public ChatGPTConfig(String customization) {
if (customization == OPENAI_CLIENT) {
initHttpLoggingInterceptor();
initOkHttpClient();
diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java
index 4df66cc5d226cbcfe8cc73c3a1fcd710d028e321..1bf4335d5f730d2dc4dc4e68f8baa6481306f86b 100644
--- a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java
+++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java
@@ -19,14 +19,7 @@ public class ChatGPTMessageHandler {
public String responseHandler(String jsonResponse) {
-// // 使用 Fastjson 解析 JSON 数据
-// JSONObject jsonObject = JSON.parseObject(jsonResponse);
-//
-// // 提取 choices 数组中第一个元素的 delta 对象下的 content 字段
-// String content = jsonObject.getJSONArray("choices").getJSONObject(1)
-// .getJSONObject("delta").getString("content");
-// //System.out.println(fullSentence.toString()); // 输出完整句子到控制台
-// return content;
+
if (Objects.equals(jsonResponse, CHATGPT_DONE)) {
return null;
}
@@ -44,14 +37,14 @@ public class ChatGPTMessageHandler {
return content;
//responses.add(content);
} else {
- //System.out.println("Content is empty or not available.");
+ System.out.println("Content is empty or not available.");
}
} else {
- //System.out.println("Delta is null or does not contain 'content'.");
+ System.out.println("Delta is null or does not contain 'content'.");
}
}
} else {
- //System.out.println("Choices array is empty.");
+ System.out.println("Choices array is empty.");
}
return null;
}
diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java
index f09701e9f6b9a9c3b653d2e1938547c43b319611..048d7d5309b4e4ff93b1d3f7adf3278afadf7459 100644
--- a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java
+++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java
@@ -31,7 +31,7 @@ public class CustomChatGPT {
//阻塞
// @Autowired
- public CustomChatGPT(String customization) throws InterruptedException {
+ public CustomChatGPT(String customization) {
chatGPTConfig = new ChatGPTConfig(customization);
openAiClientType = customization;
}
diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java
index 3a775fe0258fee4573424a14599ceb3179685038..c9389347c213d96ce57251ac0c4b910a7214ca5d 100644
--- a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java
+++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java
@@ -15,6 +15,8 @@ import static com.groupshell.constant.ChatGPTConstants.CHATGPT_DONE;
public class CustomEventSourceListener extends ConsoleEventSourceListener {
ChatGPTMessageHandler messageHandler = new ChatGPTMessageHandler();
+ Boolean isChatGPTDone = false;
+
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("OpenAI建立sse连接...");
@@ -23,12 +25,14 @@ public class CustomEventSourceListener extends ConsoleEventSourceListener {
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
//log.info("OpenAI返回数据:{}", data);
+ isChatGPTDone = false;
String response = messageHandler.responseHandler(data);
messageHandler.responses.add(response);
if (data.equals(CHATGPT_DONE)) {
log.info("OpenAI返回数据结束了");
messageHandler.printResponses();
messageHandler.responseFormat(messageHandler.responses);
+ isChatGPTDone = true;
//eventSource.request();
return;
}
@@ -59,4 +63,8 @@ public class CustomEventSourceListener extends ConsoleEventSourceListener {
public ChatGPTMessageHandler getChatGPTMessageHandler() {
return messageHandler;
}
+
+ public Boolean getChatGPTDone() {
+ return isChatGPTDone;
+ }
}
diff --git a/pom.xml b/pom.xml
index f13229623d06184505cb77c38ca861cb670f38ae..bab936ecb2afb0e32f4659a21e74e7f7b89188fc 100644
--- a/pom.xml
+++ b/pom.xml
@@ -25,7 +25,7 @@
-
+
@@ -36,6 +36,6 @@
-
+
\ No newline at end of file
diff --git a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java
index e1b0c2788d31a7b9cef4781d88650e5074677b29..ecd5e2b6f70e88fdf641a5f68dc781e668978423 100644
--- a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java
+++ b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java
@@ -6,6 +6,7 @@ import com.groupshell.dto.MessageDTO;
import com.groupshell.entity.Message;
import com.groupshell.entity.User;
import com.groupshell.entity.UserGroup;
+import com.groupshell.entity.chatgpt.CustomChatGPT;
import com.groupshell.service.MessageService;
import com.groupshell.service.UserGroupService;
import com.groupshell.service.UserService;
@@ -27,7 +28,12 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import static com.groupshell.constant.ChatGPTConstants.OPENAI_STREAM_CLIENT;
/**
* WebSocket服务
@@ -42,6 +48,7 @@ public class WebSocketServer
private final MessageService messageService;
//存放会话对象
private static final Map sessionMap=new ConcurrentHashMap<>();
+ private final CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT);
private static final Map> groupId2userIds=new ConcurrentHashMap<>();
@@ -130,27 +137,61 @@ public class WebSocketServer
//如果是向gpt提问,还需要存储并推送gpt的回答
if(messageDTO.getGpt())
{
- //todo 输入messageDTO.getContent(),输出gpt的回答answer
- String answer="";
- //存数据库
- Message message2=Message.builder()
- .groupId(messageDTO.getGroupId())
- .content(answer)
- .createTime(LocalDateTime.now())
- .build();
- messageService.save(message2);
- //向浏览器推送消息
- MessageVO messageVO2=MessageVO.builder()
- .system(false)
- .reminder(false)
- .chat(true)
- .username("GroupShellGPT")
- .groupId(messageDTO.getGroupId())
- .content(answer)
- .createTime(message2.getCreateTime()
- .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm")))
- .build();
- send(JSON.toJSONString(messageVO2),userIds);
+ try {
+ CompletableFuture future = CompletableFuture.supplyAsync(() -> {
+ customChatGPT.sendMessageToChatGPT(messageDTO.getContent());
+
+ while (!customChatGPT.getChatGPTConfig().getEventSourceListener().getChatGPTDone()) {
+ //轮询:性能开销比较大
+ //若gpt回复还未结束,则等待其回复完整
+ try {
+ Thread.sleep(100); // Reduce CPU usage with a small sleep
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new IllegalStateException("Task interrupted", e);
+ }
+ }
+ String answer = customChatGPT.getChatGPTConfig()
+ .getEventSourceListener()
+ .getChatGPTMessageHandler().responsesFomatted;
+ return answer;
+ });
+
+ future.orTimeout(20, TimeUnit.SECONDS) //20秒超时
+ .thenAccept(answer -> {
+ //存数据库
+ Message message2=Message.builder()
+ .groupId(messageDTO.getGroupId())
+ .content(answer)
+ .createTime(LocalDateTime.now())
+ .build();
+ messageService.save(message2);
+ //向浏览器推送消息
+ MessageVO messageVO2=MessageVO.builder()
+ .system(false)
+ .reminder(false)
+ .chat(true)
+ .username("GroupShellGPT")
+ .groupId(messageDTO.getGroupId())
+ .content(answer)
+ .createTime(message2.getCreateTime()
+ .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm")))
+ .build();
+ send(JSON.toJSONString(messageVO2),userIds);
+ }).exceptionally(e -> {
+ if (e instanceof TimeoutException) {
+ log.error("ChatGPT response timeout", e);
+ // Handle timeout scenario
+ } else {
+ log.error("处理 ChatGPT 回答时发生错误", e);
+ }
+ return null;
+ });
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ }
+
}
log.info("处理完成");
}