package server

import (
	"errors"
	"fmt"
	"log"
	"net"
	"strings"
	"sync"
	"time"

	kcp "github.com/xtaci/kcp-go/v5"

	"omnisocketgo/cmd/internal/latencylog"
	"omnisocketgo/cmd/internal/protocol"
	"omnisocketgo/cmd/internal/transport"
)

// kcpRelayMaxDatagramSize is both the size of the relay receive buffer and
// the largest encoded message forwardToRelay will send as one datagram.
const kcpRelayMaxDatagramSize = 60 * 1024

var (
	errKCPRelayUnavailable   = errors.New("server: kcp relay socket is not configured")
	errKCPRelayPeerUnknown   = errors.New("server: kcp relay peer address is unknown")
	errKCPRelayTooLarge      = errors.New("server: kcp relay message too large")
	errKCPUnknownLocalTarget = errors.New("server: unknown local kcp target")
)

// KCPOption configures optional behavior of a KCPHub.
type KCPOption func(*KCPHub)

// WithKCPLogger injects a latency logger into the KCP hub. It is passed to
// every transport.KCPConn the hub creates.
func WithKCPLogger(logger latencylog.Logger) KCPOption {
	return func(hub *KCPHub) {
		hub.logger = logger
	}
}

// WithKCPSessionStatsLogger injects a session statistics logger into the KCP
// hub. The logger and interval are handed to every transport.KCPConn the hub
// creates (how the interval is used is up to the transport package).
func WithKCPSessionStatsLogger(logger transport.KCPSessionStatsLogger, interval time.Duration) KCPOption {
	return func(hub *KCPHub) {
		hub.sessionStatsLogger = logger
		hub.sessionStatsInterval = interval
	}
}

// KCPHub manages the KCP sessions of registered peers and forwards messages
// between them. Messages for targets that are not registered locally can be
// forwarded to a remote hub over a raw UDP relay socket.
type KCPHub struct {
	// mu guards peers and the relay* fields below.
	mu    sync.RWMutex
	peers map[string]*transport.KCPConn

	// Set once during construction via options; read without locking.
	logger               latencylog.Logger
	sessionStatsLogger   transport.KCPSessionStatsLogger
	sessionStatsInterval time.Duration

	relaySocket    net.PacketConn
	relayPeerAddr  net.Addr
	relayLearnPeer bool
}

// NewKCPHub creates an empty KCP hub. A nil logger (from options) is replaced
// by latencylog.NoopLogger.
func NewKCPHub(opts ...KCPOption) *KCPHub {
	hub := &KCPHub{
		peers:  make(map[string]*transport.KCPConn),
		logger: latencylog.NoopLogger{},
	}
	for _, opt := range opts {
		opt(hub)
	}
	if hub.logger == nil {
		hub.logger = latencylog.NoopLogger{}
	}
	return hub
}

// SetRelaySocket configures the hub's raw UDP relay channel.
//
// peerAddr may be nil. In that case, when learnPeer is true the source address
// of the first datagram received by ServeRelay is adopted as the relay peer;
// when learnPeer is false, datagrams from any source are accepted but outbound
// relaying fails with errKCPRelayPeerUnknown.
func (h *KCPHub) SetRelaySocket(conn net.PacketConn, peerAddr net.Addr, learnPeer bool) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.relaySocket = conn
	h.relayPeerAddr = cloneRelayAddr(peerAddr)
	h.relayLearnPeer = learnPeer
}

// HasPeer reports whether the given ID is currently registered with the hub.
func (h *KCPHub) HasPeer(peerID string) bool {
	h.mu.RLock()
	defer h.mu.RUnlock()
	_, ok := h.peers[peerID]
	return ok
}

