From 444171c2d1e8caa29cdc6f74af2c95dc57be9cfe Mon Sep 17 00:00:00 2001
From: EricCheng <12955029+ericchengscut@user.noreply.gitee.com>
Date: Thu, 9 May 2024 17:07:14 +0800
Subject: [PATCH 1/2] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E5=AE=9E=E7=8E=B0?=
=?UTF-8?q?=E5=90=91GPT=E5=8F=91=E9=80=81=E6=B6=88=E6=81=AF=E4=B8=8E?=
=?UTF-8?q?=E6=8E=A5=E5=8F=97=E5=85=B6=E5=9B=9E=E5=A4=8D=E7=9A=84=E5=8A=9F?=
=?UTF-8?q?=E8=83=BD=EF=BC=88=E5=90=8C=E6=AD=A5=E5=BC=8F=EF=BC=89?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../entity/chatgpt/ChatGPTConfig.java | 2 +-
.../entity/chatgpt/CustomChatGPT.java | 2 +-
.../chatgpt/CustomEventSourceListener.java | 8 +++
pom.xml | 4 +-
.../groupshell/websocket/WebSocketServer.java | 62 ++++++++++++-------
5 files changed, 52 insertions(+), 26 deletions(-)
diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java
index 87c6af2..ab495ab 100644
--- a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java
+++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java
@@ -36,7 +36,7 @@ public class ChatGPTConfig {
private ChatCompletionResponse chatCompletionResponse; //暂时没用
private HttpLoggingInterceptor httpLoggingInterceptor;
- public ChatGPTConfig(String customization) throws InterruptedException {
+ public ChatGPTConfig(String customization) {
if (customization == OPENAI_CLIENT) {
initHttpLoggingInterceptor();
initOkHttpClient();
diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java
index f09701e..048d7d5 100644
--- a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java
+++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java
@@ -31,7 +31,7 @@ public class CustomChatGPT {
//阻塞
// @Autowired
- public CustomChatGPT(String customization) throws InterruptedException {
+ public CustomChatGPT(String customization) {
chatGPTConfig = new ChatGPTConfig(customization);
openAiClientType = customization;
}
diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java
index 3a775fe..c938934 100644
--- a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java
+++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomEventSourceListener.java
@@ -15,6 +15,8 @@ import static com.groupshell.constant.ChatGPTConstants.CHATGPT_DONE;
public class CustomEventSourceListener extends ConsoleEventSourceListener {
ChatGPTMessageHandler messageHandler = new ChatGPTMessageHandler();
+ Boolean isChatGPTDone = false;
+
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("OpenAI建立sse连接...");
@@ -23,12 +25,14 @@ public class CustomEventSourceListener extends ConsoleEventSourceListener {
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
//log.info("OpenAI返回数据:{}", data);
+ isChatGPTDone = false;
String response = messageHandler.responseHandler(data);
messageHandler.responses.add(response);
if (data.equals(CHATGPT_DONE)) {
log.info("OpenAI返回数据结束了");
messageHandler.printResponses();
messageHandler.responseFormat(messageHandler.responses);
+ isChatGPTDone = true;
//eventSource.request();
return;
}
@@ -59,4 +63,8 @@ public class CustomEventSourceListener extends ConsoleEventSourceListener {
public ChatGPTMessageHandler getChatGPTMessageHandler() {
return messageHandler;
}
+
+ public Boolean getChatGPTDone() {
+ return isChatGPTDone;
+ }
}
diff --git a/pom.xml b/pom.xml
index f132296..bab936e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -25,7 +25,7 @@
-
+
@@ -36,6 +36,6 @@
-
+
\ No newline at end of file
diff --git a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java
index b602062..495e8c3 100644
--- a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java
+++ b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java
@@ -5,6 +5,7 @@ import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.groupshell.dto.MessageDTO;
import com.groupshell.entity.Message;
import com.groupshell.entity.UserGroup;
+import com.groupshell.entity.chatgpt.CustomChatGPT;
import com.groupshell.service.MessageService;
import com.groupshell.service.UserGroupService;
import com.groupshell.service.UserService;
@@ -27,6 +28,8 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
+import static com.groupshell.constant.ChatGPTConstants.OPENAI_STREAM_CLIENT;
+
/**
* WebSocket服务
*/
@@ -40,7 +43,7 @@ public class WebSocketServer
private final MessageService messageService;
//存放会话对象
private static final Map sessionMap=new ConcurrentHashMap<>();
-
+ private CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT);
public WebSocketServer(MessageService messageService,UserService userService,UserGroupService userGroupService)
{
@@ -104,27 +107,42 @@ public class WebSocketServer
//如果是向gpt提问,还需要存储并推送gpt的回答
if(messageDTO.getGpt())
{
- //todo 输入messageDTO.getContent(),输出gpt的回答answer
- String answer="";
- //存数据库
- Message message2=Message.builder()
- .groupId(messageDTO.getGroupId())
- .content(answer)
- .createTime(LocalDateTime.now())
- .build();
- messageService.save(message2);
- //向浏览器推送消息
- MessageVO messageVO2=MessageVO.builder()
- .system(false)
- .reminder(false)
- .chat(true)
- .username("GroupShellGPT")
- .groupId(messageDTO.getGroupId())
- .content(answer)
- .createTime(message2.getCreateTime()
- .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm")))
- .build();
- send(JSON.toJSONString(messageVO2),userIds);
+ try {
+ //todo 输入messageDTO.getContent(),输出gpt的回答answer
+ //每次都new一个对象性能开销会不会很大
+ //CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT);
+ customChatGPT.sendMessageToChatGPT(messageDTO.getContent());
+ String answer = "";
+ while (!customChatGPT.getChatGPTConfig().getEventSourceListener().getChatGPTDone()) {
+ //若gpt回复还未结束,则等待其回复完整
+ }
+ answer = customChatGPT.getChatGPTConfig()
+ .getEventSourceListener()
+ .getChatGPTMessageHandler().responsesFomatted;
+ //存数据库
+ Message message2=Message.builder()
+ .groupId(messageDTO.getGroupId())
+ .content(answer)
+ .createTime(LocalDateTime.now())
+ .build();
+ messageService.save(message2);
+ //向浏览器推送消息
+ MessageVO messageVO2=MessageVO.builder()
+ .system(false)
+ .reminder(false)
+ .chat(true)
+ .username("GroupShellGPT")
+ .groupId(messageDTO.getGroupId())
+ .content(answer)
+ .createTime(message2.getCreateTime()
+ .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm")))
+ .build();
+ send(JSON.toJSONString(messageVO2),userIds);
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ }
+
}
log.info("处理完成");
}
--
Gitee
From 8632f8c710aaa1f1cb0b1288e842c4240865d561 Mon Sep 17 00:00:00 2001
From: EricCheng <12955029+ericchengscut@user.noreply.gitee.com>
Date: Thu, 9 May 2024 17:51:58 +0800
Subject: [PATCH 2/2] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E5=AE=9E=E7=8E=B0?=
=?UTF-8?q?=E5=90=91GPT=E5=8F=91=E9=80=81=E6=B6=88=E6=81=AF=E4=B8=8E?=
=?UTF-8?q?=E6=8E=A5=E5=8F=97=E5=85=B6=E5=9B=9E=E5=A4=8D=E7=9A=84=E5=8A=9F?=
=?UTF-8?q?=E8=83=BD=EF=BC=88=E5=9C=A8=E7=AD=89=E5=BE=85gpt=E5=9B=9E?=
=?UTF-8?q?=E5=A4=8D=E6=97=B6=EF=BC=8C=E5=88=9B=E5=BB=BA=E6=96=B0=E7=BA=BF?=
=?UTF-8?q?=E7=A8=8B=EF=BC=8C=E5=BC=82=E6=AD=A5=E7=AD=89=E5=BE=85=EF=BC=8C?=
=?UTF-8?q?=E4=BB=A5=E5=85=8D=E9=98=BB=E5=A1=9E=E4=B8=BB=E7=A8=8B=E5=BA=8F?=
=?UTF-8?q?=EF=BC=89?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../entity/chatgpt/ChatGPTMessageHandler.java | 15 +---
.../groupshell/websocket/WebSocketServer.java | 84 ++++++++++++-------
2 files changed, 57 insertions(+), 42 deletions(-)
diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java
index 4df66cc..1bf4335 100644
--- a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java
+++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java
@@ -19,14 +19,7 @@ public class ChatGPTMessageHandler {
public String responseHandler(String jsonResponse) {
-// // 使用 Fastjson 解析 JSON 数据
-// JSONObject jsonObject = JSON.parseObject(jsonResponse);
-//
-// // 提取 choices 数组中第一个元素的 delta 对象下的 content 字段
-// String content = jsonObject.getJSONArray("choices").getJSONObject(1)
-// .getJSONObject("delta").getString("content");
-// //System.out.println(fullSentence.toString()); // 输出完整句子到控制台
-// return content;
+
if (Objects.equals(jsonResponse, CHATGPT_DONE)) {
return null;
}
@@ -44,14 +37,14 @@ public class ChatGPTMessageHandler {
return content;
//responses.add(content);
} else {
- //System.out.println("Content is empty or not available.");
+ System.out.println("Content is empty or not available.");
}
} else {
- //System.out.println("Delta is null or does not contain 'content'.");
+ System.out.println("Delta is null or does not contain 'content'.");
}
}
} else {
- //System.out.println("Choices array is empty.");
+ System.out.println("Choices array is empty.");
}
return null;
}
diff --git a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java
index 495e8c3..79a4c9d 100644
--- a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java
+++ b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java
@@ -26,7 +26,10 @@ import java.time.format.DateTimeFormatter;
import java.util.Collection;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
import static com.groupshell.constant.ChatGPTConstants.OPENAI_STREAM_CLIENT;
@@ -43,7 +46,7 @@ public class WebSocketServer
private final MessageService messageService;
//存放会话对象
private static final Map sessionMap=new ConcurrentHashMap<>();
- private CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT);
+ private final CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT);
public WebSocketServer(MessageService messageService,UserService userService,UserGroupService userGroupService)
{
@@ -108,36 +111,55 @@ public class WebSocketServer
if(messageDTO.getGpt())
{
try {
- //todo 输入messageDTO.getContent(),输出gpt的回答answer
- //每次都new一个对象性能开销会不会很大
- //CustomChatGPT customChatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT);
- customChatGPT.sendMessageToChatGPT(messageDTO.getContent());
- String answer = "";
- while (!customChatGPT.getChatGPTConfig().getEventSourceListener().getChatGPTDone()) {
- //若gpt回复还未结束,则等待其回复完整
- }
- answer = customChatGPT.getChatGPTConfig()
- .getEventSourceListener()
- .getChatGPTMessageHandler().responsesFomatted;
- //存数据库
- Message message2=Message.builder()
- .groupId(messageDTO.getGroupId())
- .content(answer)
- .createTime(LocalDateTime.now())
- .build();
- messageService.save(message2);
- //向浏览器推送消息
- MessageVO messageVO2=MessageVO.builder()
- .system(false)
- .reminder(false)
- .chat(true)
- .username("GroupShellGPT")
- .groupId(messageDTO.getGroupId())
- .content(answer)
- .createTime(message2.getCreateTime()
- .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm")))
- .build();
- send(JSON.toJSONString(messageVO2),userIds);
+ CompletableFuture future = CompletableFuture.supplyAsync(() -> {
+ customChatGPT.sendMessageToChatGPT(messageDTO.getContent());
+
+ while (!customChatGPT.getChatGPTConfig().getEventSourceListener().getChatGPTDone()) {
+ //轮询:性能开销比较大
+ //若gpt回复还未结束,则等待其回复完整
+ try {
+ Thread.sleep(100); // Reduce CPU usage with a small sleep
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new IllegalStateException("Task interrupted", e);
+ }
+ }
+ String answer = customChatGPT.getChatGPTConfig()
+ .getEventSourceListener()
+ .getChatGPTMessageHandler().responsesFomatted;
+ return answer;
+ });
+
+ future.orTimeout(20, TimeUnit.SECONDS) //20秒超时
+ .thenAccept(answer -> {
+ //存数据库
+ Message message2=Message.builder()
+ .groupId(messageDTO.getGroupId())
+ .content(answer)
+ .createTime(LocalDateTime.now())
+ .build();
+ messageService.save(message2);
+ //向浏览器推送消息
+ MessageVO messageVO2=MessageVO.builder()
+ .system(false)
+ .reminder(false)
+ .chat(true)
+ .username("GroupShellGPT")
+ .groupId(messageDTO.getGroupId())
+ .content(answer)
+ .createTime(message2.getCreateTime()
+ .format(DateTimeFormatter.ofPattern("yyyy/MM/dd " + "HH" + ":mm")))
+ .build();
+ send(JSON.toJSONString(messageVO2),userIds);
+ }).exceptionally(e -> {
+ if (e instanceof TimeoutException) {
+ log.error("ChatGPT response timeout", e);
+ // Handle timeout scenario
+ } else {
+ log.error("处理 ChatGPT 回答时发生错误", e);
+ }
+ return null;
+ });
}
catch (Exception e) {
e.printStackTrace();
--
Gitee