Feat: UDP 框架

This commit is contained in:
2026-03-24 15:39:00 +08:00
parent 44f39c12ed
commit c126b05961
12 changed files with 2119 additions and 0 deletions

View File

@@ -0,0 +1,141 @@
package transport
import (
"fmt"
"net"
"sync"
"syscall"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// UDPConn 是对 UDP 连接的轻量封装。
// server 侧:共享同一个 net.UDPConnSend 时通过 peerAddr 指定对端地址。
// peer 侧:独立的 net.UDPConn已通过 Dial 连接到 serverSend 直接写即可。
type UDPConn struct {
conn *net.UDPConn
peerAddr *net.UDPAddr // server 侧为对端地址peer 侧为 nil连接模式下直接 Write
raw syscall.RawConn // 底层 syscall 句柄,用于 Linux socket timestamping
logger latencylog.Logger
txTimestampDebugLogger TXTimestampDebugLogger
nodeRole string // 日志中记录的节点角色,例如 "server" 或 "peer"
nodeID string // 日志中记录的节点 ID
writeMu sync.Mutex // 保护 Send 的互斥锁
closeOnce sync.Once
closeErr error
}
// UDPOption 用于为 UDPConn 注入可选行为。
type UDPOption func(*UDPConn)
// WithUDPLogger 为 UDP 连接注入业务消息日志上下文。
func WithUDPLogger(logger latencylog.Logger, nodeRole, nodeID string) UDPOption {
return func(conn *UDPConn) {
conn.logger = logger
conn.nodeRole = nodeRole
conn.nodeID = nodeID
}
}
// WithUDPTXTimestampDebugLogger 为 UDP 连接注入可选的 TX errqueue 调试日志器。
func WithUDPTXTimestampDebugLogger(logger TXTimestampDebugLogger) UDPOption {
return func(conn *UDPConn) {
conn.txTimestampDebugLogger = logger
}
}
// NewUDPConn 创建 UDP transport 连接封装。
// peerAddr 为 nil 时表示 peer 侧已连接模式conn 已 Dial 到 server
// peerAddr 非 nil 时表示 server 侧Send 时需要指定目标地址。
func NewUDPConn(conn *net.UDPConn, peerAddr *net.UDPAddr, opts ...UDPOption) (*UDPConn, error) {
udpConn := &UDPConn{
conn: conn,
peerAddr: peerAddr,
logger: latencylog.NoopLogger{},
}
for _, opt := range opts {
opt(udpConn)
}
if udpConn.logger == nil {
udpConn.logger = latencylog.NoopLogger{}
}
if err := udpConn.initUDPLinuxTimestamping(); err != nil {
return nil, err
}
return udpConn, nil
}
// Send 将一条协议消息编码为 UDP 数据报并发送。
// 多个 goroutine 可以并发调用,内部会串行化写入。
func (c *UDPConn) Send(msg protocol.Message) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
if err := c.sendMessageLinux(msg); err != nil {
return fmt.Errorf("transport: udp send message: %w", err)
}
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
return nil
}
// SendTo 将一条协议消息编码为 UDP 数据报并发送到指定地址。
// 主要用于 server 侧向特定 peer 发送消息。
func (c *UDPConn) SendTo(msg protocol.Message, addr *net.UDPAddr) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
if err := c.sendMessageToLinux(msg, addr); err != nil {
return fmt.Errorf("transport: udp send message to %s: %w", addr, err)
}
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
return nil
}
// Receive 从 UDP 连接读取一条完整协议消息。
// 返回解码后的消息和来源地址peer 侧来源地址始终为 server 地址)。
func (c *UDPConn) Receive() (protocol.Message, *net.UDPAddr, error) {
msg, addr, err := c.receiveMessageLinux()
if err != nil {
return protocol.Message{}, nil, fmt.Errorf("transport: udp receive message: %w", err)
}
return msg, addr, nil
}
// ReceiveLoop 持续从 UDP 连接读取消息并交给 handler 处理。
// handler 的第二个参数是消息来源地址。
func (c *UDPConn) ReceiveLoop(handler func(protocol.Message, *net.UDPAddr) error) error {
for {
msg, addr, err := c.Receive()
if err != nil {
return fmt.Errorf("transport: udp receive loop read: %w", err)
}
if err := handler(msg, addr); err != nil {
return fmt.Errorf("transport: udp receive loop handler: %w", err)
}
}
}
// Close 关闭底层 UDP 连接,保证重复调用安全。
// 注意server 侧多个 UDPConn 共享同一个 net.UDPConn 时,
// 只应由 UDPHub 负责关闭底层连接,不应通过此方法关闭。
func (c *UDPConn) Close() error {
c.closeOnce.Do(func() {
c.closeErr = c.conn.Close()
})
return c.closeErr
}

