415 lines
12 KiB
Go
415 lines
12 KiB
Go
package server
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
kcp "github.com/xtaci/kcp-go/v5"
|
||
|
||
"omnisocketgo/cmd/internal/latencylog"
|
||
"omnisocketgo/cmd/internal/protocol"
|
||
"omnisocketgo/cmd/internal/transport"
|
||
)
|
||
|
||
// kcpRelayMaxDatagramSize is the largest encoded message, in bytes, that the
// hub will write to the relay UDP socket. It is also the size of the receive
// buffer ServeRelay reads into. 60 KiB is presumably chosen to stay under the
// ~64 KiB UDP payload limit. NOTE(review): confirm against the relay path if
// IP fragmentation matters.
const kcpRelayMaxDatagramSize = 60 << 10
|
||
|
||
// Sentinel errors for the hub's relay and local-delivery paths. Code in this
// file distinguishes them with errors.Is, so they may be wrapped.
var (
	// errKCPRelayUnavailable means no relay socket has been configured.
	errKCPRelayUnavailable = errors.New("server: kcp relay socket is not configured")

	// errKCPRelayPeerUnknown means a relay socket exists but no peer address is known.
	errKCPRelayPeerUnknown = errors.New("server: kcp relay peer address is unknown")

	// errKCPRelayTooLarge means the encoded message exceeds kcpRelayMaxDatagramSize.
	errKCPRelayTooLarge = errors.New("server: kcp relay message too large")

	// errKCPUnknownLocalTarget means the target peer ID is not registered on this hub.
	errKCPUnknownLocalTarget = errors.New("server: unknown local kcp target")
)
|
||
|
||
// KCPOption 用于配置 KCPHub 的可选行为。
|
||
type KCPOption func(*KCPHub)
|
||
|
||
// WithKCPLogger 为 KCP hub 注入时延日志记录器。
|
||
func WithKCPLogger(logger latencylog.Logger) KCPOption {
|
||
return func(hub *KCPHub) {
|
||
hub.logger = logger
|
||
}
|
||
}
|
||
|
||
// WithKCPSessionStatsLogger 为 KCP hub 注入会话统计日志器。
|
||
func WithKCPSessionStatsLogger(logger transport.KCPSessionStatsLogger, interval time.Duration) KCPOption {
|
||
return func(hub *KCPHub) {
|
||
hub.sessionStatsLogger = logger
|
||
hub.sessionStatsInterval = interval
|
||
}
|
||
}
|
||
|
||
// KCPHub 管理已注册 peer 的 KCP 会话,并负责在它们之间转发消息。
|
||
type KCPHub struct {
|
||
mu sync.RWMutex
|
||
peers map[string]*transport.KCPConn
|
||
logger latencylog.Logger
|
||
sessionStatsLogger transport.KCPSessionStatsLogger
|
||
sessionStatsInterval time.Duration
|
||
relaySocket net.PacketConn
|
||
relayPeerAddr net.Addr
|
||
relayLearnPeer bool
|
||
}
|
||
|
||
// NewKCPHub 创建一个空的 KCP 连接中心。
|
||
func NewKCPHub(opts ...KCPOption) *KCPHub {
|
||
hub := &KCPHub{
|
||
peers: make(map[string]*transport.KCPConn),
|
||
logger: latencylog.NoopLogger{},
|
||
}
|
||
for _, opt := range opts {
|
||
opt(hub)
|
||
}
|
||
if hub.logger == nil {
|
||
hub.logger = latencylog.NoopLogger{}
|
||
}
|
||
return hub
|
||
}
|
||
|
||
// SetRelaySocket 配置 KCPHub 的原始 UDP relay 信道。
|
||
func (h *KCPHub) SetRelaySocket(conn net.PacketConn, peerAddr net.Addr, learnPeer bool) {
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
h.relaySocket = conn
|
||
h.relayPeerAddr = cloneRelayAddr(peerAddr)
|
||
h.relayLearnPeer = learnPeer
|
||
}
|
||
|
||
// HasPeer 返回给定 ID 是否已经注册到 hub。
|
||
func (h *KCPHub) HasPeer(peerID string) bool {
|
||
h.mu.RLock()
|
||
defer h.mu.RUnlock()
|
||
|
||
_, ok := h.peers[peerID]
|
||
return ok
|
||
}
|
||
|
||
// ServeRelay 持续从 relay UDP socket 读取消息,并尝试本地投递。
|
||
func (h *KCPHub) ServeRelay() error {
|
||
h.mu.RLock()
|
||
conn := h.relaySocket
|
||
h.mu.RUnlock()
|
||
|
||
if conn == nil {
|
||
return errKCPRelayUnavailable
|
||
}
|
||
|
||
buffer := make([]byte, kcpRelayMaxDatagramSize)
|
||
for {
|
||
n, addr, err := conn.ReadFrom(buffer)
|
||
if err != nil {
|
||
if isExpectedRelayServeExit(err) {
|
||
return nil
|
||
}
|
||
return fmt.Errorf("server: relay receive packet: %w", err)
|
||
}
|
||
|
||
if !h.acceptRelayPeer(addr) {
|
||
log.Printf("kcp relay dropped packet from unexpected peer %s", addr)
|
||
continue
|
||
}
|
||
|
||
msg, err := protocol.DecodeMessage(buffer[:n])
|
||
if err != nil {
|
||
log.Printf("kcp relay dropped invalid packet from %s: %v", addr, err)
|
||
continue
|
||
}
|
||
|
||
if !isRelayBusinessOrErrorMessage(msg.Type) {
|
||
log.Printf("kcp relay dropped unsupported message type %s from %s", msg.Type, addr)
|
||
continue
|
||
}
|
||
|
||
if err := h.deliverRelayedMessage(msg); err != nil {
|
||
log.Printf("kcp relay delivery for %s -> %s failed: %v", msg.From, msg.To, err)
|
||
}
|
||
}
|
||
}
|
||
|
||
// ServeSession 处理一条新接入的 KCP 会话。
|
||
func (h *KCPHub) ServeSession(session *kcp.UDPSession) error {
|
||
sessionDesc := describeKCPSession(session)
|
||
log.Printf("kcp hub accepted session %s", sessionDesc)
|
||
|
||
conn, err := transport.NewKCPConn(
|
||
session,
|
||
transport.WithKCPLogger(h.logger, latencylog.NodeRoleServer, "hub"),
|
||
transport.WithKCPSessionStatsLogger(h.sessionStatsLogger, h.sessionStatsInterval),
|
||
)
|
||
if err != nil {
|
||
_ = session.Close()
|
||
return fmt.Errorf("server: create kcp transport conn: %w", err)
|
||
}
|
||
|
||
peerID, err := h.registerConn(conn, sessionDesc)
|
||
if err != nil {
|
||
_ = conn.Close()
|
||
return err
|
||
}
|
||
defer h.unregister(peerID, conn, sessionDesc)
|
||
|
||
return h.receivePeerLoop(peerID, conn, sessionDesc)
|
||
}
|
||
|
||
// 注册新连接时,KCPHub 期望第一条消息是一个 register 消息,包含 peer 的 ID
|
||
func (h *KCPHub) registerConn(conn *transport.KCPConn, sessionDesc string) (string, error) {
|
||
msg, err := conn.Receive()
|
||
if err != nil {
|
||
log.Printf("kcp hub session %s failed before register: %v", sessionDesc, err)
|
||
return "", fmt.Errorf("server: receive kcp register: %w", err)
|
||
}
|
||
|
||
if msg.Type != protocol.MessageTypeRegister {
|
||
log.Printf("kcp hub rejecting session %s: first message type=%s from=%s", sessionDesc, msg.Type, msg.From)
|
||
if sendErr := sendKCPServerError(conn, msg.From, "first message must be register"); sendErr != nil {
|
||
return "", fmt.Errorf("server: reject unregistered kcp peer: %w", sendErr)
|
||
}
|
||
return "", fmt.Errorf("server: first kcp message must be register, got %s", msg.Type)
|
||
}
|
||
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
if _, exists := h.peers[msg.From]; exists {
|
||
log.Printf("kcp hub rejecting duplicate peer %q on session %s", msg.From, sessionDesc)
|
||
if sendErr := sendKCPServerError(conn, msg.From, fmt.Sprintf("duplicate peer id: %s", msg.From)); sendErr != nil {
|
||
return "", fmt.Errorf("server: duplicate kcp peer id %s: %w", msg.From, sendErr)
|
||
}
|
||
return "", fmt.Errorf("server: duplicate kcp peer id: %s", msg.From)
|
||
}
|
||
|
||
h.peers[msg.From] = conn
|
||
log.Printf("kcp hub registered peer %q on session %s (peers=%d)", msg.From, sessionDesc, len(h.peers))
|
||
return msg.From, nil
|
||
}
|
||
|
||
// handlePeerMessage 处理已注册 peer 发来的消息,并将其转发给目标 peer。
|
||
func (h *KCPHub) handlePeerMessage(peerID string, conn *transport.KCPConn, msg protocol.Message) error {
|
||
switch msg.Type {
|
||
case protocol.MessageTypeText, protocol.MessageTypeFile:
|
||
msg.From = peerID
|
||
|
||
if err := h.deliverToLocalPeer(msg); err == nil {
|
||
return nil
|
||
} else if !errors.Is(err, errKCPUnknownLocalTarget) {
|
||
log.Printf("kcp hub local delivery failed for %s -> %s: %v", peerID, msg.To, err)
|
||
return sendKCPServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", msg.To))
|
||
}
|
||
|
||
log.Printf("kcp hub local target miss for %s -> %s; attempting relay", peerID, msg.To)
|
||
err := h.forwardToRelay(msg)
|
||
switch {
|
||
case err == nil:
|
||
return nil
|
||
case errors.Is(err, errKCPRelayUnavailable):
|
||
log.Printf("kcp hub target %s unavailable for %s: no relay configured", msg.To, peerID)
|
||
return sendKCPServerError(conn, peerID, fmt.Sprintf("unknown target: %s", msg.To))
|
||
case errors.Is(err, errKCPRelayPeerUnknown):
|
||
log.Printf("kcp hub relay peer address is unknown for %s -> %s", peerID, msg.To)
|
||
return sendKCPServerError(conn, peerID, "failed to relay to remote peer")
|
||
case errors.Is(err, errKCPRelayTooLarge):
|
||
log.Printf("kcp hub relay rejected oversize message %s -> %s (%d bytes)", peerID, msg.To, len(msg.Body))
|
||
return sendKCPServerError(conn, peerID, "message too large for relay udp")
|
||
default:
|
||
log.Printf("kcp hub relay forward failed for %s -> %s: %v", peerID, msg.To, err)
|
||
return sendKCPServerError(conn, peerID, "failed to relay to remote peer")
|
||
}
|
||
case protocol.MessageTypeRegister, protocol.MessageTypeError:
|
||
if err := sendKCPServerError(conn, peerID, "registered peers can only send text or file messages"); err != nil {
|
||
return fmt.Errorf("server: send kcp protocol error: %w", err)
|
||
}
|
||
return fmt.Errorf("server: unexpected kcp message type from peer %s: %s", peerID, msg.Type)
|
||
default:
|
||
if err := sendKCPServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); err != nil {
|
||
return fmt.Errorf("server: send unsupported kcp type error: %w", err)
|
||
}
|
||
return fmt.Errorf("server: unsupported kcp message type: %s", msg.Type)
|
||
}
|
||
}
|
||
|
||
// receivePeerLoop 持续读取 peer 发来的消息,并交给 handlePeerMessage 处理,直到连接出错。
|
||
func (h *KCPHub) receivePeerLoop(peerID string, conn *transport.KCPConn, sessionDesc string) error {
|
||
for {
|
||
msg, err := conn.Receive()
|
||
if err != nil {
|
||
_ = conn.Close()
|
||
log.Printf("kcp hub receive loop ending for peer %q on session %s: %v", peerID, sessionDesc, err)
|
||
return fmt.Errorf("transport: kcp receive loop read: %w", err)
|
||
}
|
||
|
||
if err := h.handlePeerMessage(peerID, conn, msg); err != nil {
|
||
_ = conn.Close()
|
||
log.Printf("kcp hub handler ending for peer %q on session %s: %v", peerID, sessionDesc, err)
|
||
return fmt.Errorf("transport: kcp receive loop handler: %w", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
func (h *KCPHub) deliverRelayedMessage(msg protocol.Message) error {
|
||
if err := h.deliverToLocalPeer(msg); err == nil {
|
||
return nil
|
||
} else if !errors.Is(err, errKCPUnknownLocalTarget) {
|
||
if msg.Type == protocol.MessageTypeError {
|
||
log.Printf("kcp relay dropped undeliverable server error to %s: %v", msg.To, err)
|
||
return nil
|
||
}
|
||
return h.forwardRelayServerError(msg.From, fmt.Sprintf("failed to forward to %s", msg.To))
|
||
}
|
||
|
||
if msg.Type == protocol.MessageTypeError {
|
||
log.Printf("kcp relay dropped server error for unknown local peer %s", msg.To)
|
||
return nil
|
||
}
|
||
|
||
log.Printf("kcp hub relayed target miss for %s -> %s; sending error back", msg.From, msg.To)
|
||
return h.forwardRelayServerError(msg.From, fmt.Sprintf("unknown target: %s", msg.To))
|
||
}
|
||
|
||
func (h *KCPHub) deliverToLocalPeer(msg protocol.Message) error {
|
||
targetConn, ok := h.lookup(msg.To)
|
||
if !ok {
|
||
return fmt.Errorf("%w: %s", errKCPUnknownLocalTarget, msg.To)
|
||
}
|
||
if err := targetConn.Send(msg); err != nil {
|
||
h.unregister(msg.To, targetConn, "local-forward-failure")
|
||
_ = targetConn.Close()
|
||
return fmt.Errorf("server: forward to local peer %s: %w", msg.To, err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (h *KCPHub) forwardToRelay(msg protocol.Message) error {
|
||
payload, err := protocol.EncodeMessage(msg)
|
||
if err != nil {
|
||
return fmt.Errorf("server: encode relay message: %w", err)
|
||
}
|
||
if len(payload) > kcpRelayMaxDatagramSize {
|
||
return errKCPRelayTooLarge
|
||
}
|
||
|
||
h.mu.RLock()
|
||
conn := h.relaySocket
|
||
peerAddr := cloneRelayAddr(h.relayPeerAddr)
|
||
h.mu.RUnlock()
|
||
|
||
if conn == nil {
|
||
return errKCPRelayUnavailable
|
||
}
|
||
if peerAddr == nil {
|
||
return errKCPRelayPeerUnknown
|
||
}
|
||
|
||
if _, err := conn.WriteTo(payload, peerAddr); err != nil {
|
||
return fmt.Errorf("server: relay write to %s: %w", peerAddr, err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (h *KCPHub) forwardRelayServerError(to, message string) error {
|
||
return h.forwardToRelay(protocol.Message{
|
||
Type: protocol.MessageTypeError,
|
||
From: protocol.ServerPeerID,
|
||
To: to,
|
||
Body: []byte(message),
|
||
})
|
||
}
|
||
|
||
func (h *KCPHub) acceptRelayPeer(addr net.Addr) bool {
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
if h.relayPeerAddr == nil && h.relayLearnPeer {
|
||
h.relayPeerAddr = cloneRelayAddr(addr)
|
||
log.Printf("kcp hub learned relay peer %s", addr)
|
||
return true
|
||
}
|
||
if h.relayPeerAddr == nil {
|
||
return true
|
||
}
|
||
return sameRelayAddr(h.relayPeerAddr, addr)
|
||
}
|
||
|
||
func (h *KCPHub) lookup(peerID string) (*transport.KCPConn, bool) {
|
||
h.mu.RLock()
|
||
defer h.mu.RUnlock()
|
||
|
||
conn, ok := h.peers[peerID]
|
||
return conn, ok
|
||
}
|
||
|
||
func (h *KCPHub) unregister(peerID string, conn *transport.KCPConn, sessionDesc string) {
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
current, ok := h.peers[peerID]
|
||
if ok && current == conn {
|
||
delete(h.peers, peerID)
|
||
log.Printf("kcp hub unregistered peer %q from session %s (peers=%d)", peerID, sessionDesc, len(h.peers))
|
||
}
|
||
}
|
||
|
||
func sendKCPServerError(conn *transport.KCPConn, to, message string) error {
|
||
return conn.Send(protocol.Message{
|
||
Type: protocol.MessageTypeError,
|
||
From: protocol.ServerPeerID,
|
||
To: to,
|
||
Body: []byte(message),
|
||
})
|
||
}
|
||
|
||
func isRelayBusinessOrErrorMessage(messageType protocol.MessageType) bool {
|
||
switch messageType {
|
||
case protocol.MessageTypeText, protocol.MessageTypeFile, protocol.MessageTypeError:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// isExpectedRelayServeExit reports whether a read error from the relay socket
// means the socket was closed, which is a normal shutdown for ServeRelay.
// err must be non-nil.
func isExpectedRelayServeExit(err error) bool {
	if errors.Is(err, net.ErrClosed) {
		return true
	}
	// Fallback for errors that lost the net.ErrClosed sentinel but kept its text.
	return strings.Contains(err.Error(), "use of closed network connection")
}
|
||
|
||
// cloneRelayAddr returns a copy of addr that shares no mutable state with it.
//
// A nil address, including a typed-nil *net.UDPAddr stored in a non-nil
// net.Addr, yields a nil interface. Previously the typed-nil case passed the
// nil check and the type assertion and then panicked on udpAddr.IP.
//
// A *net.UDPAddr is deep-copied, including its IP slice. Any other Addr
// implementation is returned as-is. NOTE(review): this assumes other Addr types
// are not mutated after being handed to the hub.
func cloneRelayAddr(addr net.Addr) net.Addr {
	if addr == nil {
		return nil
	}
	udpAddr, ok := addr.(*net.UDPAddr)
	if !ok {
		return addr
	}
	if udpAddr == nil {
		return nil
	}

	ipCopy := make(net.IP, len(udpAddr.IP))
	copy(ipCopy, udpAddr.IP)
	return &net.UDPAddr{
		IP:   ipCopy,
		Port: udpAddr.Port,
		Zone: udpAddr.Zone,
	}
}
|
||
|
||
// sameRelayAddr reports whether two addresses have the same String() form.
// Two nil addresses are equal; a nil and a non-nil address are not.
func sameRelayAddr(left, right net.Addr) bool {
	switch {
	case left == nil && right == nil:
		return true
	case left == nil || right == nil:
		return false
	default:
		return left.String() == right.String()
	}
}
|
||
|
||
func describeKCPSession(session *kcp.UDPSession) string {
|
||
if session == nil {
|
||
return "conv=<nil> remote=<nil> local=<nil>"
|
||
}
|
||
return fmt.Sprintf("conv=%d remote=%s local=%s", session.GetConv(), addrString(session.RemoteAddr()), addrString(session.LocalAddr()))
|
||
}
|
||
|
||
// addrString formats addr for logging, writing "<nil>" for a nil interface.
func addrString(addr net.Addr) string {
	if addr != nil {
		return addr.String()
	}
	return "<nil>"
}
|