Feat: UDP 框架
This commit is contained in:
202
cmd/internal/peer/udp_client.go
Normal file
202
cmd/internal/peer/udp_client.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
// UDPClient 表示一个通过 UDP 连接到 server 的 peer。
|
||||
type UDPClient struct {
|
||||
id string
|
||||
conn *transport.UDPConn
|
||||
logger latencylog.Logger
|
||||
|
||||
nextID uint64
|
||||
}
|
||||
|
||||
// DialUDP 通过 UDP 连接到 server,并发送 register 消息完成身份注册。
|
||||
func DialUDP(serverAddr, peerID string, opts ...Option) (*UDPClient, error) {
|
||||
options := clientOptions{
|
||||
logger: latencylog.NoopLogger{},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
if options.logger == nil {
|
||||
options.logger = latencylog.NoopLogger{}
|
||||
}
|
||||
|
||||
udpServerAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer: resolve udp server addr %s: %w", serverAddr, err)
|
||||
}
|
||||
|
||||
var localAddr *net.UDPAddr
|
||||
if options.bindIP != "" {
|
||||
ip := net.ParseIP(options.bindIP)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("peer: invalid bind ip %q", options.bindIP)
|
||||
}
|
||||
localAddr = &net.UDPAddr{IP: ip}
|
||||
}
|
||||
|
||||
rawConn, err := net.DialUDP("udp", localAddr, udpServerAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer: dial udp server %s: %w", serverAddr, err)
|
||||
}
|
||||
|
||||
conn, err := transport.NewUDPConn(
|
||||
rawConn,
|
||||
nil, // peer 侧已连接模式,不需要指定 peerAddr
|
||||
transport.WithUDPLogger(options.logger, latencylog.NodeRolePeer, peerID),
|
||||
transport.WithUDPTXTimestampDebugLogger(options.txTimestampDebugLogger),
|
||||
)
|
||||
if err != nil {
|
||||
_ = rawConn.Close()
|
||||
return nil, fmt.Errorf("peer: create udp transport conn: %w", err)
|
||||
}
|
||||
|
||||
client := &UDPClient{
|
||||
id: peerID,
|
||||
conn: conn,
|
||||
logger: options.logger,
|
||||
}
|
||||
|
||||
// 发送 register 消息完成身份注册
|
||||
if err := conn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: peerID,
|
||||
To: protocol.ServerPeerID,
|
||||
}); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("peer: udp register with server: %w", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// ID 返回当前 client 的 peer 标识。
|
||||
func (c *UDPClient) ID() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
// SendText 向目标 peer 发送一条文本消息。
|
||||
func (c *UDPClient) SendText(to, body string) error {
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: c.nextMessageID(),
|
||||
From: c.id,
|
||||
To: to,
|
||||
}
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
|
||||
|
||||
msg.Body = []byte(body)
|
||||
|
||||
return c.conn.Send(msg)
|
||||
}
|
||||
|
||||
// SendFile 向目标 peer 发送一条文件消息。
|
||||
func (c *UDPClient) SendFile(to, fileName string, body []byte) error {
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: c.nextMessageID(),
|
||||
From: c.id,
|
||||
To: to,
|
||||
FileName: fileName,
|
||||
}
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
|
||||
|
||||
bodyCopy := make([]byte, len(body))
|
||||
copy(bodyCopy, body)
|
||||
|
||||
msg.Body = bodyCopy
|
||||
|
||||
return c.conn.Send(msg)
|
||||
}
|
||||
|
||||
// SendFilePath 从本地文件读取内容并发送给目标 peer。
|
||||
func (c *UDPClient) SendFilePath(to, path string) error {
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: c.nextMessageID(),
|
||||
From: c.id,
|
||||
To: to,
|
||||
FileName: filepath.Base(path),
|
||||
}
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
|
||||
|
||||
body, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("peer: read file %s: %w", path, err)
|
||||
}
|
||||
|
||||
msg.Body = body
|
||||
|
||||
return c.conn.Send(msg)
|
||||
}
|
||||
|
||||
// Receive 读取一条来自 server 的消息。
|
||||
func (c *UDPClient) Receive() (protocol.Message, error) {
|
||||
msg, _, err := c.conn.Receive()
|
||||
if err != nil {
|
||||
return protocol.Message{}, fmt.Errorf("peer: udp receive from server: %w", err)
|
||||
}
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// ReceiveLoop 持续接收 server 消息并交给 handler 处理。
|
||||
func (c *UDPClient) ReceiveLoop(handler func(protocol.Message) error) error {
|
||||
return c.conn.ReceiveLoop(func(msg protocol.Message, _ *net.UDPAddr) error {
|
||||
switch msg.Type {
|
||||
case protocol.MessageTypeText, protocol.MessageTypeFile, protocol.MessageTypeError:
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg)
|
||||
return handler(msg)
|
||||
default:
|
||||
return fmt.Errorf("peer: unexpected message type from server: %s", msg.Type)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// PersistMessage 将收到的业务消息写入本地磁盘。
|
||||
func (c *UDPClient) PersistMessage(msg protocol.Message, inboxDir string) (string, error) {
|
||||
if !latencylog.IsBusinessMessage(msg) {
|
||||
return "", fmt.Errorf("peer: cannot persist message type %s", msg.Type)
|
||||
}
|
||||
if inboxDir == "" {
|
||||
return "", fmt.Errorf("peer: inbox directory is required")
|
||||
}
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistBegin, msg)
|
||||
|
||||
if err := os.MkdirAll(inboxDir, 0o755); err != nil {
|
||||
return "", fmt.Errorf("peer: create inbox dir %s: %w", inboxDir, err)
|
||||
}
|
||||
|
||||
path, err := persistMessageToDisk(msg, inboxDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistEnd, msg)
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// Close 关闭与 server 的 UDP 连接。
|
||||
func (c *UDPClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *UDPClient) nextMessageID() uint64 {
|
||||
return atomic.AddUint64(&c.nextID, 1)
|
||||
}
|
||||
211
cmd/internal/peer/udp_client_test.go
Normal file
211
cmd/internal/peer/udp_client_test.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/server"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
// TestUDPDialAndSendText 验证 UDP 客户端可以成功连接、注册并发送文本消息。
|
||||
func TestUDPDialAndSendText(t *testing.T) {
|
||||
hubAddr := startUDPTestHub(t)
|
||||
|
||||
clientA, err := DialUDP(hubAddr.String(), "peer-a")
|
||||
if err != nil {
|
||||
t.Fatalf("DialUDP(peer-a) error = %v", err)
|
||||
}
|
||||
defer clientA.Close()
|
||||
|
||||
clientB, err := DialUDP(hubAddr.String(), "peer-b")
|
||||
if err != nil {
|
||||
t.Fatalf("DialUDP(peer-b) error = %v", err)
|
||||
}
|
||||
defer clientB.Close()
|
||||
|
||||
// 等待注册被处理
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// peer-a 发送文本给 peer-b
|
||||
if err := clientA.SendText("peer-b", "hello from udp"); err != nil {
|
||||
t.Fatalf("SendText() error = %v", err)
|
||||
}
|
||||
|
||||
// peer-b 接收
|
||||
msg := receiveUDPClientMessage(t, clientB)
|
||||
if msg.Type != protocol.MessageTypeText {
|
||||
t.Fatalf("message type = %s, want text", msg.Type)
|
||||
}
|
||||
if string(msg.Body) != "hello from udp" {
|
||||
t.Fatalf("message body = %q, want %q", string(msg.Body), "hello from udp")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUDPClientID 验证 ID() 返回正确的 peer 标识。
|
||||
func TestUDPClientID(t *testing.T) {
|
||||
hubAddr := startUDPTestHub(t)
|
||||
|
||||
client, err := DialUDP(hubAddr.String(), "my-peer-id")
|
||||
if err != nil {
|
||||
t.Fatalf("DialUDP() error = %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
if got := client.ID(); got != "my-peer-id" {
|
||||
t.Fatalf("ID() = %q, want %q", got, "my-peer-id")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUDPClientPersistMessage 验证 UDP 客户端可以将消息持久化到磁盘。
|
||||
func TestUDPClientPersistMessage(t *testing.T) {
|
||||
hubAddr := startUDPTestHub(t)
|
||||
|
||||
client, err := DialUDP(hubAddr.String(), "peer-persist")
|
||||
if err != nil {
|
||||
t.Fatalf("DialUDP() error = %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
inboxDir := t.TempDir()
|
||||
|
||||
// 持久化文本消息
|
||||
textMsg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "sender",
|
||||
To: "peer-persist",
|
||||
Body: []byte("persisted text"),
|
||||
}
|
||||
|
||||
path, err := client.PersistMessage(textMsg, inboxDir)
|
||||
if err != nil {
|
||||
t.Fatalf("PersistMessage(text) error = %v", err)
|
||||
}
|
||||
if !filepath.IsAbs(path) && path == "" {
|
||||
t.Fatalf("PersistMessage(text) returned empty path")
|
||||
}
|
||||
|
||||
// 持久化文件消息
|
||||
fileMsg := protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "sender",
|
||||
To: "peer-persist",
|
||||
FileName: "test.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
}
|
||||
|
||||
filePath, err := client.PersistMessage(fileMsg, inboxDir)
|
||||
if err != nil {
|
||||
t.Fatalf("PersistMessage(file) error = %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile(%s) error = %v", filePath, err)
|
||||
}
|
||||
if len(content) != 3 || content[0] != 0x01 {
|
||||
t.Fatalf("file content mismatch: got %v", content)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUDPClientSendFile 验证 UDP 客户端可以发送文件消息。
|
||||
func TestUDPClientSendFile(t *testing.T) {
|
||||
hubAddr := startUDPTestHub(t)
|
||||
|
||||
clientA, err := DialUDP(hubAddr.String(), "peer-a")
|
||||
if err != nil {
|
||||
t.Fatalf("DialUDP(peer-a) error = %v", err)
|
||||
}
|
||||
defer clientA.Close()
|
||||
|
||||
clientB, err := DialUDP(hubAddr.String(), "peer-b")
|
||||
if err != nil {
|
||||
t.Fatalf("DialUDP(peer-b) error = %v", err)
|
||||
}
|
||||
defer clientB.Close()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
fileBody := []byte{0xDE, 0xAD, 0xBE, 0xEF}
|
||||
if err := clientA.SendFile("peer-b", "test.bin", fileBody); err != nil {
|
||||
t.Fatalf("SendFile() error = %v", err)
|
||||
}
|
||||
|
||||
msg := receiveUDPClientMessage(t, clientB)
|
||||
if msg.Type != protocol.MessageTypeFile {
|
||||
t.Fatalf("message type = %s, want file", msg.Type)
|
||||
}
|
||||
if msg.FileName != "test.bin" {
|
||||
t.Fatalf("file name = %q, want %q", msg.FileName, "test.bin")
|
||||
}
|
||||
if len(msg.Body) != 4 {
|
||||
t.Fatalf("body length = %d, want 4", len(msg.Body))
|
||||
}
|
||||
}
|
||||
|
||||
// startUDPTestHub 创建并启动一个测试用 UDPHub。
|
||||
func startUDPTestHub(t *testing.T) *net.UDPAddr {
|
||||
t.Helper()
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveUDPAddr() error = %v", err)
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("ListenUDP() error = %v", err)
|
||||
}
|
||||
|
||||
hub, err := server.NewUDPHub(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("NewUDPHub() error = %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = hub.Serve()
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = hub.Close()
|
||||
})
|
||||
|
||||
return conn.LocalAddr().(*net.UDPAddr)
|
||||
}
|
||||
|
||||
// receiveUDPClientMessage 从 UDP 客户端接收一条消息,带超时。
|
||||
func receiveUDPClientMessage(t *testing.T, client *UDPClient) protocol.Message {
|
||||
t.Helper()
|
||||
|
||||
type result struct {
|
||||
msg protocol.Message
|
||||
err error
|
||||
}
|
||||
|
||||
ch := make(chan result, 1)
|
||||
go func() {
|
||||
msg, err := client.Receive()
|
||||
ch <- result{msg: msg, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case r := <-ch:
|
||||
if r.err != nil {
|
||||
t.Fatalf("Receive() error = %v", r.err)
|
||||
}
|
||||
return r.msg
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Receive() timed out after 2s")
|
||||
return protocol.Message{}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure transport package is used (needed for WithTXTimestampDebugLogger option).
|
||||
var _ transport.TXTimestampDebugLogger = nil
|
||||
185
cmd/internal/server/udp_hub.go
Normal file
185
cmd/internal/server/udp_hub.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
// UDPOption 用于配置 UDPHub 的可选行为。
|
||||
type UDPOption func(*UDPHub)
|
||||
|
||||
// WithUDPLogger 为 UDP hub 注入时延日志记录器。
|
||||
func WithUDPLogger(logger latencylog.Logger) UDPOption {
|
||||
return func(hub *UDPHub) {
|
||||
hub.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// UDPHub 管理通过 UDP 注册的 peer,并负责在它们之间转发消息。
|
||||
// 与 TCP Hub 不同,UDPHub 使用单个 net.UDPConn 与所有 peer 通信,
|
||||
// 通过维护 peerID -> UDPAddr 映射表来寻址。
|
||||
type UDPHub struct {
|
||||
mu sync.RWMutex
|
||||
peers map[string]*net.UDPAddr // peerID -> 对端 UDP 地址
|
||||
addrs map[string]string // addr.String() -> peerID,用于反查
|
||||
|
||||
conn *transport.UDPConn
|
||||
logger latencylog.Logger
|
||||
}
|
||||
|
||||
// NewUDPHub 创建一个新的 UDP 连接中心。
|
||||
func NewUDPHub(conn *net.UDPConn, opts ...UDPOption) (*UDPHub, error) {
|
||||
hub := &UDPHub{
|
||||
peers: make(map[string]*net.UDPAddr),
|
||||
addrs: make(map[string]string),
|
||||
logger: latencylog.NoopLogger{},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(hub)
|
||||
}
|
||||
|
||||
if hub.logger == nil {
|
||||
hub.logger = latencylog.NoopLogger{}
|
||||
}
|
||||
|
||||
udpConn, err := transport.NewUDPConn(
|
||||
conn,
|
||||
nil,
|
||||
transport.WithUDPLogger(hub.logger, latencylog.NodeRoleServer, "hub"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server: create udp transport conn: %w", err)
|
||||
}
|
||||
|
||||
hub.conn = udpConn
|
||||
|
||||
return hub, nil
|
||||
}
|
||||
|
||||
// Serve 启动 UDP 接收主循环,持续读取消息并处理注册/转发。
|
||||
// 此方法会阻塞,直到底层连接关闭或发生不可恢复的错误。
|
||||
func (h *UDPHub) Serve() error {
|
||||
return h.conn.ReceiveLoop(func(msg protocol.Message, addr *net.UDPAddr) error {
|
||||
if err := h.handleMessage(msg, addr); err != nil {
|
||||
log.Printf("udp hub: handle message from %s: %v", addr, err)
|
||||
}
|
||||
return nil // 不因为单条消息处理失败而退出主循环
|
||||
})
|
||||
}
|
||||
|
||||
// HasPeer 返回给定 ID 是否已注册到 hub。
|
||||
func (h *UDPHub) HasPeer(peerID string) bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
_, ok := h.peers[peerID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// handleMessage 处理从指定地址收到的消息。
|
||||
func (h *UDPHub) handleMessage(msg protocol.Message, addr *net.UDPAddr) error {
|
||||
switch msg.Type {
|
||||
case protocol.MessageTypeRegister:
|
||||
return h.registerPeer(msg, addr)
|
||||
case protocol.MessageTypeText, protocol.MessageTypeFile:
|
||||
return h.forwardMessage(msg, addr)
|
||||
case protocol.MessageTypeError:
|
||||
return h.sendErrorTo(addr, msg.From, "peers cannot send error messages")
|
||||
default:
|
||||
peerID := h.lookupPeerID(addr)
|
||||
if peerID == "" {
|
||||
peerID = msg.From
|
||||
}
|
||||
return h.sendErrorTo(addr, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type))
|
||||
}
|
||||
}
|
||||
|
||||
// registerPeer 处理 peer 的注册请求。
|
||||
func (h *UDPHub) registerPeer(msg protocol.Message, addr *net.UDPAddr) error {
|
||||
peerID := msg.From
|
||||
if peerID == "" {
|
||||
return h.sendErrorTo(addr, "", "register: missing peer id")
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 如果同一个 peerID 从新地址注册,更新地址映射(支持 peer 重启换端口)。
|
||||
if existingAddr, exists := h.peers[peerID]; exists {
|
||||
// 清理旧地址的反查映射
|
||||
delete(h.addrs, existingAddr.String())
|
||||
}
|
||||
|
||||
h.peers[peerID] = addr
|
||||
h.addrs[addr.String()] = peerID
|
||||
log.Printf("udp hub: registered peer %s from %s", peerID, addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// forwardMessage 转发业务消息到目标 peer。
|
||||
func (h *UDPHub) forwardMessage(msg protocol.Message, senderAddr *net.UDPAddr) error {
|
||||
// 通过来源地址反查发送者 peerID
|
||||
senderID := h.lookupPeerID(senderAddr)
|
||||
if senderID == "" {
|
||||
return h.sendErrorTo(senderAddr, msg.From, "not registered; send register first")
|
||||
}
|
||||
|
||||
// server 覆盖 From,不信任客户端自报身份
|
||||
msg.From = senderID
|
||||
|
||||
// 查找目标 peer 地址
|
||||
targetAddr := h.lookupAddr(msg.To)
|
||||
if targetAddr == nil {
|
||||
return h.sendErrorTo(senderAddr, senderID, fmt.Sprintf("unknown target: %s", msg.To))
|
||||
}
|
||||
|
||||
// 转发消息
|
||||
if err := h.conn.SendTo(msg, targetAddr); err != nil {
|
||||
// 转发失败,通知发送方
|
||||
_ = h.sendErrorTo(senderAddr, senderID, fmt.Sprintf("failed to forward to %s", msg.To))
|
||||
return fmt.Errorf("forward to %s at %s: %w", msg.To, targetAddr, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// lookupPeerID 通过 UDP 地址反查 peerID。
|
||||
func (h *UDPHub) lookupPeerID(addr *net.UDPAddr) string {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
return h.addrs[addr.String()]
|
||||
}
|
||||
|
||||
// lookupAddr 通过 peerID 查找 UDP 地址。
|
||||
func (h *UDPHub) lookupAddr(peerID string) *net.UDPAddr {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
return h.peers[peerID]
|
||||
}
|
||||
|
||||
// sendErrorTo 向指定地址发送错误消息。
|
||||
func (h *UDPHub) sendErrorTo(addr *net.UDPAddr, to, message string) error {
|
||||
if to == "" {
|
||||
to = "unknown"
|
||||
}
|
||||
return h.conn.SendTo(protocol.Message{
|
||||
Type: protocol.MessageTypeError,
|
||||
From: protocol.ServerPeerID,
|
||||
To: to,
|
||||
Body: []byte(message),
|
||||
}, addr)
|
||||
}
|
||||
|
||||
// Close 关闭底层 UDP 连接。
|
||||
func (h *UDPHub) Close() error {
|
||||
return h.conn.Close()
|
||||
}
|
||||
238
cmd/internal/server/udp_hub_test.go
Normal file
238
cmd/internal/server/udp_hub_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
// TestUDPHubRegisterAndForward 验证 peer 注册后可以互相转发消息。
|
||||
func TestUDPHubRegisterAndForward(t *testing.T) {
|
||||
hub, hubAddr := startUDPHub(t)
|
||||
_ = hub
|
||||
|
||||
peerA := dialUDPPeer(t, hubAddr)
|
||||
peerB := dialUDPPeer(t, hubAddr)
|
||||
|
||||
// 注册 peer-a
|
||||
sendUDPMessage(t, peerA, protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: "peer-a",
|
||||
To: protocol.ServerPeerID,
|
||||
})
|
||||
|
||||
// 注册 peer-b
|
||||
sendUDPMessage(t, peerB, protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: "peer-b",
|
||||
To: protocol.ServerPeerID,
|
||||
})
|
||||
|
||||
// 等待注册被处理
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// peer-a 发送消息给 peer-b
|
||||
sendUDPMessage(t, peerA, protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello from peer-a"),
|
||||
})
|
||||
|
||||
// peer-b 应该收到消息
|
||||
msg := receiveUDPMessage(t, peerB)
|
||||
if msg.Type != protocol.MessageTypeText {
|
||||
t.Fatalf("message type = %s, want text", msg.Type)
|
||||
}
|
||||
if msg.From != "peer-a" {
|
||||
t.Fatalf("message from = %s, want peer-a", msg.From)
|
||||
}
|
||||
if msg.To != "peer-b" {
|
||||
t.Fatalf("message to = %s, want peer-b", msg.To)
|
||||
}
|
||||
if string(msg.Body) != "hello from peer-a" {
|
||||
t.Fatalf("message body = %q, want %q", string(msg.Body), "hello from peer-a")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUDPHubRejectsUnregistered 验证未注册的 peer 发送业务消息会收到错误。
|
||||
func TestUDPHubRejectsUnregistered(t *testing.T) {
|
||||
_, hubAddr := startUDPHub(t)
|
||||
|
||||
peer := dialUDPPeer(t, hubAddr)
|
||||
|
||||
// 直接发送业务消息而不注册
|
||||
sendUDPMessage(t, peer, protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("should fail"),
|
||||
})
|
||||
|
||||
// 应该收到错误响应
|
||||
msg := receiveUDPMessage(t, peer)
|
||||
if msg.Type != protocol.MessageTypeError {
|
||||
t.Fatalf("message type = %s, want error", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUDPHubRejectsUnknownTarget 验证发送到不存在的目标会返回错误。
|
||||
func TestUDPHubRejectsUnknownTarget(t *testing.T) {
|
||||
_, hubAddr := startUDPHub(t)
|
||||
|
||||
peer := dialUDPPeer(t, hubAddr)
|
||||
|
||||
// 注册
|
||||
sendUDPMessage(t, peer, protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: "peer-a",
|
||||
To: protocol.ServerPeerID,
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 发送到不存在的目标
|
||||
sendUDPMessage(t, peer, protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-nonexistent",
|
||||
Body: []byte("should fail"),
|
||||
})
|
||||
|
||||
// 应该收到错误响应
|
||||
msg := receiveUDPMessage(t, peer)
|
||||
if msg.Type != protocol.MessageTypeError {
|
||||
t.Fatalf("message type = %s, want error", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUDPHubOverridesFromField 验证 server 会覆盖消息的 From 字段。
|
||||
func TestUDPHubOverridesFromField(t *testing.T) {
|
||||
_, hubAddr := startUDPHub(t)
|
||||
|
||||
peerA := dialUDPPeer(t, hubAddr)
|
||||
peerB := dialUDPPeer(t, hubAddr)
|
||||
|
||||
sendUDPMessage(t, peerA, protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: "peer-a",
|
||||
To: protocol.ServerPeerID,
|
||||
})
|
||||
sendUDPMessage(t, peerB, protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: "peer-b",
|
||||
To: protocol.ServerPeerID,
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// peer-a 伪造 From 为 "fake-id"
|
||||
sendUDPMessage(t, peerA, protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "fake-id",
|
||||
To: "peer-b",
|
||||
Body: []byte("spoofed"),
|
||||
})
|
||||
|
||||
msg := receiveUDPMessage(t, peerB)
|
||||
// server 应该用实际注册的 "peer-a" 覆盖 From
|
||||
if msg.From != "peer-a" {
|
||||
t.Fatalf("message from = %s, want peer-a (server should override)", msg.From)
|
||||
}
|
||||
}
|
||||
|
||||
// startUDPHub 创建并启动一个 UDPHub,返回 hub 和监听地址。
|
||||
func startUDPHub(t *testing.T) (*UDPHub, *net.UDPAddr) {
|
||||
t.Helper()
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveUDPAddr() error = %v", err)
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("ListenUDP() error = %v", err)
|
||||
}
|
||||
|
||||
hub, err := NewUDPHub(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("NewUDPHub() error = %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = hub.Serve()
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = hub.Close()
|
||||
})
|
||||
|
||||
return hub, conn.LocalAddr().(*net.UDPAddr)
|
||||
}
|
||||
|
||||
// dialUDPPeer 创建一个连接到指定地址的 UDP transport 连接。
|
||||
func dialUDPPeer(t *testing.T, serverAddr *net.UDPAddr) *transport.UDPConn {
|
||||
t.Helper()
|
||||
|
||||
raw, err := net.DialUDP("udp", nil, serverAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("DialUDP() error = %v", err)
|
||||
}
|
||||
|
||||
conn, err := transport.NewUDPConn(raw, nil)
|
||||
if err != nil {
|
||||
_ = raw.Close()
|
||||
t.Fatalf("NewUDPConn() error = %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = conn.Close()
|
||||
})
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
// sendUDPMessage 发送一条 UDP 消息。
|
||||
func sendUDPMessage(t *testing.T, conn *transport.UDPConn, msg protocol.Message) {
|
||||
t.Helper()
|
||||
|
||||
if err := conn.Send(msg); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// receiveUDPMessage 接收一条 UDP 消息,带超时。
|
||||
func receiveUDPMessage(t *testing.T, conn *transport.UDPConn) protocol.Message {
|
||||
t.Helper()
|
||||
|
||||
type result struct {
|
||||
msg protocol.Message
|
||||
err error
|
||||
}
|
||||
|
||||
ch := make(chan result, 1)
|
||||
go func() {
|
||||
msg, _, err := conn.Receive()
|
||||
ch <- result{msg: msg, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case r := <-ch:
|
||||
if r.err != nil {
|
||||
t.Fatalf("Receive() error = %v", r.err)
|
||||
}
|
||||
return r.msg
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Receive() timed out after 2s")
|
||||
return protocol.Message{}
|
||||
}
|
||||
}
|
||||
141
cmd/internal/transport/udp.go
Normal file
141
cmd/internal/transport/udp.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// UDPConn 是对 UDP 连接的轻量封装。
|
||||
// server 侧:共享同一个 net.UDPConn,Send 时通过 peerAddr 指定对端地址。
|
||||
// peer 侧:独立的 net.UDPConn,已通过 Dial 连接到 server,Send 直接写即可。
|
||||
type UDPConn struct {
|
||||
conn *net.UDPConn
|
||||
peerAddr *net.UDPAddr // server 侧为对端地址;peer 侧为 nil(连接模式下直接 Write)
|
||||
raw syscall.RawConn // 底层 syscall 句柄,用于 Linux socket timestamping
|
||||
|
||||
logger latencylog.Logger
|
||||
txTimestampDebugLogger TXTimestampDebugLogger
|
||||
nodeRole string // 日志中记录的节点角色,例如 "server" 或 "peer"
|
||||
nodeID string // 日志中记录的节点 ID
|
||||
writeMu sync.Mutex // 保护 Send 的互斥锁
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// UDPOption 用于为 UDPConn 注入可选行为。
|
||||
type UDPOption func(*UDPConn)
|
||||
|
||||
// WithUDPLogger 为 UDP 连接注入业务消息日志上下文。
|
||||
func WithUDPLogger(logger latencylog.Logger, nodeRole, nodeID string) UDPOption {
|
||||
return func(conn *UDPConn) {
|
||||
conn.logger = logger
|
||||
conn.nodeRole = nodeRole
|
||||
conn.nodeID = nodeID
|
||||
}
|
||||
}
|
||||
|
||||
// WithUDPTXTimestampDebugLogger 为 UDP 连接注入可选的 TX errqueue 调试日志器。
|
||||
func WithUDPTXTimestampDebugLogger(logger TXTimestampDebugLogger) UDPOption {
|
||||
return func(conn *UDPConn) {
|
||||
conn.txTimestampDebugLogger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// NewUDPConn 创建 UDP transport 连接封装。
|
||||
// peerAddr 为 nil 时表示 peer 侧已连接模式(conn 已 Dial 到 server)。
|
||||
// peerAddr 非 nil 时表示 server 侧,Send 时需要指定目标地址。
|
||||
func NewUDPConn(conn *net.UDPConn, peerAddr *net.UDPAddr, opts ...UDPOption) (*UDPConn, error) {
|
||||
udpConn := &UDPConn{
|
||||
conn: conn,
|
||||
peerAddr: peerAddr,
|
||||
logger: latencylog.NoopLogger{},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(udpConn)
|
||||
}
|
||||
|
||||
if udpConn.logger == nil {
|
||||
udpConn.logger = latencylog.NoopLogger{}
|
||||
}
|
||||
|
||||
if err := udpConn.initUDPLinuxTimestamping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
// Send 将一条协议消息编码为 UDP 数据报并发送。
|
||||
// 多个 goroutine 可以并发调用,内部会串行化写入。
|
||||
func (c *UDPConn) Send(msg protocol.Message) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
|
||||
|
||||
if err := c.sendMessageLinux(msg); err != nil {
|
||||
return fmt.Errorf("transport: udp send message: %w", err)
|
||||
}
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendTo 将一条协议消息编码为 UDP 数据报并发送到指定地址。
|
||||
// 主要用于 server 侧向特定 peer 发送消息。
|
||||
func (c *UDPConn) SendTo(msg protocol.Message, addr *net.UDPAddr) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
|
||||
|
||||
if err := c.sendMessageToLinux(msg, addr); err != nil {
|
||||
return fmt.Errorf("transport: udp send message to %s: %w", addr, err)
|
||||
}
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receive 从 UDP 连接读取一条完整协议消息。
|
||||
// 返回解码后的消息和来源地址(peer 侧来源地址始终为 server 地址)。
|
||||
func (c *UDPConn) Receive() (protocol.Message, *net.UDPAddr, error) {
|
||||
msg, addr, err := c.receiveMessageLinux()
|
||||
if err != nil {
|
||||
return protocol.Message{}, nil, fmt.Errorf("transport: udp receive message: %w", err)
|
||||
}
|
||||
|
||||
return msg, addr, nil
|
||||
}
|
||||
|
||||
// ReceiveLoop 持续从 UDP 连接读取消息并交给 handler 处理。
|
||||
// handler 的第二个参数是消息来源地址。
|
||||
func (c *UDPConn) ReceiveLoop(handler func(protocol.Message, *net.UDPAddr) error) error {
|
||||
for {
|
||||
msg, addr, err := c.Receive()
|
||||
if err != nil {
|
||||
return fmt.Errorf("transport: udp receive loop read: %w", err)
|
||||
}
|
||||
|
||||
if err := handler(msg, addr); err != nil {
|
||||
return fmt.Errorf("transport: udp receive loop handler: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭底层 UDP 连接,保证重复调用安全。
|
||||
// 注意:server 侧多个 UDPConn 共享同一个 net.UDPConn 时,
|
||||
// 只应由 UDPHub 负责关闭底层连接,不应通过此方法关闭。
|
||||
func (c *UDPConn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
c.closeErr = c.conn.Close()
|
||||
})
|
||||
|
||||
return c.closeErr
|
||||
}
|
||||
358
cmd/internal/transport/udp_linux.go
Normal file
358
cmd/internal/transport/udp_linux.go
Normal file
@@ -0,0 +1,358 @@
|
||||
//go:build linux
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// UDP 接收缓冲区大小,足以容纳 MaxFrameSize 加上协议头。
|
||||
const udpReceiveBufferSize = protocol.MaxFrameSize + 1024
|
||||
|
||||
// initUDPLinuxTimestamping 拿到底层 fd,并打开 Linux timestamping。
|
||||
func (c *UDPConn) initUDPLinuxTimestamping() error {
|
||||
rawConn, err := c.conn.SyscallConn()
|
||||
if err != nil || rawConn == nil {
|
||||
if err != nil {
|
||||
return fmt.Errorf("transport: udp get syscall conn: %w", err)
|
||||
}
|
||||
return fmt.Errorf("transport: udp missing syscall conn")
|
||||
}
|
||||
|
||||
// UDP 不需要 OPT_ID_TCP,使用标准的 OPT_ID 即可。
|
||||
flagCandidates := []int{
|
||||
linuxSOFTimestampingTXSched |
|
||||
linuxSOFTimestampingTXSoftware |
|
||||
linuxSOFTimestampingRXSoftware |
|
||||
linuxSOFTimestampingSoftware |
|
||||
linuxSOFTimestampingOptID |
|
||||
linuxSOFTimestampingOptTSONLY,
|
||||
linuxSOFTimestampingTXSched |
|
||||
linuxSOFTimestampingTXSoftware |
|
||||
linuxSOFTimestampingRXSoftware |
|
||||
linuxSOFTimestampingSoftware |
|
||||
linuxSOFTimestampingOptTSONLY,
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, flags := range flagCandidates {
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
lastErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if lastErr == nil {
|
||||
c.raw = rawConn
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(lastErr, syscall.EINVAL) {
|
||||
return lastErr
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// sendMessageLinux 编码消息并通过 UDP 发送,采集 TX 时间戳。
|
||||
func (c *UDPConn) sendMessageLinux(msg protocol.Message) error {
|
||||
payload, err := protocol.EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("protocol: encode message: %w", err)
|
||||
}
|
||||
|
||||
readIndex := 0
|
||||
c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePreSendDrain, &readIndex)
|
||||
|
||||
if c.peerAddr != nil {
|
||||
if err := c.udpSendTo(payload, c.peerAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := c.udpSend(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.collectAndLogUDPTXTimestampEvents(msg, &readIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendMessageToLinux 编码消息并通过 UDP 发送到指定地址,采集 TX 时间戳。
|
||||
func (c *UDPConn) sendMessageToLinux(msg protocol.Message, addr *net.UDPAddr) error {
|
||||
payload, err := protocol.EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("protocol: encode message: %w", err)
|
||||
}
|
||||
|
||||
readIndex := 0
|
||||
c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePreSendDrain, &readIndex)
|
||||
|
||||
if err := c.udpSendTo(payload, addr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.collectAndLogUDPTXTimestampEvents(msg, &readIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
// udpSend 通过已连接的 UDP socket 发送数据。
|
||||
func (c *UDPConn) udpSend(payload []byte) error {
|
||||
if c.raw != nil {
|
||||
return c.udpSendmsgRaw(payload, nil)
|
||||
}
|
||||
_, err := c.conn.Write(payload)
|
||||
return err
|
||||
}
|
||||
|
||||
// udpSendTo 通过 UDP socket 发送数据到指定地址。
|
||||
func (c *UDPConn) udpSendTo(payload []byte, addr *net.UDPAddr) error {
|
||||
if c.raw != nil {
|
||||
sa := udpAddrToSockaddr(addr)
|
||||
if sa != nil {
|
||||
return c.udpSendmsgRaw(payload, sa)
|
||||
}
|
||||
}
|
||||
_, err := c.conn.WriteToUDP(payload, addr)
|
||||
return err
|
||||
}
|
||||
|
||||
// udpSendmsgRaw 通过 sendmsg syscall 发送 UDP 数据。
|
||||
func (c *UDPConn) udpSendmsgRaw(payload []byte, to syscall.Sockaddr) error {
|
||||
var opErr error
|
||||
|
||||
for {
|
||||
err := c.raw.Control(func(fd uintptr) {
|
||||
opErr = syscall.Sendmsg(int(fd), payload, nil, to, 0)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if opErr == nil {
|
||||
return nil
|
||||
}
|
||||
if isWouldBlock(opErr) {
|
||||
time.Sleep(linuxDataPollInterval)
|
||||
continue
|
||||
}
|
||||
return opErr
|
||||
}
|
||||
}
|
||||
|
||||
// receiveMessageLinux 从 UDP 连接读取一条完整消息,并记录 RX 时间戳。
|
||||
func (c *UDPConn) receiveMessageLinux() (protocol.Message, *net.UDPAddr, error) {
|
||||
payload, addr, rxTimestamp, err := c.udpRecvFrom()
|
||||
if err != nil {
|
||||
return protocol.Message{}, nil, fmt.Errorf("protocol: udp read: %w", err)
|
||||
}
|
||||
|
||||
msg, err := protocol.DecodeMessage(payload)
|
||||
if err != nil {
|
||||
return protocol.Message{}, nil, fmt.Errorf("protocol: decode message: %w", err)
|
||||
}
|
||||
|
||||
if rxTimestamp > 0 {
|
||||
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventBRXSoftware, rxTimestamp, msg)
|
||||
}
|
||||
|
||||
return msg, addr, nil
|
||||
}
|
||||
|
||||
// udpRecvFrom 从 UDP socket 接收一个完整数据报,返回数据、来源地址和 RX 时间戳。
|
||||
func (c *UDPConn) udpRecvFrom() ([]byte, *net.UDPAddr, int64, error) {
|
||||
if c.raw != nil {
|
||||
return c.udpRecvmsgRaw()
|
||||
}
|
||||
|
||||
buf := make([]byte, udpReceiveBufferSize)
|
||||
n, addr, err := c.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
return buf[:n], addr, 0, nil
|
||||
}
|
||||
|
||||
// udpRecvmsgRaw 通过 recvmsg syscall 接收 UDP 数据,同时采集 RX 时间戳。
|
||||
func (c *UDPConn) udpRecvmsgRaw() ([]byte, *net.UDPAddr, int64, error) {
|
||||
for {
|
||||
var (
|
||||
n int
|
||||
rxTimeNS int64
|
||||
from syscall.Sockaddr
|
||||
opErr error
|
||||
)
|
||||
|
||||
buf := make([]byte, udpReceiveBufferSize)
|
||||
err := c.raw.Control(func(fd uintptr) {
|
||||
oob := make([]byte, linuxTimestampControlBufferSize)
|
||||
readN, oobN, _, sa, recvErr := syscall.Recvmsg(int(fd), buf, oob, 0)
|
||||
if recvErr != nil {
|
||||
opErr = recvErr
|
||||
return
|
||||
}
|
||||
n = readN
|
||||
from = sa
|
||||
rxTimeNS = parseRXTimestampControlMessages(oob[:oobN])
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
if opErr != nil {
|
||||
if isWouldBlock(opErr) {
|
||||
time.Sleep(linuxDataPollInterval)
|
||||
continue
|
||||
}
|
||||
return nil, nil, 0, opErr
|
||||
}
|
||||
|
||||
addr := sockaddrToUDPAddr(from)
|
||||
return buf[:n], addr, rxTimeNS, nil
|
||||
}
|
||||
}
|
||||
|
||||
// collectAndLogUDPTXTimestampEvents 采集并记录 UDP 发送的 TX 时间戳事件。
|
||||
func (c *UDPConn) collectAndLogUDPTXTimestampEvents(msg protocol.Message, readIndex *int) {
|
||||
timestamps := c.collectUDPTXTimestampEvents(msg, readIndex)
|
||||
|
||||
if ts, ok := timestamps[latencylog.EventATXSched]; ok {
|
||||
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSched, ts, msg)
|
||||
}
|
||||
if ts, ok := timestamps[latencylog.EventATXSoftware]; ok {
|
||||
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSoftware, ts, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// collectUDPTXTimestampEvents 在 errqueue 中等待 TX 时间戳。
|
||||
func (c *UDPConn) collectUDPTXTimestampEvents(msg protocol.Message, readIndex *int) map[string]int64 {
|
||||
if c.raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(linuxTXTimestampWaitTimeout)
|
||||
timestamps := make(map[string]int64, 2)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
event, err := c.recvUDPTXTimestampOnce()
|
||||
if err != nil {
|
||||
if isWouldBlock(err) {
|
||||
time.Sleep(linuxTXTimestampPollInterval)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if event.EventName == "" || event.TSUnixNano <= 0 {
|
||||
continue
|
||||
}
|
||||
*readIndex++
|
||||
|
||||
if isBusinessTXTimestampEventName(event.EventName) {
|
||||
if _, exists := timestamps[event.EventName]; !exists {
|
||||
timestamps[event.EventName] = event.TSUnixNano
|
||||
}
|
||||
}
|
||||
|
||||
if hasCompleteTXTimestampPair(timestamps) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePostSelectDrain, readIndex)
|
||||
return timestamps
|
||||
}
|
||||
|
||||
// drainPendingUDPTXTimestampEvents 清空 errqueue 中残留的时间戳事件。
|
||||
func (c *UDPConn) drainPendingUDPTXTimestampEvents(msg protocol.Message, phase string, readIndex *int) {
|
||||
if c.raw == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
event, err := c.recvUDPTXTimestampOnce()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if event.EventName == "" || event.TSUnixNano <= 0 {
|
||||
continue
|
||||
}
|
||||
*readIndex++
|
||||
}
|
||||
}
|
||||
|
||||
// recvUDPTXTimestampOnce 从 errqueue 读一次时间戳事件。
|
||||
func (c *UDPConn) recvUDPTXTimestampOnce() (txTimestampEvent, error) {
|
||||
var (
|
||||
event txTimestampEvent
|
||||
opErr error
|
||||
)
|
||||
|
||||
err := c.raw.Control(func(fd uintptr) {
|
||||
oob := make([]byte, linuxTimestampControlBufferSize)
|
||||
_, oobn, _, _, recvErr := syscall.Recvmsg(int(fd), nil, oob, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT)
|
||||
if recvErr != nil {
|
||||
opErr = recvErr
|
||||
return
|
||||
}
|
||||
event, _ = parseTXTimestampControlMessages(oob[:oobn])
|
||||
})
|
||||
if err != nil {
|
||||
return txTimestampEvent{}, err
|
||||
}
|
||||
if opErr != nil {
|
||||
return txTimestampEvent{}, opErr
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
|
||||
// udpAddrToSockaddr 将 net.UDPAddr 转换为 syscall.Sockaddr。
|
||||
func udpAddrToSockaddr(addr *net.UDPAddr) syscall.Sockaddr {
|
||||
if ip4 := addr.IP.To4(); ip4 != nil {
|
||||
sa := &syscall.SockaddrInet4{Port: addr.Port}
|
||||
copy(sa.Addr[:], ip4)
|
||||
return sa
|
||||
}
|
||||
if ip6 := addr.IP.To16(); ip6 != nil {
|
||||
sa := &syscall.SockaddrInet6{Port: addr.Port}
|
||||
copy(sa.Addr[:], ip6)
|
||||
return sa
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sockaddrToUDPAddr 将 syscall.Sockaddr 转换为 net.UDPAddr。
|
||||
func sockaddrToUDPAddr(sa syscall.Sockaddr) *net.UDPAddr {
|
||||
switch addr := sa.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
return &net.UDPAddr{
|
||||
IP: net.IP(addr.Addr[:]),
|
||||
Port: addr.Port,
|
||||
}
|
||||
case *syscall.SockaddrInet6:
|
||||
return &net.UDPAddr{
|
||||
IP: net.IP(addr.Addr[:]),
|
||||
Port: addr.Port,
|
||||
Zone: zoneToString(addr.ZoneId),
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func zoneToString(zone uint32) string {
|
||||
if zone == 0 {
|
||||
return ""
|
||||
}
|
||||
iface, err := net.InterfaceByIndex(int(zone))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return iface.Name
|
||||
}
|
||||
105
cmd/internal/transport/udp_linux_test.go
Normal file
105
cmd/internal/transport/udp_linux_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
//go:build linux
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// TestUDPLinuxTimestampingRecordsKernelEvents 验证 UDP 在 Linux 上能正确采集内核时间戳。
|
||||
func TestUDPLinuxTimestampingRecordsKernelEvents(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg protocol.Message
|
||||
}{
|
||||
{
|
||||
name: "text",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 41,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello over udp"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 42,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x00, 0x01, 0x02, 0xff},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
senderLogger := &recordingLogger{}
|
||||
receiverLogger := &recordingLogger{}
|
||||
|
||||
// 创建 server 侧监听
|
||||
serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveUDPAddr() error = %v", err)
|
||||
}
|
||||
serverRaw, err := net.ListenUDP("udp", serverAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("ListenUDP() error = %v", err)
|
||||
}
|
||||
receiver, err := NewUDPConn(
|
||||
serverRaw,
|
||||
nil,
|
||||
WithUDPLogger(receiverLogger, latencylog.NodeRolePeer, "peer-b"),
|
||||
)
|
||||
if err != nil {
|
||||
_ = serverRaw.Close()
|
||||
t.Fatalf("NewUDPConn(receiver) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = receiver.Close() })
|
||||
|
||||
// 创建 peer 侧连接
|
||||
peerRaw, err := net.DialUDP("udp", nil, serverRaw.LocalAddr().(*net.UDPAddr))
|
||||
if err != nil {
|
||||
t.Fatalf("DialUDP() error = %v", err)
|
||||
}
|
||||
sender, err := NewUDPConn(
|
||||
peerRaw,
|
||||
nil,
|
||||
WithUDPLogger(senderLogger, latencylog.NodeRolePeer, "peer-a"),
|
||||
)
|
||||
if err != nil {
|
||||
_ = peerRaw.Close()
|
||||
t.Fatalf("NewUDPConn(sender) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = sender.Close() })
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(tt.msg)
|
||||
}()
|
||||
|
||||
got, _, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
|
||||
}
|
||||
|
||||
assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSched, tt.msg.ID)
|
||||
assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSoftware, tt.msg.ID)
|
||||
assertHasEvent(t, receiverLogger.Events(), latencylog.EventBRXSoftware, tt.msg.ID)
|
||||
})
|
||||
}
|
||||
}
|
||||
358
cmd/internal/transport/udp_test.go
Normal file
358
cmd/internal/transport/udp_test.go
Normal file
@@ -0,0 +1,358 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// TestUDPSendReceiveMessage 验证 UDP transport 可以正常收发 text 和 file 消息。
|
||||
func TestUDPSendReceiveMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg protocol.Message
|
||||
}{
|
||||
{
|
||||
name: "text",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello udp"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "data.bin",
|
||||
Body: []byte{0x00, 0x10, 0xff},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sender, receiver := newUDPConnPair(t, nil, nil)
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(tt.msg)
|
||||
}()
|
||||
|
||||
got, _, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tt.msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUDPSendLogsHandoffEvents 验证 UDP Send 会记录 handoff 事件。
|
||||
func TestUDPSendLogsHandoffEvents(t *testing.T) {
|
||||
logger := &recordingLogger{}
|
||||
sender, receiver := newUDPConnPair(
|
||||
t,
|
||||
[]UDPOption{WithUDPLogger(logger, latencylog.NodeRolePeer, "peer-a")},
|
||||
nil,
|
||||
)
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(msg)
|
||||
}()
|
||||
|
||||
got, _, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
|
||||
}
|
||||
|
||||
events := logger.Events()
|
||||
if len(events) < 2 {
|
||||
t.Fatalf("event count = %d, want at least 2", len(events))
|
||||
}
|
||||
if events[0].Event != latencylog.EventSendHandoffBegin {
|
||||
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin)
|
||||
}
|
||||
// 最后一个事件应该是 SendHandoffEnd
|
||||
lastEvent := events[len(events)-1]
|
||||
if lastEvent.Event != latencylog.EventSendHandoffEnd {
|
||||
t.Fatalf("last event = %q, want %q", lastEvent.Event, latencylog.EventSendHandoffEnd)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUDPReceiveLoopDeliversMessages 验证 ReceiveLoop 会逐条交付连续到达的消息。
|
||||
func TestUDPReceiveLoopDeliversMessages(t *testing.T) {
|
||||
sender, receiver := newUDPConnPair(t, nil, nil)
|
||||
|
||||
want := []protocol.Message{
|
||||
{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
got []protocol.Message
|
||||
)
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message, _ *net.UDPAddr) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
got = append(got, msg)
|
||||
if len(got) >= len(want) {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
for _, msg := range want {
|
||||
if err := sender.Send(msg); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭发送端,ReceiveLoop 会因读取错误退出
|
||||
if err := sender.Close(); err != nil {
|
||||
t.Fatalf("sender.Close() error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if err == nil {
|
||||
t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "udp receive loop read") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want read context", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUDPCloseIsIdempotent 验证 Close 可以安全地被重复调用。
|
||||
func TestUDPCloseIsIdempotent(t *testing.T) {
|
||||
conn, peer := newUDPConnPair(t, nil, nil)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Fatalf("Close(first) error = %v", err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Fatalf("Close(second) error = %v, want nil", err)
|
||||
}
|
||||
_ = peer.Close()
|
||||
}
|
||||
|
||||
// TestUDPSendToMessage 验证 SendTo 可以向指定地址发送消息。
|
||||
func TestUDPSendToMessage(t *testing.T) {
|
||||
serverConn := newUDPListener(t)
|
||||
peerConn := newUDPDialed(t, serverConn.conn.LocalAddr().(*net.UDPAddr))
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "server",
|
||||
Body: []byte("hello sendto"),
|
||||
}
|
||||
|
||||
// peer 发送消息到 server
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- peerConn.Send(msg)
|
||||
}()
|
||||
|
||||
got, addr, err := serverConn.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
|
||||
}
|
||||
|
||||
// server 用 SendTo 回复到 peer 地址
|
||||
reply := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 2,
|
||||
From: "server",
|
||||
To: "peer-a",
|
||||
Body: []byte("reply"),
|
||||
}
|
||||
|
||||
sendErr2 := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr2 <- serverConn.SendTo(reply, addr)
|
||||
}()
|
||||
|
||||
gotReply, _, err := peerConn.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("peer Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr2; err != nil {
|
||||
t.Fatalf("SendTo() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotReply, reply) {
|
||||
t.Fatalf("reply mismatch: got %+v want %+v", gotReply, reply)
|
||||
}
|
||||
}
|
||||
|
||||
// newUDPConnPair 创建一对互相连接的 UDP transport 连接,用于测试。
|
||||
func newUDPConnPair(t *testing.T, senderOpts []UDPOption, receiverOpts []UDPOption) (*UDPConn, *UDPConn) {
|
||||
t.Helper()
|
||||
|
||||
// 创建两个 UDP socket,通过 Dial 互相连接
|
||||
addr1, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveUDPAddr() error = %v", err)
|
||||
}
|
||||
addr2, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveUDPAddr() error = %v", err)
|
||||
}
|
||||
|
||||
conn1, err := net.ListenUDP("udp", addr1)
|
||||
if err != nil {
|
||||
t.Fatalf("ListenUDP(1) error = %v", err)
|
||||
}
|
||||
conn2, err := net.ListenUDP("udp", addr2)
|
||||
if err != nil {
|
||||
_ = conn1.Close()
|
||||
t.Fatalf("ListenUDP(2) error = %v", err)
|
||||
}
|
||||
|
||||
// 用 Dial 模式连接对端
|
||||
senderRaw, err := net.DialUDP("udp", nil, conn2.LocalAddr().(*net.UDPAddr))
|
||||
if err != nil {
|
||||
_ = conn1.Close()
|
||||
_ = conn2.Close()
|
||||
t.Fatalf("DialUDP(sender) error = %v", err)
|
||||
}
|
||||
_ = conn1.Close() // 不再需要 conn1
|
||||
|
||||
receiverRaw, err := net.DialUDP("udp", conn2.LocalAddr().(*net.UDPAddr), senderRaw.LocalAddr().(*net.UDPAddr))
|
||||
if err != nil {
|
||||
_ = senderRaw.Close()
|
||||
_ = conn2.Close()
|
||||
t.Fatalf("DialUDP(receiver) error = %v", err)
|
||||
}
|
||||
_ = conn2.Close() // 不再需要 conn2
|
||||
|
||||
sender, err := NewUDPConn(senderRaw, nil, senderOpts...)
|
||||
if err != nil {
|
||||
_ = senderRaw.Close()
|
||||
_ = receiverRaw.Close()
|
||||
t.Fatalf("NewUDPConn(sender) error = %v", err)
|
||||
}
|
||||
|
||||
receiver, err := NewUDPConn(receiverRaw, nil, receiverOpts...)
|
||||
if err != nil {
|
||||
_ = sender.Close()
|
||||
_ = receiverRaw.Close()
|
||||
t.Fatalf("NewUDPConn(receiver) error = %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = sender.Close()
|
||||
_ = receiver.Close()
|
||||
})
|
||||
|
||||
return sender, receiver
|
||||
}
|
||||
|
||||
// newUDPListener 创建一个监听模式的 UDP 连接,用于测试 server 场景。
|
||||
func newUDPListener(t *testing.T) *UDPConn {
|
||||
t.Helper()
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveUDPAddr() error = %v", err)
|
||||
}
|
||||
|
||||
raw, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("ListenUDP() error = %v", err)
|
||||
}
|
||||
|
||||
conn, err := NewUDPConn(raw, nil)
|
||||
if err != nil {
|
||||
_ = raw.Close()
|
||||
t.Fatalf("NewUDPConn() error = %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = conn.Close()
|
||||
})
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
// newUDPDialed 创建一个已连接到指定地址的 UDP 连接,用于测试 peer 场景。
|
||||
func newUDPDialed(t *testing.T, serverAddr *net.UDPAddr) *UDPConn {
|
||||
t.Helper()
|
||||
|
||||
raw, err := net.DialUDP("udp", nil, serverAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("DialUDP() error = %v", err)
|
||||
}
|
||||
|
||||
conn, err := NewUDPConn(raw, nil)
|
||||
if err != nil {
|
||||
_ = raw.Close()
|
||||
t.Fatalf("NewUDPConn() error = %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = conn.Close()
|
||||
})
|
||||
|
||||
return conn
|
||||
}
|
||||
Reference in New Issue
Block a user