// ServeRelay reads datagrams from the relay UDP socket until it is closed and
// tries to deliver each one to a local peer.
//
// It returns errKCPRelayUnavailable if no relay socket is configured, nil when
// the socket is closed, and a wrapped error for any other read failure.
// Packets from unexpected sources, undecodable packets and unsupported message
// types are logged and dropped.
func (h *KCPHub) ServeRelay() error {
	h.mu.RLock()
	conn := h.relaySocket
	h.mu.RUnlock()
	if conn == nil {
		return errKCPRelayUnavailable
	}

	// NOTE(review): buffer is reused for every datagram. This is only safe if
	// the decoded message is not retained after deliverRelayedMessage returns
	// (e.g. queued asynchronously by KCPConn.Send) — confirm.
	buffer := make([]byte, kcpRelayMaxDatagramSize)
	for {
		n, addr, err := conn.ReadFrom(buffer)
		if err != nil {
			if isExpectedRelayServeExit(err) {
				return nil
			}
			return fmt.Errorf("server: relay receive packet: %w", err)
		}
		if !h.acceptRelayPeer(addr) {
			log.Printf("kcp relay dropped packet from unexpected peer %s", addr)
			continue
		}
		msg, err := protocol.DecodeMessage(buffer[:n])
		if err != nil {
			log.Printf("kcp relay dropped invalid packet from %s: %v", addr, err)
			continue
		}
		if !isRelayBusinessOrErrorMessage(msg.Type) {
			log.Printf("kcp relay dropped unsupported message type %s from %s", msg.Type, addr)
			continue
		}
		if err := h.deliverRelayedMessage(msg); err != nil {
			log.Printf("kcp relay delivery for %s -> %s failed: %v", msg.From, msg.To, err)
		}
	}
}

// ServeSession handles one newly accepted KCP session. It blocks until the
// peer's receive loop ends, and always closes the session before returning.
// The returned error is never nil: it describes why the session ended.
func (h *KCPHub) ServeSession(session *kcp.UDPSession) error {
	sessionDesc := describeKCPSession(session)
	log.Printf("kcp hub accepted session %s", sessionDesc)

	conn, err := transport.NewKCPConn(
		session,
		transport.WithKCPLogger(h.logger, latencylog.NodeRoleServer, "hub"),
		transport.WithKCPSessionStatsLogger(h.sessionStatsLogger, h.sessionStatsInterval),
	)
	if err != nil {
		_ = session.Close()
		return fmt.Errorf("server: create kcp transport conn: %w", err)
	}

	peerID, err := h.registerConn(conn, sessionDesc)
	if err != nil {
		_ = conn.Close()
		return err
	}
	defer h.unregister(peerID, conn, sessionDesc)

	return h.receivePeerLoop(peerID, conn, sessionDesc)
}

// registerConn expects the first message on conn to be a register message and
// registers conn under the message's From ID. It returns the registered ID.
//
// Non-register first messages, empty IDs, the reserved server ID and duplicate
// IDs are rejected: an error message is sent back to the peer and a non-nil
// error is returned.
func (h *KCPHub) registerConn(conn *transport.KCPConn, sessionDesc string) (string, error) {
	msg, err := conn.Receive()
	if err != nil {
		log.Printf("kcp hub session %s failed before register: %v", sessionDesc, err)
		return "", fmt.Errorf("server: receive kcp register: %w", err)
	}
	if msg.Type != protocol.MessageTypeRegister {
		log.Printf("kcp hub rejecting session %s: first message type=%s from=%s", sessionDesc, msg.Type, msg.From)
		if sendErr := sendKCPServerError(conn, msg.From, "first message must be register"); sendErr != nil {
			return "", fmt.Errorf("server: reject unregistered kcp peer: %w", sendErr)
		}
		return "", fmt.Errorf("server: first kcp message must be register, got %s", msg.Type)
	}

	peerID := msg.From
	// An empty ID is unaddressable. The server ID is reserved: forwarded
	// messages are stamped with the sender's registered ID, so a peer using it
	// could impersonate server error messages.
	if peerID == "" || peerID == protocol.ServerPeerID {
		log.Printf("kcp hub rejecting session %s: invalid peer id %q", sessionDesc, peerID)
		if sendErr := sendKCPServerError(conn, peerID, fmt.Sprintf("invalid peer id: %q", peerID)); sendErr != nil {
			return "", fmt.Errorf("server: reject invalid kcp peer id %q: %w", peerID, sendErr)
		}
		return "", fmt.Errorf("server: invalid kcp peer id: %q", peerID)
	}

	registered, total := h.tryRegister(peerID, conn)
	if !registered {
		// The rejection is sent after tryRegister has released h.mu so a slow
		// network write cannot stall the rest of the hub.
		log.Printf("kcp hub rejecting duplicate peer %q on session %s", peerID, sessionDesc)
		if sendErr := sendKCPServerError(conn, peerID, fmt.Sprintf("duplicate peer id: %s", peerID)); sendErr != nil {
			return "", fmt.Errorf("server: duplicate kcp peer id %s: %w", peerID, sendErr)
		}
		return "", fmt.Errorf("server: duplicate kcp peer id: %s", peerID)
	}
	log.Printf("kcp hub registered peer %q on session %s (peers=%d)", peerID, sessionDesc, total)
	return peerID, nil
}

