Files
OmniSocketGo/cmd/internal/transport/udp_test.go
2026-03-24 15:39:00 +08:00

359 lines
8.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package transport
import (
"net"
"reflect"
"strings"
"sync"
"testing"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// TestUDPSendReceiveMessage checks that a message sent over one end of a
// connected UDP transport pair arrives unchanged at the other end, for both a
// text payload and a file payload.
func TestUDPSendReceiveMessage(t *testing.T) {
	type roundTripCase struct {
		name string
		msg  protocol.Message
	}

	textMsg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello udp"),
	}
	fileMsg := protocol.Message{
		Type:     protocol.MessageTypeFile,
		ID:       2,
		From:     "peer-a",
		To:       "peer-b",
		FileName: "data.bin",
		Body:     []byte{0x00, 0x10, 0xff},
	}

	for _, tc := range []roundTripCase{
		{name: "text", msg: textMsg},
		{name: "file", msg: fileMsg},
	} {
		t.Run(tc.name, func(t *testing.T) {
			a, b := newUDPConnPair(t, nil, nil)

			// Send from a goroutine so the blocking Receive below runs concurrently.
			// The channel is buffered, so the goroutine never blocks on reporting.
			sendResult := make(chan error, 1)
			go func() { sendResult <- a.Send(tc.msg) }()

			received, _, recvErr := b.Receive()
			if recvErr != nil {
				t.Fatalf("Receive() error = %v", recvErr)
			}
			if sendErr := <-sendResult; sendErr != nil {
				t.Fatalf("Send() error = %v", sendErr)
			}
			if !reflect.DeepEqual(received, tc.msg) {
				t.Fatalf("message mismatch: got %+v want %+v", received, tc.msg)
			}
		})
	}
}
// TestUDPSendLogsHandoffEvents checks that a UDP Send configured with a logger
// records at least two events. The first must be EventSendHandoffBegin and the
// last must be EventSendHandoffEnd.
func TestUDPSendLogsHandoffEvents(t *testing.T) {
	rec := &recordingLogger{}
	senderOpts := []UDPOption{WithUDPLogger(rec, latencylog.NodeRolePeer, "peer-a")}
	src, dst := newUDPConnPair(t, senderOpts, nil)

	want := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   7,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	sendResult := make(chan error, 1)
	go func() { sendResult <- src.Send(want) }()

	received, _, recvErr := dst.Receive()
	if recvErr != nil {
		t.Fatalf("Receive() error = %v", recvErr)
	}
	if sendErr := <-sendResult; sendErr != nil {
		t.Fatalf("Send() error = %v", sendErr)
	}
	if !reflect.DeepEqual(received, want) {
		t.Fatalf("message mismatch: got %+v want %+v", received, want)
	}

	// Inspect the log only after Send has returned, so every event is recorded.
	evs := rec.Events()
	n := len(evs)
	switch {
	case n < 2:
		t.Fatalf("event count = %d, want at least 2", n)
	case evs[0].Event != latencylog.EventSendHandoffBegin:
		t.Fatalf("first event = %q, want %q", evs[0].Event, latencylog.EventSendHandoffBegin)
	case evs[n-1].Event != latencylog.EventSendHandoffEnd:
		t.Fatalf("last event = %q, want %q", evs[n-1].Event, latencylog.EventSendHandoffEnd)
	}
}
// TestUDPReceiveLoopDeliversMessages verifies that ReceiveLoop calls the
// handler once per arriving message, in send order. It also checks that the
// loop returns once its connection is closed.
func TestUDPReceiveLoopDeliversMessages(t *testing.T) {
	sender, receiver := newUDPConnPair(t, nil, nil)
	want := []protocol.Message{
		{
			Type: protocol.MessageTypeText,
			ID:   1,
			From: "peer-a",
			To:   "peer-b",
			Body: []byte("hello"),
		},
		{
			Type:     protocol.MessageTypeFile,
			ID:       2,
			From:     "peer-a",
			To:       "peer-b",
			FileName: "payload.bin",
			Body:     []byte{0x01, 0x02, 0x03},
		},
	}

	var (
		mu  sync.Mutex
		got []protocol.Message
	)
	// The handler closes allReceived once every expected message has been
	// delivered, so the test knows when it is safe to stop the loop.
	allReceived := make(chan struct{})
	loopErr := make(chan error, 1)
	go func() {
		loopErr <- receiver.ReceiveLoop(func(msg protocol.Message, _ *net.UDPAddr) error {
			mu.Lock()
			defer mu.Unlock()
			got = append(got, msg)
			// Use == rather than >= so the channel is closed exactly once.
			if len(got) == len(want) {
				close(allReceived)
			}
			return nil
		})
	}()

	for _, msg := range want {
		if err := sender.Send(msg); err != nil {
			t.Fatalf("Send() error = %v", err)
		}
	}
	<-allReceived

	// UDP is connectionless: closing the sender transmits nothing to the
	// receiver, so it would never wake the loop's blocked read. Closing the
	// receiver itself makes the pending read fail and ReceiveLoop return.
	if err := receiver.Close(); err != nil {
		t.Fatalf("receiver.Close() error = %v", err)
	}
	// NOTE(review): how ReceiveLoop reports a locally closed connection is not
	// visible here. It may treat it as a clean shutdown (nil). When it does
	// return an error, that error must carry the read context.
	if err := <-loopErr; err != nil && !strings.Contains(err.Error(), "udp receive loop read") {
		t.Fatalf("ReceiveLoop() error = %v, want read context", err)
	}

	mu.Lock()
	defer mu.Unlock()
	if !reflect.DeepEqual(got, want) {
		t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
	}
}
// TestUDPCloseIsIdempotent checks that calling Close twice on the same UDP
// transport connection returns nil both times.
func TestUDPCloseIsIdempotent(t *testing.T) {
	closer, other := newUDPConnPair(t, nil, nil)

	// Each entry is the failure message for one successive Close call.
	for _, format := range []string{
		"Close(first) error = %v",
		"Close(second) error = %v, want nil",
	} {
		if err := closer.Close(); err != nil {
			t.Fatalf(format, err)
		}
	}
	_ = other.Close()
}
// TestUDPSendToMessage checks a two-leg exchange. A dialed peer sends to a
// listening server, and the server answers with SendTo using the source
// address returned by Receive.
func TestUDPSendToMessage(t *testing.T) {
	server := newUDPListener(t)
	serverAddr := server.conn.LocalAddr().(*net.UDPAddr)
	peer := newUDPDialed(t, serverAddr)

	// Leg 1: the peer sends to the server over its connected socket.
	hello := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "server",
		Body: []byte("hello sendto"),
	}
	peerSent := make(chan error, 1)
	go func() { peerSent <- peer.Send(hello) }()

	fromPeer, peerAddr, recvErr := server.Receive()
	if recvErr != nil {
		t.Fatalf("Receive() error = %v", recvErr)
	}
	if sendErr := <-peerSent; sendErr != nil {
		t.Fatalf("Send() error = %v", sendErr)
	}
	if !reflect.DeepEqual(fromPeer, hello) {
		t.Fatalf("message mismatch: got %+v want %+v", fromPeer, hello)
	}

	// Leg 2: the server replies to the address it just learned from Receive.
	answer := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   2,
		From: "server",
		To:   "peer-a",
		Body: []byte("reply"),
	}
	serverSent := make(chan error, 1)
	go func() { serverSent <- server.SendTo(answer, peerAddr) }()

	fromServer, _, replyErr := peer.Receive()
	if replyErr != nil {
		t.Fatalf("peer Receive() error = %v", replyErr)
	}
	if sendToErr := <-serverSent; sendToErr != nil {
		t.Fatalf("SendTo() error = %v", sendToErr)
	}
	if !reflect.DeepEqual(fromServer, answer) {
		t.Fatalf("reply mismatch: got %+v want %+v", fromServer, answer)
	}
}
// newUDPConnPair creates two UDP transport connections on the loopback
// interface that are dialed to each other. Each socket's remote address is the
// other socket's local address.
//
// senderOpts and receiverOpts are passed to NewUDPConn for the respective side.
// Both connections are registered with t.Cleanup for closing. Any setup
// failure aborts the test via t.Fatalf after releasing what was already opened.
func newUDPConnPair(t *testing.T, senderOpts []UDPOption, receiverOpts []UDPOption) (*UDPConn, *UDPConn) {
	t.Helper()

	// Reserve a loopback port for the receiver. The placeholder stays open while
	// the sender is dialed, so the kernel cannot hand that same port to the
	// sender as its ephemeral local address.
	placeholderAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ResolveUDPAddr() error = %v", err)
	}
	placeholder, err := net.ListenUDP("udp", placeholderAddr)
	if err != nil {
		t.Fatalf("ListenUDP() error = %v", err)
	}
	receiverAddr := placeholder.LocalAddr().(*net.UDPAddr)

	senderRaw, err := net.DialUDP("udp", nil, receiverAddr)
	if err != nil {
		_ = placeholder.Close()
		t.Fatalf("DialUDP(sender) error = %v", err)
	}

	// The placeholder must be closed BEFORE the receiver binds the reserved
	// address. No SO_REUSEADDR is set on these unicast UDP sockets, so binding
	// an address another open socket still holds fails with "address already in
	// use". Another process could grab the port in the short gap before the
	// rebind; that is accepted as an unlikely test-only race.
	if err := placeholder.Close(); err != nil {
		_ = senderRaw.Close()
		t.Fatalf("placeholder Close() error = %v", err)
	}
	receiverRaw, err := net.DialUDP("udp", receiverAddr, senderRaw.LocalAddr().(*net.UDPAddr))
	if err != nil {
		_ = senderRaw.Close()
		t.Fatalf("DialUDP(receiver) error = %v", err)
	}

	sender, err := NewUDPConn(senderRaw, nil, senderOpts...)
	if err != nil {
		_ = senderRaw.Close()
		_ = receiverRaw.Close()
		t.Fatalf("NewUDPConn(sender) error = %v", err)
	}
	receiver, err := NewUDPConn(receiverRaw, nil, receiverOpts...)
	if err != nil {
		_ = sender.Close()
		_ = receiverRaw.Close()
		t.Fatalf("NewUDPConn(receiver) error = %v", err)
	}
	t.Cleanup(func() {
		_ = sender.Close()
		_ = receiver.Close()
	})
	return sender, receiver
}
// newUDPListener returns an unconnected UDP transport connection bound to an
// ephemeral loopback port, for tests that play the server role. The connection
// is closed automatically when the test finishes.
func newUDPListener(t *testing.T) *UDPConn {
	t.Helper()

	local, resolveErr := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if resolveErr != nil {
		t.Fatalf("ResolveUDPAddr() error = %v", resolveErr)
	}
	sock, listenErr := net.ListenUDP("udp", local)
	if listenErr != nil {
		t.Fatalf("ListenUDP() error = %v", listenErr)
	}
	server, wrapErr := NewUDPConn(sock, nil)
	if wrapErr != nil {
		// The wrapper was not created, so release the raw socket here.
		_ = sock.Close()
		t.Fatalf("NewUDPConn() error = %v", wrapErr)
	}
	t.Cleanup(func() { _ = server.Close() })
	return server
}
// newUDPDialed returns a UDP transport connection dialed to serverAddr from an
// ephemeral local port, for tests that play the peer role. The connection is
// closed automatically when the test finishes.
func newUDPDialed(t *testing.T, serverAddr *net.UDPAddr) *UDPConn {
	t.Helper()

	sock, dialErr := net.DialUDP("udp", nil, serverAddr)
	if dialErr != nil {
		t.Fatalf("DialUDP() error = %v", dialErr)
	}
	peer, wrapErr := NewUDPConn(sock, nil)
	if wrapErr != nil {
		// The wrapper was not created, so release the raw socket here.
		_ = sock.Close()
		t.Fatalf("NewUDPConn() error = %v", wrapErr)
	}
	t.Cleanup(func() { _ = peer.Close() })
	return peer
}