Ai
1 Star 0 Fork 0

tmfll/mcp-golang

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
server.go 18.35 KB
一键复制 编辑 原始数据 按行查看 历史
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
package mcp_golang
import (
"encoding/json"
"fmt"
"github.com/invopop/jsonschema"
"github.com/metoro-io/mcp-golang/internal/protocol"
"github.com/metoro-io/mcp-golang/internal/tools"
"github.com/metoro-io/mcp-golang/transport"
"reflect"
"strings"
)
// Here we define the actual MCP server that users will create and run
// A server can be passed a number of handlers to handle requests from clients
// Additionally it can be parametrized by a transport. This transport will be used to actually send and receive messages.
// So for example if the stdio transport is used, the server will read from stdin and write to stdout
// If the SSE transport is used, the server will send messages over an SSE connection and receive messages from HTTP POST requests.
// The interface that we're looking to support is something like [gin](https://github.com/gin-gonic/gin)s interface
type toolResponseSent struct {
Response *ToolResponse
Error error
}
// Custom JSON marshaling for ToolResponse
func (c toolResponseSent) MarshalJSON() ([]byte, error) {
if c.Error != nil {
errorText := c.Error.Error()
c.Response = NewToolReponse(NewTextContent(errorText))
}
return json.Marshal(struct {
Content []*Content `json:"content" yaml:"content" mapstructure:"content"`
IsError bool `json:"isError" yaml:"isError" mapstructure:"isError"`
}{
Content: c.Response.Content,
IsError: c.Error != nil,
})
}
// Custom JSON marshaling for ToolResponse
func (c resourceResponseSent) MarshalJSON() ([]byte, error) {
if c.Error != nil {
errorText := c.Error.Error()
c.Response = NewResourceResponse(NewTextEmbeddedResource(c.Uri, errorText, "text/plain"))
}
return json.Marshal(c.Response)
}
type resourceResponseSent struct {
Response *ResourceResponse
Uri string
Error error
}
func newResourceResponseSentError(err error) *resourceResponseSent {
return &resourceResponseSent{
Error: err,
}
}
// newToolResponseSent creates a new toolResponseSent
func newResourceResponseSent(response *ResourceResponse) *resourceResponseSent {
return &resourceResponseSent{
Response: response,
}
}
type promptResponseSent struct {
Response *PromptResponse
Error error
}
func newPromptResponseSentError(err error) *promptResponseSent {
return &promptResponseSent{
Error: err,
}
}
// newToolResponseSent creates a new toolResponseSent
func newPromptResponseSent(response *PromptResponse) *promptResponseSent {
return &promptResponseSent{
Response: response,
}
}
// Custom JSON marshaling for PromptResponse
func (c promptResponseSent) MarshalJSON() ([]byte, error) {
if c.Error != nil {
errorText := c.Error.Error()
c.Response = NewPromptResponse("error", NewPromptMessage(NewTextContent(errorText), RoleUser))
}
return json.Marshal(c)
}
type Server struct {
transport transport.Transport
tools map[string]*tool
prompts map[string]*prompt
resources map[string]*resource
serverInstructions *string
serverName string
serverVersion string
}
type prompt struct {
Name string
Description string
Handler func(baseGetPromptRequestParamsArguments) *promptResponseSent
PromptInputSchema *promptSchema
}
type tool struct {
Name string
Description string
Handler func(baseCallToolRequestParams) *toolResponseSent
ToolInputSchema *jsonschema.Schema
}
type resource struct {
Name string
Description string
Uri string
mimeType string
Handler func() *resourceResponseSent
}
func NewServer(transport transport.Transport) *Server {
return &Server{
transport: transport,
tools: make(map[string]*tool),
prompts: make(map[string]*prompt),
resources: make(map[string]*resource),
}
}
// RegisterTool registers a new tool with the server
func (s *Server) RegisterTool(name string, description string, handler any) error {
err := validateToolHandler(handler)
if err != nil {
return err
}
inputSchema := createJsonSchemaFromHandler(handler)
s.tools[name] = &tool{
Name: name,
Description: description,
Handler: createWrappedToolHandler(handler),
ToolInputSchema: inputSchema,
}
return nil
}
func (s *Server) RegisterResource(uri string, name string, description string, mimeType string, handler any) error {
err := validateResourceHandler(handler)
if err != nil {
panic(err)
}
s.resources[uri] = &resource{
Name: name,
Description: description,
Uri: uri,
mimeType: mimeType,
Handler: createWrappedResourceHandler(handler),
}
return nil
}
func createWrappedResourceHandler(userHandler any) func() *resourceResponseSent {
handlerValue := reflect.ValueOf(userHandler)
return func() *resourceResponseSent {
// Call the handler with no arguments
output := handlerValue.Call([]reflect.Value{})
if len(output) != 2 {
return newResourceResponseSentError(fmt.Errorf("handler must return exactly two values, got %d", len(output)))
}
if !output[0].CanInterface() {
return newResourceResponseSentError(fmt.Errorf("handler must return a struct, got %s", output[0].Type().Name()))
}
promptR := output[0].Interface()
if !output[1].CanInterface() {
return newResourceResponseSentError(fmt.Errorf("handler must return an error, got %s", output[1].Type().Name()))
}
errorOut := output[1].Interface()
if errorOut == nil {
return newResourceResponseSent(promptR.(*ResourceResponse))
}
return newResourceResponseSentError(errorOut.(error))
}
}
// We just want to check that handler takes no arguments and returns a ResourceResponse and an error
func validateResourceHandler(handler any) error {
handlerValue := reflect.ValueOf(handler)
handlerType := handlerValue.Type()
if handlerType.NumIn() != 0 {
return fmt.Errorf("handler must take no arguments, got %d", handlerType.NumIn())
}
if handlerType.NumOut() != 2 {
return fmt.Errorf("handler must return exactly two values, got %d", handlerType.NumOut())
}
//if handlerType.Out(0) != reflect.TypeOf((*ResourceResponse)(nil)).Elem() {
// return fmt.Errorf("handler must return ResourceResponse, got %s", handlerType.Out(0).Name())
//}
//if handlerType.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
// return fmt.Errorf("handler must return error, got %s", handlerType.Out(1).Name())
//}
return nil
}
func (s *Server) RegisterPrompt(name string, description string, handler any) error {
err := validatePromptHandler(handler)
if err != nil {
return err
}
promptSchema := createPromptSchemaFromHandler(handler)
s.prompts[name] = &prompt{
Name: name,
Description: description,
Handler: createWrappedPromptHandler(handler),
PromptInputSchema: promptSchema,
}
return nil
}
func createWrappedPromptHandler(userHandler any) func(baseGetPromptRequestParamsArguments) *promptResponseSent {
handlerValue := reflect.ValueOf(userHandler)
handlerType := handlerValue.Type()
argumentType := handlerType.In(0)
return func(arguments baseGetPromptRequestParamsArguments) *promptResponseSent {
// Instantiate a struct of the type of the arguments
if !reflect.New(argumentType).CanInterface() {
return newPromptResponseSentError(fmt.Errorf("arguments must be a struct"))
}
unmarshaledArguments := reflect.New(argumentType).Interface()
// Unmarshal the JSON into the correct type
err := json.Unmarshal(arguments.Arguments, &unmarshaledArguments)
if err != nil {
return newPromptResponseSentError(fmt.Errorf("failed to unmarshal arguments: %w", err))
}
// Need to dereference the unmarshaled arguments
of := reflect.ValueOf(unmarshaledArguments)
if of.Kind() != reflect.Ptr || !of.Elem().CanInterface() {
return newPromptResponseSentError(fmt.Errorf("arguments must be a struct"))
}
// Call the handler with the typed arguments
output := handlerValue.Call([]reflect.Value{of.Elem()})
if len(output) != 2 {
return newPromptResponseSentError(fmt.Errorf("handler must return exactly two values, got %d", len(output)))
}
if !output[0].CanInterface() {
return newPromptResponseSentError(fmt.Errorf("handler must return a struct, got %s", output[0].Type().Name()))
}
promptR := output[0].Interface()
if !output[1].CanInterface() {
return newPromptResponseSentError(fmt.Errorf("handler must return an error, got %s", output[1].Type().Name()))
}
errorOut := output[1].Interface()
if errorOut == nil {
return newPromptResponseSent(promptR.(*PromptResponse))
}
return newPromptResponseSentError(errorOut.(error))
}
}
// Get the argument and iterate over the fields, we pull description from the jsonschema description tag
// We pull required from the jsonschema required tag
// Example:
// type Content struct {
// Title string `json:"title" jsonschema:"description=The title to submit,required"`
// Description *string `json:"description" jsonschema:"description=The description to submit"`
// }
// Then we get the jsonschema for the struct where Title is a required field and Description is an optional field
func createPromptSchemaFromHandler(handler any) *promptSchema {
handlerValue := reflect.ValueOf(handler)
handlerType := handlerValue.Type()
argumentType := handlerType.In(0)
promptSchema := promptSchema{
Arguments: make([]promptSchemaArgument, argumentType.NumField()),
}
for i := 0; i < argumentType.NumField(); i++ {
field := argumentType.Field(i)
fieldName := field.Name
jsonSchemaTags := strings.Split(field.Tag.Get("jsonschema"), ",")
var description *string
var required = false
for _, tag := range jsonSchemaTags {
if strings.HasPrefix(tag, "description=") {
s := strings.TrimPrefix(tag, "description=")
description = &s
}
if tag == "required" {
required = true
}
}
promptSchema.Arguments[i] = promptSchemaArgument{
Name: fieldName,
Description: description,
Required: &required,
}
}
return &promptSchema
}
// A prompt can only take a struct with fields of type string or *string as the argument
func validatePromptHandler(handler any) error {
handlerValue := reflect.ValueOf(handler)
handlerType := handlerValue.Type()
argumentType := handlerType.In(0)
if argumentType.Kind() != reflect.Struct {
return fmt.Errorf("argument must be a struct")
}
for i := 0; i < argumentType.NumField(); i++ {
field := argumentType.Field(i)
isValid := false
if field.Type.Kind() == reflect.String {
isValid = true
}
if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.String {
isValid = true
}
if !isValid {
return fmt.Errorf("all fields of the struct must be of type string or *string, found %s", field.Type.Kind())
}
}
return nil
}
// Creates a full JSON schema from a user provided handler by introspecting the arguments
func createJsonSchemaFromHandler(handler any) *jsonschema.Schema {
handlerValue := reflect.ValueOf(handler)
handlerType := handlerValue.Type()
argumentType := handlerType.In(0)
inputSchema := jsonSchemaReflector.ReflectFromType(argumentType)
return inputSchema
}
// This takes a user provided handler and returns a wrapped handler which can be used to actually answer requests
// Concretely, it will deserialize the arguments and call the user provided handler and then serialize the response
// If the handler returns an error, it will be serialized and sent back as a tool error rather than a protocol error
func createWrappedToolHandler(userHandler any) func(baseCallToolRequestParams) *toolResponseSent {
handlerValue := reflect.ValueOf(userHandler)
handlerType := handlerValue.Type()
argumentType := handlerType.In(0)
return func(arguments baseCallToolRequestParams) *toolResponseSent {
// Instantiate a struct of the type of the arguments
if !reflect.New(argumentType).CanInterface() {
return newToolResponseSentError(fmt.Errorf("arguments must be a struct"))
}
unmarshaledArguments := reflect.New(argumentType).Interface()
// Unmarshal the JSON into the correct type
err := json.Unmarshal(arguments.Arguments, &unmarshaledArguments)
if err != nil {
return newToolResponseSentError(fmt.Errorf("failed to unmarshal arguments: %w", err))
}
// Need to dereference the unmarshaled arguments
of := reflect.ValueOf(unmarshaledArguments)
if of.Kind() != reflect.Ptr || !of.Elem().CanInterface() {
return newToolResponseSentError(fmt.Errorf("arguments must be a struct"))
}
// Call the handler with the typed arguments
output := handlerValue.Call([]reflect.Value{of.Elem()})
if len(output) != 2 {
return newToolResponseSentError(fmt.Errorf("handler must return exactly two values, got %d", len(output)))
}
if !output[0].CanInterface() {
return newToolResponseSentError(fmt.Errorf("handler must return a struct, got %s", output[0].Type().Name()))
}
tool := output[0].Interface()
if !output[1].CanInterface() {
return newToolResponseSentError(fmt.Errorf("handler must return an error, got %s", output[1].Type().Name()))
}
errorOut := output[1].Interface()
if errorOut == nil {
return newToolResponseSent(tool.(*ToolResponse))
}
return newToolResponseSentError(errorOut.(error))
}
}
func (s *Server) Serve() error {
protocol := protocol.NewProtocol(nil)
protocol.SetRequestHandler("initialize", s.handleInitialize)
protocol.SetRequestHandler("tools/list", s.handleListTools)
protocol.SetRequestHandler("tools/call", s.handleToolCalls)
protocol.SetRequestHandler("prompts/list", s.handleListPrompts)
protocol.SetRequestHandler("prompts/get", s.handlePromptCalls)
protocol.SetRequestHandler("resources/list", s.handleListResources)
protocol.SetRequestHandler("resources/read", s.handleResourceCalls)
return protocol.Connect(s.transport)
}
func (s *Server) handleInitialize(_ *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (interface{}, error) {
return initializeResult{
Meta: nil,
Capabilities: s.generateCapabilities(),
Instructions: s.serverInstructions,
ProtocolVersion: "2024-11-05",
ServerInfo: implementation{
Name: s.serverName,
Version: s.serverVersion,
},
}, nil
}
func (s *Server) handleListTools(_ *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (interface{}, error) {
return tools.ToolsResponse{
Tools: func() []tools.ToolRetType {
var ts []tools.ToolRetType
for _, tool := range s.tools {
ts = append(ts, tools.ToolRetType{
Name: tool.Name,
Description: &tool.Description,
InputSchema: tool.ToolInputSchema,
})
}
return ts
}(),
}, nil
}
func (s *Server) handleToolCalls(req *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (interface{}, error) {
params := baseCallToolRequestParams{}
// Instantiate a struct of the type of the arguments
err := json.Unmarshal(req.Params, &params)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
}
for name, tool := range s.tools {
if name != params.Name {
continue
}
return tool.Handler(params), nil
}
return nil, fmt.Errorf("unknown tool: %s", req.Method)
}
func (s *Server) generateCapabilities() serverCapabilities {
f := false
return serverCapabilities{
Tools: func() *serverCapabilitiesTools {
if s.tools == nil {
return nil
}
return &serverCapabilitiesTools{
ListChanged: &f,
}
}(),
}
}
func (s *Server) handleListPrompts(request *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (interface{}, error) {
return listPromptsResult{
Prompts: func() []*promptSchema {
var prompts []*promptSchema
for _, prompt := range s.prompts {
prompts = append(prompts, prompt.PromptInputSchema)
}
return prompts
}(),
}, nil
}
func (s *Server) handleListResources(request *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (interface{}, error) {
return listResourcesResult{
Resources: func() []*resourceSchema {
var resources []*resourceSchema
for _, resource := range s.resources {
resources = append(resources, &resourceSchema{
Annotations: nil,
Description: &resource.Description,
MimeType: &resource.mimeType,
Name: resource.Name,
Uri: resource.Uri,
})
}
return resources
}(),
}, nil
}
func (s *Server) handlePromptCalls(req *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (interface{}, error) {
params := baseGetPromptRequestParamsArguments{}
// Instantiate a struct of the type of the arguments
err := json.Unmarshal(req.Params, &params)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
}
for name, prompter := range s.prompts {
if name != params.Name {
continue
}
return prompter.Handler(params), nil
}
return nil, fmt.Errorf("unknown prompt: %s", req.Method)
}
func (s *Server) handleResourceCalls(req *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (interface{}, error) {
params := readResourceRequestParams{}
// Instantiate a struct of the type of the arguments
err := json.Unmarshal(req.Params, &params)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
}
for name, resource := range s.resources {
if name != params.Uri {
continue
}
return resource.Handler(), nil
}
return nil, fmt.Errorf("unknown prompt: %s", req.Method)
}
func validateToolHandler(handler any) error {
handlerValue := reflect.ValueOf(handler)
handlerType := handlerValue.Type()
if handlerType.NumIn() != 1 {
return fmt.Errorf("handler must take exactly one argument, got %d", handlerType.NumIn())
}
if handlerType.NumOut() != 2 {
return fmt.Errorf("handler must return exactly two values, got %d", handlerType.NumOut())
}
// Check that the output type is *tools.ToolResponse
if handlerType.Out(0) != reflect.PointerTo(reflect.TypeOf(ToolResponse{})) {
return fmt.Errorf("handler must return *tools.ToolResponse, got %s", handlerType.Out(0).Name())
}
// Check that the output type is error
if handlerType.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
return fmt.Errorf("handler must return error, got %s", handlerType.Out(1).Name())
}
return nil
}
var (
jsonSchemaReflector = jsonschema.Reflector{
BaseSchemaID: "",
Anonymous: true,
AssignAnchor: false,
AllowAdditionalProperties: true,
RequiredFromJSONSchemaTags: true,
DoNotReference: true,
ExpandedStruct: true,
FieldNameTag: "",
IgnoredTypes: nil,
Lookup: nil,
Mapper: nil,
Namer: nil,
KeyNamer: nil,
AdditionalFields: nil,
CommentMap: nil,
}
)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/tmfll/mcp-golang.git
git@gitee.com:tmfll/mcp-golang.git
tmfll
mcp-golang
mcp-golang
Chrisbattarbee-patch-1

搜索帮助