From 581c52f9b586e01dd8a13c30dd30b5ad5590c5ec Mon Sep 17 00:00:00 2001 From: Mock Date: Sat, 28 Mar 2026 14:23:00 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=B8=AD=E8=BD=AC=E6=B7=BB=E5=8A=A0udp?= =?UTF-8?q?=20relay?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/internal/peer/kcp_client_test.go | 4 +- cmd/internal/server/udp_relay.go | 44 +++++++++--------- cmd/internal/server/udp_relay_test.go | 64 ++++++++++----------------- cmd/kcpserver/main.go | 1 + cmd/udprelay/main.go | 8 ++-- 5 files changed, 50 insertions(+), 71 deletions(-) diff --git a/cmd/internal/peer/kcp_client_test.go b/cmd/internal/peer/kcp_client_test.go index 5bc51f0..f6ffcf2 100644 --- a/cmd/internal/peer/kcp_client_test.go +++ b/cmd/internal/peer/kcp_client_test.go @@ -293,12 +293,12 @@ func TestKCPClientsExchangeMessagesViaUDPRelayToSingleHub(t *testing.T) { relayWG.Add(1) go func() { defer relayWG.Done() - if serveErr := relay.Serve(); serveErr != nil { + if serveErr := relay.Serve(); serveErr != nil && !isExpectedKCPRelayServeExit(serveErr) { t.Errorf("relay.Serve() error = %v", serveErr) } }() defer func() { - _ = relayConn.Close() + _ = relay.Close() relayWG.Wait() }() diff --git a/cmd/internal/server/udp_relay.go b/cmd/internal/server/udp_relay.go index 8614f02..c37b127 100644 --- a/cmd/internal/server/udp_relay.go +++ b/cmd/internal/server/udp_relay.go @@ -9,30 +9,29 @@ import ( "omnisocketgo/cmd/internal/protocol" ) -// udpRelayBufSize 是 relay 接收缓冲区大小,与 UDP transport 层保持一致。 const udpRelayBufSize = protocol.MaxFrameSize + 1024 -// UDPRelay 是一个透明的双向 UDP 转发器。 -// 它在下游(客户端 A)和上游(server D)之间原样转发 UDP 数据报, -// 不解析也不修改协议内容。 +// UDPRelay transparently forwards UDP datagrams between one downstream client +// and a fixed upstream server. type UDPRelay struct { - downstream *net.UDPConn // 监听端口,等待下游客户端连接 - upstream *net.UDPConn // 连接到上游 server(connected socket) + downstream net.PacketConn + upstream *net.UDPConn mu sync.RWMutex - clientAddr *net.UDPAddr // 下游客户端地址,从第一个下游包学习 + clientAddr net.Addr } -// NewUDPRelay 创建一个新的 UDP relay。 -// listenConn 是已经绑定好的监听 socket(供下游客户端连接), -// upstreamAddr 是上游 server D 的地址。 -func NewUDPRelay(listenConn *net.UDPConn, upstreamAddr string) (*UDPRelay, error) { - udpUpstreamAddr, err := net.ResolveUDPAddr("udp", upstreamAddr) - if err != nil { - return nil, fmt.Errorf("relay: resolve upstream addr %s: %w", upstreamAddr, err) +// NewUDPRelay creates a relay that listens on listenConn and forwards all +// traffic to upstreamAddr. +func NewUDPRelay(listenConn net.PacketConn, upstreamAddr *net.UDPAddr) (*UDPRelay, error) { + if listenConn == nil { + return nil, fmt.Errorf("relay: listen conn is required") + } + if upstreamAddr == nil { + return nil, fmt.Errorf("relay: upstream addr is required") } - upstreamConn, err := net.DialUDP("udp", nil, udpUpstreamAddr) + upstreamConn, err := net.DialUDP("udp", nil, upstreamAddr) if err != nil { return nil, fmt.Errorf("relay: dial upstream %s: %w", upstreamAddr, err) } @@ -43,7 +42,8 @@ func NewUDPRelay(listenConn *net.UDPConn, upstreamAddr string) (*UDPRelay, error }, nil } -// Serve 启动双向转发循环,阻塞直到任一方向出错。 +// Serve starts bidirectional forwarding and blocks until either direction +// exits with an error. func (r *UDPRelay) Serve() error { errCh := make(chan error, 2) @@ -55,23 +55,21 @@ func (r *UDPRelay) Serve() error { }() err := <-errCh - // 关闭两个 conn 让另一个 goroutine 也退出 _ = r.downstream.Close() _ = r.upstream.Close() return err } -// forwardDownstreamToUpstream 从下游读取并转发到上游。 func (r *UDPRelay) forwardDownstreamToUpstream() error { buf := make([]byte, udpRelayBufSize) for { - n, addr, err := r.downstream.ReadFromUDP(buf) + n, addr, err := r.downstream.ReadFrom(buf) if err != nil { return fmt.Errorf("relay: read downstream: %w", err) } r.mu.Lock() - r.clientAddr = addr + r.clientAddr = cloneRelayAddr(addr) r.mu.Unlock() if _, err := r.upstream.Write(buf[:n]); err != nil { @@ -82,7 +80,6 @@ func (r *UDPRelay) forwardDownstreamToUpstream() error { } } -// forwardUpstreamToDownstream 从上游读取并转发到下游。 func (r *UDPRelay) forwardUpstreamToDownstream() error { buf := make([]byte, udpRelayBufSize) for { @@ -92,7 +89,7 @@ func (r *UDPRelay) forwardUpstreamToDownstream() error { } r.mu.RLock() - addr := r.clientAddr + addr := cloneRelayAddr(r.clientAddr) r.mu.RUnlock() if addr == nil { @@ -100,7 +97,7 @@ func (r *UDPRelay) forwardUpstreamToDownstream() error { continue } - if _, err := r.downstream.WriteToUDP(buf[:n], addr); err != nil { + if _, err := r.downstream.WriteTo(buf[:n], addr); err != nil { return fmt.Errorf("relay: write downstream to %s: %w", addr, err) } @@ -108,7 +105,6 @@ func (r *UDPRelay) forwardUpstreamToDownstream() error { } } -// Close 关闭 relay 的上下游连接。 func (r *UDPRelay) Close() error { err1 := r.downstream.Close() err2 := r.upstream.Close() diff --git a/cmd/internal/server/udp_relay_test.go b/cmd/internal/server/udp_relay_test.go index 206f3d5..aecff3b 100644 --- a/cmd/internal/server/udp_relay_test.go +++ b/cmd/internal/server/udp_relay_test.go @@ -14,22 +14,15 @@ import ( "omnisocketgo/cmd/internal/transport" ) -// TestUDPRelayKCPForwardAndReturn 验证 KCP 通过 UDP relay 的完整双向转发路径: -// peer-b -> D(KCP hub) -> C(UDP relay) -> peer-a 以及反向。 func TestUDPRelayKCPForwardAndReturn(t *testing.T) { - // 启动 D(KCP Hub) hub, hubAddr, hubCleanup := startKCPHubForRelay(t) defer hubCleanup() - // 启动 C(UDP Relay),upstream 指向 D relayAddr := startUDPRelay(t, hubAddr) - // peer-b 直连 D(KCP) peerBConn := dialKCPPeer(t, hubAddr) - // peer-a 连 C(通过 relay 间接连到 D) peerAConn := dialKCPPeer(t, relayAddr) - // 注册 peer-b if err := peerBConn.Send(protocol.Message{ Type: protocol.MessageTypeRegister, From: "peer-b", @@ -37,8 +30,6 @@ func TestUDPRelayKCPForwardAndReturn(t *testing.T) { }); err != nil { t.Fatalf("peerB register: %v", err) } - - // 注册 peer-a(通过 relay) if err := peerAConn.Send(protocol.Message{ Type: protocol.MessageTypeRegister, From: "peer-a", @@ -51,7 +42,6 @@ func TestUDPRelayKCPForwardAndReturn(t *testing.T) { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") - // peer-b -> peer-a(路径: B -> D -> C -> A) if err := peerBConn.Send(protocol.Message{ Type: protocol.MessageTypeText, ID: 1, @@ -76,19 +66,12 @@ func TestUDPRelayKCPForwardAndReturn(t *testing.T) { t.Fatalf("message body = %q, want %q", string(msg.Body), "hello from peer-b") } - // peer-a -> peer-b(路径: A -> C -> D -> B) if err := peerAConn.Send(protocol.Message{ Type: protocol.MessageTypeText, ID: 2, From: "peer-a", To: "peer-b", - Body: []byte("reply from peer-单个 downstream peer 通过 relay 连到 KCP server”这条 - 链路是成立的,转发逻辑本身没有明显的地址错误。cmd/internal/server/udp_relay.go 里就是原 - 样双向转发,下游来的包会记录 clientAddr 并写给上游,上游回来的包再写回这个 clientAddr。 - 关键代码在 cmd/internal/server/udp_relay.go:68 和 cmd/internal/server/udp_relay.go:89。 - - 还有一个关键事实:kcppeer 里那句 connected to ... as ... (KCP) 不能证明 peer-a 真的在 - hub 注册成功a"), + Body: []byte("reply from peer-a"), }); err != nil { t.Fatalf("peerA send text: %v", err) } @@ -108,7 +91,6 @@ func TestUDPRelayKCPForwardAndReturn(t *testing.T) { } } -// TestUDPRelayKCPFileMessage 验证通过 relay 转发 KCP 文件消息。 func TestUDPRelayKCPFileMessage(t *testing.T) { hub, hubAddr, hubCleanup := startKCPHubForRelay(t) defer hubCleanup() @@ -118,16 +100,20 @@ func TestUDPRelayKCPFileMessage(t *testing.T) { peerBConn := dialKCPPeer(t, hubAddr) peerAConn := dialKCPPeer(t, relayAddr) - _ = peerBConn.Send(protocol.Message{ + if err := peerBConn.Send(protocol.Message{ Type: protocol.MessageTypeRegister, From: "peer-b", To: protocol.ServerPeerID, - }) - _ = peerAConn.Send(protocol.Message{ + }); err != nil { + t.Fatalf("peerB register: %v", err) + } + if err := peerAConn.Send(protocol.Message{ Type: protocol.MessageTypeRegister, From: "peer-a", To: protocol.ServerPeerID, - }) + }); err != nil { + t.Fatalf("peerA register: %v", err) + } waitForRelay(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") @@ -154,18 +140,11 @@ func TestUDPRelayKCPFileMessage(t *testing.T) { if msg.FileName != "test.bin" { t.Fatalf("file name = %q, want %q", msg.FileName, "test.bin") } - if len(msg.Body) != 4 || msg.Body[0] != 0xDE { - t.Fatalf("file body mismatch: 单个 downstream peer 通过 relay 连到 KCP server”这条 - 链路是成立的,转发逻辑本身没有明显的地址错误。cmd/internal/server/udp_relay.go 里就是原 - 样双向转发,下游来的包会记录 clientAddr 并写给上游,上游回来的包再写回这个 clientAddr。 - 关键代码在 cmd/internal/server/udp_relay.go:68 和 cmd/internal/server/udp_relay.go:89。 - - 还有一个关键事实:kcppeer 里那句 connected to ... as ... (KCP) 不能证明 peer-a 真的在 - hub 注册成功got %v", msg.Body) + if string(msg.Body) != string([]byte{0xDE, 0xAD, 0xBE, 0xEF}) { + t.Fatalf("file body = %v, want %v", msg.Body, []byte{0xDE, 0xAD, 0xBE, 0xEF}) } } -// startKCPHubForRelay 启动一个 KCP hub server,返回 hub、监听地址和 cleanup 函数。 func startKCPHubForRelay(t *testing.T) (*KCPHub, string, func()) { t.Helper() @@ -222,7 +201,6 @@ func startKCPHubForRelay(t *testing.T) (*KCPHub, string, func()) { return hub, listener.Addr().String(), cleanup } -// dialKCPPeer 创建一条到指定地址的 KCP 连接,用于测试。 func dialKCPPeer(t *testing.T, serverAddr string) *transport.KCPConn { t.Helper() @@ -244,38 +222,42 @@ func dialKCPPeer(t *testing.T, serverAddr string) *transport.KCPConn { return conn } -// startUDPRelay 创建并启动一个 UDPRelay,返回其监听地址字符串。 func startUDPRelay(t *testing.T, upstreamAddr string) string { t.Helper() - addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + remoteAddr, err := net.ResolveUDPAddr("udp", upstreamAddr) if err != nil { - t.Fatalf("ResolveUDPAddr() error = %v", err) + t.Fatalf("ResolveUDPAddr(%s) error = %v", upstreamAddr, err) } - conn, err := net.ListenUDP("udp", addr) + conn, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { - t.Fatalf("ListenUDP() error = %v", err) + t.Fatalf("ListenPacket() error = %v", err) } - relay, err := NewUDPRelay(conn, upstreamAddr) + relay, err := NewUDPRelay(conn, remoteAddr) if err != nil { _ = conn.Close() t.Fatalf("NewUDPRelay() error = %v", err) } + var wg sync.WaitGroup + wg.Add(1) go func() { - _ = relay.Serve() + defer wg.Done() + if serveErr := relay.Serve(); serveErr != nil && !isExpectedRelayServeExit(serveErr) { + t.Errorf("relay.Serve() error = %v", serveErr) + } }() t.Cleanup(func() { _ = relay.Close() + wg.Wait() }) return conn.LocalAddr().String() } -// waitForRelay 轮询等待条件满足,超时则 fail。 func waitForRelay(t *testing.T, condition func() bool, description string) { t.Helper() diff --git a/cmd/kcpserver/main.go b/cmd/kcpserver/main.go index 75cf507..739aced 100644 --- a/cmd/kcpserver/main.go +++ b/cmd/kcpserver/main.go @@ -165,6 +165,7 @@ func runUDPRelayServer(listenAddr, remoteAddr string) { relay, err := server.NewUDPRelay(conn, remote) if err != nil { + _ = conn.Close() log.Fatalf("create udp relay: %v", err) } diff --git a/cmd/udprelay/main.go b/cmd/udprelay/main.go index 89abe74..ee420f9 100644 --- a/cmd/udprelay/main.go +++ b/cmd/udprelay/main.go @@ -13,17 +13,17 @@ func main() { upstreamAddr := flag.String("upstream", "127.0.0.1:9002", "upstream KCP server address (server D)") flag.Parse() - udpAddr, err := net.ResolveUDPAddr("udp", *listenAddr) + upstreamUDPAddr, err := net.ResolveUDPAddr("udp", *upstreamAddr) if err != nil { - log.Fatalf("resolve listen address %s: %v", *listenAddr, err) + log.Fatalf("resolve upstream address %s: %v", *upstreamAddr, err) } - conn, err := net.ListenUDP("udp", udpAddr) + conn, err := net.ListenPacket("udp", *listenAddr) if err != nil { log.Fatalf("listen udp on %s: %v", *listenAddr, err) } - relay, err := server.NewUDPRelay(conn, *upstreamAddr) + relay, err := server.NewUDPRelay(conn, upstreamUDPAddr) if err != nil { _ = conn.Close() log.Fatalf("create udp relay: %v", err)