View File

@@ -0,0 +1,358 @@
//go:build linux
package transport
import (
"errors"
"fmt"
"net"
"syscall"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// UDP 接收缓冲区大小,足以容纳 MaxFrameSize 加上协议头。
const udpReceiveBufferSize = protocol.MaxFrameSize + 1024
// initUDPLinuxTimestamping 拿到底层 fd并打开 Linux timestamping。
func (c *UDPConn) initUDPLinuxTimestamping() error {
rawConn, err := c.conn.SyscallConn()
if err != nil || rawConn == nil {
if err != nil {
return fmt.Errorf("transport: udp get syscall conn: %w", err)
}
return fmt.Errorf("transport: udp missing syscall conn")
}
// UDP 不需要 OPT_ID_TCP使用标准的 OPT_ID 即可。
flagCandidates := []int{
linuxSOFTimestampingTXSched |
linuxSOFTimestampingTXSoftware |
linuxSOFTimestampingRXSoftware |
linuxSOFTimestampingSoftware |
linuxSOFTimestampingOptID |
linuxSOFTimestampingOptTSONLY,
linuxSOFTimestampingTXSched |
linuxSOFTimestampingTXSoftware |
linuxSOFTimestampingRXSoftware |
linuxSOFTimestampingSoftware |
linuxSOFTimestampingOptTSONLY,
}
var lastErr error
for _, flags := range flagCandidates {
err := rawConn.Control(func(fd uintptr) {
lastErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags)
})
if err != nil {
return err
}
if lastErr == nil {
c.raw = rawConn
return nil
}
if !errors.Is(lastErr, syscall.EINVAL) {
return lastErr
}
}
return lastErr
}
// sendMessageLinux 编码消息并通过 UDP 发送,采集 TX 时间戳。
func (c *UDPConn) sendMessageLinux(msg protocol.Message) error {
payload, err := protocol.EncodeMessage(msg)
if err != nil {
return fmt.Errorf("protocol: encode message: %w", err)
}
readIndex := 0
c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePreSendDrain, &readIndex)
if c.peerAddr != nil {
if err := c.udpSendTo(payload, c.peerAddr); err != nil {
return err
}
} else {
if err := c.udpSend(payload); err != nil {
return err
}
}
c.collectAndLogUDPTXTimestampEvents(msg, &readIndex)
return nil
}
// sendMessageToLinux 编码消息并通过 UDP 发送到指定地址,采集 TX 时间戳。
func (c *UDPConn) sendMessageToLinux(msg protocol.Message, addr *net.UDPAddr) error {
payload, err := protocol.EncodeMessage(msg)
if err != nil {
return fmt.Errorf("protocol: encode message: %w", err)
}
readIndex := 0
c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePreSendDrain, &readIndex)
if err := c.udpSendTo(payload, addr); err != nil {
return err
}
c.collectAndLogUDPTXTimestampEvents(msg, &readIndex)
return nil
}
// udpSend 通过已连接的 UDP socket 发送数据。
func (c *UDPConn) udpSend(payload []byte) error {
if c.raw != nil {
return c.udpSendmsgRaw(payload, nil)
}
_, err := c.conn.Write(payload)
return err
}
// udpSendTo 通过 UDP socket 发送数据到指定地址。
func (c *UDPConn) udpSendTo(payload []byte, addr *net.UDPAddr) error {
if c.raw != nil {
sa := udpAddrToSockaddr(addr)
if sa != nil {
return c.udpSendmsgRaw(payload, sa)
}
}
_, err := c.conn.WriteToUDP(payload, addr)
return err
}
// udpSendmsgRaw 通过 sendmsg syscall 发送 UDP 数据。
func (c *UDPConn) udpSendmsgRaw(payload []byte, to syscall.Sockaddr) error {
var opErr error
for {
err := c.raw.Control(func(fd uintptr) {
opErr = syscall.Sendmsg(int(fd), payload, nil, to, 0)
})
if err != nil {
return err
}
if opErr == nil {
return nil
}
if isWouldBlock(opErr) {
time.Sleep(linuxDataPollInterval)
continue
}
return opErr
}
}
// receiveMessageLinux 从 UDP 连接读取一条完整消息,并记录 RX 时间戳。
func (c *UDPConn) receiveMessageLinux() (protocol.Message, *net.UDPAddr, error) {
payload, addr, rxTimestamp, err := c.udpRecvFrom()
if err != nil {
return protocol.Message{}, nil, fmt.Errorf("protocol: udp read: %w", err)
}
msg, err := protocol.DecodeMessage(payload)
if err != nil {
return protocol.Message{}, nil, fmt.Errorf("protocol: decode message: %w", err)
}
if rxTimestamp > 0 {
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventBRXSoftware, rxTimestamp, msg)
}
return msg, addr, nil
}
// udpRecvFrom 从 UDP socket 接收一个完整数据报,返回数据、来源地址和 RX 时间戳。
func (c *UDPConn) udpRecvFrom() ([]byte, *net.UDPAddr, int64, error) {
if c.raw != nil {
return c.udpRecvmsgRaw()
}
buf := make([]byte, udpReceiveBufferSize)
n, addr, err := c.conn.ReadFromUDP(buf)
if err != nil {
return nil, nil, 0, err
}
return buf[:n], addr, 0, nil
}
// udpRecvmsgRaw 通过 recvmsg syscall 接收 UDP 数据,同时采集 RX 时间戳。
func (c *UDPConn) udpRecvmsgRaw() ([]byte, *net.UDPAddr, int64, error) {
for {
var (
n int
rxTimeNS int64
from syscall.Sockaddr
opErr error
)
buf := make([]byte, udpReceiveBufferSize)
err := c.raw.Control(func(fd uintptr) {
oob := make([]byte, linuxTimestampControlBufferSize)
readN, oobN, _, sa, recvErr := syscall.Recvmsg(int(fd), buf, oob, 0)
if recvErr != nil {
opErr = recvErr
return
}
n = readN
from = sa
rxTimeNS = parseRXTimestampControlMessages(oob[:oobN])
})
if err != nil {
return nil, nil, 0, err
}
if opErr != nil {
if isWouldBlock(opErr) {
time.Sleep(linuxDataPollInterval)
continue
}
return nil, nil, 0, opErr
}
addr := sockaddrToUDPAddr(from)
return buf[:n], addr, rxTimeNS, nil
}
}
// collectAndLogUDPTXTimestampEvents 采集并记录 UDP 发送的 TX 时间戳事件。
func (c *UDPConn) collectAndLogUDPTXTimestampEvents(msg protocol.Message, readIndex *int) {
timestamps := c.collectUDPTXTimestampEvents(msg, readIndex)
if ts, ok := timestamps[latencylog.EventATXSched]; ok {
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSched, ts, msg)
}
if ts, ok := timestamps[latencylog.EventATXSoftware]; ok {
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSoftware, ts, msg)
}
}
// collectUDPTXTimestampEvents 在 errqueue 中等待 TX 时间戳。
func (c *UDPConn) collectUDPTXTimestampEvents(msg protocol.Message, readIndex *int) map[string]int64 {
if c.raw == nil {
return nil
}
deadline := time.Now().Add(linuxTXTimestampWaitTimeout)
timestamps := make(map[string]int64, 2)
for time.Now().Before(deadline) {
event, err := c.recvUDPTXTimestampOnce()
if err != nil {
if isWouldBlock(err) {
time.Sleep(linuxTXTimestampPollInterval)
continue
}
break
}
if event.EventName == "" || event.TSUnixNano <= 0 {
continue
}
*readIndex++
if isBusinessTXTimestampEventName(event.EventName) {
if _, exists := timestamps[event.EventName]; !exists {
timestamps[event.EventName] = event.TSUnixNano
}
}
if hasCompleteTXTimestampPair(timestamps) {
break
}
}
c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePostSelectDrain, readIndex)
return timestamps
}
// drainPendingUDPTXTimestampEvents 清空 errqueue 中残留的时间戳事件。
func (c *UDPConn) drainPendingUDPTXTimestampEvents(msg protocol.Message, phase string, readIndex *int) {
if c.raw == nil {
return
}
for {
event, err := c.recvUDPTXTimestampOnce()
if err != nil {
return
}
if event.EventName == "" || event.TSUnixNano <= 0 {
continue
}
*readIndex++
}
}
// recvUDPTXTimestampOnce 从 errqueue 读一次时间戳事件。
func (c *UDPConn) recvUDPTXTimestampOnce() (txTimestampEvent, error) {
var (
event txTimestampEvent
opErr error
)
err := c.raw.Control(func(fd uintptr) {
oob := make([]byte, linuxTimestampControlBufferSize)
_, oobn, _, _, recvErr := syscall.Recvmsg(int(fd), nil, oob, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT)
if recvErr != nil {
opErr = recvErr
return
}
event, _ = parseTXTimestampControlMessages(oob[:oobn])
})
if err != nil {
return txTimestampEvent{}, err
}
if opErr != nil {
return txTimestampEvent{}, opErr
}
return event, nil
}
// udpAddrToSockaddr 将 net.UDPAddr 转换为 syscall.Sockaddr。
func udpAddrToSockaddr(addr *net.UDPAddr) syscall.Sockaddr {
if ip4 := addr.IP.To4(); ip4 != nil {
sa := &syscall.SockaddrInet4{Port: addr.Port}
copy(sa.Addr[:], ip4)
return sa
}
if ip6 := addr.IP.To16(); ip6 != nil {
sa := &syscall.SockaddrInet6{Port: addr.Port}
copy(sa.Addr[:], ip6)
return sa
}
return nil
}
// sockaddrToUDPAddr 将 syscall.Sockaddr 转换为 net.UDPAddr。
func sockaddrToUDPAddr(sa syscall.Sockaddr) *net.UDPAddr {
switch addr := sa.(type) {
case *syscall.SockaddrInet4:
return &net.UDPAddr{
IP: net.IP(addr.Addr[:]),
Port: addr.Port,
}
case *syscall.SockaddrInet6:
return &net.UDPAddr{
IP: net.IP(addr.Addr[:]),
Port: addr.Port,
Zone: zoneToString(addr.ZoneId),
}
default:
return nil
}
}
func zoneToString(zone uint32) string {
if zone == 0 {
return ""
}
iface, err := net.InterfaceByIndex(int(zone))
if err != nil {
return ""
}
return iface.Name
}

