From 17122f6e4c67a7563652913c049cb5f35807d470 Mon Sep 17 00:00:00 2001 From: Mock Date: Sat, 28 Mar 2026 15:28:19 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=B7=BB=E5=8A=A0relay=E6=97=A5?= =?UTF-8?q?=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/internal/server/kcp_hub.go | 48 ++++++++++++++++++++++---- cmd/internal/server/udp_relay.go | 24 +++++++++++-- cmd/internal/server/udp_relay_test.go | 18 +++++++--- cmd/internal/transport/kcp.go | 27 +++++++++++++-- cmd/internal/transport/kcp_test.go | 49 +++++++++++++++++++++++++++ cmd/kcppeer/main.go | 10 +++--- cmd/kcpserver/main.go | 18 +++++++--- cmd/udprelay/main.go | 12 +++++-- 8 files changed, 179 insertions(+), 27 deletions(-) diff --git a/cmd/internal/server/kcp_hub.go b/cmd/internal/server/kcp_hub.go index 6e9289e..717eeaa 100644 --- a/cmd/internal/server/kcp_hub.go +++ b/cmd/internal/server/kcp_hub.go @@ -133,6 +133,9 @@ func (h *KCPHub) ServeRelay() error { // ServeSession 处理一条新接入的 KCP 会话。 func (h *KCPHub) ServeSession(session *kcp.UDPSession) error { + sessionDesc := describeKCPSession(session) + log.Printf("kcp hub accepted session %s", sessionDesc) + conn, err := transport.NewKCPConn( session, transport.WithKCPLogger(h.logger, latencylog.NodeRoleServer, "hub"), @@ -143,24 +146,26 @@ func (h *KCPHub) ServeSession(session *kcp.UDPSession) error { return fmt.Errorf("server: create kcp transport conn: %w", err) } - peerID, err := h.registerConn(conn) + peerID, err := h.registerConn(conn, sessionDesc) if err != nil { _ = conn.Close() return err } - defer h.unregister(peerID, conn) + defer h.unregister(peerID, conn, sessionDesc) - return h.receivePeerLoop(peerID, conn) + return h.receivePeerLoop(peerID, conn, sessionDesc) } // 注册新连接时,KCPHub 期望第一条消息是一个 register 消息,包含 peer 的 ID -func (h *KCPHub) registerConn(conn *transport.KCPConn) (string, error) { +func (h *KCPHub) registerConn(conn *transport.KCPConn, sessionDesc string) (string, error) { msg, err := conn.Receive() if err != nil { + log.Printf("kcp hub session %s failed before register: %v", sessionDesc, err) return "", fmt.Errorf("server: receive kcp register: %w", err) } if msg.Type != protocol.MessageTypeRegister { + log.Printf("kcp hub rejecting session %s: first message type=%s from=%s", sessionDesc, msg.Type, msg.From) if sendErr := sendKCPServerError(conn, msg.From, "first message must be register"); sendErr != nil { return "", fmt.Errorf("server: reject unregistered kcp peer: %w", sendErr) } @@ -171,6 +176,7 @@ func (h *KCPHub) registerConn(conn *transport.KCPConn) (string, error) { defer h.mu.Unlock() if _, exists := h.peers[msg.From]; exists { + log.Printf("kcp hub rejecting duplicate peer %q on session %s", msg.From, sessionDesc) if sendErr := sendKCPServerError(conn, msg.From, fmt.Sprintf("duplicate peer id: %s", msg.From)); sendErr != nil { return "", fmt.Errorf("server: duplicate kcp peer id %s: %w", msg.From, sendErr) } @@ -178,6 +184,7 @@ func (h *KCPHub) registerConn(conn *transport.KCPConn) (string, error) { } h.peers[msg.From] = conn + log.Printf("kcp hub registered peer %q on session %s (peers=%d)", msg.From, sessionDesc, len(h.peers)) return msg.From, nil } @@ -190,18 +197,26 @@ func (h *KCPHub) handlePeerMessage(peerID string, conn *transport.KCPConn, msg p if err := h.deliverToLocalPeer(msg); err == nil { return nil } else if !errors.Is(err, errKCPUnknownLocalTarget) { + log.Printf("kcp hub local delivery failed for %s -> %s: %v", peerID, msg.To, err) return sendKCPServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", msg.To)) } + log.Printf("kcp hub local target miss for %s -> %s; attempting relay", peerID, msg.To) err := h.forwardToRelay(msg) switch { case err == nil: return nil case errors.Is(err, errKCPRelayUnavailable): + log.Printf("kcp hub target %s unavailable for %s: no relay configured", msg.To, peerID) return sendKCPServerError(conn, peerID, fmt.Sprintf("unknown target: %s", msg.To)) + case errors.Is(err, errKCPRelayPeerUnknown): + log.Printf("kcp hub relay peer address is unknown for %s -> %s", peerID, msg.To) + return sendKCPServerError(conn, peerID, "failed to relay to remote peer") case errors.Is(err, errKCPRelayTooLarge): + log.Printf("kcp hub relay rejected oversize message %s -> %s (%d bytes)", peerID, msg.To, len(msg.Body)) return sendKCPServerError(conn, peerID, "message too large for relay udp") default: + log.Printf("kcp hub relay forward failed for %s -> %s: %v", peerID, msg.To, err) return sendKCPServerError(conn, peerID, "failed to relay to remote peer") } case protocol.MessageTypeRegister, protocol.MessageTypeError: @@ -218,16 +233,18 @@ func (h *KCPHub) handlePeerMessage(peerID string, conn *transport.KCPConn, msg p } // receivePeerLoop 持续读取 peer 发来的消息,并交给 handlePeerMessage 处理,直到连接出错。 -func (h *KCPHub) receivePeerLoop(peerID string, conn *transport.KCPConn) error { +func (h *KCPHub) receivePeerLoop(peerID string, conn *transport.KCPConn, sessionDesc string) error { for { msg, err := conn.Receive() if err != nil { _ = conn.Close() + log.Printf("kcp hub receive loop ending for peer %q on session %s: %v", peerID, sessionDesc, err) return fmt.Errorf("transport: kcp receive loop read: %w", err) } if err := h.handlePeerMessage(peerID, conn, msg); err != nil { _ = conn.Close() + log.Printf("kcp hub handler ending for peer %q on session %s: %v", peerID, sessionDesc, err) return fmt.Errorf("transport: kcp receive loop handler: %w", err) } } @@ -249,6 +266,7 @@ func (h *KCPHub) deliverRelayedMessage(msg protocol.Message) error { return nil } + log.Printf("kcp hub relayed target miss for %s -> %s; sending error back", msg.From, msg.To) return h.forwardRelayServerError(msg.From, fmt.Sprintf("unknown target: %s", msg.To)) } @@ -258,7 +276,7 @@ func (h *KCPHub) deliverToLocalPeer(msg protocol.Message) error { return fmt.Errorf("%w: %s", errKCPUnknownLocalTarget, msg.To) } if err := targetConn.Send(msg); err != nil { - h.unregister(msg.To, targetConn) + h.unregister(msg.To, targetConn, "local-forward-failure") _ = targetConn.Close() return fmt.Errorf("server: forward to local peer %s: %w", msg.To, err) } @@ -307,6 +325,7 @@ func (h *KCPHub) acceptRelayPeer(addr net.Addr) bool { if h.relayPeerAddr == nil && h.relayLearnPeer { h.relayPeerAddr = cloneRelayAddr(addr) + log.Printf("kcp hub learned relay peer %s", addr) return true } if h.relayPeerAddr == nil { @@ -323,13 +342,14 @@ func (h *KCPHub) lookup(peerID string) (*transport.KCPConn, bool) { return conn, ok } -func (h *KCPHub) unregister(peerID string, conn *transport.KCPConn) { +func (h *KCPHub) unregister(peerID string, conn *transport.KCPConn, sessionDesc string) { h.mu.Lock() defer h.mu.Unlock() current, ok := h.peers[peerID] if ok && current == conn { delete(h.peers, peerID) + log.Printf("kcp hub unregistered peer %q from session %s (peers=%d)", peerID, sessionDesc, len(h.peers)) } } @@ -378,3 +398,17 @@ func sameRelayAddr(left, right net.Addr) bool { } return left.String() == right.String() } + +func describeKCPSession(session *kcp.UDPSession) string { + if session == nil { + return "conv= remote= local=" + } + return fmt.Sprintf("conv=%d remote=%s local=%s", session.GetConv(), addrString(session.RemoteAddr()), addrString(session.LocalAddr())) +} + +func addrString(addr net.Addr) string { + if addr == nil { + return "" + } + return addr.String() +} diff --git a/cmd/internal/server/udp_relay.go b/cmd/internal/server/udp_relay.go index c37b127..044be15 100644 --- a/cmd/internal/server/udp_relay.go +++ b/cmd/internal/server/udp_relay.go @@ -31,7 +31,7 @@ func NewUDPRelay(listenConn net.PacketConn, upstreamAddr *net.UDPAddr) (*UDPRela return nil, fmt.Errorf("relay: upstream addr is required") } - upstreamConn, err := net.DialUDP("udp", nil, upstreamAddr) + upstreamConn, err := net.DialUDP(relayUDPNetwork(upstreamAddr), nil, upstreamAddr) if err != nil { return nil, fmt.Errorf("relay: dial upstream %s: %w", upstreamAddr, err) } @@ -68,10 +68,20 @@ func (r *UDPRelay) forwardDownstreamToUpstream() error { return fmt.Errorf("relay: read downstream: %w", err) } + clientAddr := cloneRelayAddr(addr) + r.mu.Lock() - r.clientAddr = cloneRelayAddr(addr) + previousAddr := cloneRelayAddr(r.clientAddr) + r.clientAddr = clientAddr r.mu.Unlock() + switch { + case previousAddr == nil: + log.Printf("relay: learned downstream client %s", clientAddr) + case !sameRelayAddr(previousAddr, clientAddr): + log.Printf("relay: downstream client changed from %s to %s", previousAddr, clientAddr) + } + if _, err := r.upstream.Write(buf[:n]); err != nil { return fmt.Errorf("relay: write upstream: %w", err) } @@ -113,3 +123,13 @@ func (r *UDPRelay) Close() error { } return err2 } + +func relayUDPNetwork(addr *net.UDPAddr) string { + if addr == nil || addr.IP == nil { + return "udp" + } + if addr.IP.To4() != nil { + return "udp4" + } + return "udp6" +} diff --git a/cmd/internal/server/udp_relay_test.go b/cmd/internal/server/udp_relay_test.go index aecff3b..3c1de41 100644 --- a/cmd/internal/server/udp_relay_test.go +++ b/cmd/internal/server/udp_relay_test.go @@ -18,7 +18,7 @@ func TestUDPRelayKCPForwardAndReturn(t *testing.T) { hub, hubAddr, hubCleanup := startKCPHubForRelay(t) defer hubCleanup() - relayAddr := startUDPRelay(t, hubAddr) + relayAddr, relay := startUDPRelay(t, hubAddr) peerBConn := dialKCPPeer(t, hubAddr) peerAConn := dialKCPPeer(t, relayAddr) @@ -41,6 +41,11 @@ func TestUDPRelayKCPForwardAndReturn(t *testing.T) { waitForRelay(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") + waitForRelay(t, func() bool { + relay.mu.RLock() + defer relay.mu.RUnlock() + return relay.clientAddr != nil + }, "relay to learn the downstream peer") if err := peerBConn.Send(protocol.Message{ Type: protocol.MessageTypeText, @@ -95,7 +100,7 @@ func TestUDPRelayKCPFileMessage(t *testing.T) { hub, hubAddr, hubCleanup := startKCPHubForRelay(t) defer hubCleanup() - relayAddr := startUDPRelay(t, hubAddr) + relayAddr, relay := startUDPRelay(t, hubAddr) peerBConn := dialKCPPeer(t, hubAddr) peerAConn := dialKCPPeer(t, relayAddr) @@ -118,6 +123,11 @@ func TestUDPRelayKCPFileMessage(t *testing.T) { waitForRelay(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") + waitForRelay(t, func() bool { + relay.mu.RLock() + defer relay.mu.RUnlock() + return relay.clientAddr != nil + }, "relay to learn the downstream peer") if err := peerBConn.Send(protocol.Message{ Type: protocol.MessageTypeFile, @@ -222,7 +232,7 @@ func dialKCPPeer(t *testing.T, serverAddr string) *transport.KCPConn { return conn } -func startUDPRelay(t *testing.T, upstreamAddr string) string { +func startUDPRelay(t *testing.T, upstreamAddr string) (string, *UDPRelay) { t.Helper() remoteAddr, err := net.ResolveUDPAddr("udp", upstreamAddr) @@ -255,7 +265,7 @@ func startUDPRelay(t *testing.T, upstreamAddr string) string { wg.Wait() }) - return conn.LocalAddr().String() + return conn.LocalAddr().String(), relay } func waitForRelay(t *testing.T, condition func() bool, description string) { diff --git a/cmd/internal/transport/kcp.go b/cmd/internal/transport/kcp.go index f88a525..8e7dcc0 100644 --- a/cmd/internal/transport/kcp.go +++ b/cmd/internal/transport/kcp.go @@ -196,15 +196,26 @@ func generateKCPConversationID() (uint32, error) { return convID, nil } -func listenKCPPacketConn(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) { +// ResolveUDPListenConfig parses a UDP listen address and returns the socket +// family that should be used for binding it. +func ResolveUDPListenConfig(listenAddr string) (string, *net.UDPAddr, error) { udpAddr, err := net.ResolveUDPAddr("udp", listenAddr) + if err != nil { + return "", nil, fmt.Errorf("transport: resolve udp listen addr %s: %w", listenAddr, err) + } + + return udpListenNetwork(udpAddr), udpAddr, nil +} + +func listenKCPPacketConn(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) { + network, udpAddr, err := ResolveUDPListenConfig(listenAddr) if err != nil { return nil, fmt.Errorf("transport: resolve kcp listen addr %s: %w", listenAddr, err) } - rawConn, err := listenUDPConn("udp", udpAddr, bindDevice) + rawConn, err := listenUDPConn(network, udpAddr, bindDevice) if err != nil { - return nil, fmt.Errorf("transport: listen udp for kcp on %s: %w", listenAddr, err) + return nil, fmt.Errorf("transport: listen %s for kcp on %s: %w", network, udpListenAddr(udpAddr), err) } packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID) @@ -280,3 +291,13 @@ func udpListenAddr(addr *net.UDPAddr) string { } return addr.String() } + +func udpListenNetwork(addr *net.UDPAddr) string { + if addr == nil || addr.IP == nil { + return "udp" + } + if addr.IP.To4() != nil { + return "udp4" + } + return "udp6" +} diff --git a/cmd/internal/transport/kcp_test.go b/cmd/internal/transport/kcp_test.go index 054ff8f..da1e81a 100644 --- a/cmd/internal/transport/kcp_test.go +++ b/cmd/internal/transport/kcp_test.go @@ -204,6 +204,55 @@ func TestKCPCloseIsIdempotent(t *testing.T) { } } +func TestResolveUDPListenConfigSelectsSocketFamily(t *testing.T) { + tests := []struct { + name string + listenAddr string + wantNetwork string + wantAddr string + }{ + { + name: "ipv4 unspecified", + listenAddr: "0.0.0.0:10909", + wantNetwork: "udp4", + wantAddr: "0.0.0.0:10909", + }, + { + name: "ipv4 loopback", + listenAddr: "127.0.0.1:10909", + wantNetwork: "udp4", + wantAddr: "127.0.0.1:10909", + }, + { + name: "ipv6 loopback", + listenAddr: "[::1]:10909", + wantNetwork: "udp6", + wantAddr: "[::1]:10909", + }, + { + name: "host omitted", + listenAddr: ":10909", + wantNetwork: "udp", + wantAddr: ":10909", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotNetwork, gotAddr, err := ResolveUDPListenConfig(tt.listenAddr) + if err != nil { + t.Fatalf("ResolveUDPListenConfig(%q) error = %v", tt.listenAddr, err) + } + if gotNetwork != tt.wantNetwork { + t.Fatalf("network = %q, want %q", gotNetwork, tt.wantNetwork) + } + if gotAddr.String() != tt.wantAddr { + t.Fatalf("addr = %q, want %q", gotAddr.String(), tt.wantAddr) + } + }) + } +} + func newKCPConnPair(t *testing.T, senderOpts []KCPOption, receiverOpts []KCPOption, senderPacketLogger KCPPacketDebugLogger, receiverPacketLogger KCPPacketDebugLogger) (*KCPConn, <-chan kcpAcceptResult, func()) { t.Helper() diff --git a/cmd/kcppeer/main.go b/cmd/kcppeer/main.go index 3d1ffa7..c3b4051 100644 --- a/cmd/kcppeer/main.go +++ b/cmd/kcppeer/main.go @@ -16,8 +16,8 @@ import ( func main() { peerID := flag.String("id", "peer-a", "peer identity") - serverAddr := flag.String("server", "127.0.0.1:9002", "KCP server address") - relayVia := flag.String("relay-via", "", "optional UDP relay address used to reach the KCP server") + serverAddr := flag.String("server", "127.0.0.1:9002", "logical KCP hub address; when -relay-via is set this may differ from the actual UDP dial target") + relayVia := flag.String("relay-via", "", "optional UDP relay address used as the actual KCP dial target") targetPeer := flag.String("to", "", "optional target peer for one outgoing message") text := flag.String("text", "", "optional text to send after connecting") filePath := flag.String("file", "", "optional file path to send after connecting") @@ -77,10 +77,12 @@ func main() { } defer client.Close() + dialTarget := *serverAddr if *relayVia != "" { - log.Printf("connected to %s via relay %s as %s (KCP)", *serverAddr, *relayVia, client.ID()) + dialTarget = *relayVia + log.Printf("opened KCP session as %s; logical server=%s, actual dial target=%s via relay; register not yet confirmed", client.ID(), *serverAddr, dialTarget) } else { - log.Printf("connected to %s as %s (KCP)", *serverAddr, client.ID()) + log.Printf("opened KCP session as %s; logical server=%s, actual dial target=%s; register not yet confirmed", client.ID(), *serverAddr, dialTarget) } receiveErr := make(chan error, 1) diff --git a/cmd/kcpserver/main.go b/cmd/kcpserver/main.go index 739aced..e45621b 100644 --- a/cmd/kcpserver/main.go +++ b/cmd/kcpserver/main.go @@ -94,6 +94,11 @@ func main() { } func runHubServer(listenAddr, bindDevice, logPath, packetDebugLogPath, sessionStatsLogPath string, statsInterval time.Duration) { + listenNetwork, _, err := transport.ResolveUDPListenConfig(listenAddr) + if err != nil { + log.Fatalf("resolve kcp listen address %s: %v", listenAddr, err) + } + hubOptions := make([]server.KCPOption, 0, 2) if logPath != "" { logger, err := latencylog.NewJSONLLogger(logPath) @@ -131,7 +136,7 @@ func runHubServer(listenAddr, bindDevice, logPath, packetDebugLogPath, sessionSt hub := server.NewKCPHub(hubOptions...) - log.Printf("kcp hub listening on %s", listener.Addr()) + log.Printf("kcp hub listening on %s %s", listenNetwork, packetConn.LocalAddr()) for { session, err := listener.AcceptKCP() @@ -152,9 +157,14 @@ func runHubServer(listenAddr, bindDevice, logPath, packetDebugLogPath, sessionSt } func runUDPRelayServer(listenAddr, remoteAddr string) { - conn, err := net.ListenPacket("udp", listenAddr) + listenNetwork, udpListenAddr, err := transport.ResolveUDPListenConfig(listenAddr) if err != nil { - log.Fatalf("listen udp relay on %s: %v", listenAddr, err) + log.Fatalf("resolve udp relay listen address %s: %v", listenAddr, err) + } + + conn, err := net.ListenPacket(listenNetwork, udpListenAddr.String()) + if err != nil { + log.Fatalf("listen %s relay on %s: %v", listenNetwork, udpListenAddr, err) } defer conn.Close() @@ -169,7 +179,7 @@ func runUDPRelayServer(listenAddr, remoteAddr string) { log.Fatalf("create udp relay: %v", err) } - log.Printf("udp relay listening on %s and forwarding to %s", conn.LocalAddr(), remote) + log.Printf("udp relay listening on %s %s and forwarding to %s", listenNetwork, conn.LocalAddr(), remote) if err := relay.Serve(); err != nil { log.Fatalf("udp relay stopped: %v", err) } diff --git a/cmd/udprelay/main.go b/cmd/udprelay/main.go index ee420f9..6108f26 100644 --- a/cmd/udprelay/main.go +++ b/cmd/udprelay/main.go @@ -6,6 +6,7 @@ import ( "net" "omnisocketgo/cmd/internal/server" + "omnisocketgo/cmd/internal/transport" ) func main() { @@ -18,9 +19,14 @@ func main() { log.Fatalf("resolve upstream address %s: %v", *upstreamAddr, err) } - conn, err := net.ListenPacket("udp", *listenAddr) + listenNetwork, udpListenAddr, err := transport.ResolveUDPListenConfig(*listenAddr) if err != nil { - log.Fatalf("listen udp on %s: %v", *listenAddr, err) + log.Fatalf("resolve udp relay listen address %s: %v", *listenAddr, err) + } + + conn, err := net.ListenPacket(listenNetwork, udpListenAddr.String()) + if err != nil { + log.Fatalf("listen %s on %s: %v", listenNetwork, udpListenAddr, err) } relay, err := server.NewUDPRelay(conn, upstreamUDPAddr) @@ -29,7 +35,7 @@ func main() { log.Fatalf("create udp relay: %v", err) } - log.Printf("udp relay listening on %s, upstream %s", conn.LocalAddr(), *upstreamAddr) + log.Printf("udp relay listening on %s %s, upstream %s", listenNetwork, conn.LocalAddr(), *upstreamAddr) if err := relay.Serve(); err != nil { log.Fatalf("udp relay serve: %v", err)