547 lines
15 KiB
Go
547 lines
15 KiB
Go
package peer
|
|
|
|
import (
|
|
"bytes"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
kcp "github.com/xtaci/kcp-go/v5"
|
|
|
|
"omnisocketgo/cmd/internal/latencylog"
|
|
"omnisocketgo/cmd/internal/protocol"
|
|
"omnisocketgo/cmd/internal/server"
|
|
"omnisocketgo/cmd/internal/transport"
|
|
)
|
|
|
|
func TestKCPDialRegistersPeer(t *testing.T) {
|
|
hub := server.NewKCPHub()
|
|
serverAddr, cleanup := startRealKCPHubServer(t, hub)
|
|
defer cleanup()
|
|
|
|
client, err := DialKCP(serverAddr, "peer-a")
|
|
if err != nil {
|
|
t.Fatalf("DialKCP() error = %v", err)
|
|
}
|
|
defer func() { _ = client.Close() }()
|
|
|
|
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
|
}
|
|
|
|
// TestKCPDialRejectsInvalidBindIP verifies that DialKCP refuses a WithBindIP
// value that is not an IP address and names the offending value in its error.
func TestKCPDialRejectsInvalidBindIP(t *testing.T) {
	client, err := DialKCP("127.0.0.1:9002", "peer-a", WithBindIP("not-an-ip"))
	if err == nil {
		// The dial unexpectedly succeeded. Close the returned client before
		// failing, otherwise nothing ever releases it.
		_ = client.Close()
		t.Fatal("DialKCP() error = nil, want invalid bind ip error")
	}
	if !strings.Contains(err.Error(), `invalid bind ip "not-an-ip"`) {
		t.Fatalf("DialKCP() error = %v, want invalid bind ip error", err)
	}
}
|
|
|
|
func TestKCPClientsExchangeTextAndFileMessages(t *testing.T) {
|
|
hub := server.NewKCPHub()
|
|
serverAddr, cleanup := startRealKCPHubServer(t, hub)
|
|
defer cleanup()
|
|
|
|
peerA, err := DialKCP(serverAddr, "peer-a")
|
|
if err != nil {
|
|
t.Fatalf("DialKCP(peer-a) error = %v", err)
|
|
}
|
|
defer func() { _ = peerA.Close() }()
|
|
|
|
peerB, err := DialKCP(serverAddr, "peer-b")
|
|
if err != nil {
|
|
t.Fatalf("DialKCP(peer-b) error = %v", err)
|
|
}
|
|
defer func() { _ = peerB.Close() }()
|
|
|
|
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
|
|
|
|
received := make(chan protocol.Message, 2)
|
|
receiveErr := make(chan error, 1)
|
|
go func() {
|
|
for i := 0; i < 2; i++ {
|
|
msg, err := peerB.Receive()
|
|
if err != nil {
|
|
receiveErr <- err
|
|
return
|
|
}
|
|
received <- msg
|
|
}
|
|
receiveErr <- nil
|
|
}()
|
|
|
|
if err := peerA.SendText("peer-b", "hello over kcp"); err != nil {
|
|
t.Fatalf("SendText() error = %v", err)
|
|
}
|
|
fileBody := []byte{0x01, 0x02, 0x03}
|
|
if err := peerA.SendFile("peer-b", "payload.bin", fileBody); err != nil {
|
|
t.Fatalf("SendFile() error = %v", err)
|
|
}
|
|
|
|
if err := <-receiveErr; err != nil {
|
|
t.Fatalf("peerB.Receive() error = %v", err)
|
|
}
|
|
|
|
gotFirst := <-received
|
|
wantFirst := protocol.Message{
|
|
Type: protocol.MessageTypeText,
|
|
ID: 1,
|
|
From: "peer-a",
|
|
To: "peer-b",
|
|
Body: []byte("hello over kcp"),
|
|
}
|
|
if !reflect.DeepEqual(gotFirst, wantFirst) {
|
|
t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, wantFirst)
|
|
}
|
|
|
|
gotSecond := <-received
|
|
wantSecond := protocol.Message{
|
|
Type: protocol.MessageTypeFile,
|
|
ID: 2,
|
|
From: "peer-a",
|
|
To: "peer-b",
|
|
FileName: "payload.bin",
|
|
Body: fileBody,
|
|
}
|
|
if !reflect.DeepEqual(gotSecond, wantSecond) {
|
|
t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, wantSecond)
|
|
}
|
|
}
|
|
|
|
// TestKCPClientReceivesServerErrorForUnknownTarget sends to a peer ID that no
// hub knows. It checks that the sender gets an error message back saying the
// target is unknown.
func TestKCPClientReceivesServerErrorForUnknownTarget(t *testing.T) {
	hub := server.NewKCPHub()
	addr, stopServer := startRealKCPHubServer(t, hub)
	defer stopServer()

	client, err := DialKCP(addr, "peer-a")
	if err != nil {
		t.Fatalf("DialKCP() error = %v", err)
	}
	defer func() { _ = client.Close() }()

	waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")

	if err = client.SendText("missing-peer", "hello"); err != nil {
		t.Fatalf("SendText() error = %v", err)
	}

	reply, err := client.Receive()
	switch {
	case err != nil:
		t.Fatalf("Receive() error = %v", err)
	case reply.Type != protocol.MessageTypeError:
		t.Fatalf("got type %s, want %s", reply.Type, protocol.MessageTypeError)
	case string(reply.Body) != "unknown target: missing-peer":
		t.Fatalf("error body = %q, want unknown target message", reply.Body)
	}
}
|
|
|
|
func TestKCPClientsExchangeMessagesWithLatencyLogs(t *testing.T) {
|
|
hub := server.NewKCPHub()
|
|
serverAddr, cleanup := startRealKCPHubServer(t, hub)
|
|
defer cleanup()
|
|
|
|
peerALogger := &recordingLogger{}
|
|
peerA, err := DialKCP(serverAddr, "peer-a", WithLogger(peerALogger))
|
|
if err != nil {
|
|
t.Fatalf("DialKCP(peer-a) error = %v", err)
|
|
}
|
|
defer func() { _ = peerA.Close() }()
|
|
|
|
peerBLogger := &recordingLogger{}
|
|
peerB, err := DialKCP(serverAddr, "peer-b", WithLogger(peerBLogger))
|
|
if err != nil {
|
|
t.Fatalf("DialKCP(peer-b) error = %v", err)
|
|
}
|
|
defer func() { _ = peerB.Close() }()
|
|
|
|
inboxDir := t.TempDir()
|
|
|
|
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
|
|
|
|
if err := peerA.SendText("peer-b", "hello"); err != nil {
|
|
t.Fatalf("SendText() error = %v", err)
|
|
}
|
|
textMsg, err := peerB.Receive()
|
|
if err != nil {
|
|
t.Fatalf("peerB.Receive(text) error = %v", err)
|
|
}
|
|
if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil {
|
|
t.Fatalf("peerB.PersistMessage(text) error = %v", err)
|
|
}
|
|
|
|
filePath := filepath.Join(t.TempDir(), "payload.bin")
|
|
if err := os.WriteFile(filePath, []byte{0x01, 0x02, 0x03}, 0o644); err != nil {
|
|
t.Fatalf("os.WriteFile() error = %v", err)
|
|
}
|
|
if err := peerA.SendFilePath("peer-b", filePath); err != nil {
|
|
t.Fatalf("SendFilePath() error = %v", err)
|
|
}
|
|
fileMsg, err := peerB.Receive()
|
|
if err != nil {
|
|
t.Fatalf("peerB.Receive(file) error = %v", err)
|
|
}
|
|
if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil {
|
|
t.Fatalf("peerB.PersistMessage(file) error = %v", err)
|
|
}
|
|
|
|
waitFor(t, func() bool { return len(peerALogger.Events()) == 6 }, "peer-a latency events")
|
|
waitFor(t, func() bool { return len(peerBLogger.Events()) == 6 }, "peer-b latency events")
|
|
|
|
assertEventSequencesByMessage(t, peerALogger.Events(), map[uint64][]string{
|
|
1: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventSendHandoffEnd},
|
|
2: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventSendHandoffEnd},
|
|
})
|
|
assertEventSequencesByMessage(t, peerBLogger.Events(), map[uint64][]string{
|
|
1: {latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd},
|
|
2: {latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd},
|
|
})
|
|
}
|
|
|
|
// TestKCPClientsExchangeMessagesAcrossRelayedServers connects peer-a to hub C
// and peer-b to hub D, then sends one text message in each direction. It
// checks the delivered messages and that each relay socket saw exactly one
// WriteTo call.
func TestKCPClientsExchangeMessagesAcrossRelayedServers(t *testing.T) {
	fx := startRelayedKCPHubs(t)
	defer fx.cleanup()

	peerA, err := DialKCP(fx.serverCAddr, "peer-a")
	if err != nil {
		t.Fatalf("DialKCP(peer-a) error = %v", err)
	}
	defer func() { _ = peerA.Close() }()

	peerB, err := DialKCP(fx.serverDAddr, "peer-b")
	if err != nil {
		t.Fatalf("DialKCP(peer-b) error = %v", err)
	}
	defer func() { _ = peerB.Close() }()

	waitFor(t, func() bool {
		return fx.hubC.HasPeer("peer-a") && fx.hubD.HasPeer("peer-b")
	}, "both relayed peers to be registered")

	// expectDelivery reads one message through receive and fails the test
	// unless it equals want. name is the receiving peer, used in messages.
	expectDelivery := func(name string, receive func() (protocol.Message, error), want protocol.Message) {
		got, recvErr := receive()
		if recvErr != nil {
			t.Fatalf("%s.Receive() error = %v", name, recvErr)
		}
		if !reflect.DeepEqual(got, want) {
			t.Fatalf("%s received %+v, want %+v", name, got, want)
		}
	}

	if err := peerA.SendText("peer-b", "hello via relay"); err != nil {
		t.Fatalf("peerA.SendText() error = %v", err)
	}
	expectDelivery("peerB", func() (protocol.Message, error) { return peerB.Receive() }, protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello via relay"),
	})

	if err := peerB.SendText("peer-a", "hello back"); err != nil {
		t.Fatalf("peerB.SendText() error = %v", err)
	}
	expectDelivery("peerA", func() (protocol.Message, error) { return peerA.Receive() }, protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-b",
		To:   "peer-a",
		Body: []byte("hello back"),
	})

	relays := []struct {
		name string
		conn *countingPacketConn
	}{
		{name: "relayC", conn: fx.relayC},
		{name: "relayD", conn: fx.relayD},
	}
	for _, relay := range relays {
		if n := relay.conn.WriteCount(); n != 1 {
			t.Fatalf("%s write count = %d, want 1", relay.name, n)
		}
	}
}
|
|
|
|
// TestKCPHubPrefersLocalPeerBeforeRelay connects both peers to hub C, even
// though a relay to hub D is configured. It checks that a message between
// them is delivered with no WriteTo calls on either relay socket.
func TestKCPHubPrefersLocalPeerBeforeRelay(t *testing.T) {
	fx := startRelayedKCPHubs(t)
	defer fx.cleanup()

	peerA, err := DialKCP(fx.serverCAddr, "peer-a")
	if err != nil {
		t.Fatalf("DialKCP(peer-a) error = %v", err)
	}
	defer func() { _ = peerA.Close() }()

	peerB, err := DialKCP(fx.serverCAddr, "peer-b")
	if err != nil {
		t.Fatalf("DialKCP(peer-b) error = %v", err)
	}
	defer func() { _ = peerB.Close() }()

	waitFor(t, func() bool {
		return fx.hubC.HasPeer("peer-a") && fx.hubC.HasPeer("peer-b")
	}, "local peers on hubC to be registered")

	if err := peerA.SendText("peer-b", "local delivery"); err != nil {
		t.Fatalf("peerA.SendText() error = %v", err)
	}

	want := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("local delivery"),
	}
	got, err := peerB.Receive()
	if err != nil {
		t.Fatalf("peerB.Receive() error = %v", err)
	}
	if !reflect.DeepEqual(got, want) {
		t.Fatalf("peerB received %+v, want %+v", got, want)
	}

	for _, relay := range [...]struct {
		name string
		conn *countingPacketConn
	}{
		{name: "relayC", conn: fx.relayC},
		{name: "relayD", conn: fx.relayD},
	} {
		if n := relay.conn.WriteCount(); n != 0 {
			t.Fatalf("%s write count = %d, want 0 for local delivery", relay.name, n)
		}
	}
}
|
|
|
|
// TestKCPRelayedUnknownTargetReturnsErrorToOriginalSender sends from peer-a on
// hub C to a peer ID that neither hub knows. It checks that peer-a gets a
// server-originated error message back, after one relay WriteTo in each
// direction.
func TestKCPRelayedUnknownTargetReturnsErrorToOriginalSender(t *testing.T) {
	fx := startRelayedKCPHubs(t)
	defer fx.cleanup()

	peerA, err := DialKCP(fx.serverCAddr, "peer-a")
	if err != nil {
		t.Fatalf("DialKCP(peer-a) error = %v", err)
	}
	defer func() { _ = peerA.Close() }()

	waitFor(t, func() bool { return fx.hubC.HasPeer("peer-a") }, "peer-a to be registered on hubC")

	if err := peerA.SendText("remote-missing", "hello"); err != nil {
		t.Fatalf("peerA.SendText() error = %v", err)
	}

	reply, err := peerA.Receive()
	if err != nil {
		t.Fatalf("peerA.Receive() error = %v", err)
	}
	switch {
	case reply.Type != protocol.MessageTypeError:
		t.Fatalf("got type %s, want %s", reply.Type, protocol.MessageTypeError)
	case reply.From != protocol.ServerPeerID:
		t.Fatalf("error from = %s, want %s", reply.From, protocol.ServerPeerID)
	case reply.To != "peer-a":
		t.Fatalf("error to = %s, want peer-a", reply.To)
	case string(reply.Body) != "unknown target: remote-missing":
		t.Fatalf("error body = %q, want unknown target from relayed hub", reply.Body)
	}

	checks := []struct {
		name, reason string
		conn         *countingPacketConn
	}{
		{name: "relayC", reason: "outbound relay", conn: fx.relayC},
		{name: "relayD", reason: "return error relay", conn: fx.relayD},
	}
	for _, check := range checks {
		if n := check.conn.WriteCount(); n != 1 {
			t.Fatalf("%s write count = %d, want 1 for %s", check.name, n, check.reason)
		}
	}
}
|
|
|
|
func TestKCPHubRejectsOversizeRelayedMessage(t *testing.T) {
|
|
fixture := startRelayedKCPHubs(t)
|
|
defer fixture.cleanup()
|
|
|
|
peerA, err := DialKCP(fixture.serverCAddr, "peer-a")
|
|
if err != nil {
|
|
t.Fatalf("DialKCP(peer-a) error = %v", err)
|
|
}
|
|
defer func() { _ = peerA.Close() }()
|
|
|
|
waitFor(t, func() bool { return fixture.hubC.HasPeer("peer-a") }, "peer-a to be registered on hubC")
|
|
|
|
body := bytes.Repeat([]byte("a"), 70*1024)
|
|
if err := peerA.SendFile("remote-peer", "payload.bin", body); err != nil {
|
|
t.Fatalf("peerA.SendFile() error = %v", err)
|
|
}
|
|
|
|
got, err := peerA.Receive()
|
|
if err != nil {
|
|
t.Fatalf("peerA.Receive() error = %v", err)
|
|
}
|
|
if got.Type != protocol.MessageTypeError {
|
|
t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError)
|
|
}
|
|
if string(got.Body) != "message too large for relay udp" {
|
|
t.Fatalf("error body = %q, want oversize relay error", got.Body)
|
|
}
|
|
if got := fixture.relayC.WriteCount(); got != 0 {
|
|
t.Fatalf("relayC write count = %d, want 0 when relay rejects oversize payload", got)
|
|
}
|
|
}
|
|
|
|
// startRealKCPHubServer opens a KCP listener on an ephemeral loopback port and
// hands every accepted session to hub.ServeSession in its own goroutine.
//
// It returns the listener's address and a cleanup function. Cleanup closes the
// listener and its packet connection, then waits for the accept loop and all
// session goroutines to return. Callers must run cleanup exactly once.
func startRealKCPHubServer(t *testing.T, hub *server.KCPHub) (string, func()) {
	t.Helper()

	listener, packetConn, err := transport.ListenKCPSessions("127.0.0.1:0", "", nil, latencylog.NodeRoleServer, "hub")
	if err != nil {
		t.Fatalf("ListenKCPSessions() error = %v", err)
	}

	var wg sync.WaitGroup
	stop := make(chan struct{})

	// stopping reports, without blocking, whether cleanup has begun.
	stopping := func() bool {
		select {
		case <-stop:
			return true
		default:
			return false
		}
	}

	// serve runs one session to completion. Only unexpected exits are logged.
	serve := func(sess *kcp.UDPSession) {
		defer wg.Done()
		serveErr := hub.ServeSession(sess)
		if serveErr == nil || isExpectedKCPHubServeExit(serveErr) {
			return
		}
		t.Logf("hub.ServeSession() ended with %v", serveErr)
	}

	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			session, acceptErr := listener.AcceptKCP()
			if acceptErr == nil {
				// Add before spawning so cleanup's Wait cannot miss this session.
				wg.Add(1)
				go serve(session)
				continue
			}
			// Errors caused by teardown are normal; anything else is a failure.
			if stopping() || strings.Contains(acceptErr.Error(), "closed") {
				return
			}
			t.Errorf("AcceptKCP() error = %v", acceptErr)
			return
		}
	}()

	return listener.Addr().String(), func() {
		// Signal stop first so the accept loop treats the coming error as expected.
		close(stop)
		_ = listener.Close()
		_ = packetConn.Close()
		wg.Wait()
	}
}
|
|
|
|
type relayedKCPHubFixture struct {
|
|
hubC *server.KCPHub
|
|
hubD *server.KCPHub
|
|
serverCAddr string
|
|
serverDAddr string
|
|
relayC *countingPacketConn
|
|
relayD *countingPacketConn
|
|
cleanup func()
|
|
}
|
|
|
|
func startRelayedKCPHubs(t *testing.T) relayedKCPHubFixture {
|
|
t.Helper()
|
|
|
|
hubC := server.NewKCPHub()
|
|
serverCAddr, cleanupC := startRealKCPHubServer(t, hubC)
|
|
|
|
hubD := server.NewKCPHub()
|
|
serverDAddr, cleanupD := startRealKCPHubServer(t, hubD)
|
|
|
|
baseRelayC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
if err != nil {
|
|
cleanupD()
|
|
cleanupC()
|
|
t.Fatalf("ListenPacket(relayC) error = %v", err)
|
|
}
|
|
relayC := &countingPacketConn{PacketConn: baseRelayC}
|
|
|
|
baseRelayD, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
if err != nil {
|
|
_ = relayC.Close()
|
|
cleanupD()
|
|
cleanupC()
|
|
t.Fatalf("ListenPacket(relayD) error = %v", err)
|
|
}
|
|
relayD := &countingPacketConn{PacketConn: baseRelayD}
|
|
|
|
hubC.SetRelaySocket(relayC, relayD.LocalAddr(), false)
|
|
hubD.SetRelaySocket(relayD, relayC.LocalAddr(), false)
|
|
|
|
stopRelayC := startRelayLoop(t, hubC, relayC)
|
|
stopRelayD := startRelayLoop(t, hubD, relayD)
|
|
|
|
cleanup := func() {
|
|
stopRelayC()
|
|
stopRelayD()
|
|
cleanupD()
|
|
cleanupC()
|
|
}
|
|
|
|
return relayedKCPHubFixture{
|
|
hubC: hubC,
|
|
hubD: hubD,
|
|
serverCAddr: serverCAddr,
|
|
serverDAddr: serverDAddr,
|
|
relayC: relayC,
|
|
relayD: relayD,
|
|
cleanup: cleanup,
|
|
}
|
|
}
|
|
|
|
func startRelayLoop(t *testing.T, hub *server.KCPHub, conn net.PacketConn) func() {
|
|
t.Helper()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
if err := hub.ServeRelay(); err != nil && !isExpectedKCPRelayServeExit(err) {
|
|
t.Errorf("hub.ServeRelay() error = %v", err)
|
|
}
|
|
}()
|
|
|
|
return func() {
|
|
_ = conn.Close()
|
|
wg.Wait()
|
|
}
|
|
}
|
|
|
|
type countingPacketConn struct {
|
|
net.PacketConn
|
|
writeCount int32
|
|
}
|
|
|
|
func (c *countingPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
|
atomic.AddInt32(&c.writeCount, 1)
|
|
return c.PacketConn.WriteTo(p, addr)
|
|
}
|
|
|
|
func (c *countingPacketConn) WriteCount() int {
|
|
return int(atomic.LoadInt32(&c.writeCount))
|
|
}
|
|
|
|
// isExpectedKCPHubServeExit reports whether err is an acceptable way for
// hub.ServeSession to end during test teardown. A nil error is expected.
// Otherwise the message must mention "closed" or "broken pipe". The "closed"
// substring already covers "use of closed network connection" and
// "io: read/write on closed pipe", so those need no separate check.
func isExpectedKCPHubServeExit(err error) bool {
	if err == nil {
		return true
	}

	msg := err.Error()
	return strings.Contains(msg, "closed") || strings.Contains(msg, "broken pipe")
}
|
|
|
|
// isExpectedKCPRelayServeExit reports whether err is an acceptable way for
// hub.ServeRelay to end after its relay socket is closed. A nil error is
// expected. Otherwise the message must mention "closed", which also covers
// "use of closed network connection". Unlike the session helper, a broken
// pipe is not treated as expected here.
func isExpectedKCPRelayServeExit(err error) bool {
	if err == nil {
		return true
	}

	return strings.Contains(err.Error(), "closed")
}
|