View File

@@ -0,0 +1,105 @@
//go:build linux
package transport
import (
"net"
"reflect"
"testing"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// TestUDPLinuxTimestampingRecordsKernelEvents 验证 UDP 在 Linux 上能正确采集内核时间戳。
func TestUDPLinuxTimestampingRecordsKernelEvents(t *testing.T) {
tests := []struct {
name string
msg protocol.Message
}{
{
name: "text",
msg: protocol.Message{
Type: protocol.MessageTypeText,
ID: 41,
From: "peer-a",
To: "peer-b",
Body: []byte("hello over udp"),
},
},
{
name: "file",
msg: protocol.Message{
Type: protocol.MessageTypeFile,
ID: 42,
From: "peer-a",
To: "peer-b",
FileName: "payload.bin",
Body: []byte{0x00, 0x01, 0x02, 0xff},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
senderLogger := &recordingLogger{}
receiverLogger := &recordingLogger{}
// 创建 server 侧监听
serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ResolveUDPAddr() error = %v", err)
}
serverRaw, err := net.ListenUDP("udp", serverAddr)
if err != nil {
t.Fatalf("ListenUDP() error = %v", err)
}
receiver, err := NewUDPConn(
serverRaw,
nil,
WithUDPLogger(receiverLogger, latencylog.NodeRolePeer, "peer-b"),
)
if err != nil {
_ = serverRaw.Close()
t.Fatalf("NewUDPConn(receiver) error = %v", err)
}
t.Cleanup(func() { _ = receiver.Close() })
// 创建 peer 侧连接
peerRaw, err := net.DialUDP("udp", nil, serverRaw.LocalAddr().(*net.UDPAddr))
if err != nil {
t.Fatalf("DialUDP() error = %v", err)
}
sender, err := NewUDPConn(
peerRaw,
nil,
WithUDPLogger(senderLogger, latencylog.NodeRolePeer, "peer-a"),
)
if err != nil {
_ = peerRaw.Close()
t.Fatalf("NewUDPConn(sender) error = %v", err)
}
t.Cleanup(func() { _ = sender.Close() })
sendErr := make(chan error, 1)
go func() {
sendErr <- sender.Send(tt.msg)
}()
got, _, err := receiver.Receive()
if err != nil {
t.Fatalf("Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("Send() error = %v", err)
}
if !reflect.DeepEqual(got, tt.msg) {
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
}
assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSched, tt.msg.ID)
assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSoftware, tt.msg.ID)
assertHasEvent(t, receiverLogger.Events(), latencylog.EventBRXSoftware, tt.msg.ID)
})
}
}

