feat:新增server upd转发功能
This commit is contained in:
@@ -6,122 +6,114 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"omnisocketgo/cmd/internal/transport"
|
"omnisocketgo/cmd/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UDPRelay 负责在固定远端与多个客户端之间双向透明转发 KCP UDP datagram。
|
// udpRelayBufSize 是 relay 接收缓冲区大小,与 UDP transport 层保持一致。
|
||||||
|
const udpRelayBufSize = protocol.MaxFrameSize + 1024
|
||||||
|
|
||||||
|
// UDPRelay 是一个透明的双向 UDP 转发器。
|
||||||
|
// 它在下游(客户端 A)和上游(server D)之间原样转发 UDP 数据报,
|
||||||
|
// 不解析也不修改协议内容。
|
||||||
type UDPRelay struct {
|
type UDPRelay struct {
|
||||||
conn net.PacketConn
|
downstream *net.UDPConn // 监听端口,等待下游客户端连接
|
||||||
remote *net.UDPAddr
|
upstream *net.UDPConn // 连接到上游 server(connected socket)
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
clients map[uint32]*net.UDPAddr
|
clientAddr *net.UDPAddr // 下游客户端地址,从第一个下游包学习
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPRelay 创建一个绑定到给定 PacketConn 的透明 UDP relay。
|
// NewUDPRelay 创建一个新的 UDP relay。
|
||||||
func NewUDPRelay(conn net.PacketConn, remote *net.UDPAddr) (*UDPRelay, error) {
|
// listenConn 是已经绑定好的监听 socket(供下游客户端连接),
|
||||||
if conn == nil {
|
// upstreamAddr 是上游 server D 的地址。
|
||||||
return nil, fmt.Errorf("server: nil udp relay conn")
|
func NewUDPRelay(listenConn *net.UDPConn, upstreamAddr string) (*UDPRelay, error) {
|
||||||
|
udpUpstreamAddr, err := net.ResolveUDPAddr("udp", upstreamAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("relay: resolve upstream addr %s: %w", upstreamAddr, err)
|
||||||
}
|
}
|
||||||
if remote == nil {
|
|
||||||
return nil, fmt.Errorf("server: nil udp relay remote")
|
upstreamConn, err := net.DialUDP("udp", nil, udpUpstreamAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("relay: dial upstream %s: %w", upstreamAddr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &UDPRelay{
|
return &UDPRelay{
|
||||||
conn: conn,
|
downstream: listenConn,
|
||||||
remote: cloneUDPAddr(remote),
|
upstream: upstreamConn,
|
||||||
clients: make(map[uint32]*net.UDPAddr),
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve 持续双向转发原始 UDP datagram,不解析业务消息。
|
// Serve 启动双向转发循环,阻塞直到任一方向出错。
|
||||||
func (r *UDPRelay) Serve() error {
|
func (r *UDPRelay) Serve() error {
|
||||||
buffer := make([]byte, 64*1024)
|
errCh := make(chan error, 2)
|
||||||
for {
|
|
||||||
n, addr, err := r.conn.ReadFrom(buffer)
|
|
||||||
if err != nil {
|
|
||||||
if isExpectedRelayServeExit(err) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("server: udp relay read packet: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
udpAddr, ok := addr.(*net.UDPAddr)
|
go func() {
|
||||||
if !ok {
|
errCh <- r.forwardDownstreamToUpstream()
|
||||||
log.Printf("udp relay dropped packet from non-udp addr %T", addr)
|
}()
|
||||||
continue
|
go func() {
|
||||||
}
|
errCh <- r.forwardUpstreamToDownstream()
|
||||||
|
}()
|
||||||
|
|
||||||
payload := append([]byte(nil), buffer[:n]...)
|
err := <-errCh
|
||||||
if sameUDPAddr(udpAddr, r.remote) {
|
// 关闭两个 conn 让另一个 goroutine 也退出
|
||||||
if err := r.forwardRemotePacket(payload); err != nil {
|
_ = r.downstream.Close()
|
||||||
log.Printf("udp relay failed forwarding remote packet from %s: %v", udpAddr, err)
|
_ = r.upstream.Close()
|
||||||
}
|
return err
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.forwardClientPacket(udpAddr, payload); err != nil {
|
|
||||||
log.Printf("udp relay failed forwarding client packet from %s: %v", udpAddr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UDPRelay) forwardClientPacket(addr *net.UDPAddr, payload []byte) error {
|
// forwardDownstreamToUpstream 从下游读取并转发到上游。
|
||||||
convID, ok := transport.ParseKCPConversationID(payload)
|
func (r *UDPRelay) forwardDownstreamToUpstream() error {
|
||||||
if !ok {
|
buf := make([]byte, udpRelayBufSize)
|
||||||
return fmt.Errorf("missing kcp conversation id")
|
for {
|
||||||
|
n, addr, err := r.downstream.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("relay: read downstream: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
r.clients[convID] = cloneUDPAddr(addr)
|
r.clientAddr = addr
|
||||||
r.mu.Unlock()
|
r.mu.Unlock()
|
||||||
|
|
||||||
if _, err := r.conn.WriteTo(payload, r.remote); err != nil {
|
if _, err := r.upstream.Write(buf[:n]); err != nil {
|
||||||
return fmt.Errorf("write conv %d to remote %s: %w", convID, r.remote, err)
|
return fmt.Errorf("relay: write upstream: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("relay: forwarded %d bytes downstream(%s) -> upstream", n, addr)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UDPRelay) forwardRemotePacket(payload []byte) error {
|
// forwardUpstreamToDownstream 从上游读取并转发到下游。
|
||||||
convID, ok := transport.ParseKCPConversationID(payload)
|
func (r *UDPRelay) forwardUpstreamToDownstream() error {
|
||||||
if !ok {
|
buf := make([]byte, udpRelayBufSize)
|
||||||
return fmt.Errorf("missing kcp conversation id")
|
for {
|
||||||
|
n, err := r.upstream.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("relay: read upstream: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
clientAddr := cloneUDPAddr(r.clients[convID])
|
addr := r.clientAddr
|
||||||
r.mu.RUnlock()
|
r.mu.RUnlock()
|
||||||
if clientAddr == nil {
|
|
||||||
return fmt.Errorf("unknown client for conv %d", convID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := r.conn.WriteTo(payload, clientAddr); err != nil {
|
|
||||||
return fmt.Errorf("write conv %d to client %s: %w", convID, clientAddr, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloneUDPAddr(addr *net.UDPAddr) *net.UDPAddr {
|
|
||||||
if addr == nil {
|
if addr == nil {
|
||||||
return nil
|
log.Printf("relay: dropping %d bytes from upstream (no downstream client yet)", n)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ipCopy := make([]byte, len(addr.IP))
|
if _, err := r.downstream.WriteToUDP(buf[:n], addr); err != nil {
|
||||||
copy(ipCopy, addr.IP)
|
return fmt.Errorf("relay: write downstream to %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
return &net.UDPAddr{
|
log.Printf("relay: forwarded %d bytes upstream -> downstream(%s)", n, addr)
|
||||||
IP: ipCopy,
|
|
||||||
Port: addr.Port,
|
|
||||||
Zone: addr.Zone,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func sameUDPAddr(left, right *net.UDPAddr) bool {
|
// Close 关闭 relay 的上下游连接。
|
||||||
if left == nil || right == nil {
|
func (r *UDPRelay) Close() error {
|
||||||
return left == right
|
err1 := r.downstream.Close()
|
||||||
|
err2 := r.upstream.Close()
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
}
|
}
|
||||||
if left.Port != right.Port || left.Zone != right.Zone {
|
return err2
|
||||||
return false
|
|
||||||
}
|
|
||||||
return left.IP.Equal(right.IP)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,103 +1,291 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
kcp "github.com/xtaci/kcp-go/v5"
|
||||||
|
|
||||||
|
"omnisocketgo/cmd/internal/latencylog"
|
||||||
|
"omnisocketgo/cmd/internal/protocol"
|
||||||
|
"omnisocketgo/cmd/internal/transport"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUDPRelayRoutesPacketsByKCPConversationID(t *testing.T) {
|
// TestUDPRelayKCPForwardAndReturn 验证 KCP 通过 UDP relay 的完整双向转发路径:
|
||||||
remote, err := net.ListenPacket("udp", "127.0.0.1:0")
|
// peer-b -> D(KCP hub) -> C(UDP relay) -> peer-a 以及反向。
|
||||||
if err != nil {
|
func TestUDPRelayKCPForwardAndReturn(t *testing.T) {
|
||||||
t.Fatalf("ListenPacket(remote) error = %v", err)
|
// 启动 D(KCP Hub)
|
||||||
}
|
hub, hubAddr, hubCleanup := startKCPHubForRelay(t)
|
||||||
defer remote.Close()
|
defer hubCleanup()
|
||||||
|
|
||||||
relayConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
// 启动 C(UDP Relay),upstream 指向 D
|
||||||
if err != nil {
|
relayAddr := startUDPRelay(t, hubAddr)
|
||||||
t.Fatalf("ListenPacket(relay) error = %v", err)
|
|
||||||
|
// peer-b 直连 D(KCP)
|
||||||
|
peerBConn := dialKCPPeer(t, hubAddr)
|
||||||
|
// peer-a 连 C(通过 relay 间接连到 D)
|
||||||
|
peerAConn := dialKCPPeer(t, relayAddr)
|
||||||
|
|
||||||
|
// 注册 peer-b
|
||||||
|
if err := peerBConn.Send(protocol.Message{
|
||||||
|
Type: protocol.MessageTypeRegister,
|
||||||
|
From: "peer-b",
|
||||||
|
To: protocol.ServerPeerID,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("peerB register: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
relay, err := NewUDPRelay(relayConn, remote.LocalAddr().(*net.UDPAddr))
|
// 注册 peer-a(通过 relay)
|
||||||
if err != nil {
|
if err := peerAConn.Send(protocol.Message{
|
||||||
_ = relayConn.Close()
|
Type: protocol.MessageTypeRegister,
|
||||||
t.Fatalf("NewUDPRelay() error = %v", err)
|
From: "peer-a",
|
||||||
|
To: protocol.ServerPeerID,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("peerA register: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
waitForRelay(t, func() bool {
|
||||||
|
return hub.HasPeer("peer-a") && hub.HasPeer("peer-b")
|
||||||
|
}, "both peers to be registered")
|
||||||
|
|
||||||
|
// peer-b -> peer-a(路径: B -> D -> C -> A)
|
||||||
|
if err := peerBConn.Send(protocol.Message{
|
||||||
|
Type: protocol.MessageTypeText,
|
||||||
|
ID: 1,
|
||||||
|
From: "peer-b",
|
||||||
|
To: "peer-a",
|
||||||
|
Body: []byte("hello from peer-b"),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("peerB send text: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := peerAConn.Receive()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("peerA receive: %v", err)
|
||||||
|
}
|
||||||
|
if msg.Type != protocol.MessageTypeText {
|
||||||
|
t.Fatalf("message type = %s, want text", msg.Type)
|
||||||
|
}
|
||||||
|
if msg.From != "peer-b" {
|
||||||
|
t.Fatalf("message from = %s, want peer-b", msg.From)
|
||||||
|
}
|
||||||
|
if string(msg.Body) != "hello from peer-b" {
|
||||||
|
t.Fatalf("message body = %q, want %q", string(msg.Body), "hello from peer-b")
|
||||||
|
}
|
||||||
|
|
||||||
|
// peer-a -> peer-b(路径: A -> C -> D -> B)
|
||||||
|
if err := peerAConn.Send(protocol.Message{
|
||||||
|
Type: protocol.MessageTypeText,
|
||||||
|
ID: 2,
|
||||||
|
From: "peer-a",
|
||||||
|
To: "peer-b",
|
||||||
|
Body: []byte("reply from peer-单个 downstream peer 通过 relay 连到 KCP server”这条
|
||||||
|
链路是成立的,转发逻辑本身没有明显的地址错误。cmd/internal/server/udp_relay.go 里就是原
|
||||||
|
样双向转发,下游来的包会记录 clientAddr 并写给上游,上游回来的包再写回这个 clientAddr。
|
||||||
|
关键代码在 cmd/internal/server/udp_relay.go:68 和 cmd/internal/server/udp_relay.go:89。
|
||||||
|
|
||||||
|
还有一个关键事实:kcppeer 里那句 connected to ... as ... (KCP) 不能证明 peer-a 真的在
|
||||||
|
hub 注册成功a"),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("peerA send text: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg2, err := peerBConn.Receive()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("peerB receive: %v", err)
|
||||||
|
}
|
||||||
|
if msg2.Type != protocol.MessageTypeText {
|
||||||
|
t.Fatalf("message type = %s, want text", msg2.Type)
|
||||||
|
}
|
||||||
|
if msg2.From != "peer-a" {
|
||||||
|
t.Fatalf("message from = %s, want peer-a", msg2.From)
|
||||||
|
}
|
||||||
|
if string(msg2.Body) != "reply from peer-a" {
|
||||||
|
t.Fatalf("message body = %q, want %q", string(msg2.Body), "reply from peer-a")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUDPRelayKCPFileMessage 验证通过 relay 转发 KCP 文件消息。
|
||||||
|
func TestUDPRelayKCPFileMessage(t *testing.T) {
|
||||||
|
hub, hubAddr, hubCleanup := startKCPHubForRelay(t)
|
||||||
|
defer hubCleanup()
|
||||||
|
|
||||||
|
relayAddr := startUDPRelay(t, hubAddr)
|
||||||
|
|
||||||
|
peerBConn := dialKCPPeer(t, hubAddr)
|
||||||
|
peerAConn := dialKCPPeer(t, relayAddr)
|
||||||
|
|
||||||
|
_ = peerBConn.Send(protocol.Message{
|
||||||
|
Type: protocol.MessageTypeRegister,
|
||||||
|
From: "peer-b",
|
||||||
|
To: protocol.ServerPeerID,
|
||||||
|
})
|
||||||
|
_ = peerAConn.Send(protocol.Message{
|
||||||
|
Type: protocol.MessageTypeRegister,
|
||||||
|
From: "peer-a",
|
||||||
|
To: protocol.ServerPeerID,
|
||||||
|
})
|
||||||
|
|
||||||
|
waitForRelay(t, func() bool {
|
||||||
|
return hub.HasPeer("peer-a") && hub.HasPeer("peer-b")
|
||||||
|
}, "both peers to be registered")
|
||||||
|
|
||||||
|
if err := peerBConn.Send(protocol.Message{
|
||||||
|
Type: protocol.MessageTypeFile,
|
||||||
|
ID: 1,
|
||||||
|
From: "peer-b",
|
||||||
|
To: "peer-a",
|
||||||
|
FileName: "test.bin",
|
||||||
|
Body: []byte{0xDE, 0xAD, 0xBE, 0xEF},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("peerB send file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := peerAConn.Receive()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("peerA receive: %v", err)
|
||||||
|
}
|
||||||
|
if msg.Type != protocol.MessageTypeFile {
|
||||||
|
t.Fatalf("message type = %s, want file", msg.Type)
|
||||||
|
}
|
||||||
|
if msg.FileName != "test.bin" {
|
||||||
|
t.Fatalf("file name = %q, want %q", msg.FileName, "test.bin")
|
||||||
|
}
|
||||||
|
if len(msg.Body) != 4 || msg.Body[0] != 0xDE {
|
||||||
|
t.Fatalf("file body mismatch: 单个 downstream peer 通过 relay 连到 KCP server”这条
|
||||||
|
链路是成立的,转发逻辑本身没有明显的地址错误。cmd/internal/server/udp_relay.go 里就是原
|
||||||
|
样双向转发,下游来的包会记录 clientAddr 并写给上游,上游回来的包再写回这个 clientAddr。
|
||||||
|
关键代码在 cmd/internal/server/udp_relay.go:68 和 cmd/internal/server/udp_relay.go:89。
|
||||||
|
|
||||||
|
还有一个关键事实:kcppeer 里那句 connected to ... as ... (KCP) 不能证明 peer-a 真的在
|
||||||
|
hub 注册成功got %v", msg.Body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startKCPHubForRelay 启动一个 KCP hub server,返回 hub、监听地址和 cleanup 函数。
|
||||||
|
func startKCPHubForRelay(t *testing.T) (*KCPHub, string, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
hub := NewKCPHub()
|
||||||
|
|
||||||
|
listener, packetConn, err := transport.ListenKCPSessions("127.0.0.1:0", "", nil, latencylog.NodeRoleServer, "hub")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenKCPSessions() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
wg sync.WaitGroup
|
||||||
|
stop = make(chan struct{})
|
||||||
|
)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if serveErr := relay.Serve(); serveErr != nil {
|
for {
|
||||||
t.Errorf("relay.Serve() error = %v", serveErr)
|
session, acceptErr := listener.AcceptKCP()
|
||||||
|
if acceptErr != nil {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if strings.Contains(acceptErr.Error(), "closed") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Errorf("AcceptKCP() error = %v", acceptErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(sess *kcp.UDPSession) {
|
||||||
|
defer wg.Done()
|
||||||
|
if serveErr := hub.ServeSession(sess); serveErr != nil {
|
||||||
|
msg := serveErr.Error()
|
||||||
|
if !strings.Contains(msg, "closed") && !strings.Contains(msg, "broken pipe") {
|
||||||
|
t.Logf("hub.ServeSession() ended with %v", serveErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(session)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
defer func() {
|
|
||||||
_ = relayConn.Close()
|
cleanup := func() {
|
||||||
|
close(stop)
|
||||||
|
_ = listener.Close()
|
||||||
|
_ = packetConn.Close()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
return hub, listener.Addr().String(), cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialKCPPeer 创建一条到指定地址的 KCP 连接,用于测试。
|
||||||
|
func dialKCPPeer(t *testing.T, serverAddr string) *transport.KCPConn {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
session, err := transport.DialKCPSession(serverAddr, "", "", nil, latencylog.NodeRolePeer, "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DialKCPSession(%s) error = %v", serverAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := transport.NewKCPConn(session)
|
||||||
|
if err != nil {
|
||||||
|
_ = session.Close()
|
||||||
|
t.Fatalf("NewKCPConn() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = conn.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// startUDPRelay 创建并启动一个 UDPRelay,返回其监听地址字符串。
|
||||||
|
func startUDPRelay(t *testing.T, upstreamAddr string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveUDPAddr() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.ListenUDP("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenUDP() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
relay, err := NewUDPRelay(conn, upstreamAddr)
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
t.Fatalf("NewUDPRelay() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_ = relay.Serve()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client1, err := net.ListenPacket("udp", "127.0.0.1:0")
|
t.Cleanup(func() {
|
||||||
if err != nil {
|
_ = relay.Close()
|
||||||
t.Fatalf("ListenPacket(client1) error = %v", err)
|
})
|
||||||
}
|
|
||||||
defer client1.Close()
|
|
||||||
|
|
||||||
client2, err := net.ListenPacket("udp", "127.0.0.1:0")
|
return conn.LocalAddr().String()
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ListenPacket(client2) error = %v", err)
|
|
||||||
}
|
|
||||||
defer client2.Close()
|
|
||||||
|
|
||||||
relayAddr := relayConn.LocalAddr()
|
|
||||||
|
|
||||||
sendPacket(t, client1, relayAddr, buildRelayTestPacket(1, []byte("client-one")))
|
|
||||||
assertPacketReceived(t, remote, buildRelayTestPacket(1, []byte("client-one")))
|
|
||||||
|
|
||||||
sendPacket(t, client2, relayAddr, buildRelayTestPacket(2, []byte("client-two")))
|
|
||||||
assertPacketReceived(t, remote, buildRelayTestPacket(2, []byte("client-two")))
|
|
||||||
|
|
||||||
sendPacket(t, remote, relayAddr, buildRelayTestPacket(2, []byte("reply-two")))
|
|
||||||
assertPacketReceived(t, client2, buildRelayTestPacket(2, []byte("reply-two")))
|
|
||||||
|
|
||||||
sendPacket(t, remote, relayAddr, buildRelayTestPacket(1, []byte("reply-one")))
|
|
||||||
assertPacketReceived(t, client1, buildRelayTestPacket(1, []byte("reply-one")))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildRelayTestPacket(convID uint32, body []byte) []byte {
|
// waitForRelay 轮询等待条件满足,超时则 fail。
|
||||||
packet := make([]byte, 4+len(body))
|
func waitForRelay(t *testing.T, condition func() bool, description string) {
|
||||||
binary.LittleEndian.PutUint32(packet[:4], convID)
|
|
||||||
copy(packet[4:], body)
|
|
||||||
return packet
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendPacket(t *testing.T, conn net.PacketConn, addr net.Addr, payload []byte) {
|
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
if err := conn.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
t.Fatalf("SetWriteDeadline() error = %v", err)
|
for time.Now().Before(deadline) {
|
||||||
|
if condition() {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if _, err := conn.WriteTo(payload, addr); err != nil {
|
time.Sleep(10 * time.Millisecond)
|
||||||
t.Fatalf("WriteTo(%s) error = %v", addr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertPacketReceived(t *testing.T, conn net.PacketConn, want []byte) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
|
||||||
t.Fatalf("SetReadDeadline() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer := make([]byte, 1024)
|
|
||||||
n, _, err := conn.ReadFrom(buffer)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ReadFrom() error = %v", err)
|
|
||||||
}
|
|
||||||
got := buffer[:n]
|
|
||||||
if string(got) != string(want) {
|
|
||||||
t.Fatalf("packet = %v, want %v", got, want)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Fatalf("timed out waiting for %s", description)
|
||||||
}
|
}
|
||||||
|
|||||||
37
cmd/udprelay/main.go
Normal file
37
cmd/udprelay/main.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"omnisocketgo/cmd/internal/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
listenAddr := flag.String("listen", ":9003", "UDP relay listen address (downstream, for KCP peer to connect)")
|
||||||
|
upstreamAddr := flag.String("upstream", "127.0.0.1:9002", "upstream KCP server address (server D)")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", *listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("resolve listen address %s: %v", *listenAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("listen udp on %s: %v", *listenAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
relay, err := server.NewUDPRelay(conn, *upstreamAddr)
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
log.Fatalf("create udp relay: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("udp relay listening on %s, upstream %s", conn.LocalAddr(), *upstreamAddr)
|
||||||
|
|
||||||
|
if err := relay.Serve(); err != nil {
|
||||||
|
log.Fatalf("udp relay serve: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user