feat:多跳(B->D->C->A)功能

This commit is contained in:
nnbcccscdscdsc
2026-03-27 23:03:00 +08:00
parent 5be3ff670f
commit 8e2bd0ffc6
8 changed files with 455 additions and 51 deletions

View File

@@ -25,6 +25,7 @@ type clientOptions struct {
udpLinuxTimestamping bool
bindIP string
bindDevice string
kcpDialAddress string
}
// Option 用于配置 Client 的可选行为,例如时延日志。
@@ -73,6 +74,13 @@ func WithBindDevice(device string) Option {
}
}
// WithKCPDialAddress 指定 KCP 实际拨号使用的 UDP 地址,可用于通过 relay 连接逻辑上的 server。
func WithKCPDialAddress(addr string) Option {
return func(options *clientOptions) {
options.kcpDialAddress = addr
}
}
// WithUDPLinuxTimestamping controls whether UDP clients enable Linux timestamping.
func WithUDPLinuxTimestamping(enabled bool) Option {
return func(options *clientOptions) {

View File

@@ -32,8 +32,13 @@ func DialKCP(serverAddr, peerID string, opts ...Option) (*KCPClient, error) {
options.logger = latencylog.NoopLogger{}
}
dialAddr := serverAddr
if options.kcpDialAddress != "" {
dialAddr = options.kcpDialAddress
}
session, err := transport.DialKCPSession(
serverAddr,
dialAddr,
options.bindIP,
options.bindDevice,
options.kcpPacketDebugLogger,

View File

@@ -267,6 +267,96 @@ func TestKCPClientsExchangeMessagesAcrossRelayedServers(t *testing.T) {
}
}
func TestKCPClientsExchangeMessagesViaUDPRelayToSingleHub(t *testing.T) {
hub := server.NewKCPHub()
serverAddr, cleanupHub := startRealKCPHubServer(t, hub)
defer cleanupHub()
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
t.Fatalf("ResolveUDPAddr(server) error = %v", err)
}
baseRelayConn, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ListenPacket(relay) error = %v", err)
}
relayConn := &countingPacketConn{PacketConn: baseRelayConn}
relay, err := server.NewUDPRelay(relayConn, remoteAddr)
if err != nil {
_ = relayConn.Close()
t.Fatalf("NewUDPRelay() error = %v", err)
}
var relayWG sync.WaitGroup
relayWG.Add(1)
go func() {
defer relayWG.Done()
if serveErr := relay.Serve(); serveErr != nil {
t.Errorf("relay.Serve() error = %v", serveErr)
}
}()
defer func() {
_ = relayConn.Close()
relayWG.Wait()
}()
peerA, err := DialKCP(serverAddr, "peer-a", WithKCPDialAddress(relayConn.LocalAddr().String()))
if err != nil {
t.Fatalf("DialKCP(peer-a via relay) error = %v", err)
}
defer func() { _ = peerA.Close() }()
peerB, err := DialKCP(serverAddr, "peer-b")
if err != nil {
t.Fatalf("DialKCP(peer-b direct) error = %v", err)
}
defer func() { _ = peerB.Close() }()
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered on the single hub")
if err := peerB.SendText("peer-a", "hello via udp relay"); err != nil {
t.Fatalf("peerB.SendText() error = %v", err)
}
gotAtA, err := peerA.Receive()
if err != nil {
t.Fatalf("peerA.Receive() error = %v", err)
}
wantAtA := protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-b",
To: "peer-a",
Body: []byte("hello via udp relay"),
}
if !reflect.DeepEqual(gotAtA, wantAtA) {
t.Fatalf("peerA received %+v, want %+v", gotAtA, wantAtA)
}
if err := peerA.SendText("peer-b", "hello back through relay"); err != nil {
t.Fatalf("peerA.SendText() error = %v", err)
}
gotAtB, err := peerB.Receive()
if err != nil {
t.Fatalf("peerB.Receive() error = %v", err)
}
wantAtB := protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-b",
Body: []byte("hello back through relay"),
}
if !reflect.DeepEqual(gotAtB, wantAtB) {
t.Fatalf("peerB received %+v, want %+v", gotAtB, wantAtB)
}
if got := relayConn.WriteCount(); got == 0 {
t.Fatal("relay should have forwarded packets for peer-a session")
}
}
func TestKCPHubPrefersLocalPeerBeforeRelay(t *testing.T) {
fixture := startRelayedKCPHubs(t)
defer fixture.cleanup()

View File

@@ -0,0 +1,127 @@
package server
import (
"fmt"
"log"
"net"
"sync"
"omnisocketgo/cmd/internal/transport"
)
// UDPRelay 负责在固定远端与多个客户端之间双向透明转发 KCP UDP datagram。
type UDPRelay struct {
conn net.PacketConn
remote *net.UDPAddr
mu sync.RWMutex
clients map[uint32]*net.UDPAddr
}
// NewUDPRelay 创建一个绑定到给定 PacketConn 的透明 UDP relay。
func NewUDPRelay(conn net.PacketConn, remote *net.UDPAddr) (*UDPRelay, error) {
if conn == nil {
return nil, fmt.Errorf("server: nil udp relay conn")
}
if remote == nil {
return nil, fmt.Errorf("server: nil udp relay remote")
}
return &UDPRelay{
conn: conn,
remote: cloneUDPAddr(remote),
clients: make(map[uint32]*net.UDPAddr),
}, nil
}
// Serve 持续双向转发原始 UDP datagram不解析业务消息。
func (r *UDPRelay) Serve() error {
buffer := make([]byte, 64*1024)
for {
n, addr, err := r.conn.ReadFrom(buffer)
if err != nil {
if isExpectedRelayServeExit(err) {
return nil
}
return fmt.Errorf("server: udp relay read packet: %w", err)
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
log.Printf("udp relay dropped packet from non-udp addr %T", addr)
continue
}
payload := append([]byte(nil), buffer[:n]...)
if sameUDPAddr(udpAddr, r.remote) {
if err := r.forwardRemotePacket(payload); err != nil {
log.Printf("udp relay failed forwarding remote packet from %s: %v", udpAddr, err)
}
continue
}
if err := r.forwardClientPacket(udpAddr, payload); err != nil {
log.Printf("udp relay failed forwarding client packet from %s: %v", udpAddr, err)
}
}
}
func (r *UDPRelay) forwardClientPacket(addr *net.UDPAddr, payload []byte) error {
convID, ok := transport.ParseKCPConversationID(payload)
if !ok {
return fmt.Errorf("missing kcp conversation id")
}
r.mu.Lock()
r.clients[convID] = cloneUDPAddr(addr)
r.mu.Unlock()
if _, err := r.conn.WriteTo(payload, r.remote); err != nil {
return fmt.Errorf("write conv %d to remote %s: %w", convID, r.remote, err)
}
return nil
}
func (r *UDPRelay) forwardRemotePacket(payload []byte) error {
convID, ok := transport.ParseKCPConversationID(payload)
if !ok {
return fmt.Errorf("missing kcp conversation id")
}
r.mu.RLock()
clientAddr := cloneUDPAddr(r.clients[convID])
r.mu.RUnlock()
if clientAddr == nil {
return fmt.Errorf("unknown client for conv %d", convID)
}
if _, err := r.conn.WriteTo(payload, clientAddr); err != nil {
return fmt.Errorf("write conv %d to client %s: %w", convID, clientAddr, err)
}
return nil
}
func cloneUDPAddr(addr *net.UDPAddr) *net.UDPAddr {
if addr == nil {
return nil
}
ipCopy := make([]byte, len(addr.IP))
copy(ipCopy, addr.IP)
return &net.UDPAddr{
IP: ipCopy,
Port: addr.Port,
Zone: addr.Zone,
}
}
func sameUDPAddr(left, right *net.UDPAddr) bool {
if left == nil || right == nil {
return left == right
}
if left.Port != right.Port || left.Zone != right.Zone {
return false
}
return left.IP.Equal(right.IP)
}

View File

@@ -0,0 +1,103 @@
package server
import (
"encoding/binary"
"net"
"sync"
"testing"
"time"
)
func TestUDPRelayRoutesPacketsByKCPConversationID(t *testing.T) {
remote, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ListenPacket(remote) error = %v", err)
}
defer remote.Close()
relayConn, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ListenPacket(relay) error = %v", err)
}
relay, err := NewUDPRelay(relayConn, remote.LocalAddr().(*net.UDPAddr))
if err != nil {
_ = relayConn.Close()
t.Fatalf("NewUDPRelay() error = %v", err)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if serveErr := relay.Serve(); serveErr != nil {
t.Errorf("relay.Serve() error = %v", serveErr)
}
}()
defer func() {
_ = relayConn.Close()
wg.Wait()
}()
client1, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ListenPacket(client1) error = %v", err)
}
defer client1.Close()
client2, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ListenPacket(client2) error = %v", err)
}
defer client2.Close()
relayAddr := relayConn.LocalAddr()
sendPacket(t, client1, relayAddr, buildRelayTestPacket(1, []byte("client-one")))
assertPacketReceived(t, remote, buildRelayTestPacket(1, []byte("client-one")))
sendPacket(t, client2, relayAddr, buildRelayTestPacket(2, []byte("client-two")))
assertPacketReceived(t, remote, buildRelayTestPacket(2, []byte("client-two")))
sendPacket(t, remote, relayAddr, buildRelayTestPacket(2, []byte("reply-two")))
assertPacketReceived(t, client2, buildRelayTestPacket(2, []byte("reply-two")))
sendPacket(t, remote, relayAddr, buildRelayTestPacket(1, []byte("reply-one")))
assertPacketReceived(t, client1, buildRelayTestPacket(1, []byte("reply-one")))
}
func buildRelayTestPacket(convID uint32, body []byte) []byte {
packet := make([]byte, 4+len(body))
binary.LittleEndian.PutUint32(packet[:4], convID)
copy(packet[4:], body)
return packet
}
func sendPacket(t *testing.T, conn net.PacketConn, addr net.Addr, payload []byte) {
t.Helper()
if err := conn.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("SetWriteDeadline() error = %v", err)
}
if _, err := conn.WriteTo(payload, addr); err != nil {
t.Fatalf("WriteTo(%s) error = %v", addr, err)
}
}
func assertPacketReceived(t *testing.T, conn net.PacketConn, want []byte) {
t.Helper()
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("SetReadDeadline() error = %v", err)
}
buffer := make([]byte, 1024)
n, _, err := conn.ReadFrom(buffer)
if err != nil {
t.Fatalf("ReadFrom() error = %v", err)
}
got := buffer[:n]
if string(got) != string(want) {
t.Fatalf("packet = %v, want %v", got, want)
}
}

View File

@@ -21,6 +21,15 @@ func parseKCPConversationID(packet []byte) *uint32 {
return &conv
}
// ParseKCPConversationID 从原始 KCP UDP datagram 中提取 conv ID。
func ParseKCPConversationID(packet []byte) (uint32, bool) {
conv := parseKCPConversationID(packet)
if conv == nil {
return 0, false
}
return *conv, true
}
func parseKCPPacketSegments(packet []byte) ([]KCPPacketDebugSegment, bool) {
if len(packet) == 0 {
return nil, false