fix: 中转添加udp relay
This commit is contained in:
@@ -293,12 +293,12 @@ func TestKCPClientsExchangeMessagesViaUDPRelayToSingleHub(t *testing.T) {
|
||||
relayWG.Add(1)
|
||||
go func() {
|
||||
defer relayWG.Done()
|
||||
if serveErr := relay.Serve(); serveErr != nil {
|
||||
if serveErr := relay.Serve(); serveErr != nil && !isExpectedKCPRelayServeExit(serveErr) {
|
||||
t.Errorf("relay.Serve() error = %v", serveErr)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = relayConn.Close()
|
||||
_ = relay.Close()
|
||||
relayWG.Wait()
|
||||
}()
|
||||
|
||||
|
||||
@@ -9,30 +9,29 @@ import (
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// udpRelayBufSize 是 relay 接收缓冲区大小,与 UDP transport 层保持一致。
|
||||
const udpRelayBufSize = protocol.MaxFrameSize + 1024
|
||||
|
||||
// UDPRelay 是一个透明的双向 UDP 转发器。
|
||||
// 它在下游(客户端 A)和上游(server D)之间原样转发 UDP 数据报,
|
||||
// 不解析也不修改协议内容。
|
||||
// UDPRelay transparently forwards UDP datagrams between one downstream client
|
||||
// and a fixed upstream server.
|
||||
type UDPRelay struct {
|
||||
downstream *net.UDPConn // 监听端口,等待下游客户端连接
|
||||
upstream *net.UDPConn // 连接到上游 server(connected socket)
|
||||
downstream net.PacketConn
|
||||
upstream *net.UDPConn
|
||||
|
||||
mu sync.RWMutex
|
||||
clientAddr *net.UDPAddr // 下游客户端地址,从第一个下游包学习
|
||||
clientAddr net.Addr
|
||||
}
|
||||
|
||||
// NewUDPRelay 创建一个新的 UDP relay。
|
||||
// listenConn 是已经绑定好的监听 socket(供下游客户端连接),
|
||||
// upstreamAddr 是上游 server D 的地址。
|
||||
func NewUDPRelay(listenConn *net.UDPConn, upstreamAddr string) (*UDPRelay, error) {
|
||||
udpUpstreamAddr, err := net.ResolveUDPAddr("udp", upstreamAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("relay: resolve upstream addr %s: %w", upstreamAddr, err)
|
||||
// NewUDPRelay creates a relay that listens on listenConn and forwards all
|
||||
// traffic to upstreamAddr.
|
||||
func NewUDPRelay(listenConn net.PacketConn, upstreamAddr *net.UDPAddr) (*UDPRelay, error) {
|
||||
if listenConn == nil {
|
||||
return nil, fmt.Errorf("relay: listen conn is required")
|
||||
}
|
||||
if upstreamAddr == nil {
|
||||
return nil, fmt.Errorf("relay: upstream addr is required")
|
||||
}
|
||||
|
||||
upstreamConn, err := net.DialUDP("udp", nil, udpUpstreamAddr)
|
||||
upstreamConn, err := net.DialUDP("udp", nil, upstreamAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("relay: dial upstream %s: %w", upstreamAddr, err)
|
||||
}
|
||||
@@ -43,7 +42,8 @@ func NewUDPRelay(listenConn *net.UDPConn, upstreamAddr string) (*UDPRelay, error
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Serve 启动双向转发循环,阻塞直到任一方向出错。
|
||||
// Serve starts bidirectional forwarding and blocks until either direction
|
||||
// exits with an error.
|
||||
func (r *UDPRelay) Serve() error {
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
@@ -55,23 +55,21 @@ func (r *UDPRelay) Serve() error {
|
||||
}()
|
||||
|
||||
err := <-errCh
|
||||
// 关闭两个 conn 让另一个 goroutine 也退出
|
||||
_ = r.downstream.Close()
|
||||
_ = r.upstream.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// forwardDownstreamToUpstream 从下游读取并转发到上游。
|
||||
func (r *UDPRelay) forwardDownstreamToUpstream() error {
|
||||
buf := make([]byte, udpRelayBufSize)
|
||||
for {
|
||||
n, addr, err := r.downstream.ReadFromUDP(buf)
|
||||
n, addr, err := r.downstream.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("relay: read downstream: %w", err)
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
r.clientAddr = addr
|
||||
r.clientAddr = cloneRelayAddr(addr)
|
||||
r.mu.Unlock()
|
||||
|
||||
if _, err := r.upstream.Write(buf[:n]); err != nil {
|
||||
@@ -82,7 +80,6 @@ func (r *UDPRelay) forwardDownstreamToUpstream() error {
|
||||
}
|
||||
}
|
||||
|
||||
// forwardUpstreamToDownstream 从上游读取并转发到下游。
|
||||
func (r *UDPRelay) forwardUpstreamToDownstream() error {
|
||||
buf := make([]byte, udpRelayBufSize)
|
||||
for {
|
||||
@@ -92,7 +89,7 @@ func (r *UDPRelay) forwardUpstreamToDownstream() error {
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
addr := r.clientAddr
|
||||
addr := cloneRelayAddr(r.clientAddr)
|
||||
r.mu.RUnlock()
|
||||
|
||||
if addr == nil {
|
||||
@@ -100,7 +97,7 @@ func (r *UDPRelay) forwardUpstreamToDownstream() error {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := r.downstream.WriteToUDP(buf[:n], addr); err != nil {
|
||||
if _, err := r.downstream.WriteTo(buf[:n], addr); err != nil {
|
||||
return fmt.Errorf("relay: write downstream to %s: %w", addr, err)
|
||||
}
|
||||
|
||||
@@ -108,7 +105,6 @@ func (r *UDPRelay) forwardUpstreamToDownstream() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭 relay 的上下游连接。
|
||||
func (r *UDPRelay) Close() error {
|
||||
err1 := r.downstream.Close()
|
||||
err2 := r.upstream.Close()
|
||||
|
||||
@@ -14,22 +14,15 @@ import (
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
// TestUDPRelayKCPForwardAndReturn 验证 KCP 通过 UDP relay 的完整双向转发路径:
|
||||
// peer-b -> D(KCP hub) -> C(UDP relay) -> peer-a 以及反向。
|
||||
func TestUDPRelayKCPForwardAndReturn(t *testing.T) {
|
||||
// 启动 D(KCP Hub)
|
||||
hub, hubAddr, hubCleanup := startKCPHubForRelay(t)
|
||||
defer hubCleanup()
|
||||
|
||||
// 启动 C(UDP Relay),upstream 指向 D
|
||||
relayAddr := startUDPRelay(t, hubAddr)
|
||||
|
||||
// peer-b 直连 D(KCP)
|
||||
peerBConn := dialKCPPeer(t, hubAddr)
|
||||
// peer-a 连 C(通过 relay 间接连到 D)
|
||||
peerAConn := dialKCPPeer(t, relayAddr)
|
||||
|
||||
// 注册 peer-b
|
||||
if err := peerBConn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: "peer-b",
|
||||
@@ -37,8 +30,6 @@ func TestUDPRelayKCPForwardAndReturn(t *testing.T) {
|
||||
}); err != nil {
|
||||
t.Fatalf("peerB register: %v", err)
|
||||
}
|
||||
|
||||
// 注册 peer-a(通过 relay)
|
||||
if err := peerAConn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: "peer-a",
|
||||
@@ -51,7 +42,6 @@ func TestUDPRelayKCPForwardAndReturn(t *testing.T) {
|
||||
return hub.HasPeer("peer-a") && hub.HasPeer("peer-b")
|
||||
}, "both peers to be registered")
|
||||
|
||||
// peer-b -> peer-a(路径: B -> D -> C -> A)
|
||||
if err := peerBConn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
@@ -76,19 +66,12 @@ func TestUDPRelayKCPForwardAndReturn(t *testing.T) {
|
||||
t.Fatalf("message body = %q, want %q", string(msg.Body), "hello from peer-b")
|
||||
}
|
||||
|
||||
// peer-a -> peer-b(路径: A -> C -> D -> B)
|
||||
if err := peerAConn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("reply from peer-单个 downstream peer 通过 relay 连到 KCP server”这条
|
||||
链路是成立的,转发逻辑本身没有明显的地址错误。cmd/internal/server/udp_relay.go 里就是原
|
||||
样双向转发,下游来的包会记录 clientAddr 并写给上游,上游回来的包再写回这个 clientAddr。
|
||||
关键代码在 cmd/internal/server/udp_relay.go:68 和 cmd/internal/server/udp_relay.go:89。
|
||||
|
||||
还有一个关键事实:kcppeer 里那句 connected to ... as ... (KCP) 不能证明 peer-a 真的在
|
||||
hub 注册成功a"),
|
||||
Body: []byte("reply from peer-a"),
|
||||
}); err != nil {
|
||||
t.Fatalf("peerA send text: %v", err)
|
||||
}
|
||||
@@ -108,7 +91,6 @@ func TestUDPRelayKCPForwardAndReturn(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestUDPRelayKCPFileMessage 验证通过 relay 转发 KCP 文件消息。
|
||||
func TestUDPRelayKCPFileMessage(t *testing.T) {
|
||||
hub, hubAddr, hubCleanup := startKCPHubForRelay(t)
|
||||
defer hubCleanup()
|
||||
@@ -118,16 +100,20 @@ func TestUDPRelayKCPFileMessage(t *testing.T) {
|
||||
peerBConn := dialKCPPeer(t, hubAddr)
|
||||
peerAConn := dialKCPPeer(t, relayAddr)
|
||||
|
||||
_ = peerBConn.Send(protocol.Message{
|
||||
if err := peerBConn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: "peer-b",
|
||||
To: protocol.ServerPeerID,
|
||||
})
|
||||
_ = peerAConn.Send(protocol.Message{
|
||||
}); err != nil {
|
||||
t.Fatalf("peerB register: %v", err)
|
||||
}
|
||||
if err := peerAConn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: "peer-a",
|
||||
To: protocol.ServerPeerID,
|
||||
})
|
||||
}); err != nil {
|
||||
t.Fatalf("peerA register: %v", err)
|
||||
}
|
||||
|
||||
waitForRelay(t, func() bool {
|
||||
return hub.HasPeer("peer-a") && hub.HasPeer("peer-b")
|
||||
@@ -154,18 +140,11 @@ func TestUDPRelayKCPFileMessage(t *testing.T) {
|
||||
if msg.FileName != "test.bin" {
|
||||
t.Fatalf("file name = %q, want %q", msg.FileName, "test.bin")
|
||||
}
|
||||
if len(msg.Body) != 4 || msg.Body[0] != 0xDE {
|
||||
t.Fatalf("file body mismatch: 单个 downstream peer 通过 relay 连到 KCP server”这条
|
||||
链路是成立的,转发逻辑本身没有明显的地址错误。cmd/internal/server/udp_relay.go 里就是原
|
||||
样双向转发,下游来的包会记录 clientAddr 并写给上游,上游回来的包再写回这个 clientAddr。
|
||||
关键代码在 cmd/internal/server/udp_relay.go:68 和 cmd/internal/server/udp_relay.go:89。
|
||||
|
||||
还有一个关键事实:kcppeer 里那句 connected to ... as ... (KCP) 不能证明 peer-a 真的在
|
||||
hub 注册成功got %v", msg.Body)
|
||||
if string(msg.Body) != string([]byte{0xDE, 0xAD, 0xBE, 0xEF}) {
|
||||
t.Fatalf("file body = %v, want %v", msg.Body, []byte{0xDE, 0xAD, 0xBE, 0xEF})
|
||||
}
|
||||
}
|
||||
|
||||
// startKCPHubForRelay 启动一个 KCP hub server,返回 hub、监听地址和 cleanup 函数。
|
||||
func startKCPHubForRelay(t *testing.T) (*KCPHub, string, func()) {
|
||||
t.Helper()
|
||||
|
||||
@@ -222,7 +201,6 @@ func startKCPHubForRelay(t *testing.T) (*KCPHub, string, func()) {
|
||||
return hub, listener.Addr().String(), cleanup
|
||||
}
|
||||
|
||||
// dialKCPPeer 创建一条到指定地址的 KCP 连接,用于测试。
|
||||
func dialKCPPeer(t *testing.T, serverAddr string) *transport.KCPConn {
|
||||
t.Helper()
|
||||
|
||||
@@ -244,38 +222,42 @@ func dialKCPPeer(t *testing.T, serverAddr string) *transport.KCPConn {
|
||||
return conn
|
||||
}
|
||||
|
||||
// startUDPRelay 创建并启动一个 UDPRelay,返回其监听地址字符串。
|
||||
func startUDPRelay(t *testing.T, upstreamAddr string) string {
|
||||
t.Helper()
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", upstreamAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveUDPAddr() error = %v", err)
|
||||
t.Fatalf("ResolveUDPAddr(%s) error = %v", upstreamAddr, err)
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ListenUDP() error = %v", err)
|
||||
t.Fatalf("ListenPacket() error = %v", err)
|
||||
}
|
||||
|
||||
relay, err := NewUDPRelay(conn, upstreamAddr)
|
||||
relay, err := NewUDPRelay(conn, remoteAddr)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("NewUDPRelay() error = %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
_ = relay.Serve()
|
||||
defer wg.Done()
|
||||
if serveErr := relay.Serve(); serveErr != nil && !isExpectedRelayServeExit(serveErr) {
|
||||
t.Errorf("relay.Serve() error = %v", serveErr)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = relay.Close()
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
return conn.LocalAddr().String()
|
||||
}
|
||||
|
||||
// waitForRelay 轮询等待条件满足,超时则 fail。
|
||||
func waitForRelay(t *testing.T, condition func() bool, description string) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user