diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java index 77ff899d04b5c083b533373c50d6133a18344fca..de6b85c0e10932d394fe7be2a4108a3cb92b770b 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTConfig.java @@ -76,9 +76,9 @@ public class ChatGPTConfig { this.okHttpClient = new OkHttpClient.Builder() // .proxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890))) // 设置代理 .addInterceptor(httpLoggingInterceptor) - .connectTimeout(15, TimeUnit.SECONDS) - .writeTimeout(15, TimeUnit.SECONDS) - .readTimeout(15, TimeUnit.SECONDS) + .connectTimeout(20, TimeUnit.SECONDS) + .writeTimeout(20, TimeUnit.SECONDS) + .readTimeout(20, TimeUnit.SECONDS) .build(); } diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java index 42052f07fbdacb9d8669975bfc03473706b6815d..31b8fcd0529bcfd66715f6f0bfa2b730b9c1e9f8 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/ChatGPTMessageHandler.java @@ -15,7 +15,7 @@ public class ChatGPTMessageHandler { // public String singleResponse; public List responses = new ArrayList(); - public String responsesFomatted; + public String responsesFormatted; public String responseHandler(String jsonResponse) { @@ -33,9 +33,9 @@ public class ChatGPTMessageHandler { if (delta != null && delta.containsKey("content")) { String content = delta.getString("content"); if (content != null && !content.isEmpty()) { - //System.out.println(content); + return content; - //responses.add(content); + } else { System.out.println("Content is empty or not available."); } @@ -54,7 +54,7 @@ public class ChatGPTMessageHandler { for (String word : responses) { fullSentence = fullSentence.append(word); } - this.responsesFomatted = fullSentence.toString(); + this.responsesFormatted = fullSentence.toString(); return fullSentence.toString(); } diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java index 048d7d5309b4e4ff93b1d3f7adf3278afadf7459..1512e74398fc504b10c9eac9af66176d0541255c 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/CustomChatGPT.java @@ -2,6 +2,7 @@ package com.groupshell.entity.chatgpt; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; +import com.unfbx.chatgpt.entity.chat.ChatCompletion; import com.unfbx.chatgpt.entity.chat.Message; import lombok.AllArgsConstructor; import lombok.Builder; @@ -68,9 +69,66 @@ public class CustomChatGPT { System.out.println("聊天结束。"); break; } - sendMessageToChatGPT(input); - System.out.println("打印ChatCompletion的Messages"); - System.out.println(chatGPTConfig.getChatCompletion().getMessages()); + if ("model".equalsIgnoreCase(input)) { + System.out.println("请输入需要更换的模型名称:"); + String inputModel = scanner.nextLine(); + switch (inputModel) { + case "gpt3.5-turbo": { + chatGPTConfig.setChatCompletion(ChatCompletion.Model.GPT_3_5_TURBO); + break; + } + case "gpt3.5-turbo-0613": { + chatGPTConfig.setChatCompletion(ChatCompletion.Model.GPT_3_5_TURBO_0613); + //System.out.println("切换模型为gpt3.5-turbo-0613"); + break; + } + case "gpt3.5-turbo-16k": { + chatGPTConfig.setChatCompletion(ChatCompletion.Model.GPT_3_5_TURBO_16K); + //System.out.println("切换模型为gpt3.5-turbo-16k"); + break; + } + case "gpt3.5-turbo-16k-0613": { + chatGPTConfig.setChatCompletion(ChatCompletion.Model.GPT_3_5_TURBO_16K_0613); + //System.out.println("切换模型为gpt3.5-turbo-16k-0613"); + break; + } + case "gpt4": { + chatGPTConfig.setChatCompletion(ChatCompletion.Model.GPT_4); + //System.out.println("切换模型为gpt4"); + break; + } + case "gpt4-0613": { + chatGPTConfig.setChatCompletion(ChatCompletion.Model.GPT_4_0613); + //System.out.println("切换模型为gpt4-0613"); + break; + } + case "gpt4-32k": { + chatGPTConfig.setChatCompletion(ChatCompletion.Model.GPT_4_32K); + //System.out.println("切换模型为gpt4-32k"); + break; + } + case "gpt4-32k-0314": { + chatGPTConfig.setChatCompletion(ChatCompletion.Model.GPT_4_32K_0314); + //System.out.println("切换模型为gpt4-32k-0314"); + break; + } + case "gpt4-32k-0613": { + chatGPTConfig.setChatCompletion(ChatCompletion.Model.GPT_4_32K_0613); + //System.out.println("切换模型为gpt4-32k-0613"); + break; + } + default: { + System.out.println("请输入正确的模型名称。"); + break; + } + } + System.out.println("现在的模型为"+chatGPTConfig.getChatCompletion().getModel()); + } + else { + sendMessageToChatGPT(input); + System.out.println("打印ChatCompletion的Messages"); + System.out.println(chatGPTConfig.getChatCompletion().getMessages()); + } } scanner.close(); } diff --git a/pojo/src/main/java/com/groupshell/entity/chatgpt/TestChatGPT.java b/pojo/src/main/java/com/groupshell/entity/chatgpt/TestChatGPT.java index c4bba68a8f4444be41a8e8ce20a2f77a622d3465..ebe4f2a799141276064b902b9c6ffa12522f3c99 100644 --- a/pojo/src/main/java/com/groupshell/entity/chatgpt/TestChatGPT.java +++ b/pojo/src/main/java/com/groupshell/entity/chatgpt/TestChatGPT.java @@ -1,12 +1,14 @@ package com.groupshell.entity.chatgpt; +import com.unfbx.chatgpt.entity.chat.ChatCompletion; + import static com.groupshell.constant.ChatGPTConstants.OPENAI_STREAM_CLIENT; -public class TestChatGPT -{ - public static void main(String[] args) throws InterruptedException - { - CustomChatGPT chatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT); - chatGPT.startChattingWithConsole(); - } +public class TestChatGPT { + public static void main(String[] args) throws Exception { + CustomChatGPT chatGPT = new CustomChatGPT(OPENAI_STREAM_CLIENT); + //chatGPT.getChasetChatCompletion(ChatCompletion.Model.GPT_4_0613); + chatGPT.startChattingWithConsole(); + } + } diff --git a/server/src/main/java/com/groupshell/service/impl/ChatGPTServiceImpl.java b/server/src/main/java/com/groupshell/service/impl/ChatGPTServiceImpl.java index 0165d4cd3a8af619f8f064fe4868c866d9423d6a..7b8bc4a075ad9cb6a9964ea4cf510b38ab6e2ddf 100644 --- a/server/src/main/java/com/groupshell/service/impl/ChatGPTServiceImpl.java +++ b/server/src/main/java/com/groupshell/service/impl/ChatGPTServiceImpl.java @@ -1,20 +1,10 @@ package com.groupshell.service.impl; -import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.groupshell.entity.chatgpt.ChatGPTMessageHandler; import com.groupshell.entity.chatgpt.CustomChatGPT; -import com.groupshell.entity.chatgpt.CustomEventSourceListener; import com.groupshell.service.ChatGPTService; -import com.unfbx.chatgpt.entity.chat.ChatCompletion; -import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse; -import com.unfbx.chatgpt.entity.chat.Message; -import jakarta.annotation.Resource; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; -import java.io.IOException; -import java.util.Collections; - import static com.groupshell.constant.ChatGPTConstants.OPENAI_STREAM_CLIENT; @Service @@ -46,8 +36,7 @@ public class ChatGPTServiceImpl implements ChatGPTService { return customChatGPT.getChatGPTConfig() .getEventSourceListener() .getChatGPTMessageHandler() - .responsesFomatted; - //return messageHandler.responsesFomatted; + .responsesFormatted; } diff --git a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java index ad7ecdff77c0e4acecfe24eeeb2db624cee8a96f..0d4cf40aa128d2fcf23e6fa1ea8417a18106f2c0 100644 --- a/server/src/main/java/com/groupshell/websocket/WebSocketServer.java +++ b/server/src/main/java/com/groupshell/websocket/WebSocketServer.java @@ -25,8 +25,6 @@ import java.time.format.DateTimeFormatter; import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import static com.groupshell.constant.ChatGPTConstants.OPENAI_STREAM_CLIENT; @@ -141,45 +139,47 @@ public class WebSocketServer if(messageDTO.getGpt()) { customChatGPT.sendMessageToChatGPT(messageDTO.getContent()); - while(!customChatGPT.getChatGPTConfig() - .getEventSourceListener() - .getChatGPTDone()) - { - //轮询:性能开销比较大 - //若gpt回复还未结束,则等待其回复完整 - try - { - Thread.sleep(100); // Reduce CPU usage with a small sleep - log.info("$$$$$$$$$"); - }catch(InterruptedException e) - { - Thread.currentThread() - .interrupt(); - throw new IllegalStateException("Task interrupted",e); - } - } - String answer=customChatGPT.getChatGPTConfig() - .getEventSourceListener() - .getChatGPTMessageHandler().responsesFomatted; - Message message2=Message.builder() - .groupId(groupId) - .content(answer) - .createTime(LocalDateTime.now()) - .build(); - //向浏览器推送消息 - MessageVO messageVO2=MessageVO.builder() - .system(false) - .reminder(false) - .chat(true) - .username("GroupShellGPT") - .content(answer) - .self(false) - .createTime(message2.getCreateTime() - .format(DateTimeFormatter.ofPattern("yyyy" + "/MM/dd " + "HH" + ":mm"))) - .build(); - send(JSON.toJSONString(messageVO2),userIds); - //存数据库 - messageService.save(message2); + while(!customChatGPT.getChatGPTConfig() + .getEventSourceListener() + .getChatGPTDone()) + { + //轮询:性能开销比较大 + //若gpt回复还未结束,则等待其回复完整 + try + { + Thread.sleep(100); // Reduce CPU usage with a small sleep + log.info("$$$$$$$$$"); + }catch(InterruptedException e) + { + Thread.currentThread() + .interrupt(); + throw new IllegalStateException("Task interrupted",e); + } + } + String answer=customChatGPT.getChatGPTConfig() + .getEventSourceListener() + .getChatGPTMessageHandler().responsesFormatted; + log.info(answer+"**************"); + + //存数据库 + Message message2=Message.builder() + .groupId(groupId) + .content(answer) + .createTime(LocalDateTime.now()) + .build(); + messageService.save(message2); + //向浏览器推送消息 + MessageVO messageVO2=MessageVO.builder() + .system(false) + .reminder(false) + .chat(true) + .username("GroupShellGPT") + .content(answer) + .self(false) + .createTime(message2.getCreateTime() + .format(DateTimeFormatter.ofPattern("yyyy" + "/MM/dd " + "HH" + ":mm"))) + .build(); + send(JSON.toJSONString(messageVO2),userIds); } log.info("处理完成"); }