// tryRegister adds conn under peerID unless the ID is already taken. It reports
// whether the registration happened and the peer count afterwards.
func (h *KCPHub) tryRegister(peerID string, conn *transport.KCPConn) (bool, int) {
	h.mu.Lock()
	defer h.mu.Unlock()
	if _, exists := h.peers[peerID]; exists {
		return false, len(h.peers)
	}
	h.peers[peerID] = conn
	return true, len(h.peers)
}

// handlePeerMessage processes one message from a registered peer. Text and
// file messages are forwarded to their target; any other type is answered
// with an error message and ends the session (a non-nil error is returned).
func (h *KCPHub) handlePeerMessage(peerID string, conn *transport.KCPConn, msg protocol.Message) error {
	switch msg.Type {
	case protocol.MessageTypeText, protocol.MessageTypeFile:
		// Never trust the sender-supplied From; stamp the registered ID.
		msg.From = peerID
		return h.routePeerMessage(peerID, conn, msg)
	case protocol.MessageTypeRegister, protocol.MessageTypeError:
		if err := sendKCPServerError(conn, peerID, "registered peers can only send text or file messages"); err != nil {
			return fmt.Errorf("server: send kcp protocol error: %w", err)
		}
		return fmt.Errorf("server: unexpected kcp message type from peer %s: %s", peerID, msg.Type)
	default:
		if err := sendKCPServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); err != nil {
			return fmt.Errorf("server: send unsupported kcp type error: %w", err)
		}
		return fmt.Errorf("server: unsupported kcp message type: %s", msg.Type)
	}
}

// routePeerMessage delivers msg to a local peer, falling back to the UDP relay
// when the target is not registered locally. Delivery failures are reported to
// the sender as error messages; the returned error is only non-nil if sending
// that error message back fails.
func (h *KCPHub) routePeerMessage(peerID string, conn *transport.KCPConn, msg protocol.Message) error {
	err := h.deliverToLocalPeer(msg)
	if err == nil {
		return nil
	}
	if !errors.Is(err, errKCPUnknownLocalTarget) {
		log.Printf("kcp hub local delivery failed for %s -> %s: %v", peerID, msg.To, err)
		return sendKCPServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", msg.To))
	}

	log.Printf("kcp hub local target miss for %s -> %s; attempting relay", peerID, msg.To)
	err = h.forwardToRelay(msg)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, errKCPRelayUnavailable):
		log.Printf("kcp hub target %s unavailable for %s: no relay configured", msg.To, peerID)
		return sendKCPServerError(conn, peerID, fmt.Sprintf("unknown target: %s", msg.To))
	case errors.Is(err, errKCPRelayPeerUnknown):
		log.Printf("kcp hub relay peer address is unknown for %s -> %s", peerID, msg.To)
		return sendKCPServerError(conn, peerID, "failed to relay to remote peer")
	case errors.Is(err, errKCPRelayTooLarge):
		log.Printf("kcp hub relay rejected oversize message %s -> %s (%d bytes)", peerID, msg.To, len(msg.Body))
		return sendKCPServerError(conn, peerID, "message too large for relay udp")
	default:
		log.Printf("kcp hub relay forward failed for %s -> %s: %v", peerID, msg.To, err)
		return sendKCPServerError(conn, peerID, "failed to relay to remote peer")
	}
}

