Files
OmniSocketGo/cmd/internal/server/kcp_hub.go
nnbcccscdscdsc be013b701b feat:KCP协议
2026-03-24 21:09:06 +08:00

174 lines
4.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package server
import (
"fmt"
"sync"
kcp "github.com/xtaci/kcp-go/v5"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
// KCPOption configures optional behavior of a KCPHub. NewKCPHub applies the
// options it receives in the order they are given.
type KCPOption func(*KCPHub)
// WithKCPLogger returns an option that makes the hub use logger for latency
// logging. Every connection the hub creates is given this logger. Passing a
// nil logger is tolerated: NewKCPHub falls back to a no-op logger.
func WithKCPLogger(logger latencylog.Logger) KCPOption {
	setLogger := func(target *KCPHub) {
		target.logger = logger
	}
	return setLogger
}
// KCPHub keeps the KCP connections of registered peers, keyed by peer ID, and
// forwards messages between them.
type KCPHub struct {
	// mu guards peers.
	mu sync.RWMutex
	// peers maps each registered peer ID to its connection.
	peers map[string]*transport.KCPConn
	// logger is handed to every connection created by ServeSession.
	// NewKCPHub never leaves it nil.
	logger latencylog.Logger
}
// NewKCPHub returns a hub with no registered peers. Options are applied in
// order on top of the defaults, which use a no-op latency logger. If an option
// leaves the logger nil, the no-op logger is restored, so the hub always has a
// usable logger.
func NewKCPHub(opts ...KCPOption) *KCPHub {
	h := new(KCPHub)
	h.peers = make(map[string]*transport.KCPConn)
	h.logger = latencylog.NoopLogger{}

	for i := range opts {
		opts[i](h)
	}

	if h.logger == nil {
		h.logger = latencylog.NoopLogger{}
	}
	return h
}
// HasPeer reports whether a connection is currently registered under peerID.
// It is safe to call concurrently with registration and removal.
func (h *KCPHub) HasPeer(peerID string) bool {
	_, registered := h.lookup(peerID)
	return registered
}
// ServeSession handles one newly accepted KCP session for its whole lifetime.
//
// It works in three steps:
//   - wrap the session in a transport.KCPConn that uses the hub's logger;
//   - require the peer to register;
//   - relay the peer's messages until the connection fails.
//
// The call blocks until then and returns the error that ended the session.
// Once registration has succeeded, the peer is removed from the hub when this
// function returns.
func (h *KCPHub) ServeSession(session *kcp.UDPSession) error {
	logOpt := transport.WithKCPLogger(h.logger, latencylog.NodeRoleServer, "hub")
	conn, wrapErr := transport.NewKCPConn(session, logOpt)
	if wrapErr != nil {
		// The creation error is what the caller needs; a close failure on a
		// session we are abandoning adds nothing.
		_ = session.Close()
		return fmt.Errorf("server: create kcp transport conn: %w", wrapErr)
	}

	peerID, regErr := h.registerConn(conn)
	if regErr != nil {
		// Registration failed, so this connection is never reachable via the hub.
		_ = conn.Close()
		return regErr
	}

	// unregister only removes the entry if it still points at this conn.
	defer h.unregister(peerID, conn)
	return h.receivePeerLoop(peerID, conn)
}
// registerConn performs the registration handshake on a freshly accepted
// connection. The first message must be a register message whose From field
// is the peer ID being claimed. On success, conn is stored in h.peers under
// that ID and the ID is returned.
//
// A protocol violation fails registration. The violations are:
//   - a first message that is not a register message;
//   - an empty peer ID;
//   - protocol.ServerPeerID as the peer ID;
//   - a peer ID that is already registered.
//
// On a violation an error reply is sent to the peer first. If that send fails,
// the send error is wrapped and returned. The caller owns conn and must close
// it whenever an error is returned.
//
// NOTE(review): the first conn.Receive is not bounded by any deadline here, so
// a client that never registers keeps this goroutine waiting. Consider a read
// deadline on the underlying session.
func (h *KCPHub) registerConn(conn *transport.KCPConn) (string, error) {
	msg, err := conn.Receive()
	if err != nil {
		return "", fmt.Errorf("server: receive kcp register: %w", err)
	}
	if msg.Type != protocol.MessageTypeRegister {
		if sendErr := sendKCPServerError(conn, msg.From, "first message must be register"); sendErr != nil {
			return "", fmt.Errorf("server: reject unregistered kcp peer: %w", sendErr)
		}
		return "", fmt.Errorf("server: first kcp message must be register, got %s", msg.Type)
	}

	peerID := msg.From
	// An empty ID cannot be addressed meaningfully. The server's own ID is the
	// sender of every hub error reply, so letting a peer claim it would allow
	// forwarded messages to masquerade as server messages.
	if peerID == "" {
		if sendErr := sendKCPServerError(conn, peerID, "peer id must not be empty"); sendErr != nil {
			return "", fmt.Errorf("server: reject empty kcp peer id: %w", sendErr)
		}
		return "", fmt.Errorf("server: kcp register without peer id")
	}
	if peerID == protocol.ServerPeerID {
		if sendErr := sendKCPServerError(conn, peerID, fmt.Sprintf("reserved peer id: %s", peerID)); sendErr != nil {
			return "", fmt.Errorf("server: reserved kcp peer id %s: %w", peerID, sendErr)
		}
		return "", fmt.Errorf("server: reserved kcp peer id: %s", peerID)
	}

	// Check-and-insert atomically, but keep the critical section free of I/O.
	h.mu.Lock()
	_, exists := h.peers[peerID]
	if !exists {
		h.peers[peerID] = conn
	}
	h.mu.Unlock()

	if exists {
		// The reply is sent only after the lock is released. Send writes to the
		// peer's network session, and holding the write lock across it would
		// stall every lookup and unregister in the hub.
		if sendErr := sendKCPServerError(conn, peerID, fmt.Sprintf("duplicate peer id: %s", peerID)); sendErr != nil {
			return "", fmt.Errorf("server: duplicate kcp peer id %s: %w", peerID, sendErr)
		}
		return "", fmt.Errorf("server: duplicate kcp peer id: %s", peerID)
	}
	return peerID, nil
}
// handlePeerMessage processes one message from the registered peer peerID,
// which arrived on conn.
//
// Text and file messages have From overwritten with peerID, so peers cannot
// forge the sender, and are sent to the connection registered under msg.To.
// An unknown target, or a failed send, is reported back to the sender with a
// server error message. A failed send also drops and closes the target
// connection. In these cases the return value is the result of sending that
// reply.
//
// Any other message type gets an error reply and a non-nil return value.
// The receive loop treats that return value as fatal.
func (h *KCPHub) handlePeerMessage(peerID string, conn *transport.KCPConn, msg protocol.Message) error {
	if msg.Type == protocol.MessageTypeRegister || msg.Type == protocol.MessageTypeError {
		// A peer may register only once, and only the server emits errors.
		if replyErr := sendKCPServerError(conn, peerID, "registered peers can only send text or file messages"); replyErr != nil {
			return fmt.Errorf("server: send kcp protocol error: %w", replyErr)
		}
		return fmt.Errorf("server: unexpected kcp message type from peer %s: %s", peerID, msg.Type)
	}

	if msg.Type != protocol.MessageTypeText && msg.Type != protocol.MessageTypeFile {
		if replyErr := sendKCPServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); replyErr != nil {
			return fmt.Errorf("server: send unsupported kcp type error: %w", replyErr)
		}
		return fmt.Errorf("server: unsupported kcp message type: %s", msg.Type)
	}

	dest := msg.To
	msg.From = peerID
	target, found := h.lookup(dest)
	if !found {
		return sendKCPServerError(conn, peerID, fmt.Sprintf("unknown target: %s", dest))
	}

	if fwdErr := target.Send(msg); fwdErr != nil {
		// Treat the target as dead. unregister only removes the entry if it
		// still refers to this exact connection.
		h.unregister(dest, target)
		_ = target.Close()
		return sendKCPServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", dest))
	}
	return nil
}
// receivePeerLoop reads messages from the registered peer peerID and passes
// each one to handlePeerMessage. It runs until a read fails or the handler
// returns an error. It therefore never returns nil: the error that stopped the
// loop is wrapped and returned.
//
// conn is always closed before the function returns. Removing the peer from
// the hub is left to the caller.
func (h *KCPHub) receivePeerLoop(peerID string, conn *transport.KCPConn) error {
	// A single deferred close covers every exit path. The close error is
	// ignored because the loop's own error is the one worth reporting.
	defer func() { _ = conn.Close() }()

	for {
		msg, err := conn.Receive()
		if err != nil {
			return fmt.Errorf("server: kcp receive loop read for peer %s: %w", peerID, err)
		}
		if err := h.handlePeerMessage(peerID, conn, msg); err != nil {
			return fmt.Errorf("server: kcp receive loop handler for peer %s: %w", peerID, err)
		}
	}
}
// lookup returns the connection registered under peerID. ok is false when no
// peer with that ID is registered. The map is read under the read lock.
func (h *KCPHub) lookup(peerID string) (conn *transport.KCPConn, ok bool) {
	h.mu.RLock()
	conn, ok = h.peers[peerID]
	h.mu.RUnlock()
	return conn, ok
}
// unregister removes peerID from the hub, but only while it still maps to
// conn. This stops a stale connection from evicting a newer registration that
// reuses the same ID.
func (h *KCPHub) unregister(peerID string, conn *transport.KCPConn) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.peers[peerID] != conn {
		return
	}
	delete(h.peers, peerID)
}
// sendKCPServerError sends conn an error message from protocol.ServerPeerID,
// addressed to the peer ID to, with message as the body. It returns the
// result of conn.Send.
func sendKCPServerError(conn *transport.KCPConn, to, message string) error {
	reply := protocol.Message{
		Type: protocol.MessageTypeError,
		From: protocol.ServerPeerID,
		To:   to,
		Body: []byte(message),
	}
	return conn.Send(reply)
}