Files
OmniSocketGo/cmd/internal/peer/kcp_client_test.go
2026-03-27 23:03:00 +08:00

637 lines
17 KiB
Go

package peer
import (
"bytes"
"net"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"sync/atomic"
"testing"
kcp "github.com/xtaci/kcp-go/v5"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/server"
"omnisocketgo/cmd/internal/transport"
)
// TestKCPDialRegistersPeer verifies that a successful DialKCP against a live
// hub leads to the hub reporting the dialing peer ID as registered.
func TestKCPDialRegistersPeer(t *testing.T) {
	hub := server.NewKCPHub()
	addr, stopServer := startRealKCPHubServer(t, hub)
	defer stopServer()

	conn, dialErr := DialKCP(addr, "peer-a")
	if dialErr != nil {
		t.Fatalf("DialKCP() error = %v", dialErr)
	}
	defer func() { _ = conn.Close() }()

	registered := func() bool { return hub.HasPeer("peer-a") }
	waitFor(t, registered, "peer-a to be registered")
}
// TestKCPDialRejectsInvalidBindIP checks that DialKCP validates the WithBindIP
// option and names the offending value in the returned error.
func TestKCPDialRejectsInvalidBindIP(t *testing.T) {
	const badIP = "not-an-ip"
	_, err := DialKCP("127.0.0.1:9002", "peer-a", WithBindIP(badIP))
	switch {
	case err == nil:
		t.Fatal("DialKCP() error = nil, want invalid bind ip error")
	case !strings.Contains(err.Error(), `invalid bind ip "not-an-ip"`):
		t.Fatalf("DialKCP() error = %v, want invalid bind ip error", err)
	}
}
// TestKCPClientsExchangeTextAndFileMessages sends one text message and one
// file message from peer-a to peer-b through a single hub. It checks that
// peer-b receives both, in send order, with IDs 1 and 2.
func TestKCPClientsExchangeTextAndFileMessages(t *testing.T) {
	hub := server.NewKCPHub()
	addr, stopServer := startRealKCPHubServer(t, hub)
	defer stopServer()

	sender, err := DialKCP(addr, "peer-a")
	if err != nil {
		t.Fatalf("DialKCP(peer-a) error = %v", err)
	}
	defer func() { _ = sender.Close() }()

	receiver, err := DialKCP(addr, "peer-b")
	if err != nil {
		t.Fatalf("DialKCP(peer-b) error = %v", err)
	}
	defer func() { _ = receiver.Close() }()

	waitFor(t, func() bool {
		return hub.HasPeer("peer-a") && hub.HasPeer("peer-b")
	}, "both peers to be registered")

	const expected = 2
	inbox := make(chan protocol.Message, expected)
	done := make(chan error, 1)
	// Receive on a separate goroutine while the sends happen below. Both
	// channels are buffered, so this goroutine's channel sends never block.
	go func() {
		for n := 0; n < expected; n++ {
			msg, recvErr := receiver.Receive()
			if recvErr != nil {
				done <- recvErr
				return
			}
			inbox <- msg
		}
		done <- nil
	}()

	if err := sender.SendText("peer-b", "hello over kcp"); err != nil {
		t.Fatalf("SendText() error = %v", err)
	}
	payload := []byte{0x01, 0x02, 0x03}
	if err := sender.SendFile("peer-b", "payload.bin", payload); err != nil {
		t.Fatalf("SendFile() error = %v", err)
	}
	if err := <-done; err != nil {
		t.Fatalf("peerB.Receive() error = %v", err)
	}

	want := []protocol.Message{
		{
			Type: protocol.MessageTypeText,
			ID:   1,
			From: "peer-a",
			To:   "peer-b",
			Body: []byte("hello over kcp"),
		},
		{
			Type:     protocol.MessageTypeFile,
			ID:       2,
			From:     "peer-a",
			To:       "peer-b",
			FileName: "payload.bin",
			Body:     payload,
		},
	}
	labels := []string{"first", "second"}
	for i := range want {
		got := <-inbox
		if !reflect.DeepEqual(got, want[i]) {
			t.Fatalf("%s message mismatch: got %+v want %+v", labels[i], got, want[i])
		}
	}
}
// TestKCPClientReceivesServerErrorForUnknownTarget verifies that a message to
// a peer ID the hub does not know comes back to the sender as an error
// message whose body names the missing target.
func TestKCPClientReceivesServerErrorForUnknownTarget(t *testing.T) {
	hub := server.NewKCPHub()
	addr, stopServer := startRealKCPHubServer(t, hub)
	defer stopServer()

	client, err := DialKCP(addr, "peer-a")
	if err != nil {
		t.Fatalf("DialKCP() error = %v", err)
	}
	defer func() { _ = client.Close() }()
	waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")

	if err = client.SendText("missing-peer", "hello"); err != nil {
		t.Fatalf("SendText() error = %v", err)
	}

	reply, err := client.Receive()
	if err != nil {
		t.Fatalf("Receive() error = %v", err)
	}
	if reply.Type != protocol.MessageTypeError {
		t.Fatalf("got type %s, want %s", reply.Type, protocol.MessageTypeError)
	}
	if body := string(reply.Body); body != "unknown target: missing-peer" {
		t.Fatalf("error body = %q, want unknown target message", reply.Body)
	}
}
// TestKCPClientsExchangeMessagesWithLatencyLogs sends a text message and then
// a file (from disk) from peer-a to peer-b. peer-b persists each message into
// a temp inbox. The test then checks each side's recorded latency events:
//   - three send-side events per message on peer-a;
//   - three receive/persist events per message on peer-b.
func TestKCPClientsExchangeMessagesWithLatencyLogs(t *testing.T) {
	hub := server.NewKCPHub()
	addr, stopServer := startRealKCPHubServer(t, hub)
	defer stopServer()

	senderLog := &recordingLogger{}
	sender, err := DialKCP(addr, "peer-a", WithLogger(senderLog))
	if err != nil {
		t.Fatalf("DialKCP(peer-a) error = %v", err)
	}
	defer func() { _ = sender.Close() }()

	receiverLog := &recordingLogger{}
	receiver, err := DialKCP(addr, "peer-b", WithLogger(receiverLog))
	if err != nil {
		t.Fatalf("DialKCP(peer-b) error = %v", err)
	}
	defer func() { _ = receiver.Close() }()

	inboxDir := t.TempDir()
	waitFor(t, func() bool {
		return hub.HasPeer("peer-a") && hub.HasPeer("peer-b")
	}, "both peers to be registered")

	// receiveAndPersist pulls the next message off peer-b and persists it
	// into inboxDir; label is only used in failure messages.
	receiveAndPersist := func(label string) {
		msg, recvErr := receiver.Receive()
		if recvErr != nil {
			t.Fatalf("peerB.Receive(%s) error = %v", label, recvErr)
		}
		if _, persistErr := receiver.PersistMessage(msg, inboxDir); persistErr != nil {
			t.Fatalf("peerB.PersistMessage(%s) error = %v", label, persistErr)
		}
	}

	if err := sender.SendText("peer-b", "hello"); err != nil {
		t.Fatalf("SendText() error = %v", err)
	}
	receiveAndPersist("text")

	srcPath := filepath.Join(t.TempDir(), "payload.bin")
	if err := os.WriteFile(srcPath, []byte{0x01, 0x02, 0x03}, 0o644); err != nil {
		t.Fatalf("os.WriteFile() error = %v", err)
	}
	if err := sender.SendFilePath("peer-b", srcPath); err != nil {
		t.Fatalf("SendFilePath() error = %v", err)
	}
	receiveAndPersist("file")

	// Two messages x three events each = six events per side.
	sixEvents := func(l *recordingLogger) func() bool {
		return func() bool { return len(l.Events()) == 6 }
	}
	waitFor(t, sixEvents(senderLog), "peer-a latency events")
	waitFor(t, sixEvents(receiverLog), "peer-b latency events")

	// Fresh slices per message ID so the expectations never alias each other.
	sendSide := func() []string {
		return []string{latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventSendHandoffEnd}
	}
	recvSide := func() []string {
		return []string{latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd}
	}
	assertEventSequencesByMessage(t, senderLog.Events(), map[uint64][]string{
		1: sendSide(),
		2: sendSide(),
	})
	assertEventSequencesByMessage(t, receiverLog.Events(), map[uint64][]string{
		1: recvSide(),
		2: recvSide(),
	})
}
// TestKCPClientsExchangeMessagesAcrossRelayedServers connects the two peers to
// different hubs (peer-a on C, peer-b on D) and exchanges one text message in
// each direction. It then checks that each relay socket wrote exactly once.
func TestKCPClientsExchangeMessagesAcrossRelayedServers(t *testing.T) {
	fx := startRelayedKCPHubs(t)
	defer fx.cleanup()

	peerA, err := DialKCP(fx.serverCAddr, "peer-a")
	if err != nil {
		t.Fatalf("DialKCP(peer-a) error = %v", err)
	}
	defer func() { _ = peerA.Close() }()

	peerB, err := DialKCP(fx.serverDAddr, "peer-b")
	if err != nil {
		t.Fatalf("DialKCP(peer-b) error = %v", err)
	}
	defer func() { _ = peerB.Close() }()

	waitFor(t, func() bool {
		return fx.hubC.HasPeer("peer-a") && fx.hubD.HasPeer("peer-b")
	}, "both relayed peers to be registered")

	// text builds the expected first text message from one peer to another.
	text := func(from, to, body string) protocol.Message {
		return protocol.Message{
			Type: protocol.MessageTypeText,
			ID:   1,
			From: from,
			To:   to,
			Body: []byte(body),
		}
	}

	// peer-a (hub C) -> peer-b (hub D).
	if err := peerA.SendText("peer-b", "hello via relay"); err != nil {
		t.Fatalf("peerA.SendText() error = %v", err)
	}
	atB, err := peerB.Receive()
	if err != nil {
		t.Fatalf("peerB.Receive() error = %v", err)
	}
	if want := text("peer-a", "peer-b", "hello via relay"); !reflect.DeepEqual(atB, want) {
		t.Fatalf("peerB received %+v, want %+v", atB, want)
	}

	// peer-b (hub D) -> peer-a (hub C).
	if err := peerB.SendText("peer-a", "hello back"); err != nil {
		t.Fatalf("peerB.SendText() error = %v", err)
	}
	atA, err := peerA.Receive()
	if err != nil {
		t.Fatalf("peerA.Receive() error = %v", err)
	}
	if want := text("peer-b", "peer-a", "hello back"); !reflect.DeepEqual(atA, want) {
		t.Fatalf("peerA received %+v, want %+v", atA, want)
	}

	relays := []struct {
		label string
		conn  *countingPacketConn
	}{
		{"relayC", fx.relayC},
		{"relayD", fx.relayD},
	}
	for _, r := range relays {
		if writes := r.conn.WriteCount(); writes != 1 {
			t.Fatalf("%s write count = %d, want 1", r.label, writes)
		}
	}
}
// TestKCPClientsExchangeMessagesViaUDPRelayToSingleHub puts a server.UDPRelay
// in front of a single hub:
//   - peer-a dials the relay's address (WithKCPDialAddress);
//   - peer-b dials the hub directly.
// It checks delivery in both directions and that the relay socket wrote at
// least one packet.
func TestKCPClientsExchangeMessagesViaUDPRelayToSingleHub(t *testing.T) {
	hub := server.NewKCPHub()
	serverAddr, cleanupHub := startRealKCPHubServer(t, hub)
	defer cleanupHub()

	remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
	if err != nil {
		t.Fatalf("ResolveUDPAddr(server) error = %v", err)
	}
	baseRelayConn, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ListenPacket(relay) error = %v", err)
	}
	relayConn := &countingPacketConn{PacketConn: baseRelayConn}
	relay, err := server.NewUDPRelay(relayConn, remoteAddr)
	if err != nil {
		_ = relayConn.Close()
		t.Fatalf("NewUDPRelay() error = %v", err)
	}

	var relayWG sync.WaitGroup
	relayWG.Add(1)
	go func() {
		defer relayWG.Done()
		// Teardown stops Serve by closing relayConn, so a closed-connection
		// error is an expected exit. Filter it the same way startRelayLoop
		// does, rather than failing the test during cleanup.
		if serveErr := relay.Serve(); serveErr != nil && !isExpectedKCPRelayServeExit(serveErr) {
			t.Errorf("relay.Serve() error = %v", serveErr)
		}
	}()
	defer func() {
		_ = relayConn.Close()
		relayWG.Wait()
	}()

	peerA, err := DialKCP(serverAddr, "peer-a", WithKCPDialAddress(relayConn.LocalAddr().String()))
	if err != nil {
		t.Fatalf("DialKCP(peer-a via relay) error = %v", err)
	}
	defer func() { _ = peerA.Close() }()

	peerB, err := DialKCP(serverAddr, "peer-b")
	if err != nil {
		t.Fatalf("DialKCP(peer-b direct) error = %v", err)
	}
	defer func() { _ = peerB.Close() }()

	waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered on the single hub")

	// Direct peer -> relayed peer.
	if err := peerB.SendText("peer-a", "hello via udp relay"); err != nil {
		t.Fatalf("peerB.SendText() error = %v", err)
	}
	gotAtA, err := peerA.Receive()
	if err != nil {
		t.Fatalf("peerA.Receive() error = %v", err)
	}
	wantAtA := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-b",
		To:   "peer-a",
		Body: []byte("hello via udp relay"),
	}
	if !reflect.DeepEqual(gotAtA, wantAtA) {
		t.Fatalf("peerA received %+v, want %+v", gotAtA, wantAtA)
	}

	// Relayed peer -> direct peer.
	if err := peerA.SendText("peer-b", "hello back through relay"); err != nil {
		t.Fatalf("peerA.SendText() error = %v", err)
	}
	gotAtB, err := peerB.Receive()
	if err != nil {
		t.Fatalf("peerB.Receive() error = %v", err)
	}
	wantAtB := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello back through relay"),
	}
	if !reflect.DeepEqual(gotAtB, wantAtB) {
		t.Fatalf("peerB received %+v, want %+v", gotAtB, wantAtB)
	}

	if got := relayConn.WriteCount(); got == 0 {
		t.Fatal("relay should have forwarded packets for peer-a session")
	}
}
// TestKCPHubPrefersLocalPeerBeforeRelay connects both peers to hub C of the
// relayed two-hub fixture. It checks that a message between them is delivered
// and that neither relay socket recorded any writes.
func TestKCPHubPrefersLocalPeerBeforeRelay(t *testing.T) {
	fx := startRelayedKCPHubs(t)
	defer fx.cleanup()

	peerA, err := DialKCP(fx.serverCAddr, "peer-a")
	if err != nil {
		t.Fatalf("DialKCP(peer-a) error = %v", err)
	}
	defer func() { _ = peerA.Close() }()

	peerB, err := DialKCP(fx.serverCAddr, "peer-b")
	if err != nil {
		t.Fatalf("DialKCP(peer-b) error = %v", err)
	}
	defer func() { _ = peerB.Close() }()

	waitFor(t, func() bool {
		return fx.hubC.HasPeer("peer-a") && fx.hubC.HasPeer("peer-b")
	}, "local peers on hubC to be registered")

	if err := peerA.SendText("peer-b", "local delivery"); err != nil {
		t.Fatalf("peerA.SendText() error = %v", err)
	}
	delivered, err := peerB.Receive()
	if err != nil {
		t.Fatalf("peerB.Receive() error = %v", err)
	}
	expected := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("local delivery"),
	}
	if !reflect.DeepEqual(delivered, expected) {
		t.Fatalf("peerB received %+v, want %+v", delivered, expected)
	}

	relays := []struct {
		label string
		conn  *countingPacketConn
	}{
		{"relayC", fx.relayC},
		{"relayD", fx.relayD},
	}
	for _, r := range relays {
		if writes := r.conn.WriteCount(); writes != 0 {
			t.Fatalf("%s write count = %d, want 0 for local delivery", r.label, writes)
		}
	}
}
// TestKCPRelayedUnknownTargetReturnsErrorToOriginalSender sends to a target
// that no hub knows. The message is relayed from C to D, and the error comes
// back to peer-a from the server peer ID. Each relay socket should have
// written exactly once: outbound on C, return on D.
func TestKCPRelayedUnknownTargetReturnsErrorToOriginalSender(t *testing.T) {
	fx := startRelayedKCPHubs(t)
	defer fx.cleanup()

	sender, err := DialKCP(fx.serverCAddr, "peer-a")
	if err != nil {
		t.Fatalf("DialKCP(peer-a) error = %v", err)
	}
	defer func() { _ = sender.Close() }()
	waitFor(t, func() bool { return fx.hubC.HasPeer("peer-a") }, "peer-a to be registered on hubC")

	if err := sender.SendText("remote-missing", "hello"); err != nil {
		t.Fatalf("peerA.SendText() error = %v", err)
	}
	reply, err := sender.Receive()
	if err != nil {
		t.Fatalf("peerA.Receive() error = %v", err)
	}

	switch {
	case reply.Type != protocol.MessageTypeError:
		t.Fatalf("got type %s, want %s", reply.Type, protocol.MessageTypeError)
	case reply.From != protocol.ServerPeerID:
		t.Fatalf("error from = %s, want %s", reply.From, protocol.ServerPeerID)
	case reply.To != "peer-a":
		t.Fatalf("error to = %s, want peer-a", reply.To)
	case string(reply.Body) != "unknown target: remote-missing":
		t.Fatalf("error body = %q, want unknown target from relayed hub", reply.Body)
	}

	if n := fx.relayC.WriteCount(); n != 1 {
		t.Fatalf("relayC write count = %d, want 1 for outbound relay", n)
	}
	if n := fx.relayD.WriteCount(); n != 1 {
		t.Fatalf("relayD write count = %d, want 1 for return error relay", n)
	}
}
// TestKCPHubRejectsOversizeRelayedMessage sends a 70 KiB file from peer-a on
// hub C to a peer that is not registered on hub C. The sender should get a
// "message too large for relay udp" error, and relayC should record no writes.
func TestKCPHubRejectsOversizeRelayedMessage(t *testing.T) {
	fx := startRelayedKCPHubs(t)
	defer fx.cleanup()

	sender, err := DialKCP(fx.serverCAddr, "peer-a")
	if err != nil {
		t.Fatalf("DialKCP(peer-a) error = %v", err)
	}
	defer func() { _ = sender.Close() }()
	waitFor(t, func() bool { return fx.hubC.HasPeer("peer-a") }, "peer-a to be registered on hubC")

	// 70 KiB is larger than a single UDP datagram can carry (65,535 bytes
	// including headers).
	oversized := bytes.Repeat([]byte{'a'}, 70*1024)
	if err := sender.SendFile("remote-peer", "payload.bin", oversized); err != nil {
		t.Fatalf("peerA.SendFile() error = %v", err)
	}

	reply, err := sender.Receive()
	if err != nil {
		t.Fatalf("peerA.Receive() error = %v", err)
	}
	if reply.Type != protocol.MessageTypeError {
		t.Fatalf("got type %s, want %s", reply.Type, protocol.MessageTypeError)
	}
	if string(reply.Body) != "message too large for relay udp" {
		t.Fatalf("error body = %q, want oversize relay error", reply.Body)
	}
	if writes := fx.relayC.WriteCount(); writes != 0 {
		t.Fatalf("relayC write count = %d, want 0 when relay rejects oversize payload", writes)
	}
}
// startRealKCPHubServer listens for KCP sessions on an ephemeral loopback port.
// Each accepted session is handed to hub.ServeSession on its own goroutine.
//
// It returns the listen address and a cleanup func. Cleanup:
//   - signals shutdown;
//   - closes the listener and its packet conn;
//   - waits for the accept loop and every session goroutine to return.
//
// Cleanup must be called exactly once, because it closes an internal channel.
func startRealKCPHubServer(t *testing.T, hub *server.KCPHub) (string, func()) {
	t.Helper()
	listener, packetConn, err := transport.ListenKCPSessions("127.0.0.1:0", "", nil, latencylog.NodeRoleServer, "hub")
	if err != nil {
		t.Fatalf("ListenKCPSessions() error = %v", err)
	}

	stopping := make(chan struct{})
	var workers sync.WaitGroup

	// serve runs one session to completion. Errors that isExpectedKCPHubServeExit
	// classifies as normal teardown are ignored; anything else is only logged.
	serve := func(sess *kcp.UDPSession) {
		defer workers.Done()
		serveErr := hub.ServeSession(sess)
		if serveErr == nil || isExpectedKCPHubServeExit(serveErr) {
			return
		}
		t.Logf("hub.ServeSession() ended with %v", serveErr)
	}

	// acceptLoop accepts sessions until AcceptKCP fails. The failure is
	// treated as a normal stop if cleanup has started or the error text
	// mentions "closed"; otherwise it is reported.
	acceptLoop := func() {
		defer workers.Done()
		for {
			session, acceptErr := listener.AcceptKCP()
			if acceptErr == nil {
				// The accept loop still holds a WaitGroup slot here, so adding
				// one more for the session goroutine cannot race with Wait.
				workers.Add(1)
				go serve(session)
				continue
			}
			select {
			case <-stopping:
				return
			default:
			}
			if !strings.Contains(acceptErr.Error(), "closed") {
				t.Errorf("AcceptKCP() error = %v", acceptErr)
			}
			return
		}
	}

	workers.Add(1)
	go acceptLoop()

	return listener.Addr().String(), func() {
		close(stopping)
		_ = listener.Close()
		_ = packetConn.Close()
		workers.Wait()
	}
}
// relayedKCPHubFixture groups the pieces built by startRelayedKCPHubs:
//   - two KCP hubs, C and D, each serving clients on its own loopback address;
//   - a counting UDP relay socket per hub, cross-wired so each relay targets
//     the other hub's socket.
type relayedKCPHubFixture struct {
	// Hubs under test.
	hubC, hubD *server.KCPHub
	// Client-facing KCP listen addresses of hubC and hubD, respectively.
	serverCAddr, serverDAddr string
	// Relay sockets handed to hubC and hubD; WriteCount reports WriteTo calls on each.
	relayC, relayD *countingPacketConn
	// cleanup stops both relay loops and then both hub servers.
	cleanup func()
}
// startRelayedKCPHubs builds a two-hub relay fixture:
//   - starts two real KCP hub servers, C and D;
//   - gives each hub its own counting UDP relay socket on loopback;
//   - points each hub's relay at the other hub's socket.
// If setup fails part-way, everything started so far is torn down before the
// test is failed.
func startRelayedKCPHubs(t *testing.T) relayedKCPHubFixture {
	t.Helper()
	hubC := server.NewKCPHub()
	addrC, stopC := startRealKCPHubServer(t, hubC)
	hubD := server.NewKCPHub()
	addrD, stopD := startRealKCPHubServer(t, hubD)

	// stopServers tears down the hub servers in reverse start order.
	stopServers := func() {
		stopD()
		stopC()
	}

	rawC, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		stopServers()
		t.Fatalf("ListenPacket(relayC) error = %v", err)
	}
	relayC := &countingPacketConn{PacketConn: rawC}

	rawD, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		_ = relayC.Close()
		stopServers()
		t.Fatalf("ListenPacket(relayD) error = %v", err)
	}
	relayD := &countingPacketConn{PacketConn: rawD}

	// Cross-wire: C's relay targets D's socket and vice versa.
	// NOTE(review): the meaning of the trailing bool is defined by
	// server.KCPHub.SetRelaySocket.
	hubC.SetRelaySocket(relayC, relayD.LocalAddr(), false)
	hubD.SetRelaySocket(relayD, relayC.LocalAddr(), false)

	stopRelayC := startRelayLoop(t, hubC, relayC)
	stopRelayD := startRelayLoop(t, hubD, relayD)

	return relayedKCPHubFixture{
		hubC:        hubC,
		hubD:        hubD,
		serverCAddr: addrC,
		serverDAddr: addrD,
		relayC:      relayC,
		relayD:      relayD,
		cleanup: func() {
			// Relay loops are stopped before the hub servers.
			stopRelayC()
			stopRelayD()
			stopServers()
		},
	}
}
// startRelayLoop runs hub.ServeRelay on a background goroutine. Any error that
// isExpectedKCPRelayServeExit does not classify as a normal exit is reported.
//
// It returns a stop func that closes conn and then waits for the goroutine to
// finish. Closing conn is expected to make ServeRelay return; conn is assumed
// to be the socket given to hub.SetRelaySocket. The stop func may be called
// more than once; later calls return immediately.
func startRelayLoop(t *testing.T, hub *server.KCPHub, conn net.PacketConn) func() {
	t.Helper()
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		err := hub.ServeRelay()
		if isExpectedKCPRelayServeExit(err) { // also true for nil
			return
		}
		t.Errorf("hub.ServeRelay() error = %v", err)
	}()
	return func() {
		_ = conn.Close()
		<-finished
	}
}
// countingPacketConn wraps a net.PacketConn and counts calls to WriteTo, so
// tests can see how many datagrams were sent through a relay socket. All other
// methods are promoted unchanged from the embedded connection.
type countingPacketConn struct {
	net.PacketConn
	writeCount int32 // number of WriteTo calls; accessed only via sync/atomic
}

