diff --git a/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/McpOptions.java b/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/McpOptions.java index 45944b1b7b7c2c612c0dc5a043bd13cd710ff21a..fc922311a25db6b830f45908fbc0507d52469daf 100644 --- a/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/McpOptions.java +++ b/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/McpOptions.java @@ -14,6 +14,7 @@ import com.alibaba.fastjson2.JSONObject; import tech.smartboot.feat.Feat; import tech.smartboot.feat.ai.mcp.model.Implementation; import tech.smartboot.feat.ai.mcp.model.Roots; +import tech.smartboot.feat.core.client.HttpRest; import java.util.ArrayList; import java.util.List; @@ -27,6 +28,7 @@ public class McpOptions { private String baseUrl; private String mcpEndpoint = "/mcp"; private String sseEndpoint = "/sse"; + private Consumer onInitRest; private final Implementation implementation = Implementation.of("feat-mcp-client", "Feat MCP", Feat.VERSION); private boolean roots; private boolean sampling; @@ -61,10 +63,19 @@ public class McpOptions { return sseEndpoint; } - public void setSseEndpoint(String sseEndpoint) { + public McpOptions setSseEndpoint(String sseEndpoint) { this.sseEndpoint = sseEndpoint; + return this; } + public Consumer getOnInitRest() { + return this.onInitRest; + } + + public McpOptions setOnInitRest(Consumer onInitRest) { + this.onInitRest = onInitRest; + return this; + } public Implementation getImplementation() { return implementation; diff --git a/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/SseTransport.java b/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/SseTransport.java index aeaf9d3b30a7fe97cac1e15adbc442b341b8dd74..35cbeea71b06f4cdd61050f35de75b108a0e859b 100644 --- a/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/SseTransport.java +++ b/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/SseTransport.java @@ -44,8 +44,8 @@ final class SseTransport extends Transport { public SseTransport(McpOptions options) { super(options); - httpClient = new HttpClient(options.getBaseUrl()); - sseClient = new HttpClient(options.getBaseUrl()); + httpClient = new HttpClient(options.getBaseUrl()).setOnInitRest(options.getOnInitRest()); + sseClient = new HttpClient(options.getBaseUrl()).setOnInitRest(options.getOnInitRest()); sseClient.post(options.getSseEndpoint()).header(header -> header.set(HeaderName.ACCEPT, HeaderValue.ContentType.EVENT_STREAM).set(HeaderName.CACHE_CONTROL.getName(), HeaderValue.NO_CACHE).set(HeaderName.CONNECTION.getName(), HeaderValue.Connection.KEEPALIVE)).onResponseBody(new ServerSentEventStream() { @Override public void onEvent(HttpResponse httpResponse, Map event) { diff --git a/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/StreamableTransport.java b/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/StreamableTransport.java index 0c62a0fcaccaeda3bfa062bbe9dc74c0a43fa44a..f31284e34c6472b742a29c2bd36d2993271413db 100644 --- a/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/StreamableTransport.java +++ b/feat-ai/src/main/java/tech/smartboot/feat/ai/mcp/client/StreamableTransport.java @@ -42,8 +42,8 @@ final class StreamableTransport extends Transport { public StreamableTransport(McpOptions options) { super(options); - httpClient = new HttpClient(options.getBaseUrl()); - sseClient = new HttpClient(options.getBaseUrl()); + httpClient = new HttpClient(options.getBaseUrl()).setOnInitRest(options.getOnInitRest()); + sseClient = new HttpClient(options.getBaseUrl()).setOnInitRest(options.getOnInitRest()); } @Override diff --git a/feat-core/src/main/java/tech/smartboot/feat/core/client/HttpClient.java b/feat-core/src/main/java/tech/smartboot/feat/core/client/HttpClient.java index 1bc7845314355d819aad619ee3427795e80a8bc9..673c64d725e16b9761356ff64b8719ab125a9f48 100644 --- a/feat-core/src/main/java/tech/smartboot/feat/core/client/HttpClient.java +++ b/feat-core/src/main/java/tech/smartboot/feat/core/client/HttpClient.java @@ -28,6 +28,7 @@ import java.util.Base64; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Consumer; /** * @author 三刀 zhengjunweimail@163.com @@ -61,6 +62,12 @@ public final class HttpClient { private final HttpMessageProcessor processor = new HttpMessageProcessor(); private final String uri; + private Consumer onInitRest; + + public HttpClient setOnInitRest(Consumer onInitRest) { + this.onInitRest = onInitRest; + return this; + } public HttpClient(String url) { int schemaIndex = url.indexOf("://"); if (schemaIndex == -1) { @@ -185,6 +192,9 @@ public final class HttpClient { releaseConnection(client); return null; }); + if(onInitRest != null) { + onInitRest.accept(httpRestImpl); + } }