x
This commit is contained in:
107
internal/mcp/mcp.go
Normal file
107
internal/mcp/mcp.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
ProcessChanCap = 1000
|
||||
)
|
||||
|
||||
type MCP struct {
|
||||
sessions map[string]*Session
|
||||
sessionMu sync.Mutex
|
||||
|
||||
ProcessChan chan ProcessCtx
|
||||
}
|
||||
|
||||
func NewMCP() *MCP {
|
||||
return &MCP{
|
||||
sessions: make(map[string]*Session),
|
||||
ProcessChan: make(chan ProcessCtx, ProcessChanCap),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MCP) HandleSSE(c *gin.Context) {
|
||||
id := uuid.New().String()
|
||||
m.sessionMu.Lock()
|
||||
m.sessions[id] = NewSession(c, id)
|
||||
m.sessionMu.Unlock()
|
||||
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
<-c.Request.Context().Done()
|
||||
return false
|
||||
})
|
||||
|
||||
m.sessionMu.Lock()
|
||||
delete(m.sessions, id)
|
||||
m.sessionMu.Unlock()
|
||||
}
|
||||
|
||||
func (m *MCP) GetSession(id string) *Session {
|
||||
m.sessionMu.Lock()
|
||||
defer m.sessionMu.Unlock()
|
||||
return m.sessions[id]
|
||||
}
|
||||
|
||||
func (m *MCP) HandleMessages(c *gin.Context) {
|
||||
|
||||
// panic("xxx")
|
||||
|
||||
// 啊这, 一个 sessionid 有 3 种写法 session_id, sessionId, sessionid
|
||||
// 官方 SDK 是 session_id: https://github.com/modelcontextprotocol/python-sdk/blob/c897868/src/mcp/server/sse.py#L98
|
||||
// 写的是 sessionId: https://github.com/modelcontextprotocol/inspector/blob/aeaf32f/server/src/index.ts#L157
|
||||
|
||||
sessionID := c.Query("session_id")
|
||||
if sessionID == "" {
|
||||
sessionID = c.Query("sessionId")
|
||||
}
|
||||
if sessionID == "" {
|
||||
sessionID = c.Param("sessionid")
|
||||
}
|
||||
if sessionID == "" {
|
||||
c.JSON(http.StatusBadRequest, ErrInvalidSessionID.JsonRPC())
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
session := m.GetSession(sessionID)
|
||||
if session == nil {
|
||||
c.JSON(http.StatusNotFound, ErrSessionNotFound.JsonRPC())
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrInvalidRequest.JsonRPC())
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("收到消息: %v\n", req)
|
||||
select {
|
||||
case m.ProcessChan <- ProcessCtx{Session: session, Request: &req}:
|
||||
default:
|
||||
c.JSON(http.StatusTooManyRequests, ErrTooManyRequests.JsonRPC())
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.String(http.StatusAccepted, "Accepted")
|
||||
}
|
||||
|
||||
func (m *MCP) Close() {
|
||||
close(m.ProcessChan)
|
||||
}
|
||||
|
||||
type ProcessCtx struct {
|
||||
Session *Session
|
||||
Request *Request
|
||||
}
|
||||
Reference in New Issue
Block a user