Files
OmniSocketGo/cmd/internal/server/kcp_hub.go
2026-03-28 15:28:19 +08:00

415 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package server
import (
"errors"
"fmt"
"log"
"net"
"strings"
"sync"
"time"
kcp "github.com/xtaci/kcp-go/v5"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
// kcpRelayMaxDatagramSize is the size of the relay read buffer and the largest
// encoded message that will be sent over the relay socket (60 KiB).
const kcpRelayMaxDatagramSize = 60 << 10

// Sentinel errors used by the relay and local-delivery paths; compare them
// with errors.Is because callers may receive wrapped variants.
var (
	// errKCPRelayUnavailable: no relay socket has been configured.
	errKCPRelayUnavailable = errors.New("server: kcp relay socket is not configured")
	// errKCPRelayPeerUnknown: a relay socket exists but no peer address is known yet.
	errKCPRelayPeerUnknown = errors.New("server: kcp relay peer address is unknown")
	// errKCPRelayTooLarge: the encoded message exceeds kcpRelayMaxDatagramSize.
	errKCPRelayTooLarge = errors.New("server: kcp relay message too large")
	// errKCPUnknownLocalTarget: the destination peer is not registered on this hub.
	errKCPUnknownLocalTarget = errors.New("server: unknown local kcp target")
)
// KCPOption configures optional behavior of a KCPHub. NewKCPHub applies the
// options in the order they are passed.
type KCPOption func(*KCPHub)
// WithKCPLogger returns an option that installs the latency logger used by the
// hub. It is also handed to every KCP transport connection the hub creates.
// If the logger is nil, NewKCPHub falls back to latencylog.NoopLogger.
func WithKCPLogger(logger latencylog.Logger) KCPOption {
	setLogger := func(target *KCPHub) {
		target.logger = logger
	}
	return setLogger
}
// WithKCPSessionStatsLogger returns an option that installs a session-statistics
// logger and its reporting interval. Both values are passed unchanged to every
// KCP transport connection the hub creates.
func WithKCPSessionStatsLogger(logger transport.KCPSessionStatsLogger, interval time.Duration) KCPOption {
	setStats := func(target *KCPHub) {
		target.sessionStatsLogger, target.sessionStatsInterval = logger, interval
	}
	return setStats
}
// KCPHub manages the KCP sessions of registered peers and forwards messages
// between them. Messages addressed to peers that are not registered locally
// can be forwarded over an optional raw UDP relay socket.
type KCPHub struct {
	// mu guards peers and the relay* fields below.
	mu sync.RWMutex
	// peers maps a registered peer ID to its KCP connection.
	peers map[string]*transport.KCPConn
	// logger records latency events; never nil after NewKCPHub.
	logger latencylog.Logger
	// sessionStatsLogger and sessionStatsInterval are passed to every
	// transport.KCPConn created by ServeSession.
	sessionStatsLogger   transport.KCPSessionStatsLogger
	sessionStatsInterval time.Duration
	// relaySocket is the raw UDP relay channel; nil when no relay is configured.
	relaySocket net.PacketConn
	// relayPeerAddr is the only address relay datagrams are accepted from and
	// sent to; nil until configured or learned.
	relayPeerAddr net.Addr
	// relayLearnPeer, when true, adopts the sender of the first relay datagram
	// as relayPeerAddr if none is set.
	relayLearnPeer bool
}
// NewKCPHub returns a hub with no registered peers and no relay socket, after
// applying opts in order. A nil logger left by the options is replaced with
// latencylog.NoopLogger so the logger field is always usable.
func NewKCPHub(opts ...KCPOption) *KCPHub {
	hub := new(KCPHub)
	hub.peers = make(map[string]*transport.KCPConn)
	hub.logger = latencylog.NoopLogger{}

	for i := range opts {
		opts[i](hub)
	}

	// An option may have explicitly installed a nil logger.
	if hub.logger == nil {
		hub.logger = latencylog.NoopLogger{}
	}
	return hub
}
// SetRelaySocket configures the raw UDP relay channel. conn is the relay
// socket. peerAddr is the fixed relay peer, copied via cloneRelayAddr, and may
// be nil. learnPeer controls whether the first sender is adopted as the peer
// while peerAddr is unset.
func (h *KCPHub) SetRelaySocket(conn net.PacketConn, peerAddr net.Addr, learnPeer bool) {
	peerCopy := cloneRelayAddr(peerAddr)

	h.mu.Lock()
	h.relaySocket, h.relayPeerAddr, h.relayLearnPeer = conn, peerCopy, learnPeer
	h.mu.Unlock()
}
// HasPeer reports whether peerID is currently registered with the hub.
func (h *KCPHub) HasPeer(peerID string) bool {
	_, found := h.lookup(peerID)
	return found
}
// ServeRelay reads datagrams from the relay UDP socket in a loop and tries to
// deliver each one to a locally registered peer.
//
// The socket is read from the hub once at start, so a later SetRelaySocket
// call does not affect a loop that is already running. A datagram is logged
// and dropped when it comes from an unexpected address, fails to decode, or is
// not a text, file, or error message. Delivery failures are logged and the
// loop continues.
//
// It returns errKCPRelayUnavailable if no relay socket is configured, nil once
// the socket is closed, and a wrapped error for any other read failure.
func (h *KCPHub) ServeRelay() error {
	h.mu.RLock()
	socket := h.relaySocket
	h.mu.RUnlock()
	if socket == nil {
		return errKCPRelayUnavailable
	}

	// A single buffer is reused for every read.
	// NOTE(review): this assumes protocol.DecodeMessage / delivery do not keep
	// references into packet past the current iteration — confirm.
	packet := make([]byte, kcpRelayMaxDatagramSize)
	for {
		size, sender, readErr := socket.ReadFrom(packet)
		if readErr != nil {
			// NOTE(review): net.PacketConn permits n > 0 alongside an error;
			// any such bytes are discarded here.
			if isExpectedRelayServeExit(readErr) {
				return nil
			}
			return fmt.Errorf("server: relay receive packet: %w", readErr)
		}
		if !h.acceptRelayPeer(sender) {
			log.Printf("kcp relay dropped packet from unexpected peer %s", sender)
			continue
		}

		msg, decodeErr := protocol.DecodeMessage(packet[:size])
		switch {
		case decodeErr != nil:
			log.Printf("kcp relay dropped invalid packet from %s: %v", sender, decodeErr)
		case !isRelayBusinessOrErrorMessage(msg.Type):
			log.Printf("kcp relay dropped unsupported message type %s from %s", msg.Type, sender)
		default:
			if deliverErr := h.deliverRelayedMessage(msg); deliverErr != nil {
				log.Printf("kcp relay delivery for %s -> %s failed: %v", msg.From, msg.To, deliverErr)
			}
		}
	}
}
// ServeSession handles one newly accepted KCP session for its whole lifetime.
// It wraps the session in a transport.KCPConn, performs the register
// handshake, and then pumps messages until the connection fails. The peer is
// unregistered when the loop ends.
//
// It always returns a non-nil error: transport creation, registration, or the
// receive loop failing is what ends the session.
func (h *KCPHub) ServeSession(session *kcp.UDPSession) error {
	desc := describeKCPSession(session)
	log.Printf("kcp hub accepted session %s", desc)

	peerConn, createErr := transport.NewKCPConn(
		session,
		transport.WithKCPLogger(h.logger, latencylog.NodeRoleServer, "hub"),
		transport.WithKCPSessionStatsLogger(h.sessionStatsLogger, h.sessionStatsInterval),
	)
	if createErr != nil {
		// No KCPConn owns the session yet, so close the raw session directly.
		_ = session.Close()
		return fmt.Errorf("server: create kcp transport conn: %w", createErr)
	}

	peerID, regErr := h.registerConn(peerConn, desc)
	if regErr != nil {
		_ = peerConn.Close()
		return regErr
	}
	defer h.unregister(peerID, peerConn, desc)

	return h.receivePeerLoop(peerID, peerConn, desc)
}
// registerConn performs the registration handshake for a freshly accepted
// connection. The first message must be a register message; its From field
// becomes the peer ID under which conn is stored in h.peers.
//
// On success it returns the registered peer ID. It returns an error when the
// first receive fails, when the first message is not a register message, or
// when the peer ID is already registered. In the last two cases it first tries
// to send a server error message back to the client. It does not close conn.
//
// The hub lock is held only for the map check-and-insert. The duplicate-ID
// rejection is sent after the lock is released, because conn.Send performs
// network I/O that may block. Doing it under the hub-wide write lock would
// stall every other lookup, register, and unregister meanwhile.
func (h *KCPHub) registerConn(conn *transport.KCPConn, sessionDesc string) (string, error) {
	msg, err := conn.Receive()
	if err != nil {
		log.Printf("kcp hub session %s failed before register: %v", sessionDesc, err)
		return "", fmt.Errorf("server: receive kcp register: %w", err)
	}
	if msg.Type != protocol.MessageTypeRegister {
		log.Printf("kcp hub rejecting session %s: first message type=%s from=%s", sessionDesc, msg.Type, msg.From)
		if sendErr := sendKCPServerError(conn, msg.From, "first message must be register"); sendErr != nil {
			return "", fmt.Errorf("server: reject unregistered kcp peer: %w", sendErr)
		}
		return "", fmt.Errorf("server: first kcp message must be register, got %s", msg.Type)
	}
	peerID := msg.From

	// Atomic check-and-insert; keep the critical section free of I/O.
	h.mu.Lock()
	_, duplicate := h.peers[peerID]
	if !duplicate {
		h.peers[peerID] = conn
	}
	peerCount := len(h.peers)
	h.mu.Unlock()

	if duplicate {
		log.Printf("kcp hub rejecting duplicate peer %q on session %s", peerID, sessionDesc)
		if sendErr := sendKCPServerError(conn, peerID, fmt.Sprintf("duplicate peer id: %s", peerID)); sendErr != nil {
			return "", fmt.Errorf("server: duplicate kcp peer id %s: %w", peerID, sendErr)
		}
		return "", fmt.Errorf("server: duplicate kcp peer id: %s", peerID)
	}

	log.Printf("kcp hub registered peer %q on session %s (peers=%d)", peerID, sessionDesc, peerCount)
	return peerID, nil
}
// handlePeerMessage processes one message from the registered peer peerID.
//
// Text and file messages are stamped with From = peerID, so a client cannot
// spoof its sender. They are delivered to a local peer, or forwarded over the
// relay when the target is not registered locally. A forwarding failure is
// reported to the sender as a server error message. The return value is then
// only the result of sending that notice, so the session keeps running.
//
// Register and error messages, and any unknown type, get a server error reply
// and a non-nil return. The receive loop treats that as fatal for the session.
func (h *KCPHub) handlePeerMessage(peerID string, conn *transport.KCPConn, msg protocol.Message) error {
	switch msg.Type {
	case protocol.MessageTypeText, protocol.MessageTypeFile:
		// Forwardable; handled below.
	case protocol.MessageTypeRegister, protocol.MessageTypeError:
		if sendErr := sendKCPServerError(conn, peerID, "registered peers can only send text or file messages"); sendErr != nil {
			return fmt.Errorf("server: send kcp protocol error: %w", sendErr)
		}
		return fmt.Errorf("server: unexpected kcp message type from peer %s: %s", peerID, msg.Type)
	default:
		if sendErr := sendKCPServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); sendErr != nil {
			return fmt.Errorf("server: send unsupported kcp type error: %w", sendErr)
		}
		return fmt.Errorf("server: unsupported kcp message type: %s", msg.Type)
	}

	// Never trust the client-supplied sender.
	msg.From = peerID

	localErr := h.deliverToLocalPeer(msg)
	if localErr == nil {
		return nil
	}
	if !errors.Is(localErr, errKCPUnknownLocalTarget) {
		log.Printf("kcp hub local delivery failed for %s -> %s: %v", peerID, msg.To, localErr)
		return sendKCPServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", msg.To))
	}

	log.Printf("kcp hub local target miss for %s -> %s; attempting relay", peerID, msg.To)
	relayErr := h.forwardToRelay(msg)
	if relayErr == nil {
		return nil
	}

	var reply string
	switch {
	case errors.Is(relayErr, errKCPRelayUnavailable):
		log.Printf("kcp hub target %s unavailable for %s: no relay configured", msg.To, peerID)
		reply = fmt.Sprintf("unknown target: %s", msg.To)
	case errors.Is(relayErr, errKCPRelayPeerUnknown):
		log.Printf("kcp hub relay peer address is unknown for %s -> %s", peerID, msg.To)
		reply = "failed to relay to remote peer"
	case errors.Is(relayErr, errKCPRelayTooLarge):
		log.Printf("kcp hub relay rejected oversize message %s -> %s (%d bytes)", peerID, msg.To, len(msg.Body))
		reply = "message too large for relay udp"
	default:
		log.Printf("kcp hub relay forward failed for %s -> %s: %v", peerID, msg.To, relayErr)
		reply = "failed to relay to remote peer"
	}
	return sendKCPServerError(conn, peerID, reply)
}
// receivePeerLoop repeatedly reads messages from conn and hands them to
// handlePeerMessage. It stops when a receive or the handler fails. On either
// failure it closes conn, logs why the loop ended, and returns the wrapped
// cause, so it never returns nil.
//
// NOTE(review): the wrapped errors use a "transport:" prefix although this is
// package server; kept as-is because callers may match on it.
func (h *KCPHub) receivePeerLoop(peerID string, conn *transport.KCPConn, sessionDesc string) error {
	for {
		msg, recvErr := conn.Receive()
		if recvErr != nil {
			_ = conn.Close()
			log.Printf("kcp hub receive loop ending for peer %q on session %s: %v", peerID, sessionDesc, recvErr)
			return fmt.Errorf("transport: kcp receive loop read: %w", recvErr)
		}

		handleErr := h.handlePeerMessage(peerID, conn, msg)
		if handleErr == nil {
			continue
		}
		_ = conn.Close()
		log.Printf("kcp hub handler ending for peer %q on session %s: %v", peerID, sessionDesc, handleErr)
		return fmt.Errorf("transport: kcp receive loop handler: %w", handleErr)
	}
}
// deliverRelayedMessage delivers a message received from the relay to a local
// peer. If delivery fails for a text or file message, a server error addressed
// to the original sender is sent back over the relay. That error says
// "unknown target" when the peer is not registered here, or "failed to forward"
// when sending to it failed. Undeliverable error messages are logged and
// dropped, never answered, so two hubs cannot bounce errors back and forth.
func (h *KCPHub) deliverRelayedMessage(msg protocol.Message) error {
	localErr := h.deliverToLocalPeer(msg)
	if localErr == nil {
		return nil
	}

	isServerError := msg.Type == protocol.MessageTypeError
	unknownTarget := errors.Is(localErr, errKCPUnknownLocalTarget)

	switch {
	case isServerError && unknownTarget:
		log.Printf("kcp relay dropped server error for unknown local peer %s", msg.To)
		return nil
	case isServerError:
		log.Printf("kcp relay dropped undeliverable server error to %s: %v", msg.To, localErr)
		return nil
	case unknownTarget:
		log.Printf("kcp hub relayed target miss for %s -> %s; sending error back", msg.From, msg.To)
		return h.forwardRelayServerError(msg.From, fmt.Sprintf("unknown target: %s", msg.To))
	default:
		return h.forwardRelayServerError(msg.From, fmt.Sprintf("failed to forward to %s", msg.To))
	}
}
// deliverToLocalPeer sends msg to the locally registered peer msg.To.
// It returns an error wrapping errKCPUnknownLocalTarget if no such peer is
// registered. If the send fails, the target is unregistered (only if it is
// still the same connection), closed, and the send error is returned wrapped.
func (h *KCPHub) deliverToLocalPeer(msg protocol.Message) error {
	target, found := h.lookup(msg.To)
	if !found {
		return fmt.Errorf("%w: %s", errKCPUnknownLocalTarget, msg.To)
	}

	sendErr := target.Send(msg)
	if sendErr == nil {
		return nil
	}
	// A failed send means the target session is unusable; evict it so later
	// messages report "unknown target" instead of failing the same way again.
	h.unregister(msg.To, target, "local-forward-failure")
	_ = target.Close()
	return fmt.Errorf("server: forward to local peer %s: %w", msg.To, sendErr)
}
// forwardToRelay encodes msg and sends it as a single datagram over the raw UDP
// relay socket to the configured (or learned) relay peer.
//
// It returns:
//   - errKCPRelayUnavailable when no relay socket is configured;
//   - a wrapped encode error when msg cannot be encoded;
//   - errKCPRelayTooLarge when the encoded payload exceeds kcpRelayMaxDatagramSize;
//   - errKCPRelayPeerUnknown when no relay peer address is known yet;
//   - a wrapped WriteTo error if the send fails.
//
// The relay-socket check runs before encoding. A hub without a relay therefore
// reports errKCPRelayUnavailable, which handlePeerMessage maps to
// "unknown target", instead of a size or encode error for a relay that does
// not exist. It also avoids encoding work that would be thrown away.
func (h *KCPHub) forwardToRelay(msg protocol.Message) error {
	h.mu.RLock()
	conn := h.relaySocket
	peerAddr := cloneRelayAddr(h.relayPeerAddr)
	h.mu.RUnlock()

	if conn == nil {
		return errKCPRelayUnavailable
	}

	payload, err := protocol.EncodeMessage(msg)
	if err != nil {
		return fmt.Errorf("server: encode relay message: %w", err)
	}
	// The receiving ServeRelay reads into a buffer of this size, so anything
	// larger would be truncated on the other side.
	if len(payload) > kcpRelayMaxDatagramSize {
		return errKCPRelayTooLarge
	}
	if peerAddr == nil {
		return errKCPRelayPeerUnknown
	}

	if _, err := conn.WriteTo(payload, peerAddr); err != nil {
		return fmt.Errorf("server: relay write to %s: %w", peerAddr, err)
	}
	return nil
}
// forwardRelayServerError sends a server-originated error message, addressed to
// peer `to` with body `message`, over the relay socket via forwardToRelay.
func (h *KCPHub) forwardRelayServerError(to, message string) error {
	var notice protocol.Message
	notice.Type = protocol.MessageTypeError
	notice.From = protocol.ServerPeerID
	notice.To = to
	notice.Body = []byte(message)
	return h.forwardToRelay(notice)
}
// acceptRelayPeer reports whether a relay datagram from addr should be
// processed.
//   - If a relay peer address is set, only that address (compared by String())
//     is accepted.
//   - If none is set and learning is enabled, addr is adopted as the relay peer
//     and accepted.
//   - If none is set and learning is disabled, every sender is accepted.
//
// NOTE(review): in learn mode the first sender wins regardless of the
// datagram's content, since this runs before decoding in ServeRelay.
func (h *KCPHub) acceptRelayPeer(addr net.Addr) bool {
	h.mu.Lock()
	defer h.mu.Unlock()

	switch {
	case h.relayPeerAddr != nil:
		return sameRelayAddr(h.relayPeerAddr, addr)
	case h.relayLearnPeer:
		h.relayPeerAddr = cloneRelayAddr(addr)
		log.Printf("kcp hub learned relay peer %s", addr)
		return true
	default:
		return true
	}
}
// lookup returns the connection registered under peerID and whether one was
// found, reading h.peers under the read lock.
func (h *KCPHub) lookup(peerID string) (*transport.KCPConn, bool) {
	h.mu.RLock()
	found, ok := h.peers[peerID]
	h.mu.RUnlock()
	return found, ok
}
// unregister removes peerID from the hub, but only if it is still mapped to
// conn. A stale connection therefore cannot evict a newer registration that
// reuses the same ID. sessionDesc is used only for logging.
func (h *KCPHub) unregister(peerID string, conn *transport.KCPConn, sessionDesc string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if registered, found := h.peers[peerID]; !found || registered != conn {
		return
	}
	delete(h.peers, peerID)
	log.Printf("kcp hub unregistered peer %q from session %s (peers=%d)", peerID, sessionDesc, len(h.peers))
}
// sendKCPServerError sends a server-originated error message (From =
// protocol.ServerPeerID) addressed to `to` with body `message` over conn, and
// returns the result of the send.
func sendKCPServerError(conn *transport.KCPConn, to, message string) error {
	var notice protocol.Message
	notice.Type = protocol.MessageTypeError
	notice.From = protocol.ServerPeerID
	notice.To = to
	notice.Body = []byte(message)
	return conn.Send(notice)
}
// isRelayBusinessOrErrorMessage reports whether messageType may travel over the
// relay: text, file, or error messages only.
func isRelayBusinessOrErrorMessage(messageType protocol.MessageType) bool {
	return messageType == protocol.MessageTypeText ||
		messageType == protocol.MessageTypeFile ||
		messageType == protocol.MessageTypeError
}
// isExpectedRelayServeExit reports whether a read error means the relay socket
// was closed, which is the normal way to stop ServeRelay. The substring check
// also covers errors that carry the classic closed-connection text without
// wrapping net.ErrClosed. err must be non-nil.
func isExpectedRelayServeExit(err error) bool {
	if errors.Is(err, net.ErrClosed) {
		return true
	}
	return strings.Contains(err.Error(), "use of closed network connection")
}
// cloneRelayAddr returns a copy of addr that does not share mutable state with
// the caller's value.
//
// A *net.UDPAddr is deep-copied, including its IP slice, so later changes by
// the caller cannot affect the stored address. Any other net.Addr
// implementation is returned unchanged.
//
// Both a nil interface and a typed-nil *net.UDPAddr yield an untyped nil, so
// callers can reliably test the result with `== nil`. The typed-nil case used
// to dereference the nil pointer and panic.
func cloneRelayAddr(addr net.Addr) net.Addr {
	if addr == nil {
		return nil
	}
	udpAddr, ok := addr.(*net.UDPAddr)
	if !ok {
		return addr
	}
	if udpAddr == nil {
		// Typed nil: reading udpAddr.IP would panic, and returning addr as-is
		// would produce a non-nil interface that wraps a nil pointer.
		return nil
	}
	ipCopy := make(net.IP, len(udpAddr.IP))
	copy(ipCopy, udpAddr.IP)
	return &net.UDPAddr{
		IP:   ipCopy,
		Port: udpAddr.Port,
		Zone: udpAddr.Zone,
	}
}
// sameRelayAddr reports whether two addresses refer to the same endpoint,
// comparing their String() forms. Two nil addresses are equal; nil never
// equals a non-nil address.
func sameRelayAddr(left, right net.Addr) bool {
	if left != nil && right != nil {
		return left.String() == right.String()
	}
	return left == nil && right == nil
}
// describeKCPSession renders a session as "conv=<id> remote=<addr> local=<addr>"
// for log lines. A nil session renders every field as "<nil>".
func describeKCPSession(session *kcp.UDPSession) string {
	if session != nil {
		conv := session.GetConv()
		remote := addrString(session.RemoteAddr())
		local := addrString(session.LocalAddr())
		return fmt.Sprintf("conv=%d remote=%s local=%s", conv, remote, local)
	}
	return "conv=<nil> remote=<nil> local=<nil>"
}
// addrString returns addr.String(), or "<nil>" for a nil interface.
func addrString(addr net.Addr) string {
	if addr != nil {
		return addr.String()
	}
	return "<nil>"
}