// WriteTo counts the call first and then delegates to the wrapped connection.
// The call is counted even if the underlying write fails.
func (conn *countingPacketConn) WriteTo(payload []byte, dst net.Addr) (int, error) {
	atomic.AddInt32(&conn.writeCount, 1)
	inner := conn.PacketConn
	return inner.WriteTo(payload, dst)
}

// WriteCount returns how many times WriteTo has been called so far. It is safe
// to call concurrently with WriteTo.
func (conn *countingPacketConn) WriteCount() int {
	n := atomic.LoadInt32(&conn.writeCount)
	return int(n)
}
// isExpectedKCPHubServeExit reports whether an error from hub.ServeSession
// looks like ordinary teardown rather than a real failure. nil counts as
// expected. Otherwise the error text is matched against closed/broken-pipe
// phrasings. The explicit closed-pipe marker is already covered by "closed";
// it is kept for readability.
func isExpectedKCPHubServeExit(err error) bool {
	if err == nil {
		return true
	}
	text := err.Error()
	for _, marker := range []string{"closed", "broken pipe", "io: read/write on closed pipe"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
// isExpectedKCPRelayServeExit reports whether an error from a relay serve loop
// means the socket was closed on purpose, rather than a real failure. nil
// counts as expected. The "use of closed network connection" case is already
// covered by the broader "closed" check; it is listed for readability.
func isExpectedKCPRelayServeExit(err error) bool {
	if err == nil {
		return true
	}
	switch msg := err.Error(); {
	case strings.Contains(msg, "closed"),
		strings.Contains(msg, "use of closed network connection"):
		return true
	default:
		return false
	}
}