From c126b0596107599f8de63a37a6d7f68ca3f51121 Mon Sep 17 00:00:00 2001 From: Mock Date: Tue, 24 Mar 2026 15:39:00 +0800 Subject: [PATCH] =?UTF-8?q?Feat:=20UDP=20=E6=A1=86=E6=9E=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .claude/settings.local.json | 9 + cmd/internal/peer/udp_client.go | 202 +++++++++++++ cmd/internal/peer/udp_client_test.go | 211 +++++++++++++ cmd/internal/server/udp_hub.go | 185 ++++++++++++ cmd/internal/server/udp_hub_test.go | 238 +++++++++++++++ cmd/internal/transport/udp.go | 141 +++++++++ cmd/internal/transport/udp_linux.go | 358 +++++++++++++++++++++++ cmd/internal/transport/udp_linux_test.go | 105 +++++++ cmd/internal/transport/udp_test.go | 358 +++++++++++++++++++++++ cmd/udppeer/interactive.go | 70 +++++ cmd/udppeer/main.go | 194 ++++++++++++ cmd/udpserver/main.go | 48 +++ 12 files changed, 2119 insertions(+) create mode 100644 .claude/settings.local.json create mode 100644 cmd/internal/peer/udp_client.go create mode 100644 cmd/internal/peer/udp_client_test.go create mode 100644 cmd/internal/server/udp_hub.go create mode 100644 cmd/internal/server/udp_hub_test.go create mode 100644 cmd/internal/transport/udp.go create mode 100644 cmd/internal/transport/udp_linux.go create mode 100644 cmd/internal/transport/udp_linux_test.go create mode 100644 cmd/internal/transport/udp_test.go create mode 100644 cmd/udppeer/interactive.go create mode 100644 cmd/udppeer/main.go create mode 100644 cmd/udpserver/main.go diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..3a6bd95 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,9 @@ +{ + "permissions": { + "allow": [ + "Bash(go vet:*)", + "Bash(go build:*)", + "Bash(go test:*)" + ] + } +} diff --git a/cmd/internal/peer/udp_client.go b/cmd/internal/peer/udp_client.go new file mode 100644 index 0000000..ca8eccc --- /dev/null +++ b/cmd/internal/peer/udp_client.go @@ -0,0 +1,202 @@ +package peer + +import ( + "fmt" + "net" + "os" + "path/filepath" + "sync/atomic" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" + "omnisocketgo/cmd/internal/transport" +) + +// UDPClient 表示一个通过 UDP 连接到 server 的 peer。 +type UDPClient struct { + id string + conn *transport.UDPConn + logger latencylog.Logger + + nextID uint64 +} + +// DialUDP 通过 UDP 连接到 server,并发送 register 消息完成身份注册。 +func DialUDP(serverAddr, peerID string, opts ...Option) (*UDPClient, error) { + options := clientOptions{ + logger: latencylog.NoopLogger{}, + } + for _, opt := range opts { + opt(&options) + } + if options.logger == nil { + options.logger = latencylog.NoopLogger{} + } + + udpServerAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + return nil, fmt.Errorf("peer: resolve udp server addr %s: %w", serverAddr, err) + } + + var localAddr *net.UDPAddr + if options.bindIP != "" { + ip := net.ParseIP(options.bindIP) + if ip == nil { + return nil, fmt.Errorf("peer: invalid bind ip %q", options.bindIP) + } + localAddr = &net.UDPAddr{IP: ip} + } + + rawConn, err := net.DialUDP("udp", localAddr, udpServerAddr) + if err != nil { + return nil, fmt.Errorf("peer: dial udp server %s: %w", serverAddr, err) + } + + conn, err := transport.NewUDPConn( + rawConn, + nil, // peer 侧已连接模式,不需要指定 peerAddr + transport.WithUDPLogger(options.logger, latencylog.NodeRolePeer, peerID), + transport.WithUDPTXTimestampDebugLogger(options.txTimestampDebugLogger), + ) + if err != nil { + _ = rawConn.Close() + return nil, fmt.Errorf("peer: create udp transport conn: %w", err) + } + + client := &UDPClient{ + id: peerID, + conn: conn, + logger: options.logger, + } + + // 发送 register 消息完成身份注册 + if err := conn.Send(protocol.Message{ + Type: protocol.MessageTypeRegister, + From: peerID, + To: protocol.ServerPeerID, + }); err != nil { + _ = conn.Close() + return nil, fmt.Errorf("peer: udp register with server: %w", err) + } + + return client, nil +} + +// ID 返回当前 client 的 peer 标识。 +func (c *UDPClient) ID() string { + return c.id +} + +// SendText 向目标 peer 发送一条文本消息。 +func (c *UDPClient) SendText(to, body string) error { + msg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: c.nextMessageID(), + From: c.id, + To: to, + } + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg) + + msg.Body = []byte(body) + + return c.conn.Send(msg) +} + +// SendFile 向目标 peer 发送一条文件消息。 +func (c *UDPClient) SendFile(to, fileName string, body []byte) error { + msg := protocol.Message{ + Type: protocol.MessageTypeFile, + ID: c.nextMessageID(), + From: c.id, + To: to, + FileName: fileName, + } + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg) + + bodyCopy := make([]byte, len(body)) + copy(bodyCopy, body) + + msg.Body = bodyCopy + + return c.conn.Send(msg) +} + +// SendFilePath 从本地文件读取内容并发送给目标 peer。 +func (c *UDPClient) SendFilePath(to, path string) error { + msg := protocol.Message{ + Type: protocol.MessageTypeFile, + ID: c.nextMessageID(), + From: c.id, + To: to, + FileName: filepath.Base(path), + } + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg) + + body, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("peer: read file %s: %w", path, err) + } + + msg.Body = body + + return c.conn.Send(msg) +} + +// Receive 读取一条来自 server 的消息。 +func (c *UDPClient) Receive() (protocol.Message, error) { + msg, _, err := c.conn.Receive() + if err != nil { + return protocol.Message{}, fmt.Errorf("peer: udp receive from server: %w", err) + } + + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg) + + return msg, nil +} + +// ReceiveLoop 持续接收 server 消息并交给 handler 处理。 +func (c *UDPClient) ReceiveLoop(handler func(protocol.Message) error) error { + return c.conn.ReceiveLoop(func(msg protocol.Message, _ *net.UDPAddr) error { + switch msg.Type { + case protocol.MessageTypeText, protocol.MessageTypeFile, protocol.MessageTypeError: + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg) + return handler(msg) + default: + return fmt.Errorf("peer: unexpected message type from server: %s", msg.Type) + } + }) +} + +// PersistMessage 将收到的业务消息写入本地磁盘。 +func (c *UDPClient) PersistMessage(msg protocol.Message, inboxDir string) (string, error) { + if !latencylog.IsBusinessMessage(msg) { + return "", fmt.Errorf("peer: cannot persist message type %s", msg.Type) + } + if inboxDir == "" { + return "", fmt.Errorf("peer: inbox directory is required") + } + + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistBegin, msg) + + if err := os.MkdirAll(inboxDir, 0o755); err != nil { + return "", fmt.Errorf("peer: create inbox dir %s: %w", inboxDir, err) + } + + path, err := persistMessageToDisk(msg, inboxDir) + if err != nil { + return "", err + } + + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistEnd, msg) + + return path, nil +} + +// Close 关闭与 server 的 UDP 连接。 +func (c *UDPClient) Close() error { + return c.conn.Close() +} + +func (c *UDPClient) nextMessageID() uint64 { + return atomic.AddUint64(&c.nextID, 1) +} diff --git a/cmd/internal/peer/udp_client_test.go b/cmd/internal/peer/udp_client_test.go new file mode 100644 index 0000000..95aa4c1 --- /dev/null +++ b/cmd/internal/peer/udp_client_test.go @@ -0,0 +1,211 @@ +package peer + +import ( + "net" + "os" + "path/filepath" + "testing" + "time" + + "omnisocketgo/cmd/internal/protocol" + "omnisocketgo/cmd/internal/server" + "omnisocketgo/cmd/internal/transport" +) + +// TestUDPDialAndSendText 验证 UDP 客户端可以成功连接、注册并发送文本消息。 +func TestUDPDialAndSendText(t *testing.T) { + hubAddr := startUDPTestHub(t) + + clientA, err := DialUDP(hubAddr.String(), "peer-a") + if err != nil { + t.Fatalf("DialUDP(peer-a) error = %v", err) + } + defer clientA.Close() + + clientB, err := DialUDP(hubAddr.String(), "peer-b") + if err != nil { + t.Fatalf("DialUDP(peer-b) error = %v", err) + } + defer clientB.Close() + + // 等待注册被处理 + time.Sleep(50 * time.Millisecond) + + // peer-a 发送文本给 peer-b + if err := clientA.SendText("peer-b", "hello from udp"); err != nil { + t.Fatalf("SendText() error = %v", err) + } + + // peer-b 接收 + msg := receiveUDPClientMessage(t, clientB) + if msg.Type != protocol.MessageTypeText { + t.Fatalf("message type = %s, want text", msg.Type) + } + if string(msg.Body) != "hello from udp" { + t.Fatalf("message body = %q, want %q", string(msg.Body), "hello from udp") + } +} + +// TestUDPClientID 验证 ID() 返回正确的 peer 标识。 +func TestUDPClientID(t *testing.T) { + hubAddr := startUDPTestHub(t) + + client, err := DialUDP(hubAddr.String(), "my-peer-id") + if err != nil { + t.Fatalf("DialUDP() error = %v", err) + } + defer client.Close() + + if got := client.ID(); got != "my-peer-id" { + t.Fatalf("ID() = %q, want %q", got, "my-peer-id") + } +} + +// TestUDPClientPersistMessage 验证 UDP 客户端可以将消息持久化到磁盘。 +func TestUDPClientPersistMessage(t *testing.T) { + hubAddr := startUDPTestHub(t) + + client, err := DialUDP(hubAddr.String(), "peer-persist") + if err != nil { + t.Fatalf("DialUDP() error = %v", err) + } + defer client.Close() + + inboxDir := t.TempDir() + + // 持久化文本消息 + textMsg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "sender", + To: "peer-persist", + Body: []byte("persisted text"), + } + + path, err := client.PersistMessage(textMsg, inboxDir) + if err != nil { + t.Fatalf("PersistMessage(text) error = %v", err) + } + if !filepath.IsAbs(path) && path == "" { + t.Fatalf("PersistMessage(text) returned empty path") + } + + // 持久化文件消息 + fileMsg := protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 2, + From: "sender", + To: "peer-persist", + FileName: "test.bin", + Body: []byte{0x01, 0x02, 0x03}, + } + + filePath, err := client.PersistMessage(fileMsg, inboxDir) + if err != nil { + t.Fatalf("PersistMessage(file) error = %v", err) + } + + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("ReadFile(%s) error = %v", filePath, err) + } + if len(content) != 3 || content[0] != 0x01 { + t.Fatalf("file content mismatch: got %v", content) + } +} + +// TestUDPClientSendFile 验证 UDP 客户端可以发送文件消息。 +func TestUDPClientSendFile(t *testing.T) { + hubAddr := startUDPTestHub(t) + + clientA, err := DialUDP(hubAddr.String(), "peer-a") + if err != nil { + t.Fatalf("DialUDP(peer-a) error = %v", err) + } + defer clientA.Close() + + clientB, err := DialUDP(hubAddr.String(), "peer-b") + if err != nil { + t.Fatalf("DialUDP(peer-b) error = %v", err) + } + defer clientB.Close() + + time.Sleep(50 * time.Millisecond) + + fileBody := []byte{0xDE, 0xAD, 0xBE, 0xEF} + if err := clientA.SendFile("peer-b", "test.bin", fileBody); err != nil { + t.Fatalf("SendFile() error = %v", err) + } + + msg := receiveUDPClientMessage(t, clientB) + if msg.Type != protocol.MessageTypeFile { + t.Fatalf("message type = %s, want file", msg.Type) + } + if msg.FileName != "test.bin" { + t.Fatalf("file name = %q, want %q", msg.FileName, "test.bin") + } + if len(msg.Body) != 4 { + t.Fatalf("body length = %d, want 4", len(msg.Body)) + } +} + +// startUDPTestHub 创建并启动一个测试用 UDPHub。 +func startUDPTestHub(t *testing.T) *net.UDPAddr { + t.Helper() + + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveUDPAddr() error = %v", err) + } + + conn, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatalf("ListenUDP() error = %v", err) + } + + hub, err := server.NewUDPHub(conn) + if err != nil { + _ = conn.Close() + t.Fatalf("NewUDPHub() error = %v", err) + } + + go func() { + _ = hub.Serve() + }() + + t.Cleanup(func() { + _ = hub.Close() + }) + + return conn.LocalAddr().(*net.UDPAddr) +} + +// receiveUDPClientMessage 从 UDP 客户端接收一条消息,带超时。 +func receiveUDPClientMessage(t *testing.T, client *UDPClient) protocol.Message { + t.Helper() + + type result struct { + msg protocol.Message + err error + } + + ch := make(chan result, 1) + go func() { + msg, err := client.Receive() + ch <- result{msg: msg, err: err} + }() + + select { + case r := <-ch: + if r.err != nil { + t.Fatalf("Receive() error = %v", r.err) + } + return r.msg + case <-time.After(2 * time.Second): + t.Fatal("Receive() timed out after 2s") + return protocol.Message{} + } +} + +// Ensure transport package is used (needed for WithTXTimestampDebugLogger option). +var _ transport.TXTimestampDebugLogger = nil diff --git a/cmd/internal/server/udp_hub.go b/cmd/internal/server/udp_hub.go new file mode 100644 index 0000000..4275337 --- /dev/null +++ b/cmd/internal/server/udp_hub.go @@ -0,0 +1,185 @@ +package server + +import ( + "fmt" + "log" + "net" + "sync" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" + "omnisocketgo/cmd/internal/transport" +) + +// UDPOption 用于配置 UDPHub 的可选行为。 +type UDPOption func(*UDPHub) + +// WithUDPLogger 为 UDP hub 注入时延日志记录器。 +func WithUDPLogger(logger latencylog.Logger) UDPOption { + return func(hub *UDPHub) { + hub.logger = logger + } +} + +// UDPHub 管理通过 UDP 注册的 peer,并负责在它们之间转发消息。 +// 与 TCP Hub 不同,UDPHub 使用单个 net.UDPConn 与所有 peer 通信, +// 通过维护 peerID -> UDPAddr 映射表来寻址。 +type UDPHub struct { + mu sync.RWMutex + peers map[string]*net.UDPAddr // peerID -> 对端 UDP 地址 + addrs map[string]string // addr.String() -> peerID,用于反查 + + conn *transport.UDPConn + logger latencylog.Logger +} + +// NewUDPHub 创建一个新的 UDP 连接中心。 +func NewUDPHub(conn *net.UDPConn, opts ...UDPOption) (*UDPHub, error) { + hub := &UDPHub{ + peers: make(map[string]*net.UDPAddr), + addrs: make(map[string]string), + logger: latencylog.NoopLogger{}, + } + + for _, opt := range opts { + opt(hub) + } + + if hub.logger == nil { + hub.logger = latencylog.NoopLogger{} + } + + udpConn, err := transport.NewUDPConn( + conn, + nil, + transport.WithUDPLogger(hub.logger, latencylog.NodeRoleServer, "hub"), + ) + if err != nil { + return nil, fmt.Errorf("server: create udp transport conn: %w", err) + } + + hub.conn = udpConn + + return hub, nil +} + +// Serve 启动 UDP 接收主循环,持续读取消息并处理注册/转发。 +// 此方法会阻塞,直到底层连接关闭或发生不可恢复的错误。 +func (h *UDPHub) Serve() error { + return h.conn.ReceiveLoop(func(msg protocol.Message, addr *net.UDPAddr) error { + if err := h.handleMessage(msg, addr); err != nil { + log.Printf("udp hub: handle message from %s: %v", addr, err) + } + return nil // 不因为单条消息处理失败而退出主循环 + }) +} + +// HasPeer 返回给定 ID 是否已注册到 hub。 +func (h *UDPHub) HasPeer(peerID string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + + _, ok := h.peers[peerID] + return ok +} + +// handleMessage 处理从指定地址收到的消息。 +func (h *UDPHub) handleMessage(msg protocol.Message, addr *net.UDPAddr) error { + switch msg.Type { + case protocol.MessageTypeRegister: + return h.registerPeer(msg, addr) + case protocol.MessageTypeText, protocol.MessageTypeFile: + return h.forwardMessage(msg, addr) + case protocol.MessageTypeError: + return h.sendErrorTo(addr, msg.From, "peers cannot send error messages") + default: + peerID := h.lookupPeerID(addr) + if peerID == "" { + peerID = msg.From + } + return h.sendErrorTo(addr, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)) + } +} + +// registerPeer 处理 peer 的注册请求。 +func (h *UDPHub) registerPeer(msg protocol.Message, addr *net.UDPAddr) error { + peerID := msg.From + if peerID == "" { + return h.sendErrorTo(addr, "", "register: missing peer id") + } + + h.mu.Lock() + defer h.mu.Unlock() + + // 如果同一个 peerID 从新地址注册,更新地址映射(支持 peer 重启换端口)。 + if existingAddr, exists := h.peers[peerID]; exists { + // 清理旧地址的反查映射 + delete(h.addrs, existingAddr.String()) + } + + h.peers[peerID] = addr + h.addrs[addr.String()] = peerID + log.Printf("udp hub: registered peer %s from %s", peerID, addr) + return nil +} + +// forwardMessage 转发业务消息到目标 peer。 +func (h *UDPHub) forwardMessage(msg protocol.Message, senderAddr *net.UDPAddr) error { + // 通过来源地址反查发送者 peerID + senderID := h.lookupPeerID(senderAddr) + if senderID == "" { + return h.sendErrorTo(senderAddr, msg.From, "not registered; send register first") + } + + // server 覆盖 From,不信任客户端自报身份 + msg.From = senderID + + // 查找目标 peer 地址 + targetAddr := h.lookupAddr(msg.To) + if targetAddr == nil { + return h.sendErrorTo(senderAddr, senderID, fmt.Sprintf("unknown target: %s", msg.To)) + } + + // 转发消息 + if err := h.conn.SendTo(msg, targetAddr); err != nil { + // 转发失败,通知发送方 + _ = h.sendErrorTo(senderAddr, senderID, fmt.Sprintf("failed to forward to %s", msg.To)) + return fmt.Errorf("forward to %s at %s: %w", msg.To, targetAddr, err) + } + + return nil +} + +// lookupPeerID 通过 UDP 地址反查 peerID。 +func (h *UDPHub) lookupPeerID(addr *net.UDPAddr) string { + h.mu.RLock() + defer h.mu.RUnlock() + + return h.addrs[addr.String()] +} + +// lookupAddr 通过 peerID 查找 UDP 地址。 +func (h *UDPHub) lookupAddr(peerID string) *net.UDPAddr { + h.mu.RLock() + defer h.mu.RUnlock() + + return h.peers[peerID] +} + +// sendErrorTo 向指定地址发送错误消息。 +func (h *UDPHub) sendErrorTo(addr *net.UDPAddr, to, message string) error { + if to == "" { + to = "unknown" + } + return h.conn.SendTo(protocol.Message{ + Type: protocol.MessageTypeError, + From: protocol.ServerPeerID, + To: to, + Body: []byte(message), + }, addr) +} + +// Close 关闭底层 UDP 连接。 +func (h *UDPHub) Close() error { + return h.conn.Close() +} diff --git a/cmd/internal/server/udp_hub_test.go b/cmd/internal/server/udp_hub_test.go new file mode 100644 index 0000000..f3db724 --- /dev/null +++ b/cmd/internal/server/udp_hub_test.go @@ -0,0 +1,238 @@ +package server + +import ( + "net" + "testing" + "time" + + "omnisocketgo/cmd/internal/protocol" + "omnisocketgo/cmd/internal/transport" +) + +// TestUDPHubRegisterAndForward 验证 peer 注册后可以互相转发消息。 +func TestUDPHubRegisterAndForward(t *testing.T) { + hub, hubAddr := startUDPHub(t) + _ = hub + + peerA := dialUDPPeer(t, hubAddr) + peerB := dialUDPPeer(t, hubAddr) + + // 注册 peer-a + sendUDPMessage(t, peerA, protocol.Message{ + Type: protocol.MessageTypeRegister, + From: "peer-a", + To: protocol.ServerPeerID, + }) + + // 注册 peer-b + sendUDPMessage(t, peerB, protocol.Message{ + Type: protocol.MessageTypeRegister, + From: "peer-b", + To: protocol.ServerPeerID, + }) + + // 等待注册被处理 + time.Sleep(50 * time.Millisecond) + + // peer-a 发送消息给 peer-b + sendUDPMessage(t, peerA, protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("hello from peer-a"), + }) + + // peer-b 应该收到消息 + msg := receiveUDPMessage(t, peerB) + if msg.Type != protocol.MessageTypeText { + t.Fatalf("message type = %s, want text", msg.Type) + } + if msg.From != "peer-a" { + t.Fatalf("message from = %s, want peer-a", msg.From) + } + if msg.To != "peer-b" { + t.Fatalf("message to = %s, want peer-b", msg.To) + } + if string(msg.Body) != "hello from peer-a" { + t.Fatalf("message body = %q, want %q", string(msg.Body), "hello from peer-a") + } +} + +// TestUDPHubRejectsUnregistered 验证未注册的 peer 发送业务消息会收到错误。 +func TestUDPHubRejectsUnregistered(t *testing.T) { + _, hubAddr := startUDPHub(t) + + peer := dialUDPPeer(t, hubAddr) + + // 直接发送业务消息而不注册 + sendUDPMessage(t, peer, protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("should fail"), + }) + + // 应该收到错误响应 + msg := receiveUDPMessage(t, peer) + if msg.Type != protocol.MessageTypeError { + t.Fatalf("message type = %s, want error", msg.Type) + } +} + +// TestUDPHubRejectsUnknownTarget 验证发送到不存在的目标会返回错误。 +func TestUDPHubRejectsUnknownTarget(t *testing.T) { + _, hubAddr := startUDPHub(t) + + peer := dialUDPPeer(t, hubAddr) + + // 注册 + sendUDPMessage(t, peer, protocol.Message{ + Type: protocol.MessageTypeRegister, + From: "peer-a", + To: protocol.ServerPeerID, + }) + + time.Sleep(50 * time.Millisecond) + + // 发送到不存在的目标 + sendUDPMessage(t, peer, protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-nonexistent", + Body: []byte("should fail"), + }) + + // 应该收到错误响应 + msg := receiveUDPMessage(t, peer) + if msg.Type != protocol.MessageTypeError { + t.Fatalf("message type = %s, want error", msg.Type) + } +} + +// TestUDPHubOverridesFromField 验证 server 会覆盖消息的 From 字段。 +func TestUDPHubOverridesFromField(t *testing.T) { + _, hubAddr := startUDPHub(t) + + peerA := dialUDPPeer(t, hubAddr) + peerB := dialUDPPeer(t, hubAddr) + + sendUDPMessage(t, peerA, protocol.Message{ + Type: protocol.MessageTypeRegister, + From: "peer-a", + To: protocol.ServerPeerID, + }) + sendUDPMessage(t, peerB, protocol.Message{ + Type: protocol.MessageTypeRegister, + From: "peer-b", + To: protocol.ServerPeerID, + }) + + time.Sleep(50 * time.Millisecond) + + // peer-a 伪造 From 为 "fake-id" + sendUDPMessage(t, peerA, protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "fake-id", + To: "peer-b", + Body: []byte("spoofed"), + }) + + msg := receiveUDPMessage(t, peerB) + // server 应该用实际注册的 "peer-a" 覆盖 From + if msg.From != "peer-a" { + t.Fatalf("message from = %s, want peer-a (server should override)", msg.From) + } +} + +// startUDPHub 创建并启动一个 UDPHub,返回 hub 和监听地址。 +func startUDPHub(t *testing.T) (*UDPHub, *net.UDPAddr) { + t.Helper() + + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveUDPAddr() error = %v", err) + } + + conn, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatalf("ListenUDP() error = %v", err) + } + + hub, err := NewUDPHub(conn) + if err != nil { + _ = conn.Close() + t.Fatalf("NewUDPHub() error = %v", err) + } + + go func() { + _ = hub.Serve() + }() + + t.Cleanup(func() { + _ = hub.Close() + }) + + return hub, conn.LocalAddr().(*net.UDPAddr) +} + +// dialUDPPeer 创建一个连接到指定地址的 UDP transport 连接。 +func dialUDPPeer(t *testing.T, serverAddr *net.UDPAddr) *transport.UDPConn { + t.Helper() + + raw, err := net.DialUDP("udp", nil, serverAddr) + if err != nil { + t.Fatalf("DialUDP() error = %v", err) + } + + conn, err := transport.NewUDPConn(raw, nil) + if err != nil { + _ = raw.Close() + t.Fatalf("NewUDPConn() error = %v", err) + } + + t.Cleanup(func() { + _ = conn.Close() + }) + + return conn +} + +// sendUDPMessage 发送一条 UDP 消息。 +func sendUDPMessage(t *testing.T, conn *transport.UDPConn, msg protocol.Message) { + t.Helper() + + if err := conn.Send(msg); err != nil { + t.Fatalf("Send() error = %v", err) + } +} + +// receiveUDPMessage 接收一条 UDP 消息,带超时。 +func receiveUDPMessage(t *testing.T, conn *transport.UDPConn) protocol.Message { + t.Helper() + + type result struct { + msg protocol.Message + err error + } + + ch := make(chan result, 1) + go func() { + msg, _, err := conn.Receive() + ch <- result{msg: msg, err: err} + }() + + select { + case r := <-ch: + if r.err != nil { + t.Fatalf("Receive() error = %v", r.err) + } + return r.msg + case <-time.After(2 * time.Second): + t.Fatal("Receive() timed out after 2s") + return protocol.Message{} + } +} diff --git a/cmd/internal/transport/udp.go b/cmd/internal/transport/udp.go new file mode 100644 index 0000000..07bba36 --- /dev/null +++ b/cmd/internal/transport/udp.go @@ -0,0 +1,141 @@ +package transport + +import ( + "fmt" + "net" + "sync" + "syscall" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" +) + +// UDPConn 是对 UDP 连接的轻量封装。 +// server 侧:共享同一个 net.UDPConn,Send 时通过 peerAddr 指定对端地址。 +// peer 侧:独立的 net.UDPConn,已通过 Dial 连接到 server,Send 直接写即可。 +type UDPConn struct { + conn *net.UDPConn + peerAddr *net.UDPAddr // server 侧为对端地址;peer 侧为 nil(连接模式下直接 Write) + raw syscall.RawConn // 底层 syscall 句柄,用于 Linux socket timestamping + + logger latencylog.Logger + txTimestampDebugLogger TXTimestampDebugLogger + nodeRole string // 日志中记录的节点角色,例如 "server" 或 "peer" + nodeID string // 日志中记录的节点 ID + writeMu sync.Mutex // 保护 Send 的互斥锁 + closeOnce sync.Once + closeErr error +} + +// UDPOption 用于为 UDPConn 注入可选行为。 +type UDPOption func(*UDPConn) + +// WithUDPLogger 为 UDP 连接注入业务消息日志上下文。 +func WithUDPLogger(logger latencylog.Logger, nodeRole, nodeID string) UDPOption { + return func(conn *UDPConn) { + conn.logger = logger + conn.nodeRole = nodeRole + conn.nodeID = nodeID + } +} + +// WithUDPTXTimestampDebugLogger 为 UDP 连接注入可选的 TX errqueue 调试日志器。 +func WithUDPTXTimestampDebugLogger(logger TXTimestampDebugLogger) UDPOption { + return func(conn *UDPConn) { + conn.txTimestampDebugLogger = logger + } +} + +// NewUDPConn 创建 UDP transport 连接封装。 +// peerAddr 为 nil 时表示 peer 侧已连接模式(conn 已 Dial 到 server)。 +// peerAddr 非 nil 时表示 server 侧,Send 时需要指定目标地址。 +func NewUDPConn(conn *net.UDPConn, peerAddr *net.UDPAddr, opts ...UDPOption) (*UDPConn, error) { + udpConn := &UDPConn{ + conn: conn, + peerAddr: peerAddr, + logger: latencylog.NoopLogger{}, + } + + for _, opt := range opts { + opt(udpConn) + } + + if udpConn.logger == nil { + udpConn.logger = latencylog.NoopLogger{} + } + + if err := udpConn.initUDPLinuxTimestamping(); err != nil { + return nil, err + } + + return udpConn, nil +} + +// Send 将一条协议消息编码为 UDP 数据报并发送。 +// 多个 goroutine 可以并发调用,内部会串行化写入。 +func (c *UDPConn) Send(msg protocol.Message) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg) + + if err := c.sendMessageLinux(msg); err != nil { + return fmt.Errorf("transport: udp send message: %w", err) + } + + latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg) + + return nil +} + +// SendTo 将一条协议消息编码为 UDP 数据报并发送到指定地址。 +// 主要用于 server 侧向特定 peer 发送消息。 +func (c *UDPConn) SendTo(msg protocol.Message, addr *net.UDPAddr) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg) + + if err := c.sendMessageToLinux(msg, addr); err != nil { + return fmt.Errorf("transport: udp send message to %s: %w", addr, err) + } + + latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg) + + return nil +} + +// Receive 从 UDP 连接读取一条完整协议消息。 +// 返回解码后的消息和来源地址(peer 侧来源地址始终为 server 地址)。 +func (c *UDPConn) Receive() (protocol.Message, *net.UDPAddr, error) { + msg, addr, err := c.receiveMessageLinux() + if err != nil { + return protocol.Message{}, nil, fmt.Errorf("transport: udp receive message: %w", err) + } + + return msg, addr, nil +} + +// ReceiveLoop 持续从 UDP 连接读取消息并交给 handler 处理。 +// handler 的第二个参数是消息来源地址。 +func (c *UDPConn) ReceiveLoop(handler func(protocol.Message, *net.UDPAddr) error) error { + for { + msg, addr, err := c.Receive() + if err != nil { + return fmt.Errorf("transport: udp receive loop read: %w", err) + } + + if err := handler(msg, addr); err != nil { + return fmt.Errorf("transport: udp receive loop handler: %w", err) + } + } +} + +// Close 关闭底层 UDP 连接,保证重复调用安全。 +// 注意:server 侧多个 UDPConn 共享同一个 net.UDPConn 时, +// 只应由 UDPHub 负责关闭底层连接,不应通过此方法关闭。 +func (c *UDPConn) Close() error { + c.closeOnce.Do(func() { + c.closeErr = c.conn.Close() + }) + + return c.closeErr +} diff --git a/cmd/internal/transport/udp_linux.go b/cmd/internal/transport/udp_linux.go new file mode 100644 index 0000000..7271594 --- /dev/null +++ b/cmd/internal/transport/udp_linux.go @@ -0,0 +1,358 @@ +//go:build linux + +package transport + +import ( + "errors" + "fmt" + "net" + "syscall" + "time" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" +) + +// UDP 接收缓冲区大小,足以容纳 MaxFrameSize 加上协议头。 +const udpReceiveBufferSize = protocol.MaxFrameSize + 1024 + +// initUDPLinuxTimestamping 拿到底层 fd,并打开 Linux timestamping。 +func (c *UDPConn) initUDPLinuxTimestamping() error { + rawConn, err := c.conn.SyscallConn() + if err != nil || rawConn == nil { + if err != nil { + return fmt.Errorf("transport: udp get syscall conn: %w", err) + } + return fmt.Errorf("transport: udp missing syscall conn") + } + + // UDP 不需要 OPT_ID_TCP,使用标准的 OPT_ID 即可。 + flagCandidates := []int{ + linuxSOFTimestampingTXSched | + linuxSOFTimestampingTXSoftware | + linuxSOFTimestampingRXSoftware | + linuxSOFTimestampingSoftware | + linuxSOFTimestampingOptID | + linuxSOFTimestampingOptTSONLY, + linuxSOFTimestampingTXSched | + linuxSOFTimestampingTXSoftware | + linuxSOFTimestampingRXSoftware | + linuxSOFTimestampingSoftware | + linuxSOFTimestampingOptTSONLY, + } + + var lastErr error + for _, flags := range flagCandidates { + err := rawConn.Control(func(fd uintptr) { + lastErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags) + }) + if err != nil { + return err + } + if lastErr == nil { + c.raw = rawConn + return nil + } + if !errors.Is(lastErr, syscall.EINVAL) { + return lastErr + } + } + + return lastErr +} + +// sendMessageLinux 编码消息并通过 UDP 发送,采集 TX 时间戳。 +func (c *UDPConn) sendMessageLinux(msg protocol.Message) error { + payload, err := protocol.EncodeMessage(msg) + if err != nil { + return fmt.Errorf("protocol: encode message: %w", err) + } + + readIndex := 0 + c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePreSendDrain, &readIndex) + + if c.peerAddr != nil { + if err := c.udpSendTo(payload, c.peerAddr); err != nil { + return err + } + } else { + if err := c.udpSend(payload); err != nil { + return err + } + } + + c.collectAndLogUDPTXTimestampEvents(msg, &readIndex) + return nil +} + +// sendMessageToLinux 编码消息并通过 UDP 发送到指定地址,采集 TX 时间戳。 +func (c *UDPConn) sendMessageToLinux(msg protocol.Message, addr *net.UDPAddr) error { + payload, err := protocol.EncodeMessage(msg) + if err != nil { + return fmt.Errorf("protocol: encode message: %w", err) + } + + readIndex := 0 + c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePreSendDrain, &readIndex) + + if err := c.udpSendTo(payload, addr); err != nil { + return err + } + + c.collectAndLogUDPTXTimestampEvents(msg, &readIndex) + return nil +} + +// udpSend 通过已连接的 UDP socket 发送数据。 +func (c *UDPConn) udpSend(payload []byte) error { + if c.raw != nil { + return c.udpSendmsgRaw(payload, nil) + } + _, err := c.conn.Write(payload) + return err +} + +// udpSendTo 通过 UDP socket 发送数据到指定地址。 +func (c *UDPConn) udpSendTo(payload []byte, addr *net.UDPAddr) error { + if c.raw != nil { + sa := udpAddrToSockaddr(addr) + if sa != nil { + return c.udpSendmsgRaw(payload, sa) + } + } + _, err := c.conn.WriteToUDP(payload, addr) + return err +} + +// udpSendmsgRaw 通过 sendmsg syscall 发送 UDP 数据。 +func (c *UDPConn) udpSendmsgRaw(payload []byte, to syscall.Sockaddr) error { + var opErr error + + for { + err := c.raw.Control(func(fd uintptr) { + opErr = syscall.Sendmsg(int(fd), payload, nil, to, 0) + }) + if err != nil { + return err + } + if opErr == nil { + return nil + } + if isWouldBlock(opErr) { + time.Sleep(linuxDataPollInterval) + continue + } + return opErr + } +} + +// receiveMessageLinux 从 UDP 连接读取一条完整消息,并记录 RX 时间戳。 +func (c *UDPConn) receiveMessageLinux() (protocol.Message, *net.UDPAddr, error) { + payload, addr, rxTimestamp, err := c.udpRecvFrom() + if err != nil { + return protocol.Message{}, nil, fmt.Errorf("protocol: udp read: %w", err) + } + + msg, err := protocol.DecodeMessage(payload) + if err != nil { + return protocol.Message{}, nil, fmt.Errorf("protocol: decode message: %w", err) + } + + if rxTimestamp > 0 { + latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventBRXSoftware, rxTimestamp, msg) + } + + return msg, addr, nil +} + +// udpRecvFrom 从 UDP socket 接收一个完整数据报,返回数据、来源地址和 RX 时间戳。 +func (c *UDPConn) udpRecvFrom() ([]byte, *net.UDPAddr, int64, error) { + if c.raw != nil { + return c.udpRecvmsgRaw() + } + + buf := make([]byte, udpReceiveBufferSize) + n, addr, err := c.conn.ReadFromUDP(buf) + if err != nil { + return nil, nil, 0, err + } + + return buf[:n], addr, 0, nil +} + +// udpRecvmsgRaw 通过 recvmsg syscall 接收 UDP 数据,同时采集 RX 时间戳。 +func (c *UDPConn) udpRecvmsgRaw() ([]byte, *net.UDPAddr, int64, error) { + for { + var ( + n int + rxTimeNS int64 + from syscall.Sockaddr + opErr error + ) + + buf := make([]byte, udpReceiveBufferSize) + err := c.raw.Control(func(fd uintptr) { + oob := make([]byte, linuxTimestampControlBufferSize) + readN, oobN, _, sa, recvErr := syscall.Recvmsg(int(fd), buf, oob, 0) + if recvErr != nil { + opErr = recvErr + return + } + n = readN + from = sa + rxTimeNS = parseRXTimestampControlMessages(oob[:oobN]) + }) + if err != nil { + return nil, nil, 0, err + } + if opErr != nil { + if isWouldBlock(opErr) { + time.Sleep(linuxDataPollInterval) + continue + } + return nil, nil, 0, opErr + } + + addr := sockaddrToUDPAddr(from) + return buf[:n], addr, rxTimeNS, nil + } +} + +// collectAndLogUDPTXTimestampEvents 采集并记录 UDP 发送的 TX 时间戳事件。 +func (c *UDPConn) collectAndLogUDPTXTimestampEvents(msg protocol.Message, readIndex *int) { + timestamps := c.collectUDPTXTimestampEvents(msg, readIndex) + + if ts, ok := timestamps[latencylog.EventATXSched]; ok { + latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSched, ts, msg) + } + if ts, ok := timestamps[latencylog.EventATXSoftware]; ok { + latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSoftware, ts, msg) + } +} + +// collectUDPTXTimestampEvents 在 errqueue 中等待 TX 时间戳。 +func (c *UDPConn) collectUDPTXTimestampEvents(msg protocol.Message, readIndex *int) map[string]int64 { + if c.raw == nil { + return nil + } + + deadline := time.Now().Add(linuxTXTimestampWaitTimeout) + timestamps := make(map[string]int64, 2) + + for time.Now().Before(deadline) { + event, err := c.recvUDPTXTimestampOnce() + if err != nil { + if isWouldBlock(err) { + time.Sleep(linuxTXTimestampPollInterval) + continue + } + break + } + if event.EventName == "" || event.TSUnixNano <= 0 { + continue + } + *readIndex++ + + if isBusinessTXTimestampEventName(event.EventName) { + if _, exists := timestamps[event.EventName]; !exists { + timestamps[event.EventName] = event.TSUnixNano + } + } + + if hasCompleteTXTimestampPair(timestamps) { + break + } + } + + c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePostSelectDrain, readIndex) + return timestamps +} + +// drainPendingUDPTXTimestampEvents 清空 errqueue 中残留的时间戳事件。 +func (c *UDPConn) drainPendingUDPTXTimestampEvents(msg protocol.Message, phase string, readIndex *int) { + if c.raw == nil { + return + } + + for { + event, err := c.recvUDPTXTimestampOnce() + if err != nil { + return + } + if event.EventName == "" || event.TSUnixNano <= 0 { + continue + } + *readIndex++ + } +} + +// recvUDPTXTimestampOnce 从 errqueue 读一次时间戳事件。 +func (c *UDPConn) recvUDPTXTimestampOnce() (txTimestampEvent, error) { + var ( + event txTimestampEvent + opErr error + ) + + err := c.raw.Control(func(fd uintptr) { + oob := make([]byte, linuxTimestampControlBufferSize) + _, oobn, _, _, recvErr := syscall.Recvmsg(int(fd), nil, oob, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT) + if recvErr != nil { + opErr = recvErr + return + } + event, _ = parseTXTimestampControlMessages(oob[:oobn]) + }) + if err != nil { + return txTimestampEvent{}, err + } + if opErr != nil { + return txTimestampEvent{}, opErr + } + + return event, nil +} + +// udpAddrToSockaddr 将 net.UDPAddr 转换为 syscall.Sockaddr。 +func udpAddrToSockaddr(addr *net.UDPAddr) syscall.Sockaddr { + if ip4 := addr.IP.To4(); ip4 != nil { + sa := &syscall.SockaddrInet4{Port: addr.Port} + copy(sa.Addr[:], ip4) + return sa + } + if ip6 := addr.IP.To16(); ip6 != nil { + sa := &syscall.SockaddrInet6{Port: addr.Port} + copy(sa.Addr[:], ip6) + return sa + } + return nil +} + +// sockaddrToUDPAddr 将 syscall.Sockaddr 转换为 net.UDPAddr。 +func sockaddrToUDPAddr(sa syscall.Sockaddr) *net.UDPAddr { + switch addr := sa.(type) { + case *syscall.SockaddrInet4: + return &net.UDPAddr{ + IP: net.IP(addr.Addr[:]), + Port: addr.Port, + } + case *syscall.SockaddrInet6: + return &net.UDPAddr{ + IP: net.IP(addr.Addr[:]), + Port: addr.Port, + Zone: zoneToString(addr.ZoneId), + } + default: + return nil + } +} + +func zoneToString(zone uint32) string { + if zone == 0 { + return "" + } + iface, err := net.InterfaceByIndex(int(zone)) + if err != nil { + return "" + } + return iface.Name +} diff --git a/cmd/internal/transport/udp_linux_test.go b/cmd/internal/transport/udp_linux_test.go new file mode 100644 index 0000000..71c3496 --- /dev/null +++ b/cmd/internal/transport/udp_linux_test.go @@ -0,0 +1,105 @@ +//go:build linux + +package transport + +import ( + "net" + "reflect" + "testing" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" +) + +// TestUDPLinuxTimestampingRecordsKernelEvents 验证 UDP 在 Linux 上能正确采集内核时间戳。 +func TestUDPLinuxTimestampingRecordsKernelEvents(t *testing.T) { + tests := []struct { + name string + msg protocol.Message + }{ + { + name: "text", + msg: protocol.Message{ + Type: protocol.MessageTypeText, + ID: 41, + From: "peer-a", + To: "peer-b", + Body: []byte("hello over udp"), + }, + }, + { + name: "file", + msg: protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 42, + From: "peer-a", + To: "peer-b", + FileName: "payload.bin", + Body: []byte{0x00, 0x01, 0x02, 0xff}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + senderLogger := &recordingLogger{} + receiverLogger := &recordingLogger{} + + // 创建 server 侧监听 + serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveUDPAddr() error = %v", err) + } + serverRaw, err := net.ListenUDP("udp", serverAddr) + if err != nil { + t.Fatalf("ListenUDP() error = %v", err) + } + receiver, err := NewUDPConn( + serverRaw, + nil, + WithUDPLogger(receiverLogger, latencylog.NodeRolePeer, "peer-b"), + ) + if err != nil { + _ = serverRaw.Close() + t.Fatalf("NewUDPConn(receiver) error = %v", err) + } + t.Cleanup(func() { _ = receiver.Close() }) + + // 创建 peer 侧连接 + peerRaw, err := net.DialUDP("udp", nil, serverRaw.LocalAddr().(*net.UDPAddr)) + if err != nil { + t.Fatalf("DialUDP() error = %v", err) + } + sender, err := NewUDPConn( + peerRaw, + nil, + WithUDPLogger(senderLogger, latencylog.NodeRolePeer, "peer-a"), + ) + if err != nil { + _ = peerRaw.Close() + t.Fatalf("NewUDPConn(sender) error = %v", err) + } + t.Cleanup(func() { _ = sender.Close() }) + + sendErr := make(chan error, 1) + go func() { + sendErr <- sender.Send(tt.msg) + }() + + got, _, err := receiver.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if err := <-sendErr; err != nil { + t.Fatalf("Send() error = %v", err) + } + if !reflect.DeepEqual(got, tt.msg) { + t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg) + } + + assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSched, tt.msg.ID) + assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSoftware, tt.msg.ID) + assertHasEvent(t, receiverLogger.Events(), latencylog.EventBRXSoftware, tt.msg.ID) + }) + } +} diff --git a/cmd/internal/transport/udp_test.go b/cmd/internal/transport/udp_test.go new file mode 100644 index 0000000..bc136bf --- /dev/null +++ b/cmd/internal/transport/udp_test.go @@ -0,0 +1,358 @@ +package transport + +import ( + "net" + "reflect" + "strings" + "sync" + "testing" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" +) + +// TestUDPSendReceiveMessage 验证 UDP transport 可以正常收发 text 和 file 消息。 +func TestUDPSendReceiveMessage(t *testing.T) { + tests := []struct { + name string + msg protocol.Message + }{ + { + name: "text", + msg: protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("hello udp"), + }, + }, + { + name: "file", + msg: protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 2, + From: "peer-a", + To: "peer-b", + FileName: "data.bin", + Body: []byte{0x00, 0x10, 0xff}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sender, receiver := newUDPConnPair(t, nil, nil) + + sendErr := make(chan error, 1) + go func() { + sendErr <- sender.Send(tt.msg) + }() + + got, _, err := receiver.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if err := <-sendErr; err != nil { + t.Fatalf("Send() error = %v", err) + } + + if !reflect.DeepEqual(got, tt.msg) { + t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg) + } + }) + } +} + +// TestUDPSendLogsHandoffEvents 验证 UDP Send 会记录 handoff 事件。 +func TestUDPSendLogsHandoffEvents(t *testing.T) { + logger := &recordingLogger{} + sender, receiver := newUDPConnPair( + t, + []UDPOption{WithUDPLogger(logger, latencylog.NodeRolePeer, "peer-a")}, + nil, + ) + + msg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 7, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + } + + sendErr := make(chan error, 1) + go func() { + sendErr <- sender.Send(msg) + }() + + got, _, err := receiver.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if err := <-sendErr; err != nil { + t.Fatalf("Send() error = %v", err) + } + if !reflect.DeepEqual(got, msg) { + t.Fatalf("message mismatch: got %+v want %+v", got, msg) + } + + events := logger.Events() + if len(events) < 2 { + t.Fatalf("event count = %d, want at least 2", len(events)) + } + if events[0].Event != latencylog.EventSendHandoffBegin { + t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin) + } + // 最后一个事件应该是 SendHandoffEnd + lastEvent := events[len(events)-1] + if lastEvent.Event != latencylog.EventSendHandoffEnd { + t.Fatalf("last event = %q, want %q", lastEvent.Event, latencylog.EventSendHandoffEnd) + } +} + +// TestUDPReceiveLoopDeliversMessages 验证 ReceiveLoop 会逐条交付连续到达的消息。 +func TestUDPReceiveLoopDeliversMessages(t *testing.T) { + sender, receiver := newUDPConnPair(t, nil, nil) + + want := []protocol.Message{ + { + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + }, + { + Type: protocol.MessageTypeFile, + ID: 2, + From: "peer-a", + To: "peer-b", + FileName: "payload.bin", + Body: []byte{0x01, 0x02, 0x03}, + }, + } + + var ( + mu sync.Mutex + got []protocol.Message + ) + loopErr := make(chan error, 1) + go func() { + loopErr <- receiver.ReceiveLoop(func(msg protocol.Message, _ *net.UDPAddr) error { + mu.Lock() + defer mu.Unlock() + got = append(got, msg) + if len(got) >= len(want) { + return nil + } + return nil + }) + }() + + for _, msg := range want { + if err := sender.Send(msg); err != nil { + t.Fatalf("Send() error = %v", err) + } + } + + // 关闭发送端,ReceiveLoop 会因读取错误退出 + if err := sender.Close(); err != nil { + t.Fatalf("sender.Close() error = %v", err) + } + + err := <-loopErr + if err == nil { + t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close") + } + if !strings.Contains(err.Error(), "udp receive loop read") { + t.Fatalf("ReceiveLoop() error = %v, want read context", err) + } + + mu.Lock() + defer mu.Unlock() + if !reflect.DeepEqual(got, want) { + t.Fatalf("received messages mismatch: got %+v want %+v", got, want) + } +} + +// TestUDPCloseIsIdempotent 验证 Close 可以安全地被重复调用。 +func TestUDPCloseIsIdempotent(t *testing.T) { + conn, peer := newUDPConnPair(t, nil, nil) + + if err := conn.Close(); err != nil { + t.Fatalf("Close(first) error = %v", err) + } + if err := conn.Close(); err != nil { + t.Fatalf("Close(second) error = %v, want nil", err) + } + _ = peer.Close() +} + +// TestUDPSendToMessage 验证 SendTo 可以向指定地址发送消息。 +func TestUDPSendToMessage(t *testing.T) { + serverConn := newUDPListener(t) + peerConn := newUDPDialed(t, serverConn.conn.LocalAddr().(*net.UDPAddr)) + + msg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "server", + Body: []byte("hello sendto"), + } + + // peer 发送消息到 server + sendErr := make(chan error, 1) + go func() { + sendErr <- peerConn.Send(msg) + }() + + got, addr, err := serverConn.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if err := <-sendErr; err != nil { + t.Fatalf("Send() error = %v", err) + } + if !reflect.DeepEqual(got, msg) { + t.Fatalf("message mismatch: got %+v want %+v", got, msg) + } + + // server 用 SendTo 回复到 peer 地址 + reply := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 2, + From: "server", + To: "peer-a", + Body: []byte("reply"), + } + + sendErr2 := make(chan error, 1) + go func() { + sendErr2 <- serverConn.SendTo(reply, addr) + }() + + gotReply, _, err := peerConn.Receive() + if err != nil { + t.Fatalf("peer Receive() error = %v", err) + } + if err := <-sendErr2; err != nil { + t.Fatalf("SendTo() error = %v", err) + } + if !reflect.DeepEqual(gotReply, reply) { + t.Fatalf("reply mismatch: got %+v want %+v", gotReply, reply) + } +} + +// newUDPConnPair 创建一对互相连接的 UDP transport 连接,用于测试。 +func newUDPConnPair(t *testing.T, senderOpts []UDPOption, receiverOpts []UDPOption) (*UDPConn, *UDPConn) { + t.Helper() + + // 创建两个 UDP socket,通过 Dial 互相连接 + addr1, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveUDPAddr() error = %v", err) + } + addr2, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveUDPAddr() error = %v", err) + } + + conn1, err := net.ListenUDP("udp", addr1) + if err != nil { + t.Fatalf("ListenUDP(1) error = %v", err) + } + conn2, err := net.ListenUDP("udp", addr2) + if err != nil { + _ = conn1.Close() + t.Fatalf("ListenUDP(2) error = %v", err) + } + + // 用 Dial 模式连接对端 + senderRaw, err := net.DialUDP("udp", nil, conn2.LocalAddr().(*net.UDPAddr)) + if err != nil { + _ = conn1.Close() + _ = conn2.Close() + t.Fatalf("DialUDP(sender) error = %v", err) + } + _ = conn1.Close() // 不再需要 conn1 + + receiverRaw, err := net.DialUDP("udp", conn2.LocalAddr().(*net.UDPAddr), senderRaw.LocalAddr().(*net.UDPAddr)) + if err != nil { + _ = senderRaw.Close() + _ = conn2.Close() + t.Fatalf("DialUDP(receiver) error = %v", err) + } + _ = conn2.Close() // 不再需要 conn2 + + sender, err := NewUDPConn(senderRaw, nil, senderOpts...) + if err != nil { + _ = senderRaw.Close() + _ = receiverRaw.Close() + t.Fatalf("NewUDPConn(sender) error = %v", err) + } + + receiver, err := NewUDPConn(receiverRaw, nil, receiverOpts...) + if err != nil { + _ = sender.Close() + _ = receiverRaw.Close() + t.Fatalf("NewUDPConn(receiver) error = %v", err) + } + + t.Cleanup(func() { + _ = sender.Close() + _ = receiver.Close() + }) + + return sender, receiver +} + +// newUDPListener 创建一个监听模式的 UDP 连接,用于测试 server 场景。 +func newUDPListener(t *testing.T) *UDPConn { + t.Helper() + + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveUDPAddr() error = %v", err) + } + + raw, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatalf("ListenUDP() error = %v", err) + } + + conn, err := NewUDPConn(raw, nil) + if err != nil { + _ = raw.Close() + t.Fatalf("NewUDPConn() error = %v", err) + } + + t.Cleanup(func() { + _ = conn.Close() + }) + + return conn +} + +// newUDPDialed 创建一个已连接到指定地址的 UDP 连接,用于测试 peer 场景。 +func newUDPDialed(t *testing.T, serverAddr *net.UDPAddr) *UDPConn { + t.Helper() + + raw, err := net.DialUDP("udp", nil, serverAddr) + if err != nil { + t.Fatalf("DialUDP() error = %v", err) + } + + conn, err := NewUDPConn(raw, nil) + if err != nil { + _ = raw.Close() + t.Fatalf("NewUDPConn() error = %v", err) + } + + t.Cleanup(func() { + _ = conn.Close() + }) + + return conn +} diff --git a/cmd/udppeer/interactive.go b/cmd/udppeer/interactive.go new file mode 100644 index 0000000..fb4f08d --- /dev/null +++ b/cmd/udppeer/interactive.go @@ -0,0 +1,70 @@ +package main + +import ( + "errors" + "fmt" + "strings" +) + +var errUDPEmptyCommand = errors.New("interactive command is empty") + +type udpInteractiveCommand struct { + name string + to string + value string +} + +func parseUDPInteractiveCommand(line string) (udpInteractiveCommand, error) { + commandName, rest, ok := cutUDPField(strings.TrimSpace(line)) + if !ok { + return udpInteractiveCommand{}, errUDPEmptyCommand + } + + switch strings.ToLower(commandName) { + case "help", "h", "?": + return udpInteractiveCommand{name: "help"}, nil + case "quit", "exit": + return udpInteractiveCommand{name: "quit"}, nil + case "text": + to, body, err := parseUDPTargetValue(rest, "text") + if err != nil { + return udpInteractiveCommand{}, err + } + return udpInteractiveCommand{name: "text", to: to, value: body}, nil + case "file": + to, path, err := parseUDPTargetValue(rest, "file") + if err != nil { + return udpInteractiveCommand{}, err + } + return udpInteractiveCommand{name: "file", to: to, value: path}, nil + default: + return udpInteractiveCommand{}, fmt.Errorf("unknown command %q; type help for usage", commandName) + } +} + +func parseUDPTargetValue(rest, commandName string) (string, string, error) { + to, value, ok := cutUDPField(strings.TrimSpace(rest)) + if !ok { + return "", "", fmt.Errorf("%s command requires a target peer and payload", commandName) + } + if strings.TrimSpace(value) == "" { + return "", "", fmt.Errorf("%s command requires a non-empty payload", commandName) + } + + return to, strings.TrimSpace(value), nil +} + +func cutUDPField(input string) (string, string, bool) { + trimmed := strings.TrimSpace(input) + if trimmed == "" { + return "", "", false + } + + for i, r := range trimmed { + if r == ' ' || r == '\t' { + return trimmed[:i], strings.TrimSpace(trimmed[i+1:]), true + } + } + + return trimmed, "", true +} diff --git a/cmd/udppeer/main.go b/cmd/udppeer/main.go new file mode 100644 index 0000000..18f9345 --- /dev/null +++ b/cmd/udppeer/main.go @@ -0,0 +1,194 @@ +package main + +import ( + "bufio" + "flag" + "fmt" + "io" + "log" + "os" + + "omnisocketgo/cmd/internal/latencylog" + peerpkg "omnisocketgo/cmd/internal/peer" + "omnisocketgo/cmd/internal/protocol" + "omnisocketgo/cmd/internal/transport" +) + +func main() { + peerID := flag.String("id", "peer-a", "peer identity") + serverAddr := flag.String("server", "127.0.0.1:9001", "UDP server address") + targetPeer := flag.String("to", "", "optional target peer for one outgoing message") + text := flag.String("text", "", "optional text to send after connecting") + filePath := flag.String("file", "", "optional file path to send after connecting") + bindIP := flag.String("bind-ip", "", "optional local source IP used when dialing the server") + inboxDir := flag.String("inbox-dir", "inbox", "directory used to persist received text and file messages") + logPath := flag.String("latency-log", "", "optional JSONL file path for latency timestamp logs") + txTimestampDebugLogPath := flag.String("tx-ts-debug-log", "", "optional JSONL file path for TX errqueue debug records") + interactive := flag.Bool("interactive", true, "enable interactive REPL for repeated text/file sends on the same connection") + flag.Parse() + + clientOptions := make([]peerpkg.Option, 0, 4) + if *logPath != "" { + logger, err := latencylog.NewJSONLLogger(*logPath) + if err != nil { + log.Fatalf("create latency logger %s: %v", *logPath, err) + } + defer logger.Close() + clientOptions = append(clientOptions, peerpkg.WithLogger(logger)) + } + if *txTimestampDebugLogPath != "" { + logger, err := transport.NewJSONLTXTimestampDebugLogger(*txTimestampDebugLogPath) + if err != nil { + log.Fatalf("create tx timestamp debug logger %s: %v", *txTimestampDebugLogPath, err) + } + defer logger.Close() + clientOptions = append(clientOptions, peerpkg.WithTXTimestampDebugLogger(logger)) + } + if *bindIP != "" { + clientOptions = append(clientOptions, peerpkg.WithBindIP(*bindIP)) + } + + client, err := peerpkg.DialUDP(*serverAddr, *peerID, clientOptions...) + if err != nil { + log.Fatalf("dial udp server %s: %v", *serverAddr, err) + } + defer client.Close() + + log.Printf("connected to %s as %s (UDP)", *serverAddr, client.ID()) + + receiveErr := make(chan error, 1) + go func() { + receiveErr <- client.ReceiveLoop(func(msg protocol.Message) error { + switch msg.Type { + case protocol.MessageTypeText: + path, err := client.PersistMessage(msg, *inboxDir) + if err != nil { + return err + } + log.Printf("received text from %s to %s and persisted to %s", msg.From, msg.To, path) + case protocol.MessageTypeFile: + path, err := client.PersistMessage(msg, *inboxDir) + if err != nil { + return err + } + log.Printf("received file from %s to %s: %s (%d bytes) -> %s", msg.From, msg.To, msg.FileName, len(msg.Body), path) + case protocol.MessageTypeError: + log.Printf("received %s from %s to %s: %s", msg.Type, msg.From, msg.To, string(msg.Body)) + default: + log.Printf("received unexpected message type %s from %s", msg.Type, msg.From) + } + return nil + }) + }() + + if *text != "" && *filePath != "" { + log.Fatal("only one of -text or -file may be specified") + } + + if (*text != "" || *filePath != "") && *targetPeer == "" { + log.Fatal("flag -to is required when sending text or file") + } + + if *targetPeer != "" && *text != "" { + if err := client.SendText(*targetPeer, *text); err != nil { + log.Fatalf("send text to %s: %v", *targetPeer, err) + } + log.Printf("sent text to %s", *targetPeer) + } + + if *targetPeer != "" && *filePath != "" { + if err := client.SendFilePath(*targetPeer, *filePath); err != nil { + log.Fatalf("send file %s to %s: %v", *filePath, *targetPeer, err) + } + log.Printf("sent file %s to %s", *filePath, *targetPeer) + } + + if *interactive { + if err := runUDPInteractiveShell(client, os.Stdin, os.Stdout, receiveErr); err != nil { + log.Printf("interactive shell ended: %v", err) + } + return + } + + if err := <-receiveErr; err != nil { + log.Printf("receive loop ended: %v", err) + } +} + +func runUDPInteractiveShell(client *peerpkg.UDPClient, in io.Reader, out io.Writer, receiveErr <-chan error) error { + printUDPInteractiveHelp(out) + lines, inputErr := readUDPInteractiveLines(in, out, fmt.Sprintf("%s> ", client.ID())) + + for { + select { + case err := <-receiveErr: + return err + case line, ok := <-lines: + if !ok { + return <-inputErr + } + + command, err := parseUDPInteractiveCommand(line) + if err != nil { + if err == errUDPEmptyCommand { + continue + } + log.Printf("interactive command error: %v", err) + continue + } + + switch command.name { + case "help": + printUDPInteractiveHelp(out) + case "quit": + return nil + case "text": + if err := client.SendText(command.to, command.value); err != nil { + log.Printf("send text to %s: %v", command.to, err) + continue + } + log.Printf("sent text to %s", command.to) + case "file": + if err := client.SendFilePath(command.to, command.value); err != nil { + log.Printf("send file %s to %s: %v", command.value, command.to, err) + continue + } + log.Printf("sent file %s to %s", command.value, command.to) + } + } + } +} + +func readUDPInteractiveLines(in io.Reader, out io.Writer, prompt string) (<-chan string, <-chan error) { + lines := make(chan string) + errs := make(chan error, 1) + + go func() { + defer close(lines) + + scanner := bufio.NewScanner(in) + scanner.Buffer(make([]byte, 0, 1024), 1024*1024) + + for { + if _, err := fmt.Fprint(out, prompt); err != nil { + errs <- err + return + } + if !scanner.Scan() { + errs <- scanner.Err() + return + } + lines <- scanner.Text() + } + }() + + return lines, errs +} + +func printUDPInteractiveHelp(w io.Writer) { + _, _ = fmt.Fprintln(w, "interactive mode commands (UDP):") + _, _ = fmt.Fprintln(w, " help show this help") + _, _ = fmt.Fprintln(w, " text send one text message over UDP") + _, _ = fmt.Fprintln(w, " file send one file over UDP") + _, _ = fmt.Fprintln(w, " quit exit this peer process") +} diff --git a/cmd/udpserver/main.go b/cmd/udpserver/main.go new file mode 100644 index 0000000..5a938e0 --- /dev/null +++ b/cmd/udpserver/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "flag" + "log" + "net" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/server" +) + +func main() { + listenAddr := flag.String("listen", ":9001", "UDP server listen address") + logPath := flag.String("latency-log", "", "optional JSONL file path for latency timestamp logs") + flag.Parse() + + hubOptions := make([]server.UDPOption, 0, 1) + if *logPath != "" { + logger, err := latencylog.NewJSONLLogger(*logPath) + if err != nil { + log.Fatalf("create latency logger %s: %v", *logPath, err) + } + defer logger.Close() + hubOptions = append(hubOptions, server.WithUDPLogger(logger)) + } + + udpAddr, err := net.ResolveUDPAddr("udp", *listenAddr) + if err != nil { + log.Fatalf("resolve udp address %s: %v", *listenAddr, err) + } + + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + log.Fatalf("listen udp on %s: %v", *listenAddr, err) + } + defer conn.Close() + + hub, err := server.NewUDPHub(conn, hubOptions...) + if err != nil { + log.Fatalf("create udp hub: %v", err) + } + + log.Printf("udp server listening on %s", conn.LocalAddr()) + + if err := hub.Serve(); err != nil { + log.Fatalf("udp hub serve: %v", err) + } +}