Files
OmniSocketGo/cmd/internal/transport/kcp_test.go
nnbcccscdscdsc be013b701b feat:KCP协议
2026-03-24 21:09:06 +08:00

285 lines
6.5 KiB
Go

package transport
import (
"reflect"
"strings"
"sync"
"testing"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// recordingKCPPacketDebugLogger is an in-memory packet debug logger for tests.
// It keeps every record handed to LogKCPPacketDebugRecord, guarded by a mutex so
// writers and readers may run on different goroutines. It is presumably passed
// where a KCPPacketDebugLogger is expected — confirm against its call sites.
type recordingKCPPacketDebugLogger struct {
	mu      sync.Mutex             // guards records
	records []KCPPacketDebugRecord // every logged record, in call order
}

// LogKCPPacketDebugRecord appends record to the in-memory log. It always
// returns nil.
func (r *recordingKCPPacketDebugLogger) LogKCPPacketDebugRecord(record KCPPacketDebugRecord) error {
	r.mu.Lock()
	r.records = append(r.records, record)
	r.mu.Unlock()
	return nil
}

// Records returns a copy of everything logged so far (nil if nothing has been
// logged), so callers can inspect it without holding the lock and without
// observing later appends.
func (r *recordingKCPPacketDebugLogger) Records() []KCPPacketDebugRecord {
	r.mu.Lock()
	defer r.mu.Unlock()

	var snapshot []KCPPacketDebugRecord
	snapshot = append(snapshot, r.records...)
	return snapshot
}
// kcpAcceptResult carries the outcome of the background accept performed by
// newKCPConnPair: the wrapped server-side connection and/or the error that
// occurred while accepting or wrapping the session.
type kcpAcceptResult struct {
	conn *KCPConn // accepted connection; may be nil when err is set
	err  error    // accept or NewKCPConn failure, nil on success
}
// TestKCPSendReceiveMessage round-trips a text message and a file message over
// a loopback KCP connection pair and requires the receiver to decode exactly
// the message that was sent.
func TestKCPSendReceiveMessage(t *testing.T) {
	cases := []struct {
		name string
		msg  protocol.Message
	}{
		{
			name: "text",
			msg: protocol.Message{
				Type: protocol.MessageTypeText,
				ID:   1,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte("hello kcp"),
			},
		},
		{
			name: "file",
			msg: protocol.Message{
				Type:     protocol.MessageTypeFile,
				ID:       2,
				From:     "peer-a",
				To:       "peer-b",
				FileName: "payload.bin",
				Body:     []byte{0x00, 0x10, 0xff},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			receiverOpts := []KCPOption{WithKCPLogger(latencylog.NoopLogger{}, latencylog.NodeRolePeer, "peer-b")}
			sender, accepted, cleanup := newKCPConnPair(t, nil, receiverOpts, nil, nil)
			defer cleanup()

			// Send is started before waiting for the accept so that the two sides
			// can make progress concurrently; the result is collected afterwards.
			sendResult := make(chan error, 1)
			go func() { sendResult <- sender.Send(tc.msg) }()

			receiver := awaitAcceptedKCPConn(t, accepted)
			got, recvErr := receiver.Receive()
			if recvErr != nil {
				t.Fatalf("receiver.Receive() error = %v", recvErr)
			}
			if sendErr := <-sendResult; sendErr != nil {
				t.Fatalf("sender.Send() error = %v", sendErr)
			}
			if !reflect.DeepEqual(got, tc.msg) {
				t.Fatalf("message mismatch: got %+v want %+v", got, tc.msg)
			}
		})
	}
}
// TestKCPSendLogsHandoffEvents sends one message from a sender configured with
// a recording latency logger and requires exactly two logged events, in order:
// EventSendHandoffBegin then EventSendHandoffEnd. The message itself must still
// arrive unchanged.
func TestKCPSendLogsHandoffEvents(t *testing.T) {
	logger := &recordingLogger{}
	senderOpts := []KCPOption{WithKCPLogger(logger, latencylog.NodeRolePeer, "peer-a")}
	sender, accepted, cleanup := newKCPConnPair(t, senderOpts, nil, nil, nil)
	defer cleanup()

	want := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   7,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	sendResult := make(chan error, 1)
	go func() { sendResult <- sender.Send(want) }()

	receiver := awaitAcceptedKCPConn(t, accepted)
	got, recvErr := receiver.Receive()
	if recvErr != nil {
		t.Fatalf("receiver.Receive() error = %v", recvErr)
	}
	if sendErr := <-sendResult; sendErr != nil {
		t.Fatalf("sender.Send() error = %v", sendErr)
	}
	if !reflect.DeepEqual(got, want) {
		t.Fatalf("message mismatch: got %+v want %+v", got, want)
	}

	// The events are inspected only after Send has returned.
	events := logger.Events()
	if n := len(events); n != 2 {
		t.Fatalf("event count = %d, want 2", n)
	}
	if first := events[0].Event; first != latencylog.EventSendHandoffBegin {
		t.Fatalf("first event = %q, want %q", first, latencylog.EventSendHandoffBegin)
	}
	if second := events[1].Event; second != latencylog.EventSendHandoffEnd {
		t.Fatalf("second event = %q, want %q", second, latencylog.EventSendHandoffEnd)
	}
}
// TestKCPReceiveLoopStopsOnClose delivers one message to ReceiveLoop whose
// handler closes the receiving connection. It requires ReceiveLoop to return an
// error whose text mentions "closed" or "pipe", and the handler to have seen
// exactly the sent message.
func TestKCPReceiveLoopStopsOnClose(t *testing.T) {
	sender, accepted, cleanup := newKCPConnPair(t, nil, nil, nil, nil)
	defer cleanup()

	want := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	sendResult := make(chan error, 1)
	go func() { sendResult <- sender.Send(want) }()

	receiver := awaitAcceptedKCPConn(t, accepted)

	var (
		recvMu   sync.Mutex
		received []protocol.Message
	)
	loopResult := make(chan error, 1)
	go func() {
		// Record the message, then close the receiver from inside the handler;
		// the returned error is propagated to ReceiveLoop.
		handler := func(m protocol.Message) error {
			recvMu.Lock()
			received = append(received, m)
			recvMu.Unlock()
			return receiver.Close()
		}
		loopResult <- receiver.ReceiveLoop(handler)
	}()

	if sendErr := <-sendResult; sendErr != nil {
		t.Fatalf("sender.Send() error = %v", sendErr)
	}

	loopErr := <-loopResult
	closeRelated := loopErr != nil &&
		(strings.Contains(loopErr.Error(), "closed") || strings.Contains(loopErr.Error(), "pipe"))
	if !closeRelated {
		t.Fatalf("ReceiveLoop() error = %v, want close-related error", loopErr)
	}

	recvMu.Lock()
	defer recvMu.Unlock()
	if len(received) != 1 || !reflect.DeepEqual(received[0], want) {
		t.Fatalf("received messages mismatch: got %+v want [%+v]", received, want)
	}
}
// TestKCPCloseIsIdempotent closes the same KCPConn twice and requires both
// calls to return nil.
func TestKCPCloseIsIdempotent(t *testing.T) {
	conn, _, cleanup := newKCPConnPair(t, nil, nil, nil, nil)
	defer cleanup()

	if closeErr := conn.Close(); closeErr != nil {
		t.Fatalf("Close(first) error = %v", closeErr)
	}
	// A repeated Close must be a no-op rather than an error.
	if closeErr := conn.Close(); closeErr != nil {
		t.Fatalf("Close(second) error = %v, want nil", closeErr)
	}
}
// newKCPConnPair builds a loopback KCP connection pair for tests.
//
// Setup proceeds in three steps:
//   - Listen on an ephemeral 127.0.0.1 port via ListenKCPSessions (node
//     "peer-b", receiverPacketLogger).
//   - Dial that address via DialKCPSession (node "peer-a",
//     senderPacketLogger).
//   - Wrap the dialed session with NewKCPConn(senderOpts...) to produce the
//     returned sender.
//
// Accepting runs in a background goroutine. Its outcome is delivered exactly
// once on the returned channel (see awaitAcceptedKCPConn). On success that is
// the accepted session wrapped with NewKCPConn(receiverOpts...); otherwise it
// is the error.
//
// Any setup failure releases what was already opened and aborts the test via
// t.Fatalf. The returned cleanup:
//   - closes the sender and the listener;
//   - waits for the accept goroutine to exit;
//   - closes the accepted connection, even if the test already took it from
//     the channel;
//   - finally closes the packet connection.
func newKCPConnPair(t *testing.T, senderOpts []KCPOption, receiverOpts []KCPOption, senderPacketLogger KCPPacketDebugLogger, receiverPacketLogger KCPPacketDebugLogger) (*KCPConn, <-chan kcpAcceptResult, func()) {
	t.Helper()

	listener, packetConn, err := ListenKCPSessions("127.0.0.1:0", "", receiverPacketLogger, latencylog.NodeRolePeer, "peer-b")
	if err != nil {
		t.Fatalf("ListenKCPSessions() error = %v", err)
	}

	// Buffered so the accept goroutine never blocks when nobody reads the result.
	accepted := make(chan kcpAcceptResult, 1)
	acceptDone := make(chan struct{})
	// acceptedConn is written only by the accept goroutine before it closes
	// acceptDone, and read only after receiving from acceptDone, so the channel
	// close orders the accesses and no lock is needed.
	var acceptedConn *KCPConn
	go func() {
		defer close(acceptDone)
		session, acceptErr := listener.AcceptKCP()
		if acceptErr != nil {
			accepted <- kcpAcceptResult{err: acceptErr}
			return
		}
		conn, connErr := NewKCPConn(session, receiverOpts...)
		if connErr != nil {
			// Mirror the sender path: release the raw session when it could not
			// be wrapped.
			_ = session.Close()
		}
		acceptedConn = conn
		accepted <- kcpAcceptResult{conn: conn, err: connErr}
	}()

	// closeReceiverSide tears down the listening side. Close errors are ignored
	// deliberately: this is best-effort teardown, and the test may already have
	// closed the accepted connection (KCPConn.Close is idempotent).
	closeReceiverSide := func() {
		// Expected to unblock a pending AcceptKCP (kcp-go's Listener.Close does).
		// NOTE(review): confirm for the listener returned by ListenKCPSessions.
		_ = listener.Close()
		select {
		case <-acceptDone:
			if acceptedConn != nil {
				_ = acceptedConn.Close()
			}
		case <-time.After(5 * time.Second):
			t.Errorf("timed out waiting for the KCP accept goroutine to exit")
		}
		_ = packetConn.Close()
	}

	session, err := DialKCPSession(listener.Addr().String(), "", "", senderPacketLogger, latencylog.NodeRolePeer, "peer-a")
	if err != nil {
		closeReceiverSide()
		t.Fatalf("DialKCPSession() error = %v", err)
	}
	sender, err := NewKCPConn(session, senderOpts...)
	if err != nil {
		_ = session.Close()
		closeReceiverSide()
		t.Fatalf("NewKCPConn(sender) error = %v", err)
	}

	cleanup := func() {
		_ = sender.Close()
		closeReceiverSide()
	}
	return sender, accepted, cleanup
}
// awaitAcceptedKCPConn waits for the accept result produced by newKCPConnPair
// and returns the accepted connection.
//
// The test fails if:
//   - accepting or wrapping the session failed;
//   - no connection was produced;
//   - no result arrives within 5 seconds.
//
// The timeout makes a broken handshake fail fast with a clear message instead
// of hanging until the global `go test` timeout.
func awaitAcceptedKCPConn(t *testing.T, accepted <-chan kcpAcceptResult) *KCPConn {
	t.Helper()

	var result kcpAcceptResult
	select {
	case result = <-accepted:
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for AcceptKCP")
	}

	if result.err != nil {
		t.Fatalf("AcceptKCP() error = %v", result.err)
	}
	if result.conn == nil {
		t.Fatal("accepted KCP conn = nil")
	}
	return result.conn
}
// waitForKCPPacketRecords polls logger every 10ms for up to 2 seconds until
// condition reports true for a snapshot of the logged records. If it never
// does, the test fails with description.
//
// The condition is always evaluated once more after the deadline has passed.
// That way a delayed wake-up from Sleep cannot turn a condition that is
// already satisfied into a spurious timeout. The failure message includes how
// many records were observed, to aid debugging.
func waitForKCPPacketRecords(t *testing.T, logger *recordingKCPPacketDebugLogger, condition func([]KCPPacketDebugRecord) bool, description string) {
	t.Helper()

	const (
		timeout      = 2 * time.Second
		pollInterval = 10 * time.Millisecond
	)
	deadline := time.Now().Add(timeout)
	for {
		records := logger.Records()
		if condition(records) {
			return
		}
		if !time.Now().Before(deadline) {
			t.Fatalf("timed out waiting for %s (observed %d records)", description, len(records))
		}
		time.Sleep(pollInterval)
	}
}