// receivePeerLoop reads messages from the peer and passes them to
// handlePeerMessage until a read or handler error occurs. It closes conn
// before returning and always returns a non-nil error.
func (h *KCPHub) receivePeerLoop(peerID string, conn *transport.KCPConn, sessionDesc string) error {
	for {
		msg, err := conn.Receive()
		if err != nil {
			_ = conn.Close()
			log.Printf("kcp hub receive loop ending for peer %q on session %s: %v", peerID, sessionDesc, err)
			return fmt.Errorf("server: kcp receive loop read: %w", err)
		}
		if err := h.handlePeerMessage(peerID, conn, msg); err != nil {
			_ = conn.Close()
			log.Printf("kcp hub handler ending for peer %q on session %s: %v", peerID, sessionDesc, err)
			return fmt.Errorf("server: kcp receive loop handler: %w", err)
		}
	}
}

// deliverRelayedMessage delivers a message received over the relay to a local
// peer. If delivery fails, an error message is relayed back to the original
// sender, except when msg itself is an error message: those are dropped to
// avoid error ping-pong between hubs.
func (h *KCPHub) deliverRelayedMessage(msg protocol.Message) error {
	err := h.deliverToLocalPeer(msg)
	if err == nil {
		return nil
	}

	isError := msg.Type == protocol.MessageTypeError
	if !errors.Is(err, errKCPUnknownLocalTarget) {
		if isError {
			log.Printf("kcp relay dropped undeliverable server error to %s: %v", msg.To, err)
			return nil
		}
		return h.forwardRelayServerError(msg.From, fmt.Sprintf("failed to forward to %s", msg.To))
	}
	if isError {
		log.Printf("kcp relay dropped server error for unknown local peer %s", msg.To)
		return nil
	}
	log.Printf("kcp hub relayed target miss for %s -> %s; sending error back", msg.From, msg.To)
	return h.forwardRelayServerError(msg.From, fmt.Sprintf("unknown target: %s", msg.To))
}

// deliverToLocalPeer sends msg to the locally registered peer msg.To. It
// returns an error wrapping errKCPUnknownLocalTarget if no such peer exists.
// On a send failure the target is unregistered and its connection closed.
func (h *KCPHub) deliverToLocalPeer(msg protocol.Message) error {
	targetConn, ok := h.lookup(msg.To)
	if !ok {
		return fmt.Errorf("%w: %s", errKCPUnknownLocalTarget, msg.To)
	}
	if err := targetConn.Send(msg); err != nil {
		h.unregister(msg.To, targetConn, "local-forward-failure")
		_ = targetConn.Close()
		return fmt.Errorf("server: forward to local peer %s: %w", msg.To, err)
	}
	return nil
}

// forwardToRelay encodes msg and writes it as one datagram to the relay peer.
//
// It returns errKCPRelayUnavailable when no relay socket is configured,
// errKCPRelayPeerUnknown when the relay peer address is not (yet) known, and
// errKCPRelayTooLarge when the encoded message exceeds
// kcpRelayMaxDatagramSize. Relay availability is checked before encoding so
// that "no relay" is never misreported as "too large".
func (h *KCPHub) forwardToRelay(msg protocol.Message) error {
	h.mu.RLock()
	conn := h.relaySocket
	peerAddr := cloneRelayAddr(h.relayPeerAddr)
	h.mu.RUnlock()

	if conn == nil {
		return errKCPRelayUnavailable
	}
	if peerAddr == nil {
		return errKCPRelayPeerUnknown
	}

	payload, err := protocol.EncodeMessage(msg)
	if err != nil {
		return fmt.Errorf("server: encode relay message: %w", err)
	}
	if len(payload) > kcpRelayMaxDatagramSize {
		return errKCPRelayTooLarge
	}
	if _, err := conn.WriteTo(payload, peerAddr); err != nil {
		return fmt.Errorf("server: relay write to %s: %w", peerAddr, err)
	}
	return nil
}

