From 34d2f574ace50176b7a1f129c8f05ffdf4ab064d Mon Sep 17 00:00:00 2001 From: nnbcccscdscdsc <2709767634@qq.com> Date: Sat, 28 Mar 2026 13:13:17 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E6=96=B0=E5=A2=9Eserver=20upd=E8=BD=AC?= =?UTF-8?q?=E5=8F=91=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/internal/server/udp_relay.go | 176 +++++++------- cmd/internal/server/udp_relay_test.go | 336 ++++++++++++++++++++------ cmd/udprelay/main.go | 37 +++ 3 files changed, 383 insertions(+), 166 deletions(-) create mode 100644 cmd/udprelay/main.go diff --git a/cmd/internal/server/udp_relay.go b/cmd/internal/server/udp_relay.go index 3c39d73..8614f02 100644 --- a/cmd/internal/server/udp_relay.go +++ b/cmd/internal/server/udp_relay.go @@ -6,122 +6,114 @@ import ( "net" "sync" - "omnisocketgo/cmd/internal/transport" + "omnisocketgo/cmd/internal/protocol" ) -// UDPRelay 负责在固定远端与多个客户端之间双向透明转发 KCP UDP datagram。 -type UDPRelay struct { - conn net.PacketConn - remote *net.UDPAddr +// udpRelayBufSize 是 relay 接收缓冲区大小,与 UDP transport 层保持一致。 +const udpRelayBufSize = protocol.MaxFrameSize + 1024 - mu sync.RWMutex - clients map[uint32]*net.UDPAddr +// UDPRelay 是一个透明的双向 UDP 转发器。 +// 它在下游(客户端 A)和上游(server D)之间原样转发 UDP 数据报, +// 不解析也不修改协议内容。 +type UDPRelay struct { + downstream *net.UDPConn // 监听端口,等待下游客户端连接 + upstream *net.UDPConn // 连接到上游 server(connected socket) + + mu sync.RWMutex + clientAddr *net.UDPAddr // 下游客户端地址,从第一个下游包学习 } -// NewUDPRelay 创建一个绑定到给定 PacketConn 的透明 UDP relay。 -func NewUDPRelay(conn net.PacketConn, remote *net.UDPAddr) (*UDPRelay, error) { - if conn == nil { - return nil, fmt.Errorf("server: nil udp relay conn") +// NewUDPRelay 创建一个新的 UDP relay。 +// listenConn 是已经绑定好的监听 socket(供下游客户端连接), +// upstreamAddr 是上游 server D 的地址。 +func NewUDPRelay(listenConn *net.UDPConn, upstreamAddr string) (*UDPRelay, error) { + udpUpstreamAddr, err := net.ResolveUDPAddr("udp", upstreamAddr) + if err != nil { + return nil, fmt.Errorf("relay: resolve upstream addr %s: %w", upstreamAddr, err) } - if remote == nil { - return nil, fmt.Errorf("server: nil udp relay remote") + + upstreamConn, err := net.DialUDP("udp", nil, udpUpstreamAddr) + if err != nil { + return nil, fmt.Errorf("relay: dial upstream %s: %w", upstreamAddr, err) } return &UDPRelay{ - conn: conn, - remote: cloneUDPAddr(remote), - clients: make(map[uint32]*net.UDPAddr), + downstream: listenConn, + upstream: upstreamConn, }, nil } -// Serve 持续双向转发原始 UDP datagram,不解析业务消息。 +// Serve 启动双向转发循环,阻塞直到任一方向出错。 func (r *UDPRelay) Serve() error { - buffer := make([]byte, 64*1024) + errCh := make(chan error, 2) + + go func() { + errCh <- r.forwardDownstreamToUpstream() + }() + go func() { + errCh <- r.forwardUpstreamToDownstream() + }() + + err := <-errCh + // 关闭两个 conn 让另一个 goroutine 也退出 + _ = r.downstream.Close() + _ = r.upstream.Close() + return err +} + +// forwardDownstreamToUpstream 从下游读取并转发到上游。 +func (r *UDPRelay) forwardDownstreamToUpstream() error { + buf := make([]byte, udpRelayBufSize) for { - n, addr, err := r.conn.ReadFrom(buffer) + n, addr, err := r.downstream.ReadFromUDP(buf) if err != nil { - if isExpectedRelayServeExit(err) { - return nil - } - return fmt.Errorf("server: udp relay read packet: %w", err) + return fmt.Errorf("relay: read downstream: %w", err) } - udpAddr, ok := addr.(*net.UDPAddr) - if !ok { - log.Printf("udp relay dropped packet from non-udp addr %T", addr) + r.mu.Lock() + r.clientAddr = addr + r.mu.Unlock() + + if _, err := r.upstream.Write(buf[:n]); err != nil { + return fmt.Errorf("relay: write upstream: %w", err) + } + + log.Printf("relay: forwarded %d bytes downstream(%s) -> upstream", n, addr) + } +} + +// forwardUpstreamToDownstream 从上游读取并转发到下游。 +func (r *UDPRelay) forwardUpstreamToDownstream() error { + buf := make([]byte, udpRelayBufSize) + for { + n, err := r.upstream.Read(buf) + if err != nil { + return fmt.Errorf("relay: read upstream: %w", err) + } + + r.mu.RLock() + addr := r.clientAddr + r.mu.RUnlock() + + if addr == nil { + log.Printf("relay: dropping %d bytes from upstream (no downstream client yet)", n) continue } - payload := append([]byte(nil), buffer[:n]...) - if sameUDPAddr(udpAddr, r.remote) { - if err := r.forwardRemotePacket(payload); err != nil { - log.Printf("udp relay failed forwarding remote packet from %s: %v", udpAddr, err) - } - continue + if _, err := r.downstream.WriteToUDP(buf[:n], addr); err != nil { + return fmt.Errorf("relay: write downstream to %s: %w", addr, err) } - if err := r.forwardClientPacket(udpAddr, payload); err != nil { - log.Printf("udp relay failed forwarding client packet from %s: %v", udpAddr, err) - } + log.Printf("relay: forwarded %d bytes upstream -> downstream(%s)", n, addr) } } -func (r *UDPRelay) forwardClientPacket(addr *net.UDPAddr, payload []byte) error { - convID, ok := transport.ParseKCPConversationID(payload) - if !ok { - return fmt.Errorf("missing kcp conversation id") +// Close 关闭 relay 的上下游连接。 +func (r *UDPRelay) Close() error { + err1 := r.downstream.Close() + err2 := r.upstream.Close() + if err1 != nil { + return err1 } - - r.mu.Lock() - r.clients[convID] = cloneUDPAddr(addr) - r.mu.Unlock() - - if _, err := r.conn.WriteTo(payload, r.remote); err != nil { - return fmt.Errorf("write conv %d to remote %s: %w", convID, r.remote, err) - } - return nil -} - -func (r *UDPRelay) forwardRemotePacket(payload []byte) error { - convID, ok := transport.ParseKCPConversationID(payload) - if !ok { - return fmt.Errorf("missing kcp conversation id") - } - - r.mu.RLock() - clientAddr := cloneUDPAddr(r.clients[convID]) - r.mu.RUnlock() - if clientAddr == nil { - return fmt.Errorf("unknown client for conv %d", convID) - } - - if _, err := r.conn.WriteTo(payload, clientAddr); err != nil { - return fmt.Errorf("write conv %d to client %s: %w", convID, clientAddr, err) - } - return nil -} - -func cloneUDPAddr(addr *net.UDPAddr) *net.UDPAddr { - if addr == nil { - return nil - } - - ipCopy := make([]byte, len(addr.IP)) - copy(ipCopy, addr.IP) - - return &net.UDPAddr{ - IP: ipCopy, - Port: addr.Port, - Zone: addr.Zone, - } -} - -func sameUDPAddr(left, right *net.UDPAddr) bool { - if left == nil || right == nil { - return left == right - } - if left.Port != right.Port || left.Zone != right.Zone { - return false - } - return left.IP.Equal(right.IP) + return err2 } diff --git a/cmd/internal/server/udp_relay_test.go b/cmd/internal/server/udp_relay_test.go index f866211..206f3d5 100644 --- a/cmd/internal/server/udp_relay_test.go +++ b/cmd/internal/server/udp_relay_test.go @@ -1,103 +1,291 @@ package server import ( - "encoding/binary" "net" + "strings" "sync" "testing" "time" + + kcp "github.com/xtaci/kcp-go/v5" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" + "omnisocketgo/cmd/internal/transport" ) -func TestUDPRelayRoutesPacketsByKCPConversationID(t *testing.T) { - remote, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - t.Fatalf("ListenPacket(remote) error = %v", err) - } - defer remote.Close() +// TestUDPRelayKCPForwardAndReturn 验证 KCP 通过 UDP relay 的完整双向转发路径: +// peer-b -> D(KCP hub) -> C(UDP relay) -> peer-a 以及反向。 +func TestUDPRelayKCPForwardAndReturn(t *testing.T) { + // 启动 D(KCP Hub) + hub, hubAddr, hubCleanup := startKCPHubForRelay(t) + defer hubCleanup() - relayConn, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - t.Fatalf("ListenPacket(relay) error = %v", err) + // 启动 C(UDP Relay),upstream 指向 D + relayAddr := startUDPRelay(t, hubAddr) + + // peer-b 直连 D(KCP) + peerBConn := dialKCPPeer(t, hubAddr) + // peer-a 连 C(通过 relay 间接连到 D) + peerAConn := dialKCPPeer(t, relayAddr) + + // 注册 peer-b + if err := peerBConn.Send(protocol.Message{ + Type: protocol.MessageTypeRegister, + From: "peer-b", + To: protocol.ServerPeerID, + }); err != nil { + t.Fatalf("peerB register: %v", err) } - relay, err := NewUDPRelay(relayConn, remote.LocalAddr().(*net.UDPAddr)) - if err != nil { - _ = relayConn.Close() - t.Fatalf("NewUDPRelay() error = %v", err) + // 注册 peer-a(通过 relay) + if err := peerAConn.Send(protocol.Message{ + Type: protocol.MessageTypeRegister, + From: "peer-a", + To: protocol.ServerPeerID, + }); err != nil { + t.Fatalf("peerA register: %v", err) } - var wg sync.WaitGroup + waitForRelay(t, func() bool { + return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") + }, "both peers to be registered") + + // peer-b -> peer-a(路径: B -> D -> C -> A) + if err := peerBConn.Send(protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-b", + To: "peer-a", + Body: []byte("hello from peer-b"), + }); err != nil { + t.Fatalf("peerB send text: %v", err) + } + + msg, err := peerAConn.Receive() + if err != nil { + t.Fatalf("peerA receive: %v", err) + } + if msg.Type != protocol.MessageTypeText { + t.Fatalf("message type = %s, want text", msg.Type) + } + if msg.From != "peer-b" { + t.Fatalf("message from = %s, want peer-b", msg.From) + } + if string(msg.Body) != "hello from peer-b" { + t.Fatalf("message body = %q, want %q", string(msg.Body), "hello from peer-b") + } + + // peer-a -> peer-b(路径: A -> C -> D -> B) + if err := peerAConn.Send(protocol.Message{ + Type: protocol.MessageTypeText, + ID: 2, + From: "peer-a", + To: "peer-b", + Body: []byte("reply from peer-单个 downstream peer 通过 relay 连到 KCP server”这条 + 链路是成立的,转发逻辑本身没有明显的地址错误。cmd/internal/server/udp_relay.go 里就是原 + 样双向转发,下游来的包会记录 clientAddr 并写给上游,上游回来的包再写回这个 clientAddr。 + 关键代码在 cmd/internal/server/udp_relay.go:68 和 cmd/internal/server/udp_relay.go:89。 + + 还有一个关键事实:kcppeer 里那句 connected to ... as ... (KCP) 不能证明 peer-a 真的在 + hub 注册成功a"), + }); err != nil { + t.Fatalf("peerA send text: %v", err) + } + + msg2, err := peerBConn.Receive() + if err != nil { + t.Fatalf("peerB receive: %v", err) + } + if msg2.Type != protocol.MessageTypeText { + t.Fatalf("message type = %s, want text", msg2.Type) + } + if msg2.From != "peer-a" { + t.Fatalf("message from = %s, want peer-a", msg2.From) + } + if string(msg2.Body) != "reply from peer-a" { + t.Fatalf("message body = %q, want %q", string(msg2.Body), "reply from peer-a") + } +} + +// TestUDPRelayKCPFileMessage 验证通过 relay 转发 KCP 文件消息。 +func TestUDPRelayKCPFileMessage(t *testing.T) { + hub, hubAddr, hubCleanup := startKCPHubForRelay(t) + defer hubCleanup() + + relayAddr := startUDPRelay(t, hubAddr) + + peerBConn := dialKCPPeer(t, hubAddr) + peerAConn := dialKCPPeer(t, relayAddr) + + _ = peerBConn.Send(protocol.Message{ + Type: protocol.MessageTypeRegister, + From: "peer-b", + To: protocol.ServerPeerID, + }) + _ = peerAConn.Send(protocol.Message{ + Type: protocol.MessageTypeRegister, + From: "peer-a", + To: protocol.ServerPeerID, + }) + + waitForRelay(t, func() bool { + return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") + }, "both peers to be registered") + + if err := peerBConn.Send(protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 1, + From: "peer-b", + To: "peer-a", + FileName: "test.bin", + Body: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + }); err != nil { + t.Fatalf("peerB send file: %v", err) + } + + msg, err := peerAConn.Receive() + if err != nil { + t.Fatalf("peerA receive: %v", err) + } + if msg.Type != protocol.MessageTypeFile { + t.Fatalf("message type = %s, want file", msg.Type) + } + if msg.FileName != "test.bin" { + t.Fatalf("file name = %q, want %q", msg.FileName, "test.bin") + } + if len(msg.Body) != 4 || msg.Body[0] != 0xDE { + t.Fatalf("file body mismatch: 单个 downstream peer 通过 relay 连到 KCP server”这条 + 链路是成立的,转发逻辑本身没有明显的地址错误。cmd/internal/server/udp_relay.go 里就是原 + 样双向转发,下游来的包会记录 clientAddr 并写给上游,上游回来的包再写回这个 clientAddr。 + 关键代码在 cmd/internal/server/udp_relay.go:68 和 cmd/internal/server/udp_relay.go:89。 + + 还有一个关键事实:kcppeer 里那句 connected to ... as ... (KCP) 不能证明 peer-a 真的在 + hub 注册成功got %v", msg.Body) + } +} + +// startKCPHubForRelay 启动一个 KCP hub server,返回 hub、监听地址和 cleanup 函数。 +func startKCPHubForRelay(t *testing.T) (*KCPHub, string, func()) { + t.Helper() + + hub := NewKCPHub() + + listener, packetConn, err := transport.ListenKCPSessions("127.0.0.1:0", "", nil, latencylog.NodeRoleServer, "hub") + if err != nil { + t.Fatalf("ListenKCPSessions() error = %v", err) + } + + var ( + wg sync.WaitGroup + stop = make(chan struct{}) + ) + wg.Add(1) go func() { defer wg.Done() - if serveErr := relay.Serve(); serveErr != nil { - t.Errorf("relay.Serve() error = %v", serveErr) + for { + session, acceptErr := listener.AcceptKCP() + if acceptErr != nil { + select { + case <-stop: + return + default: + } + if strings.Contains(acceptErr.Error(), "closed") { + return + } + t.Errorf("AcceptKCP() error = %v", acceptErr) + return + } + + wg.Add(1) + go func(sess *kcp.UDPSession) { + defer wg.Done() + if serveErr := hub.ServeSession(sess); serveErr != nil { + msg := serveErr.Error() + if !strings.Contains(msg, "closed") && !strings.Contains(msg, "broken pipe") { + t.Logf("hub.ServeSession() ended with %v", serveErr) + } + } + }(session) } }() - defer func() { - _ = relayConn.Close() + + cleanup := func() { + close(stop) + _ = listener.Close() + _ = packetConn.Close() wg.Wait() + } + + return hub, listener.Addr().String(), cleanup +} + +// dialKCPPeer 创建一条到指定地址的 KCP 连接,用于测试。 +func dialKCPPeer(t *testing.T, serverAddr string) *transport.KCPConn { + t.Helper() + + session, err := transport.DialKCPSession(serverAddr, "", "", nil, latencylog.NodeRolePeer, "test") + if err != nil { + t.Fatalf("DialKCPSession(%s) error = %v", serverAddr, err) + } + + conn, err := transport.NewKCPConn(session) + if err != nil { + _ = session.Close() + t.Fatalf("NewKCPConn() error = %v", err) + } + + t.Cleanup(func() { + _ = conn.Close() + }) + + return conn +} + +// startUDPRelay 创建并启动一个 UDPRelay,返回其监听地址字符串。 +func startUDPRelay(t *testing.T, upstreamAddr string) string { + t.Helper() + + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveUDPAddr() error = %v", err) + } + + conn, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatalf("ListenUDP() error = %v", err) + } + + relay, err := NewUDPRelay(conn, upstreamAddr) + if err != nil { + _ = conn.Close() + t.Fatalf("NewUDPRelay() error = %v", err) + } + + go func() { + _ = relay.Serve() }() - client1, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - t.Fatalf("ListenPacket(client1) error = %v", err) - } - defer client1.Close() + t.Cleanup(func() { + _ = relay.Close() + }) - client2, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - t.Fatalf("ListenPacket(client2) error = %v", err) - } - defer client2.Close() - - relayAddr := relayConn.LocalAddr() - - sendPacket(t, client1, relayAddr, buildRelayTestPacket(1, []byte("client-one"))) - assertPacketReceived(t, remote, buildRelayTestPacket(1, []byte("client-one"))) - - sendPacket(t, client2, relayAddr, buildRelayTestPacket(2, []byte("client-two"))) - assertPacketReceived(t, remote, buildRelayTestPacket(2, []byte("client-two"))) - - sendPacket(t, remote, relayAddr, buildRelayTestPacket(2, []byte("reply-two"))) - assertPacketReceived(t, client2, buildRelayTestPacket(2, []byte("reply-two"))) - - sendPacket(t, remote, relayAddr, buildRelayTestPacket(1, []byte("reply-one"))) - assertPacketReceived(t, client1, buildRelayTestPacket(1, []byte("reply-one"))) + return conn.LocalAddr().String() } -func buildRelayTestPacket(convID uint32, body []byte) []byte { - packet := make([]byte, 4+len(body)) - binary.LittleEndian.PutUint32(packet[:4], convID) - copy(packet[4:], body) - return packet -} - -func sendPacket(t *testing.T, conn net.PacketConn, addr net.Addr, payload []byte) { +// waitForRelay 轮询等待条件满足,超时则 fail。 +func waitForRelay(t *testing.T, condition func() bool, description string) { t.Helper() - if err := conn.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil { - t.Fatalf("SetWriteDeadline() error = %v", err) - } - if _, err := conn.WriteTo(payload, addr); err != nil { - t.Fatalf("WriteTo(%s) error = %v", addr, err) - } -} - -func assertPacketReceived(t *testing.T, conn net.PacketConn, want []byte) { - t.Helper() - - if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { - t.Fatalf("SetReadDeadline() error = %v", err) - } - - buffer := make([]byte, 1024) - n, _, err := conn.ReadFrom(buffer) - if err != nil { - t.Fatalf("ReadFrom() error = %v", err) - } - got := buffer[:n] - if string(got) != string(want) { - t.Fatalf("packet = %v, want %v", got, want) + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if condition() { + return + } + time.Sleep(10 * time.Millisecond) } + + t.Fatalf("timed out waiting for %s", description) } diff --git a/cmd/udprelay/main.go b/cmd/udprelay/main.go new file mode 100644 index 0000000..89abe74 --- /dev/null +++ b/cmd/udprelay/main.go @@ -0,0 +1,37 @@ +package main + +import ( + "flag" + "log" + "net" + + "omnisocketgo/cmd/internal/server" +) + +func main() { + listenAddr := flag.String("listen", ":9003", "UDP relay listen address (downstream, for KCP peer to connect)") + upstreamAddr := flag.String("upstream", "127.0.0.1:9002", "upstream KCP server address (server D)") + flag.Parse() + + udpAddr, err := net.ResolveUDPAddr("udp", *listenAddr) + if err != nil { + log.Fatalf("resolve listen address %s: %v", *listenAddr, err) + } + + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + log.Fatalf("listen udp on %s: %v", *listenAddr, err) + } + + relay, err := server.NewUDPRelay(conn, *upstreamAddr) + if err != nil { + _ = conn.Close() + log.Fatalf("create udp relay: %v", err) + } + + log.Printf("udp relay listening on %s, upstream %s", conn.LocalAddr(), *upstreamAddr) + + if err := relay.Serve(); err != nil { + log.Fatalf("udp relay serve: %v", err) + } +}