285 lines
6.5 KiB
Go
285 lines
6.5 KiB
Go
package transport
|
|
|
|
import (
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"omnisocketgo/cmd/internal/latencylog"
|
|
"omnisocketgo/cmd/internal/protocol"
|
|
)
|
|
|
|
type recordingKCPPacketDebugLogger struct {
|
|
mu sync.Mutex
|
|
records []KCPPacketDebugRecord
|
|
}
|
|
|
|
func (l *recordingKCPPacketDebugLogger) LogKCPPacketDebugRecord(record KCPPacketDebugRecord) error {
|
|
l.mu.Lock()
|
|
defer l.mu.Unlock()
|
|
|
|
l.records = append(l.records, record)
|
|
return nil
|
|
}
|
|
|
|
func (l *recordingKCPPacketDebugLogger) Records() []KCPPacketDebugRecord {
|
|
l.mu.Lock()
|
|
defer l.mu.Unlock()
|
|
|
|
return append([]KCPPacketDebugRecord(nil), l.records...)
|
|
}
|
|
|
|
type kcpAcceptResult struct {
|
|
conn *KCPConn
|
|
err error
|
|
}
|
|
|
|
func TestKCPSendReceiveMessage(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg protocol.Message
|
|
}{
|
|
{
|
|
name: "text",
|
|
msg: protocol.Message{
|
|
Type: protocol.MessageTypeText,
|
|
ID: 1,
|
|
From: "peer-a",
|
|
To: "peer-b",
|
|
Body: []byte("hello kcp"),
|
|
},
|
|
},
|
|
{
|
|
name: "file",
|
|
msg: protocol.Message{
|
|
Type: protocol.MessageTypeFile,
|
|
ID: 2,
|
|
From: "peer-a",
|
|
To: "peer-b",
|
|
FileName: "payload.bin",
|
|
Body: []byte{0x00, 0x10, 0xff},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
sender, accepted, cleanup := newKCPConnPair(
|
|
t,
|
|
nil,
|
|
[]KCPOption{WithKCPLogger(latencylog.NoopLogger{}, latencylog.NodeRolePeer, "peer-b")},
|
|
nil,
|
|
nil,
|
|
)
|
|
defer cleanup()
|
|
|
|
sendErr := make(chan error, 1)
|
|
go func() {
|
|
sendErr <- sender.Send(tt.msg)
|
|
}()
|
|
|
|
receiver := awaitAcceptedKCPConn(t, accepted)
|
|
got, err := receiver.Receive()
|
|
if err != nil {
|
|
t.Fatalf("receiver.Receive() error = %v", err)
|
|
}
|
|
if err := <-sendErr; err != nil {
|
|
t.Fatalf("sender.Send() error = %v", err)
|
|
}
|
|
if !reflect.DeepEqual(got, tt.msg) {
|
|
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestKCPSendLogsHandoffEvents(t *testing.T) {
|
|
logger := &recordingLogger{}
|
|
sender, accepted, cleanup := newKCPConnPair(
|
|
t,
|
|
[]KCPOption{WithKCPLogger(logger, latencylog.NodeRolePeer, "peer-a")},
|
|
nil,
|
|
nil,
|
|
nil,
|
|
)
|
|
defer cleanup()
|
|
|
|
msg := protocol.Message{
|
|
Type: protocol.MessageTypeText,
|
|
ID: 7,
|
|
From: "peer-a",
|
|
To: "peer-b",
|
|
Body: []byte("hello"),
|
|
}
|
|
|
|
sendErr := make(chan error, 1)
|
|
go func() {
|
|
sendErr <- sender.Send(msg)
|
|
}()
|
|
|
|
receiver := awaitAcceptedKCPConn(t, accepted)
|
|
got, err := receiver.Receive()
|
|
if err != nil {
|
|
t.Fatalf("receiver.Receive() error = %v", err)
|
|
}
|
|
if err := <-sendErr; err != nil {
|
|
t.Fatalf("sender.Send() error = %v", err)
|
|
}
|
|
if !reflect.DeepEqual(got, msg) {
|
|
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
|
|
}
|
|
|
|
events := logger.Events()
|
|
if len(events) != 2 {
|
|
t.Fatalf("event count = %d, want 2", len(events))
|
|
}
|
|
if events[0].Event != latencylog.EventSendHandoffBegin {
|
|
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin)
|
|
}
|
|
if events[1].Event != latencylog.EventSendHandoffEnd {
|
|
t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventSendHandoffEnd)
|
|
}
|
|
}
|
|
|
|
func TestKCPReceiveLoopStopsOnClose(t *testing.T) {
|
|
sender, accepted, cleanup := newKCPConnPair(t, nil, nil, nil, nil)
|
|
defer cleanup()
|
|
|
|
msg := protocol.Message{
|
|
Type: protocol.MessageTypeText,
|
|
ID: 1,
|
|
From: "peer-a",
|
|
To: "peer-b",
|
|
Body: []byte("hello"),
|
|
}
|
|
|
|
sendErr := make(chan error, 1)
|
|
go func() {
|
|
sendErr <- sender.Send(msg)
|
|
}()
|
|
|
|
receiver := awaitAcceptedKCPConn(t, accepted)
|
|
|
|
var (
|
|
mu sync.Mutex
|
|
got []protocol.Message
|
|
)
|
|
loopErr := make(chan error, 1)
|
|
go func() {
|
|
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
|
|
mu.Lock()
|
|
got = append(got, msg)
|
|
mu.Unlock()
|
|
return receiver.Close()
|
|
})
|
|
}()
|
|
|
|
if err := <-sendErr; err != nil {
|
|
t.Fatalf("sender.Send() error = %v", err)
|
|
}
|
|
|
|
err := <-loopErr
|
|
if err == nil || (!strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "pipe")) {
|
|
t.Fatalf("ReceiveLoop() error = %v, want close-related error", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
if len(got) != 1 || !reflect.DeepEqual(got[0], msg) {
|
|
t.Fatalf("received messages mismatch: got %+v want [%+v]", got, msg)
|
|
}
|
|
}
|
|
|
|
func TestKCPCloseIsIdempotent(t *testing.T) {
|
|
sender, _, cleanup := newKCPConnPair(t, nil, nil, nil, nil)
|
|
defer cleanup()
|
|
|
|
if err := sender.Close(); err != nil {
|
|
t.Fatalf("Close(first) error = %v", err)
|
|
}
|
|
if err := sender.Close(); err != nil {
|
|
t.Fatalf("Close(second) error = %v, want nil", err)
|
|
}
|
|
}
|
|
|
|
func newKCPConnPair(t *testing.T, senderOpts []KCPOption, receiverOpts []KCPOption, senderPacketLogger KCPPacketDebugLogger, receiverPacketLogger KCPPacketDebugLogger) (*KCPConn, <-chan kcpAcceptResult, func()) {
|
|
t.Helper()
|
|
|
|
listener, packetConn, err := ListenKCPSessions("127.0.0.1:0", "", receiverPacketLogger, latencylog.NodeRolePeer, "peer-b")
|
|
if err != nil {
|
|
t.Fatalf("ListenKCPSessions() error = %v", err)
|
|
}
|
|
|
|
accepted := make(chan kcpAcceptResult, 1)
|
|
go func() {
|
|
session, acceptErr := listener.AcceptKCP()
|
|
if acceptErr != nil {
|
|
accepted <- kcpAcceptResult{err: acceptErr}
|
|
return
|
|
}
|
|
|
|
conn, connErr := NewKCPConn(session, receiverOpts...)
|
|
accepted <- kcpAcceptResult{conn: conn, err: connErr}
|
|
}()
|
|
|
|
session, err := DialKCPSession(listener.Addr().String(), "", "", senderPacketLogger, latencylog.NodeRolePeer, "peer-a")
|
|
if err != nil {
|
|
_ = packetConn.Close()
|
|
_ = listener.Close()
|
|
t.Fatalf("DialKCPSession() error = %v", err)
|
|
}
|
|
|
|
sender, err := NewKCPConn(session, senderOpts...)
|
|
if err != nil {
|
|
_ = session.Close()
|
|
_ = packetConn.Close()
|
|
_ = listener.Close()
|
|
t.Fatalf("NewKCPConn(sender) error = %v", err)
|
|
}
|
|
|
|
cleanup := func() {
|
|
_ = sender.Close()
|
|
select {
|
|
case result := <-accepted:
|
|
if result.conn != nil {
|
|
_ = result.conn.Close()
|
|
}
|
|
default:
|
|
}
|
|
_ = listener.Close()
|
|
_ = packetConn.Close()
|
|
}
|
|
|
|
return sender, accepted, cleanup
|
|
}
|
|
|
|
func awaitAcceptedKCPConn(t *testing.T, accepted <-chan kcpAcceptResult) *KCPConn {
|
|
t.Helper()
|
|
|
|
result := <-accepted
|
|
if result.err != nil {
|
|
t.Fatalf("AcceptKCP() error = %v", result.err)
|
|
}
|
|
if result.conn == nil {
|
|
t.Fatal("accepted KCP conn = nil")
|
|
}
|
|
return result.conn
|
|
}
|
|
|
|
func waitForKCPPacketRecords(t *testing.T, logger *recordingKCPPacketDebugLogger, condition func([]KCPPacketDebugRecord) bool, description string) {
|
|
t.Helper()
|
|
|
|
deadline := time.Now().Add(2 * time.Second)
|
|
for time.Now().Before(deadline) {
|
|
records := logger.Records()
|
|
if condition(records) {
|
|
return
|
|
}
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
|
|
t.Fatalf("timed out waiting for %s", description)
|
|
}
|