// forwardRelayServerError relays a server-originated error message to peer
// `to` on the remote hub.
func (h *KCPHub) forwardRelayServerError(to, message string) error {
	return h.forwardToRelay(newKCPServerError(to, message))
}

// acceptRelayPeer reports whether a relay datagram from addr should be
// processed. If no relay peer is known and learning is enabled, addr becomes
// the relay peer. If no peer is known and learning is disabled, every source
// is accepted. Otherwise only the known peer address is accepted.
func (h *KCPHub) acceptRelayPeer(addr net.Addr) bool {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.relayPeerAddr == nil {
		if h.relayLearnPeer {
			h.relayPeerAddr = cloneRelayAddr(addr)
			log.Printf("kcp hub learned relay peer %s", addr)
		}
		return true
	}
	return sameRelayAddr(h.relayPeerAddr, addr)
}

// lookup returns the connection registered under peerID, if any.
func (h *KCPHub) lookup(peerID string) (*transport.KCPConn, bool) {
	h.mu.RLock()
	defer h.mu.RUnlock()
	conn, ok := h.peers[peerID]
	return conn, ok
}

// unregister removes peerID from the hub, but only if it is still mapped to
// conn, so a stale session cannot evict a newer registration.
func (h *KCPHub) unregister(peerID string, conn *transport.KCPConn, sessionDesc string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	if current, ok := h.peers[peerID]; ok && current == conn {
		delete(h.peers, peerID)
		log.Printf("kcp hub unregistered peer %q from session %s (peers=%d)", peerID, sessionDesc, len(h.peers))
	}
}

// newKCPServerError builds an error message originating from the server.
func newKCPServerError(to, message string) protocol.Message {
	return protocol.Message{
		Type: protocol.MessageTypeError,
		From: protocol.ServerPeerID,
		To:   to,
		Body: []byte(message),
	}
}

// sendKCPServerError sends a server-originated error message over conn.
func sendKCPServerError(conn *transport.KCPConn, to, message string) error {
	return conn.Send(newKCPServerError(to, message))
}

// isRelayBusinessOrErrorMessage reports whether a message type may be
// accepted from the relay: text, file and error messages only.
func isRelayBusinessOrErrorMessage(messageType protocol.MessageType) bool {
	switch messageType {
	case protocol.MessageTypeText, protocol.MessageTypeFile, protocol.MessageTypeError:
		return true
	default:
		return false
	}
}

// isExpectedRelayServeExit reports whether a relay read error means the socket
// was closed. The string match covers PacketConn implementations that do not
// wrap net.ErrClosed.
func isExpectedRelayServeExit(err error) bool {
	return errors.Is(err, net.ErrClosed) || strings.Contains(err.Error(), "use of closed network connection")
}

// cloneRelayAddr returns a deep copy of a *net.UDPAddr (so the IP slice is not
// shared). Other address types are returned as-is; nil stays nil.
func cloneRelayAddr(addr net.Addr) net.Addr {
	if addr == nil {
		return nil
	}
	udpAddr, ok := addr.(*net.UDPAddr)
	if !ok {
		return addr
	}
	ipCopy := make([]byte, len(udpAddr.IP))
	copy(ipCopy, udpAddr.IP)
	return &net.UDPAddr{
		IP:   ipCopy,
		Port: udpAddr.Port,
		Zone: udpAddr.Zone,
	}
}

// sameRelayAddr compares two addresses by their string form. Two nils are
// equal; nil never equals a non-nil address.
func sameRelayAddr(left, right net.Addr) bool {
	if left == nil || right == nil {
		return left == right
	}
	return left.String() == right.String()
}

// describeKCPSession formats a session's conversation ID and endpoints for
// logging. A nil session yields empty fields.
func describeKCPSession(session *kcp.UDPSession) string {
	if session == nil {
		return "conv= remote= local="
	}
	return fmt.Sprintf("conv=%d remote=%s local=%s", session.GetConv(), addrString(session.RemoteAddr()), addrString(session.LocalAddr()))
}

// addrString returns addr.String(), or "" for a nil address.
func addrString(addr net.Addr) string {
	if addr == nil {
		return ""
	}
	return addr.String()
}