View File

@@ -0,0 +1,358 @@
package transport
import (
"net"
"reflect"
"strings"
"sync"
"testing"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// TestUDPSendReceiveMessage 验证 UDP transport 可以正常收发 text 和 file 消息。
func TestUDPSendReceiveMessage(t *testing.T) {
tests := []struct {
name string
msg protocol.Message
}{
{
name: "text",
msg: protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-b",
Body: []byte("hello udp"),
},
},
{
name: "file",
msg: protocol.Message{
Type: protocol.MessageTypeFile,
ID: 2,
From: "peer-a",
To: "peer-b",
FileName: "data.bin",
Body: []byte{0x00, 0x10, 0xff},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sender, receiver := newUDPConnPair(t, nil, nil)
sendErr := make(chan error, 1)
go func() {
sendErr <- sender.Send(tt.msg)
}()
got, _, err := receiver.Receive()
if err != nil {
t.Fatalf("Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("Send() error = %v", err)
}
if !reflect.DeepEqual(got, tt.msg) {
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
}
})
}
}
// TestUDPSendLogsHandoffEvents 验证 UDP Send 会记录 handoff 事件。
func TestUDPSendLogsHandoffEvents(t *testing.T) {
logger := &recordingLogger{}
sender, receiver := newUDPConnPair(
t,
[]UDPOption{WithUDPLogger(logger, latencylog.NodeRolePeer, "peer-a")},
nil,
)
msg := protocol.Message{
Type: protocol.MessageTypeText,
ID: 7,
From: "peer-a",
To: "peer-b",
Body: []byte("hello"),
}
sendErr := make(chan error, 1)
go func() {
sendErr <- sender.Send(msg)
}()
got, _, err := receiver.Receive()
if err != nil {
t.Fatalf("Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("Send() error = %v", err)
}
if !reflect.DeepEqual(got, msg) {
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
}
events := logger.Events()
if len(events) < 2 {
t.Fatalf("event count = %d, want at least 2", len(events))
}
if events[0].Event != latencylog.EventSendHandoffBegin {
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin)
}
// 最后一个事件应该是 SendHandoffEnd
lastEvent := events[len(events)-1]
if lastEvent.Event != latencylog.EventSendHandoffEnd {
t.Fatalf("last event = %q, want %q", lastEvent.Event, latencylog.EventSendHandoffEnd)
}
}
// TestUDPReceiveLoopDeliversMessages 验证 ReceiveLoop 会逐条交付连续到达的消息。
func TestUDPReceiveLoopDeliversMessages(t *testing.T) {
sender, receiver := newUDPConnPair(t, nil, nil)
want := []protocol.Message{
{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-b",
Body: []byte("hello"),
},
{
Type: protocol.MessageTypeFile,
ID: 2,
From: "peer-a",
To: "peer-b",
FileName: "payload.bin",
Body: []byte{0x01, 0x02, 0x03},
},
}
var (
mu sync.Mutex
got []protocol.Message
)
loopErr := make(chan error, 1)
go func() {
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message, _ *net.UDPAddr) error {
mu.Lock()
defer mu.Unlock()
got = append(got, msg)
if len(got) >= len(want) {
return nil
}
return nil
})
}()
for _, msg := range want {
if err := sender.Send(msg); err != nil {
t.Fatalf("Send() error = %v", err)
}
}
// 关闭发送端ReceiveLoop 会因读取错误退出
if err := sender.Close(); err != nil {
t.Fatalf("sender.Close() error = %v", err)
}
err := <-loopErr
if err == nil {
t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close")
}
if !strings.Contains(err.Error(), "udp receive loop read") {
t.Fatalf("ReceiveLoop() error = %v, want read context", err)
}
mu.Lock()
defer mu.Unlock()
if !reflect.DeepEqual(got, want) {
t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
}
}
// TestUDPCloseIsIdempotent 验证 Close 可以安全地被重复调用。
func TestUDPCloseIsIdempotent(t *testing.T) {
conn, peer := newUDPConnPair(t, nil, nil)
if err := conn.Close(); err != nil {
t.Fatalf("Close(first) error = %v", err)
}
if err := conn.Close(); err != nil {
t.Fatalf("Close(second) error = %v, want nil", err)
}
_ = peer.Close()
}
// TestUDPSendToMessage 验证 SendTo 可以向指定地址发送消息。
func TestUDPSendToMessage(t *testing.T) {
serverConn := newUDPListener(t)
peerConn := newUDPDialed(t, serverConn.conn.LocalAddr().(*net.UDPAddr))
msg := protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "server",
Body: []byte("hello sendto"),
}
// peer 发送消息到 server
sendErr := make(chan error, 1)
go func() {
sendErr <- peerConn.Send(msg)
}()
got, addr, err := serverConn.Receive()
if err != nil {
t.Fatalf("Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("Send() error = %v", err)
}
if !reflect.DeepEqual(got, msg) {
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
}
// server 用 SendTo 回复到 peer 地址
reply := protocol.Message{
Type: protocol.MessageTypeText,
ID: 2,
From: "server",
To: "peer-a",
Body: []byte("reply"),
}
sendErr2 := make(chan error, 1)
go func() {
sendErr2 <- serverConn.SendTo(reply, addr)
}()
gotReply, _, err := peerConn.Receive()
if err != nil {
t.Fatalf("peer Receive() error = %v", err)
}
if err := <-sendErr2; err != nil {
t.Fatalf("SendTo() error = %v", err)
}
if !reflect.DeepEqual(gotReply, reply) {
t.Fatalf("reply mismatch: got %+v want %+v", gotReply, reply)
}
}
// newUDPConnPair 创建一对互相连接的 UDP transport 连接,用于测试。
func newUDPConnPair(t *testing.T, senderOpts []UDPOption, receiverOpts []UDPOption) (*UDPConn, *UDPConn) {
t.Helper()
// 创建两个 UDP socket通过 Dial 互相连接
addr1, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ResolveUDPAddr() error = %v", err)
}
addr2, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ResolveUDPAddr() error = %v", err)
}
conn1, err := net.ListenUDP("udp", addr1)
if err != nil {
t.Fatalf("ListenUDP(1) error = %v", err)
}
conn2, err := net.ListenUDP("udp", addr2)
if err != nil {
_ = conn1.Close()
t.Fatalf("ListenUDP(2) error = %v", err)
}
// 用 Dial 模式连接对端
senderRaw, err := net.DialUDP("udp", nil, conn2.LocalAddr().(*net.UDPAddr))
if err != nil {
_ = conn1.Close()
_ = conn2.Close()
t.Fatalf("DialUDP(sender) error = %v", err)
}
_ = conn1.Close() // 不再需要 conn1
receiverRaw, err := net.DialUDP("udp", conn2.LocalAddr().(*net.UDPAddr), senderRaw.LocalAddr().(*net.UDPAddr))
if err != nil {
_ = senderRaw.Close()
_ = conn2.Close()
t.Fatalf("DialUDP(receiver) error = %v", err)
}
_ = conn2.Close() // 不再需要 conn2
sender, err := NewUDPConn(senderRaw, nil, senderOpts...)
if err != nil {
_ = senderRaw.Close()
_ = receiverRaw.Close()
t.Fatalf("NewUDPConn(sender) error = %v", err)
}
receiver, err := NewUDPConn(receiverRaw, nil, receiverOpts...)
if err != nil {
_ = sender.Close()
_ = receiverRaw.Close()
t.Fatalf("NewUDPConn(receiver) error = %v", err)
}
t.Cleanup(func() {
_ = sender.Close()
_ = receiver.Close()
})
return sender, receiver
}
// newUDPListener 创建一个监听模式的 UDP 连接,用于测试 server 场景。
func newUDPListener(t *testing.T) *UDPConn {
t.Helper()
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ResolveUDPAddr() error = %v", err)
}
raw, err := net.ListenUDP("udp", addr)
if err != nil {
t.Fatalf("ListenUDP() error = %v", err)
}
conn, err := NewUDPConn(raw, nil)
if err != nil {
_ = raw.Close()
t.Fatalf("NewUDPConn() error = %v", err)
}
t.Cleanup(func() {
_ = conn.Close()
})
return conn
}
// newUDPDialed 创建一个已连接到指定地址的 UDP 连接,用于测试 peer 场景。
func newUDPDialed(t *testing.T, serverAddr *net.UDPAddr) *UDPConn {
t.Helper()
raw, err := net.DialUDP("udp", nil, serverAddr)
if err != nil {
t.Fatalf("DialUDP() error = %v", err)
}
conn, err := NewUDPConn(raw, nil)
if err != nil {
_ = raw.Close()
t.Fatalf("NewUDPConn() error = %v", err)
}
t.Cleanup(func() {
_ = conn.Close()
})
return conn
}