diff --git a/cmd/internal/peer/client.go b/cmd/internal/peer/client.go index f31699e..108e970 100644 --- a/cmd/internal/peer/client.go +++ b/cmd/internal/peer/client.go @@ -25,6 +25,7 @@ type clientOptions struct { udpLinuxTimestamping bool bindIP string bindDevice string + kcpDialAddress string } // Option 用于配置 Client 的可选行为,例如时延日志。 @@ -73,6 +74,13 @@ func WithBindDevice(device string) Option { } } +// WithKCPDialAddress 指定 KCP 实际拨号使用的 UDP 地址,可用于通过 relay 连接逻辑上的 server。 +func WithKCPDialAddress(addr string) Option { + return func(options *clientOptions) { + options.kcpDialAddress = addr + } +} + // WithUDPLinuxTimestamping controls whether UDP clients enable Linux timestamping. func WithUDPLinuxTimestamping(enabled bool) Option { return func(options *clientOptions) { diff --git a/cmd/internal/peer/kcp_client.go b/cmd/internal/peer/kcp_client.go index 130294e..246d8e3 100644 --- a/cmd/internal/peer/kcp_client.go +++ b/cmd/internal/peer/kcp_client.go @@ -32,8 +32,13 @@ func DialKCP(serverAddr, peerID string, opts ...Option) (*KCPClient, error) { options.logger = latencylog.NoopLogger{} } + dialAddr := serverAddr + if options.kcpDialAddress != "" { + dialAddr = options.kcpDialAddress + } + session, err := transport.DialKCPSession( - serverAddr, + dialAddr, options.bindIP, options.bindDevice, options.kcpPacketDebugLogger, diff --git a/cmd/internal/peer/kcp_client_test.go b/cmd/internal/peer/kcp_client_test.go index 51392c4..5bc51f0 100644 --- a/cmd/internal/peer/kcp_client_test.go +++ b/cmd/internal/peer/kcp_client_test.go @@ -267,6 +267,96 @@ func TestKCPClientsExchangeMessagesAcrossRelayedServers(t *testing.T) { } } +func TestKCPClientsExchangeMessagesViaUDPRelayToSingleHub(t *testing.T) { + hub := server.NewKCPHub() + serverAddr, cleanupHub := startRealKCPHubServer(t, hub) + defer cleanupHub() + + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + t.Fatalf("ResolveUDPAddr(server) error = %v", err) + } + + baseRelayConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ListenPacket(relay) error = %v", err) + } + relayConn := &countingPacketConn{PacketConn: baseRelayConn} + + relay, err := server.NewUDPRelay(relayConn, remoteAddr) + if err != nil { + _ = relayConn.Close() + t.Fatalf("NewUDPRelay() error = %v", err) + } + + var relayWG sync.WaitGroup + relayWG.Add(1) + go func() { + defer relayWG.Done() + if serveErr := relay.Serve(); serveErr != nil { + t.Errorf("relay.Serve() error = %v", serveErr) + } + }() + defer func() { + _ = relayConn.Close() + relayWG.Wait() + }() + + peerA, err := DialKCP(serverAddr, "peer-a", WithKCPDialAddress(relayConn.LocalAddr().String())) + if err != nil { + t.Fatalf("DialKCP(peer-a via relay) error = %v", err) + } + defer func() { _ = peerA.Close() }() + + peerB, err := DialKCP(serverAddr, "peer-b") + if err != nil { + t.Fatalf("DialKCP(peer-b direct) error = %v", err) + } + defer func() { _ = peerB.Close() }() + + waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered on the single hub") + + if err := peerB.SendText("peer-a", "hello via udp relay"); err != nil { + t.Fatalf("peerB.SendText() error = %v", err) + } + gotAtA, err := peerA.Receive() + if err != nil { + t.Fatalf("peerA.Receive() error = %v", err) + } + wantAtA := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-b", + To: "peer-a", + Body: []byte("hello via udp relay"), + } + if !reflect.DeepEqual(gotAtA, wantAtA) { + t.Fatalf("peerA received %+v, want %+v", gotAtA, wantAtA) + } + + if err := peerA.SendText("peer-b", "hello back through relay"); err != nil { + t.Fatalf("peerA.SendText() error = %v", err) + } + gotAtB, err := peerB.Receive() + if err != nil { + t.Fatalf("peerB.Receive() error = %v", err) + } + wantAtB := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("hello back through relay"), + } + if !reflect.DeepEqual(gotAtB, wantAtB) { + t.Fatalf("peerB received %+v, want %+v", gotAtB, wantAtB) + } + + if got := relayConn.WriteCount(); got == 0 { + t.Fatal("relay should have forwarded packets for peer-a session") + } +} + func TestKCPHubPrefersLocalPeerBeforeRelay(t *testing.T) { fixture := startRelayedKCPHubs(t) defer fixture.cleanup() diff --git a/cmd/internal/server/udp_relay.go b/cmd/internal/server/udp_relay.go new file mode 100644 index 0000000..3c39d73 --- /dev/null +++ b/cmd/internal/server/udp_relay.go @@ -0,0 +1,127 @@ +package server + +import ( + "fmt" + "log" + "net" + "sync" + + "omnisocketgo/cmd/internal/transport" +) + +// UDPRelay 负责在固定远端与多个客户端之间双向透明转发 KCP UDP datagram。 +type UDPRelay struct { + conn net.PacketConn + remote *net.UDPAddr + + mu sync.RWMutex + clients map[uint32]*net.UDPAddr +} + +// NewUDPRelay 创建一个绑定到给定 PacketConn 的透明 UDP relay。 +func NewUDPRelay(conn net.PacketConn, remote *net.UDPAddr) (*UDPRelay, error) { + if conn == nil { + return nil, fmt.Errorf("server: nil udp relay conn") + } + if remote == nil { + return nil, fmt.Errorf("server: nil udp relay remote") + } + + return &UDPRelay{ + conn: conn, + remote: cloneUDPAddr(remote), + clients: make(map[uint32]*net.UDPAddr), + }, nil +} + +// Serve 持续双向转发原始 UDP datagram,不解析业务消息。 +func (r *UDPRelay) Serve() error { + buffer := make([]byte, 64*1024) + for { + n, addr, err := r.conn.ReadFrom(buffer) + if err != nil { + if isExpectedRelayServeExit(err) { + return nil + } + return fmt.Errorf("server: udp relay read packet: %w", err) + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + log.Printf("udp relay dropped packet from non-udp addr %T", addr) + continue + } + + payload := append([]byte(nil), buffer[:n]...) + if sameUDPAddr(udpAddr, r.remote) { + if err := r.forwardRemotePacket(payload); err != nil { + log.Printf("udp relay failed forwarding remote packet from %s: %v", udpAddr, err) + } + continue + } + + if err := r.forwardClientPacket(udpAddr, payload); err != nil { + log.Printf("udp relay failed forwarding client packet from %s: %v", udpAddr, err) + } + } +} + +func (r *UDPRelay) forwardClientPacket(addr *net.UDPAddr, payload []byte) error { + convID, ok := transport.ParseKCPConversationID(payload) + if !ok { + return fmt.Errorf("missing kcp conversation id") + } + + r.mu.Lock() + r.clients[convID] = cloneUDPAddr(addr) + r.mu.Unlock() + + if _, err := r.conn.WriteTo(payload, r.remote); err != nil { + return fmt.Errorf("write conv %d to remote %s: %w", convID, r.remote, err) + } + return nil +} + +func (r *UDPRelay) forwardRemotePacket(payload []byte) error { + convID, ok := transport.ParseKCPConversationID(payload) + if !ok { + return fmt.Errorf("missing kcp conversation id") + } + + r.mu.RLock() + clientAddr := cloneUDPAddr(r.clients[convID]) + r.mu.RUnlock() + if clientAddr == nil { + return fmt.Errorf("unknown client for conv %d", convID) + } + + if _, err := r.conn.WriteTo(payload, clientAddr); err != nil { + return fmt.Errorf("write conv %d to client %s: %w", convID, clientAddr, err) + } + return nil +} + +func cloneUDPAddr(addr *net.UDPAddr) *net.UDPAddr { + if addr == nil { + return nil + } + + ipCopy := make([]byte, len(addr.IP)) + copy(ipCopy, addr.IP) + + return &net.UDPAddr{ + IP: ipCopy, + Port: addr.Port, + Zone: addr.Zone, + } +} + +func sameUDPAddr(left, right *net.UDPAddr) bool { + if left == nil || right == nil { + return left == right + } + if left.Port != right.Port || left.Zone != right.Zone { + return false + } + return left.IP.Equal(right.IP) +} diff --git a/cmd/internal/server/udp_relay_test.go b/cmd/internal/server/udp_relay_test.go new file mode 100644 index 0000000..f866211 --- /dev/null +++ b/cmd/internal/server/udp_relay_test.go @@ -0,0 +1,103 @@ +package server + +import ( + "encoding/binary" + "net" + "sync" + "testing" + "time" +) + +func TestUDPRelayRoutesPacketsByKCPConversationID(t *testing.T) { + remote, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ListenPacket(remote) error = %v", err) + } + defer remote.Close() + + relayConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ListenPacket(relay) error = %v", err) + } + + relay, err := NewUDPRelay(relayConn, remote.LocalAddr().(*net.UDPAddr)) + if err != nil { + _ = relayConn.Close() + t.Fatalf("NewUDPRelay() error = %v", err) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if serveErr := relay.Serve(); serveErr != nil { + t.Errorf("relay.Serve() error = %v", serveErr) + } + }() + defer func() { + _ = relayConn.Close() + wg.Wait() + }() + + client1, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ListenPacket(client1) error = %v", err) + } + defer client1.Close() + + client2, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ListenPacket(client2) error = %v", err) + } + defer client2.Close() + + relayAddr := relayConn.LocalAddr() + + sendPacket(t, client1, relayAddr, buildRelayTestPacket(1, []byte("client-one"))) + assertPacketReceived(t, remote, buildRelayTestPacket(1, []byte("client-one"))) + + sendPacket(t, client2, relayAddr, buildRelayTestPacket(2, []byte("client-two"))) + assertPacketReceived(t, remote, buildRelayTestPacket(2, []byte("client-two"))) + + sendPacket(t, remote, relayAddr, buildRelayTestPacket(2, []byte("reply-two"))) + assertPacketReceived(t, client2, buildRelayTestPacket(2, []byte("reply-two"))) + + sendPacket(t, remote, relayAddr, buildRelayTestPacket(1, []byte("reply-one"))) + assertPacketReceived(t, client1, buildRelayTestPacket(1, []byte("reply-one"))) +} + +func buildRelayTestPacket(convID uint32, body []byte) []byte { + packet := make([]byte, 4+len(body)) + binary.LittleEndian.PutUint32(packet[:4], convID) + copy(packet[4:], body) + return packet +} + +func sendPacket(t *testing.T, conn net.PacketConn, addr net.Addr, payload []byte) { + t.Helper() + + if err := conn.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("SetWriteDeadline() error = %v", err) + } + if _, err := conn.WriteTo(payload, addr); err != nil { + t.Fatalf("WriteTo(%s) error = %v", addr, err) + } +} + +func assertPacketReceived(t *testing.T, conn net.PacketConn, want []byte) { + t.Helper() + + if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("SetReadDeadline() error = %v", err) + } + + buffer := make([]byte, 1024) + n, _, err := conn.ReadFrom(buffer) + if err != nil { + t.Fatalf("ReadFrom() error = %v", err) + } + got := buffer[:n] + if string(got) != string(want) { + t.Fatalf("packet = %v, want %v", got, want) + } +} diff --git a/cmd/internal/transport/kcp_packet_metadata.go b/cmd/internal/transport/kcp_packet_metadata.go index 7b50adc..1469814 100644 --- a/cmd/internal/transport/kcp_packet_metadata.go +++ b/cmd/internal/transport/kcp_packet_metadata.go @@ -21,6 +21,15 @@ func parseKCPConversationID(packet []byte) *uint32 { return &conv } +// ParseKCPConversationID 从原始 KCP UDP datagram 中提取 conv ID。 +func ParseKCPConversationID(packet []byte) (uint32, bool) { + conv := parseKCPConversationID(packet) + if conv == nil { + return 0, false + } + return *conv, true +} + func parseKCPPacketSegments(packet []byte) ([]KCPPacketDebugSegment, bool) { if len(packet) == 0 { return nil, false diff --git a/cmd/kcppeer/main.go b/cmd/kcppeer/main.go index cc90d3d..3d1ffa7 100644 --- a/cmd/kcppeer/main.go +++ b/cmd/kcppeer/main.go @@ -17,6 +17,7 @@ import ( func main() { peerID := flag.String("id", "peer-a", "peer identity") serverAddr := flag.String("server", "127.0.0.1:9002", "KCP server address") + relayVia := flag.String("relay-via", "", "optional UDP relay address used to reach the KCP server") targetPeer := flag.String("to", "", "optional target peer for one outgoing message") text := flag.String("text", "", "optional text to send after connecting") filePath := flag.String("file", "", "optional file path to send after connecting") @@ -66,6 +67,9 @@ func main() { if *bindDevice != "" { clientOptions = append(clientOptions, peerpkg.WithBindDevice(*bindDevice)) } + if *relayVia != "" { + clientOptions = append(clientOptions, peerpkg.WithKCPDialAddress(*relayVia)) + } client, err := peerpkg.DialKCP(*serverAddr, *peerID, clientOptions...) if err != nil { @@ -73,7 +77,11 @@ func main() { } defer client.Close() - log.Printf("connected to %s as %s (KCP)", *serverAddr, client.ID()) + if *relayVia != "" { + log.Printf("connected to %s via relay %s as %s (KCP)", *serverAddr, *relayVia, client.ID()) + } else { + log.Printf("connected to %s as %s (KCP)", *serverAddr, client.ID()) + } receiveErr := make(chan error, 1) go func() { diff --git a/cmd/kcpserver/main.go b/cmd/kcpserver/main.go index 99d1956..75cf507 100644 --- a/cmd/kcpserver/main.go +++ b/cmd/kcpserver/main.go @@ -5,6 +5,7 @@ import ( "log" "net" "strings" + "time" kcp "github.com/xtaci/kcp-go/v5" @@ -13,94 +14,124 @@ import ( "omnisocketgo/cmd/internal/transport" ) +const ( + kcpServerModeHub = "hub" + kcpServerModeRelay = "relay" +) + func main() { - listenAddr := flag.String("listen", ":9002", "KCP server listen address") + mode := flag.String("mode", kcpServerModeHub, "kcpserver mode: hub or relay") + listenAddr := flag.String("listen", ":9002", "listen address; KCP listener in hub mode, UDP relay listener in relay mode") bindDevice := flag.String("bind-device", "", "optional Linux network device used when listening") logPath := flag.String("latency-log", "", "optional JSONL file path for latency timestamp logs") kcpTimestampDebugLogPath := flag.String("kcp-ts-debug-log", "", "optional JSONL file path for KCP packet kernel timestamp debug records") kcpSessionStatsLogPath := flag.String("kcp-session-stats-log", "", "optional JSONL file path for KCP session stats records") kcpSessionStatsInterval := flag.String("kcp-session-stats-interval", transport.DefaultKCPSessionStatsInterval.String(), "sampling interval for KCP session stats, for example 100ms") - relayListenAddr := flag.String("relay-listen", "", "optional raw UDP relay listen address") - relayPeerAddr := flag.String("relay-peer", "", "optional fixed raw UDP relay peer address") - relayLearnPeer := flag.Bool("relay-learn-peer", false, "learn the relay peer address from the first inbound relay packet") + relayListenAddr := flag.String("relay-listen", "", "deprecated alias for -listen in relay mode") + relayRemoteAddr := flag.String("relay-remote", "", "fixed remote UDP address used in relay mode") + relayPeerAddr := flag.String("relay-peer", "", "deprecated alias for -relay-remote") flag.Parse() + var relayRemoteFlagSet bool + var relayPeerFlagSet bool + var relayListenFlagSet bool + flag.Visit(func(f *flag.Flag) { + switch f.Name { + case "relay-listen": + relayListenFlagSet = true + case "relay-remote": + relayRemoteFlagSet = true + case "relay-peer": + relayPeerFlagSet = true + } + }) + + switch { + case relayRemoteFlagSet && relayPeerFlagSet && *relayRemoteAddr != *relayPeerAddr: + log.Fatal("flags -relay-remote and -relay-peer must match when both are set") + case *relayRemoteAddr == "" && *relayPeerAddr != "": + *relayRemoteAddr = *relayPeerAddr + } + if relayPeerFlagSet { + log.Printf("warning: flag -relay-peer is deprecated; use -relay-remote instead") + } + if relayListenFlagSet { + if *relayListenAddr == "" { + log.Fatal("flag -relay-listen must not be empty when set") + } + if *mode != kcpServerModeRelay { + log.Fatal("flag -relay-listen may only be used in relay mode") + } + if *listenAddr != ":9002" && *listenAddr != *relayListenAddr { + log.Fatal("flags -listen and -relay-listen must match when both are set in relay mode") + } + *listenAddr = *relayListenAddr + log.Printf("warning: flag -relay-listen is deprecated; use -listen with -mode=relay instead") + } + statsInterval, err := transport.ParseKCPSessionStatsInterval(*kcpSessionStatsInterval) if err != nil { log.Fatalf("parse -kcp-session-stats-interval=%q: %v", *kcpSessionStatsInterval, err) } + switch *mode { + case kcpServerModeHub: + if *relayRemoteAddr != "" { + log.Fatal("flag -relay-remote may only be used in relay mode") + } + runHubServer(*listenAddr, *bindDevice, *logPath, *kcpTimestampDebugLogPath, *kcpSessionStatsLogPath, statsInterval) + case kcpServerModeRelay: + if *bindDevice != "" { + log.Fatal("flag -bind-device is not supported in relay mode") + } + if *relayRemoteAddr == "" { + log.Fatal("flag -relay-remote is required in relay mode") + } + runUDPRelayServer(*listenAddr, *relayRemoteAddr) + default: + log.Fatalf("unsupported -mode=%q; want %q or %q", *mode, kcpServerModeHub, kcpServerModeRelay) + } +} + +func runHubServer(listenAddr, bindDevice, logPath, packetDebugLogPath, sessionStatsLogPath string, statsInterval time.Duration) { hubOptions := make([]server.KCPOption, 0, 2) - if *logPath != "" { - logger, err := latencylog.NewJSONLLogger(*logPath) + if logPath != "" { + logger, err := latencylog.NewJSONLLogger(logPath) if err != nil { - log.Fatalf("create latency logger %s: %v", *logPath, err) + log.Fatalf("create latency logger %s: %v", logPath, err) } defer logger.Close() hubOptions = append(hubOptions, server.WithKCPLogger(logger)) } var packetLogger transport.KCPPacketDebugLogger - if *kcpTimestampDebugLogPath != "" { - logger, err := transport.NewJSONLKCPPacketDebugLogger(*kcpTimestampDebugLogPath) + if packetDebugLogPath != "" { + logger, err := transport.NewJSONLKCPPacketDebugLogger(packetDebugLogPath) if err != nil { - log.Fatalf("create kcp packet debug logger %s: %v", *kcpTimestampDebugLogPath, err) + log.Fatalf("create kcp packet debug logger %s: %v", packetDebugLogPath, err) } defer logger.Close() packetLogger = logger } - if *kcpSessionStatsLogPath != "" { - logger, err := transport.NewJSONLKCPSessionStatsLogger(*kcpSessionStatsLogPath) + if sessionStatsLogPath != "" { + logger, err := transport.NewJSONLKCPSessionStatsLogger(sessionStatsLogPath) if err != nil { - log.Fatalf("create kcp session stats logger %s: %v", *kcpSessionStatsLogPath, err) + log.Fatalf("create kcp session stats logger %s: %v", sessionStatsLogPath, err) } defer logger.Close() hubOptions = append(hubOptions, server.WithKCPSessionStatsLogger(logger, statsInterval)) } - listener, packetConn, err := transport.ListenKCPSessions(*listenAddr, *bindDevice, packetLogger, latencylog.NodeRoleServer, "hub") + listener, packetConn, err := transport.ListenKCPSessions(listenAddr, bindDevice, packetLogger, latencylog.NodeRoleServer, "hub") if err != nil { - log.Fatalf("listen kcp on %s: %v", *listenAddr, err) + log.Fatalf("listen kcp on %s: %v", listenAddr, err) } defer packetConn.Close() defer listener.Close() hub := server.NewKCPHub(hubOptions...) - if *relayPeerAddr != "" && *relayListenAddr == "" { - log.Fatal("flag -relay-listen is required when -relay-peer is set") - } - if *relayLearnPeer && *relayListenAddr == "" { - log.Fatal("flag -relay-listen is required when -relay-learn-peer is set") - } - if *relayListenAddr != "" { - relayConn, err := net.ListenPacket("udp", *relayListenAddr) - if err != nil { - log.Fatalf("listen relay udp on %s: %v", *relayListenAddr, err) - } - defer relayConn.Close() - - var relayPeer net.Addr - if *relayPeerAddr != "" { - relayPeer, err = net.ResolveUDPAddr("udp", *relayPeerAddr) - if err != nil { - log.Fatalf("resolve relay peer %s: %v", *relayPeerAddr, err) - } - } - - hub.SetRelaySocket(relayConn, relayPeer, *relayLearnPeer) - go func() { - if serveErr := hub.ServeRelay(); serveErr != nil { - log.Printf("kcp relay loop ended: %v", serveErr) - } - }() - log.Printf("kcp relay listening on %s", relayConn.LocalAddr()) - if relayPeer != nil { - log.Printf("kcp relay peer configured as %s", relayPeer) - } - } - - log.Printf("kcp server listening on %s", listener.Addr()) + log.Printf("kcp hub listening on %s", listener.Addr()) for { session, err := listener.AcceptKCP() @@ -119,3 +150,26 @@ func main() { }(session) } } + +func runUDPRelayServer(listenAddr, remoteAddr string) { + conn, err := net.ListenPacket("udp", listenAddr) + if err != nil { + log.Fatalf("listen udp relay on %s: %v", listenAddr, err) + } + defer conn.Close() + + remote, err := net.ResolveUDPAddr("udp", remoteAddr) + if err != nil { + log.Fatalf("resolve relay remote %s: %v", remoteAddr, err) + } + + relay, err := server.NewUDPRelay(conn, remote) + if err != nil { + log.Fatalf("create udp relay: %v", err) + } + + log.Printf("udp relay listening on %s and forwarding to %s", conn.LocalAddr(), remote) + if err := relay.Serve(); err != nil { + log.Fatalf("udp relay stopped: %v", err) + } +}