Files
nnbcccscdscdsc be013b701b feat:KCP协议
2026-03-24 21:09:06 +08:00

355 lines
8.3 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package transport
import (
"net"
"reflect"
"strings"
"sync"
"testing"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// TestUDPSendReceiveMessage verifies that the UDP transport can send and
// receive both text and file messages, and that the received message is
// deeply equal to the one that was sent.
func TestUDPSendReceiveMessage(t *testing.T) {
	textMsg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello udp"),
	}
	fileMsg := protocol.Message{
		Type:     protocol.MessageTypeFile,
		ID:       2,
		From:     "peer-a",
		To:       "peer-b",
		FileName: "data.bin",
		Body:     []byte{0x00, 0x10, 0xff},
	}

	type testCase struct {
		name string
		msg  protocol.Message
	}
	cases := []testCase{
		{name: "text", msg: textMsg},
		{name: "file", msg: fileMsg},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			sender, receiver := newUDPConnPair(t, nil, nil)

			// Send from a separate goroutine; the channel has capacity 1 so
			// the goroutine can finish even if the test aborts before
			// reading the result.
			sendResult := make(chan error, 1)
			go func() {
				sendResult <- sender.Send(tc.msg)
			}()

			received, _, recvErr := receiver.Receive()
			if recvErr != nil {
				t.Fatalf("Receive() error = %v", recvErr)
			}

			sendFailure := <-sendResult
			if sendFailure != nil {
				t.Fatalf("Send() error = %v", sendFailure)
			}

			if !reflect.DeepEqual(received, tc.msg) {
				t.Fatalf("message mismatch: got %+v want %+v", received, tc.msg)
			}
		})
	}
}
// TestUDPSendLogsHandoffEvents verifies that Send on a UDP connection
// configured with a logger records handoff events: the first recorded event
// must be EventSendHandoffBegin and the last must be EventSendHandoffEnd.
func TestUDPSendLogsHandoffEvents(t *testing.T) {
	rec := &recordingLogger{}
	senderOpts := []UDPOption{WithUDPLogger(rec, latencylog.NodeRolePeer, "peer-a")}
	sender, receiver := newUDPConnPair(t, senderOpts, nil)

	outgoing := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   7,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	// Send concurrently while this goroutine waits on Receive.
	sendResult := make(chan error, 1)
	go func() {
		sendResult <- sender.Send(outgoing)
	}()

	incoming, _, recvErr := receiver.Receive()
	if recvErr != nil {
		t.Fatalf("Receive() error = %v", recvErr)
	}

	sendFailure := <-sendResult
	if sendFailure != nil {
		t.Fatalf("Send() error = %v", sendFailure)
	}

	if !reflect.DeepEqual(incoming, outgoing) {
		t.Fatalf("message mismatch: got %+v want %+v", incoming, outgoing)
	}

	recorded := rec.Events()
	if n := len(recorded); n < 2 {
		t.Fatalf("event count = %d, want at least 2", n)
	}

	if firstKind := recorded[0].Event; firstKind != latencylog.EventSendHandoffBegin {
		t.Fatalf("first event = %q, want %q", firstKind, latencylog.EventSendHandoffBegin)
	}

	// The final recorded event is expected to be SendHandoffEnd.
	if lastKind := recorded[len(recorded)-1].Event; lastKind != latencylog.EventSendHandoffEnd {
		t.Fatalf("last event = %q, want %q", lastKind, latencylog.EventSendHandoffEnd)
	}
}
// TestUDPReceiveLoopDeliversMessages verifies that ReceiveLoop hands each
// consecutively arriving message to the handler, one at a time, and returns a
// close-related error once the receiver has been closed from inside the
// handler.
func TestUDPReceiveLoopDeliversMessages(t *testing.T) {
	sender, receiver := newUDPConnPair(t, nil, nil)

	expected := []protocol.Message{
		{
			Type: protocol.MessageTypeText,
			ID:   1,
			From: "peer-a",
			To:   "peer-b",
			Body: []byte("hello"),
		},
		{
			Type:     protocol.MessageTypeFile,
			ID:       2,
			From:     "peer-a",
			To:       "peer-b",
			FileName: "payload.bin",
			Body:     []byte{0x01, 0x02, 0x03},
		},
	}

	var (
		guard     sync.Mutex
		delivered []protocol.Message
	)

	// The handler collects messages under the mutex and closes the receiver
	// once as many messages have arrived as were sent.
	handler := func(msg protocol.Message, _ *net.UDPAddr) error {
		guard.Lock()
		delivered = append(delivered, msg)
		remaining := len(expected) - len(delivered)
		guard.Unlock()

		if remaining > 0 {
			return nil
		}
		return receiver.Close()
	}

	loopDone := make(chan error, 1)
	go func() {
		loopDone <- receiver.ReceiveLoop(handler)
	}()

	for i := range expected {
		if sendErr := sender.Send(expected[i]); sendErr != nil {
			t.Fatalf("Send() error = %v", sendErr)
		}
	}

	loopErr := <-loopDone
	if loopErr == nil {
		t.Fatal("ReceiveLoop() error = nil, want non-nil after receiver close")
	}

	// Accept any error text that mentions the connection being closed. The
	// second substring already contains the first; both are kept explicit.
	text := loopErr.Error()
	closeRelated := strings.Contains(text, "closed") ||
		strings.Contains(text, "use of closed network connection")
	if !closeRelated {
		t.Fatalf("ReceiveLoop() error = %v, want close-related error", loopErr)
	}

	guard.Lock()
	defer guard.Unlock()
	if !reflect.DeepEqual(delivered, expected) {
		t.Fatalf("received messages mismatch: got %+v want %+v", delivered, expected)
	}
}
// TestUDPCloseIsIdempotent verifies that Close can safely be called more than
// once on the same connection: both calls must return a nil error.
func TestUDPCloseIsIdempotent(t *testing.T) {
	conn, peer := newUDPConnPair(t, nil, nil)

	firstErr := conn.Close()
	if firstErr != nil {
		t.Fatalf("Close(first) error = %v", firstErr)
	}

	secondErr := conn.Close()
	if secondErr != nil {
		t.Fatalf("Close(second) error = %v, want nil", secondErr)
	}

	// Best-effort close of the other side; its error is not under test.
	_ = peer.Close()
}
// TestUDPSendToMessage verifies that SendTo can deliver a message to an
// explicitly specified address: a dialed peer sends to a listening server,
// and the server replies with SendTo using the address returned by Receive.
func TestUDPSendToMessage(t *testing.T) {
	server := newUDPListener(t)
	serverAddr := server.conn.LocalAddr().(*net.UDPAddr)
	peer := newUDPDialed(t, serverAddr)

	request := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "server",
		Body: []byte("hello sendto"),
	}

	// The peer sends the request to the server.
	requestSent := make(chan error, 1)
	go func() {
		requestSent <- peer.Send(request)
	}()

	gotRequest, peerAddr, recvErr := server.Receive()
	if recvErr != nil {
		t.Fatalf("Receive() error = %v", recvErr)
	}
	if sendFailure := <-requestSent; sendFailure != nil {
		t.Fatalf("Send() error = %v", sendFailure)
	}
	if !reflect.DeepEqual(gotRequest, request) {
		t.Fatalf("message mismatch: got %+v want %+v", gotRequest, request)
	}

	// The server answers via SendTo, targeting the peer address it received from.
	response := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   2,
		From: "server",
		To:   "peer-a",
		Body: []byte("reply"),
	}

	responseSent := make(chan error, 1)
	go func() {
		responseSent <- server.SendTo(response, peerAddr)
	}()

	gotResponse, _, peerRecvErr := peer.Receive()
	if peerRecvErr != nil {
		t.Fatalf("peer Receive() error = %v", peerRecvErr)
	}
	if sendFailure := <-responseSent; sendFailure != nil {
		t.Fatalf("SendTo() error = %v", sendFailure)
	}
	if !reflect.DeepEqual(gotResponse, response) {
		t.Fatalf("reply mismatch: got %+v want %+v", gotResponse, response)
	}
}
// newUDPConnPair creates a pair of UDP transport connections on the loopback
// interface, each dialed to the other's local address, for use in tests.
//
// senderOpts and receiverOpts are forwarded to NewUDPConn for the respective
// side. The first return value is the sender and the second is the receiver.
// Both are closed through t.Cleanup. Any setup failure closes the sockets
// opened so far and aborts the test with t.Fatalf.
func newUDPConnPair(t *testing.T, senderOpts []UDPOption, receiverOpts []UDPOption) (*UDPConn, *UDPConn) {
	t.Helper()

	// Bind a placeholder socket to an OS-chosen loopback port. Its address is
	// the one the sender dials and the receiver later binds to.
	loopbackAny, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ResolveUDPAddr() error = %v", err)
	}
	placeholder, err := net.ListenUDP("udp", loopbackAny)
	if err != nil {
		t.Fatalf("ListenUDP(2) error = %v", err)
	}
	receiverLocalAddr := placeholder.LocalAddr().(*net.UDPAddr)

	// Dial the sender towards the address the receiver is going to use.
	senderRaw, err := net.DialUDP("udp", nil, receiverLocalAddr)
	if err != nil {
		_ = placeholder.Close()
		t.Fatalf("DialUDP(sender) error = %v", err)
	}

	// Release the local address the receiver plans to bind.
	// NOTE(review): the port is briefly free between this Close and the
	// DialUDP below, so another socket could in principle claim it.
	_ = placeholder.Close()

	receiverRaw, err := net.DialUDP("udp", receiverLocalAddr, senderRaw.LocalAddr().(*net.UDPAddr))
	if err != nil {
		_ = senderRaw.Close()
		t.Fatalf("DialUDP(receiver) error = %v", err)
	}

	sender, err := NewUDPConn(senderRaw, nil, senderOpts...)
	if err != nil {
		_ = senderRaw.Close()
		_ = receiverRaw.Close()
		t.Fatalf("NewUDPConn(sender) error = %v", err)
	}

	receiver, err := NewUDPConn(receiverRaw, nil, receiverOpts...)
	if err != nil {
		_ = sender.Close()
		_ = receiverRaw.Close()
		t.Fatalf("NewUDPConn(receiver) error = %v", err)
	}

	t.Cleanup(func() {
		_ = sender.Close()
		_ = receiver.Close()
	})

	return sender, receiver
}
// newUDPListener creates a UDP transport connection on top of a socket opened
// with net.ListenUDP on an OS-chosen loopback port, for tests that exercise
// the server scenario. The connection is closed through t.Cleanup.
func newUDPListener(t *testing.T) *UDPConn {
	t.Helper()

	loopbackAny, resolveErr := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if resolveErr != nil {
		t.Fatalf("ResolveUDPAddr() error = %v", resolveErr)
	}

	sock, listenErr := net.ListenUDP("udp", loopbackAny)
	if listenErr != nil {
		t.Fatalf("ListenUDP() error = %v", listenErr)
	}

	listener, wrapErr := NewUDPConn(sock, nil)
	if wrapErr != nil {
		// The raw socket is not owned by a UDPConn yet, so close it here.
		_ = sock.Close()
		t.Fatalf("NewUDPConn() error = %v", wrapErr)
	}

	t.Cleanup(func() { _ = listener.Close() })
	return listener
}
// newUDPDialed creates a UDP transport connection on top of a socket dialed
// to serverAddr, for tests that exercise the peer scenario. The connection is
// closed through t.Cleanup.
func newUDPDialed(t *testing.T, serverAddr *net.UDPAddr) *UDPConn {
	t.Helper()

	sock, dialErr := net.DialUDP("udp", nil, serverAddr)
	if dialErr != nil {
		t.Fatalf("DialUDP() error = %v", dialErr)
	}

	dialed, wrapErr := NewUDPConn(sock, nil)
	if wrapErr != nil {
		// The raw socket is not owned by a UDPConn yet, so close it here.
		_ = sock.Close()
		t.Fatalf("NewUDPConn() error = %v", wrapErr)
	}

	t.Cleanup(func() { _ = dialed.Close() })
	return dialed
}