package peer import ( "bytes" "net" "os" "path/filepath" "reflect" "strings" "sync" "sync/atomic" "testing" kcp "github.com/xtaci/kcp-go/v5" "omnisocketgo/cmd/internal/latencylog" "omnisocketgo/cmd/internal/protocol" "omnisocketgo/cmd/internal/server" "omnisocketgo/cmd/internal/transport" ) func TestKCPDialRegistersPeer(t *testing.T) { hub := server.NewKCPHub() serverAddr, cleanup := startRealKCPHubServer(t, hub) defer cleanup() client, err := DialKCP(serverAddr, "peer-a") if err != nil { t.Fatalf("DialKCP() error = %v", err) } defer func() { _ = client.Close() }() waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") } func TestKCPDialRejectsInvalidBindIP(t *testing.T) { _, err := DialKCP("127.0.0.1:9002", "peer-a", WithBindIP("not-an-ip")) if err == nil { t.Fatal("DialKCP() error = nil, want invalid bind ip error") } if !strings.Contains(err.Error(), `invalid bind ip "not-an-ip"`) { t.Fatalf("DialKCP() error = %v, want invalid bind ip error", err) } } func TestKCPClientsExchangeTextAndFileMessages(t *testing.T) { hub := server.NewKCPHub() serverAddr, cleanup := startRealKCPHubServer(t, hub) defer cleanup() peerA, err := DialKCP(serverAddr, "peer-a") if err != nil { t.Fatalf("DialKCP(peer-a) error = %v", err) } defer func() { _ = peerA.Close() }() peerB, err := DialKCP(serverAddr, "peer-b") if err != nil { t.Fatalf("DialKCP(peer-b) error = %v", err) } defer func() { _ = peerB.Close() }() waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") received := make(chan protocol.Message, 2) receiveErr := make(chan error, 1) go func() { for i := 0; i < 2; i++ { msg, err := peerB.Receive() if err != nil { receiveErr <- err return } received <- msg } receiveErr <- nil }() if err := peerA.SendText("peer-b", "hello over kcp"); err != nil { t.Fatalf("SendText() error = %v", err) } fileBody := []byte{0x01, 0x02, 0x03} if err := peerA.SendFile("peer-b", "payload.bin", fileBody); err != nil { t.Fatalf("SendFile() error = %v", err) } if err := <-receiveErr; err != nil { t.Fatalf("peerB.Receive() error = %v", err) } gotFirst := <-received wantFirst := protocol.Message{ Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("hello over kcp"), } if !reflect.DeepEqual(gotFirst, wantFirst) { t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, wantFirst) } gotSecond := <-received wantSecond := protocol.Message{ Type: protocol.MessageTypeFile, ID: 2, From: "peer-a", To: "peer-b", FileName: "payload.bin", Body: fileBody, } if !reflect.DeepEqual(gotSecond, wantSecond) { t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, wantSecond) } } func TestKCPClientReceivesServerErrorForUnknownTarget(t *testing.T) { hub := server.NewKCPHub() serverAddr, cleanup := startRealKCPHubServer(t, hub) defer cleanup() client, err := DialKCP(serverAddr, "peer-a") if err != nil { t.Fatalf("DialKCP() error = %v", err) } defer func() { _ = client.Close() }() waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") if err := client.SendText("missing-peer", "hello"); err != nil { t.Fatalf("SendText() error = %v", err) } got, err := client.Receive() if err != nil { t.Fatalf("Receive() error = %v", err) } if got.Type != protocol.MessageTypeError { t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError) } if string(got.Body) != "unknown target: missing-peer" { t.Fatalf("error body = %q, want unknown target message", got.Body) } } func TestKCPClientsExchangeMessagesWithLatencyLogs(t *testing.T) { hub := server.NewKCPHub() serverAddr, cleanup := startRealKCPHubServer(t, hub) defer cleanup() peerALogger := &recordingLogger{} peerA, err := DialKCP(serverAddr, "peer-a", WithLogger(peerALogger)) if err != nil { t.Fatalf("DialKCP(peer-a) error = %v", err) } defer func() { _ = peerA.Close() }() peerBLogger := &recordingLogger{} peerB, err := DialKCP(serverAddr, "peer-b", WithLogger(peerBLogger)) if err != nil { t.Fatalf("DialKCP(peer-b) error = %v", err) } defer func() { _ = peerB.Close() }() inboxDir := t.TempDir() waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") if err := peerA.SendText("peer-b", "hello"); err != nil { t.Fatalf("SendText() error = %v", err) } textMsg, err := peerB.Receive() if err != nil { t.Fatalf("peerB.Receive(text) error = %v", err) } if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil { t.Fatalf("peerB.PersistMessage(text) error = %v", err) } filePath := filepath.Join(t.TempDir(), "payload.bin") if err := os.WriteFile(filePath, []byte{0x01, 0x02, 0x03}, 0o644); err != nil { t.Fatalf("os.WriteFile() error = %v", err) } if err := peerA.SendFilePath("peer-b", filePath); err != nil { t.Fatalf("SendFilePath() error = %v", err) } fileMsg, err := peerB.Receive() if err != nil { t.Fatalf("peerB.Receive(file) error = %v", err) } if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil { t.Fatalf("peerB.PersistMessage(file) error = %v", err) } waitFor(t, func() bool { return len(peerALogger.Events()) == 6 }, "peer-a latency events") waitFor(t, func() bool { return len(peerBLogger.Events()) == 6 }, "peer-b latency events") assertEventSequencesByMessage(t, peerALogger.Events(), map[uint64][]string{ 1: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventSendHandoffEnd}, 2: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventSendHandoffEnd}, }) assertEventSequencesByMessage(t, peerBLogger.Events(), map[uint64][]string{ 1: {latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd}, 2: {latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd}, }) } func TestKCPClientsExchangeMessagesAcrossRelayedServers(t *testing.T) { fixture := startRelayedKCPHubs(t) defer fixture.cleanup() peerA, err := DialKCP(fixture.serverCAddr, "peer-a") if err != nil { t.Fatalf("DialKCP(peer-a) error = %v", err) } defer func() { _ = peerA.Close() }() peerB, err := DialKCP(fixture.serverDAddr, "peer-b") if err != nil { t.Fatalf("DialKCP(peer-b) error = %v", err) } defer func() { _ = peerB.Close() }() waitFor(t, func() bool { return fixture.hubC.HasPeer("peer-a") && fixture.hubD.HasPeer("peer-b") }, "both relayed peers to be registered") if err := peerA.SendText("peer-b", "hello via relay"); err != nil { t.Fatalf("peerA.SendText() error = %v", err) } gotAtB, err := peerB.Receive() if err != nil { t.Fatalf("peerB.Receive() error = %v", err) } wantAtB := protocol.Message{ Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("hello via relay"), } if !reflect.DeepEqual(gotAtB, wantAtB) { t.Fatalf("peerB received %+v, want %+v", gotAtB, wantAtB) } if err := peerB.SendText("peer-a", "hello back"); err != nil { t.Fatalf("peerB.SendText() error = %v", err) } gotAtA, err := peerA.Receive() if err != nil { t.Fatalf("peerA.Receive() error = %v", err) } wantAtA := protocol.Message{ Type: protocol.MessageTypeText, ID: 1, From: "peer-b", To: "peer-a", Body: []byte("hello back"), } if !reflect.DeepEqual(gotAtA, wantAtA) { t.Fatalf("peerA received %+v, want %+v", gotAtA, wantAtA) } if got := fixture.relayC.WriteCount(); got != 1 { t.Fatalf("relayC write count = %d, want 1", got) } if got := fixture.relayD.WriteCount(); got != 1 { t.Fatalf("relayD write count = %d, want 1", got) } } func TestKCPHubPrefersLocalPeerBeforeRelay(t *testing.T) { fixture := startRelayedKCPHubs(t) defer fixture.cleanup() peerA, err := DialKCP(fixture.serverCAddr, "peer-a") if err != nil { t.Fatalf("DialKCP(peer-a) error = %v", err) } defer func() { _ = peerA.Close() }() peerB, err := DialKCP(fixture.serverCAddr, "peer-b") if err != nil { t.Fatalf("DialKCP(peer-b) error = %v", err) } defer func() { _ = peerB.Close() }() waitFor(t, func() bool { return fixture.hubC.HasPeer("peer-a") && fixture.hubC.HasPeer("peer-b") }, "local peers on hubC to be registered") if err := peerA.SendText("peer-b", "local delivery"); err != nil { t.Fatalf("peerA.SendText() error = %v", err) } got, err := peerB.Receive() if err != nil { t.Fatalf("peerB.Receive() error = %v", err) } want := protocol.Message{ Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("local delivery"), } if !reflect.DeepEqual(got, want) { t.Fatalf("peerB received %+v, want %+v", got, want) } if got := fixture.relayC.WriteCount(); got != 0 { t.Fatalf("relayC write count = %d, want 0 for local delivery", got) } if got := fixture.relayD.WriteCount(); got != 0 { t.Fatalf("relayD write count = %d, want 0 for local delivery", got) } } func TestKCPRelayedUnknownTargetReturnsErrorToOriginalSender(t *testing.T) { fixture := startRelayedKCPHubs(t) defer fixture.cleanup() peerA, err := DialKCP(fixture.serverCAddr, "peer-a") if err != nil { t.Fatalf("DialKCP(peer-a) error = %v", err) } defer func() { _ = peerA.Close() }() waitFor(t, func() bool { return fixture.hubC.HasPeer("peer-a") }, "peer-a to be registered on hubC") if err := peerA.SendText("remote-missing", "hello"); err != nil { t.Fatalf("peerA.SendText() error = %v", err) } got, err := peerA.Receive() if err != nil { t.Fatalf("peerA.Receive() error = %v", err) } if got.Type != protocol.MessageTypeError { t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError) } if got.From != protocol.ServerPeerID { t.Fatalf("error from = %s, want %s", got.From, protocol.ServerPeerID) } if got.To != "peer-a" { t.Fatalf("error to = %s, want peer-a", got.To) } if string(got.Body) != "unknown target: remote-missing" { t.Fatalf("error body = %q, want unknown target from relayed hub", got.Body) } if got := fixture.relayC.WriteCount(); got != 1 { t.Fatalf("relayC write count = %d, want 1 for outbound relay", got) } if got := fixture.relayD.WriteCount(); got != 1 { t.Fatalf("relayD write count = %d, want 1 for return error relay", got) } } func TestKCPHubRejectsOversizeRelayedMessage(t *testing.T) { fixture := startRelayedKCPHubs(t) defer fixture.cleanup() peerA, err := DialKCP(fixture.serverCAddr, "peer-a") if err != nil { t.Fatalf("DialKCP(peer-a) error = %v", err) } defer func() { _ = peerA.Close() }() waitFor(t, func() bool { return fixture.hubC.HasPeer("peer-a") }, "peer-a to be registered on hubC") body := bytes.Repeat([]byte("a"), 70*1024) if err := peerA.SendFile("remote-peer", "payload.bin", body); err != nil { t.Fatalf("peerA.SendFile() error = %v", err) } got, err := peerA.Receive() if err != nil { t.Fatalf("peerA.Receive() error = %v", err) } if got.Type != protocol.MessageTypeError { t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError) } if string(got.Body) != "message too large for relay udp" { t.Fatalf("error body = %q, want oversize relay error", got.Body) } if got := fixture.relayC.WriteCount(); got != 0 { t.Fatalf("relayC write count = %d, want 0 when relay rejects oversize payload", got) } } func startRealKCPHubServer(t *testing.T, hub *server.KCPHub) (string, func()) { t.Helper() listener, packetConn, err := transport.ListenKCPSessions("127.0.0.1:0", "", nil, latencylog.NodeRoleServer, "hub") if err != nil { t.Fatalf("ListenKCPSessions() error = %v", err) } var ( wg sync.WaitGroup stop = make(chan struct{}) ) wg.Add(1) go func() { defer wg.Done() for { session, acceptErr := listener.AcceptKCP() if acceptErr != nil { select { case <-stop: return default: } if strings.Contains(acceptErr.Error(), "closed") { return } t.Errorf("AcceptKCP() error = %v", acceptErr) return } wg.Add(1) go func(sess *kcp.UDPSession) { defer wg.Done() if serveErr := hub.ServeSession(sess); serveErr != nil && !isExpectedKCPHubServeExit(serveErr) { t.Logf("hub.ServeSession() ended with %v", serveErr) } }(session) } }() cleanup := func() { close(stop) _ = listener.Close() _ = packetConn.Close() wg.Wait() } return listener.Addr().String(), cleanup } type relayedKCPHubFixture struct { hubC *server.KCPHub hubD *server.KCPHub serverCAddr string serverDAddr string relayC *countingPacketConn relayD *countingPacketConn cleanup func() } func startRelayedKCPHubs(t *testing.T) relayedKCPHubFixture { t.Helper() hubC := server.NewKCPHub() serverCAddr, cleanupC := startRealKCPHubServer(t, hubC) hubD := server.NewKCPHub() serverDAddr, cleanupD := startRealKCPHubServer(t, hubD) baseRelayC, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { cleanupD() cleanupC() t.Fatalf("ListenPacket(relayC) error = %v", err) } relayC := &countingPacketConn{PacketConn: baseRelayC} baseRelayD, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { _ = relayC.Close() cleanupD() cleanupC() t.Fatalf("ListenPacket(relayD) error = %v", err) } relayD := &countingPacketConn{PacketConn: baseRelayD} hubC.SetRelaySocket(relayC, relayD.LocalAddr(), false) hubD.SetRelaySocket(relayD, relayC.LocalAddr(), false) stopRelayC := startRelayLoop(t, hubC, relayC) stopRelayD := startRelayLoop(t, hubD, relayD) cleanup := func() { stopRelayC() stopRelayD() cleanupD() cleanupC() } return relayedKCPHubFixture{ hubC: hubC, hubD: hubD, serverCAddr: serverCAddr, serverDAddr: serverDAddr, relayC: relayC, relayD: relayD, cleanup: cleanup, } } func startRelayLoop(t *testing.T, hub *server.KCPHub, conn net.PacketConn) func() { t.Helper() var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() if err := hub.ServeRelay(); err != nil && !isExpectedKCPRelayServeExit(err) { t.Errorf("hub.ServeRelay() error = %v", err) } }() return func() { _ = conn.Close() wg.Wait() } } type countingPacketConn struct { net.PacketConn writeCount int32 } func (c *countingPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { atomic.AddInt32(&c.writeCount, 1) return c.PacketConn.WriteTo(p, addr) } func (c *countingPacketConn) WriteCount() int { return int(atomic.LoadInt32(&c.writeCount)) } func isExpectedKCPHubServeExit(err error) bool { if err == nil { return true } message := err.Error() return strings.Contains(message, "closed") || strings.Contains(message, "broken pipe") || strings.Contains(message, "io: read/write on closed pipe") } func isExpectedKCPRelayServeExit(err error) bool { if err == nil { return true } message := err.Error() return strings.Contains(message, "closed") || strings.Contains(message, "use of closed network connection") }