Files
OmniSocketGo/cmd/internal/peer/client_test.go
2026-03-23 20:50:44 +08:00

820 lines
21 KiB
Go

package peer
import (
"bytes"
"encoding/json"
"net"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"testing"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/server"
"omnisocketgo/cmd/internal/transport"
)
// recordingLogger is an in-memory latency logger that keeps every event it is
// handed, in arrival order, so tests can inspect the emitted sequence. Access
// to the recorded history is serialized by mu.
type recordingLogger struct {
	mu     sync.Mutex
	events []latencylog.Event
}

// LogEvent appends event to the recorded history. It never fails.
func (l *recordingLogger) LogEvent(event latencylog.Event) error {
	l.mu.Lock()
	l.events = append(l.events, event)
	l.mu.Unlock()
	return nil
}

// Events returns a copy of everything recorded so far, so later LogEvent calls
// cannot alter a snapshot the caller is holding. It returns nil when nothing
// has been recorded yet.
func (l *recordingLogger) Events() []latencylog.Event {
	l.mu.Lock()
	defer l.mu.Unlock()
	if len(l.events) == 0 {
		return nil
	}
	snapshot := make([]latencylog.Event, len(l.events))
	copy(snapshot, l.events)
	return snapshot
}
// failingLogger is a latency logger whose LogEvent always fails. Tests use it
// to check that logging errors do not surface from the client's send and
// persist paths.
type failingLogger struct{}

// LogEvent discards the event and reports net.ErrClosed.
func (failingLogger) LogEvent(_ latencylog.Event) error {
	return net.ErrClosed
}
// TestDialRegistersPeer dials the stubbed hub with no extra options and waits
// for the hub to report the peer ID the client was dialed with.
func TestDialRegistersPeer(t *testing.T) {
	hub := server.NewHub()
	restoreDial := stubDialToHub(t, hub)
	defer restoreDial()

	c, err := Dial("ignored", "peer-a")
	if err != nil {
		t.Fatalf("Dial() error = %v", err)
	}
	defer func() { _ = c.Close() }()

	registered := func() bool { return hub.HasPeer("peer-a") }
	waitFor(t, registered, "peer-a to be registered")
}
// TestDialRegistersPeerWithBindIP is TestDialRegistersPeer with the local
// side bound to 127.0.0.1 via WithBindIP; registration must still succeed.
func TestDialRegistersPeerWithBindIP(t *testing.T) {
	hub := server.NewHub()
	restoreDial := stubDialToHub(t, hub)
	defer restoreDial()

	c, err := Dial("ignored", "peer-a", WithBindIP("127.0.0.1"))
	if err != nil {
		t.Fatalf("Dial() with bind ip error = %v", err)
	}
	defer func() { _ = c.Close() }()

	registered := func() bool { return hub.HasPeer("peer-a") }
	waitFor(t, registered, "peer-a to be registered")
}
// TestDialRejectsInvalidBindIP checks that a malformed bind IP makes Dial fail
// with an error that names the offending value.
func TestDialRejectsInvalidBindIP(t *testing.T) {
	const wantFragment = `invalid bind ip "not-an-ip"`

	_, err := Dial("ignored", "peer-a", WithBindIP("not-an-ip"))
	switch {
	case err == nil:
		t.Fatal("Dial() error = nil, want invalid bind ip error")
	case !strings.Contains(err.Error(), wantFragment):
		t.Fatalf("Dial() error = %v, want invalid bind ip error", err)
	}
}
// TestDialPassesBindDeviceOptionToDialer swaps dialServer for a fake that
// records the requested bind device and then fails. It checks that
// WithBindDevice reaches the dialer's clientOptions unchanged.
func TestDialPassesBindDeviceOptionToDialer(t *testing.T) {
	saved := dialServer
	defer func() { dialServer = saved }()

	var seenDevice string
	// Fail after recording so Dial returns without needing a real connection.
	dialServer = func(_ string, opts clientOptions) (net.Conn, error) {
		seenDevice = opts.bindDevice
		return nil, net.ErrClosed
	}

	if _, err := Dial("ignored", "peer-a", WithBindDevice("wwan0")); err == nil {
		t.Fatal("Dial() error = nil, want dial error")
	}
	if want := "wwan0"; seenDevice != want {
		t.Fatalf("bind device = %q, want %q", seenDevice, want)
	}
}
// TestClientsExchangeTextAndFileMessages sends a text message and then a file
// message from peer-a to peer-b through a real hub. It checks that peer-b
// receives both, in order, with IDs 1 and 2 and all fields intact.
func TestClientsExchangeTextAndFileMessages(t *testing.T) {
	hub := server.NewHub()
	cleanup := stubDialToHub(t, hub)
	defer cleanup()
	peerA, err := Dial("ignored", "peer-a")
	if err != nil {
		t.Fatalf("Dial(peer-a) error = %v", err)
	}
	defer func() { _ = peerA.Close() }()
	peerB, err := Dial("ignored", "peer-b")
	if err != nil {
		t.Fatalf("Dial(peer-b) error = %v", err)
	}
	defer func() { _ = peerB.Close() }()
	waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")

	// Both channels are buffered, so the receiver goroutine never blocks on
	// them even if the test has already failed and stopped reading.
	received := make(chan protocol.Message, 2)
	receiveErr := make(chan error, 1)
	go func() {
		for i := 0; i < 2; i++ {
			msg, err := peerB.Receive()
			if err != nil {
				receiveErr <- err
				return
			}
			received <- msg
		}
		receiveErr <- nil
	}()
	if err := peerA.SendText("peer-b", "hello"); err != nil {
		t.Fatalf("SendText() error = %v", err)
	}
	fileBody := []byte{0x01, 0x02, 0x03}
	if err := peerA.SendFile("peer-b", "payload.bin", fileBody); err != nil {
		t.Fatalf("SendFile() error = %v", err)
	}

	// Bound the wait so a dropped message fails promptly instead of hanging
	// until the global `go test` timeout. On timeout, the deferred
	// peerB.Close() unblocks Receive and lets the goroutine exit.
	select {
	case err := <-receiveErr:
		if err != nil {
			t.Fatalf("peerB.Receive() error = %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for peer-b to receive both messages")
	}

	gotFirst := <-received
	wantFirst := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}
	if !reflect.DeepEqual(gotFirst, wantFirst) {
		t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, wantFirst)
	}
	gotSecond := <-received
	wantSecond := protocol.Message{
		Type:     protocol.MessageTypeFile,
		ID:       2,
		From:     "peer-a",
		To:       "peer-b",
		FileName: "payload.bin",
		Body:     fileBody,
	}
	if !reflect.DeepEqual(gotSecond, wantSecond) {
		t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, wantSecond)
	}
}
// TestClientReceivesServerErrorForUnknownTarget sends to a peer ID the hub
// does not know. It checks that the sender gets back an error-type message
// naming the missing target.
func TestClientReceivesServerErrorForUnknownTarget(t *testing.T) {
	hub := server.NewHub()
	restoreDial := stubDialToHub(t, hub)
	defer restoreDial()

	c, err := Dial("ignored", "peer-a")
	if err != nil {
		t.Fatalf("Dial() error = %v", err)
	}
	defer func() { _ = c.Close() }()
	waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")

	if err := c.SendText("missing-peer", "hello"); err != nil {
		t.Fatalf("SendText() error = %v", err)
	}
	reply, err := c.Receive()
	if err != nil {
		t.Fatalf("Receive() error = %v", err)
	}
	if reply.Type != protocol.MessageTypeError {
		t.Fatalf("got type %s, want %s", reply.Type, protocol.MessageTypeError)
	}
	const wantBody = "unknown target: missing-peer"
	if string(reply.Body) != wantBody {
		t.Fatalf("error body = %q, want unknown target message", reply.Body)
	}
}
// TestClientReceiveLoopHandlesForwardedMessages routes one text message from
// peer-a to peer-b and checks that peer-b's ReceiveLoop passes it to the
// handler exactly once. The handler closes peer-b after the first message, so
// ReceiveLoop is expected to return an error mentioning "closed".
func TestClientReceiveLoopHandlesForwardedMessages(t *testing.T) {
	hub := server.NewHub()
	cleanup := stubDialToHub(t, hub)
	defer cleanup()
	peerA, err := Dial("ignored", "peer-a")
	if err != nil {
		t.Fatalf("Dial(peer-a) error = %v", err)
	}
	defer func() { _ = peerA.Close() }()
	peerB, err := Dial("ignored", "peer-b")
	if err != nil {
		t.Fatalf("Dial(peer-b) error = %v", err)
	}
	defer func() { _ = peerB.Close() }()
	waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")

	// got is written by the handler goroutine and read here afterwards. mu
	// guards it.
	var (
		mu  sync.Mutex
		got []protocol.Message
	)
	loopErr := make(chan error, 1)
	go func() {
		loopErr <- peerB.ReceiveLoop(func(msg protocol.Message) error {
			mu.Lock()
			defer mu.Unlock()
			got = append(got, msg)
			if len(got) == 1 {
				return peerB.Close()
			}
			return nil
		})
	}()
	if err := peerA.SendText("peer-b", "hello"); err != nil {
		t.Fatalf("SendText() error = %v", err)
	}

	// Bound the wait so a lost message fails promptly. On timeout, the deferred
	// peerB.Close() unblocks ReceiveLoop, and the buffered channel lets the
	// goroutine exit.
	select {
	case err = <-loopErr:
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for ReceiveLoop to return")
	}
	// "use of closed network connection" already contains "closed", so this
	// single check accepts both the transport's and net's close errors.
	if err == nil || !strings.Contains(err.Error(), "closed") {
		t.Fatalf("ReceiveLoop() error = %v, want close-related error", err)
	}

	mu.Lock()
	defer mu.Unlock()
	want := []protocol.Message{
		{
			Type: protocol.MessageTypeText,
			ID:   1,
			From: "peer-a",
			To:   "peer-b",
			Body: []byte("hello"),
		},
	}
	if !reflect.DeepEqual(got, want) {
		t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
	}
}
func TestClientSendLogsLatencyEvents(t *testing.T) {
tests := []struct {
name string
setup func(*testing.T) string
send func(*Client, string) error
wantMsg protocol.Message
wantEvents []string
}{
{
name: "text",
send: func(client *Client, _ string) error {
return client.SendText("peer-b", "hello")
},
wantMsg: protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-b",
Body: []byte("hello"),
},
wantEvents: []string{
latencylog.EventAAppPrepBegin,
latencylog.EventSendHandoffBegin,
latencylog.EventATXSched,
latencylog.EventATXSoftware,
latencylog.EventSendHandoffEnd,
},
},
{
name: "file-bytes",
send: func(client *Client, _ string) error {
return client.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03})
},
wantMsg: protocol.Message{
Type: protocol.MessageTypeFile,
ID: 1,
From: "peer-a",
To: "peer-b",
FileName: "payload.bin",
Body: []byte{0x01, 0x02, 0x03},
},
wantEvents: []string{
latencylog.EventAAppPrepBegin,
latencylog.EventSendHandoffBegin,
latencylog.EventATXSched,
latencylog.EventATXSoftware,
latencylog.EventSendHandoffEnd,
},
},
{
name: "file-path",
setup: func(t *testing.T) string {
t.Helper()
path := filepath.Join(t.TempDir(), "payload.bin")
if err := os.WriteFile(path, []byte{0x01, 0x02, 0x03}, 0o644); err != nil {
t.Fatalf("os.WriteFile() error = %v", err)
}
return path
},
send: func(client *Client, path string) error {
return client.SendFilePath("peer-b", path)
},
wantMsg: protocol.Message{
Type: protocol.MessageTypeFile,
ID: 1,
From: "peer-a",
To: "peer-b",
FileName: "payload.bin",
Body: []byte{0x01, 0x02, 0x03},
},
wantEvents: []string{
latencylog.EventAAppPrepBegin,
latencylog.EventSendHandoffBegin,
latencylog.EventATXSched,
latencylog.EventATXSoftware,
latencylog.EventSendHandoffEnd,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
inputPath := ""
if tt.setup != nil {
inputPath = tt.setup(t)
}
logger := &recordingLogger{}
clientConn, receiver := newClientTransportPair(
t,
[]transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-a")},
nil,
)
client := &Client{
id: "peer-a",
conn: clientConn,
logger: logger,
}
sendErr := make(chan error, 1)
go func() {
sendErr <- tt.send(client, inputPath)
}()
got, err := receiver.Receive()
if err != nil {
t.Fatalf("receiver.Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("send() error = %v", err)
}
if !reflect.DeepEqual(got, tt.wantMsg) {
t.Fatalf("message mismatch: got %+v want %+v", got, tt.wantMsg)
}
events := logger.Events()
if len(events) != len(tt.wantEvents) {
t.Fatalf("event count = %d, want %d", len(events), len(tt.wantEvents))
}
for i, wantEvent := range tt.wantEvents {
if events[i].Event != wantEvent {
t.Fatalf("event[%d] = %q, want %q", i, events[i].Event, wantEvent)
}
if events[i].MessageID != tt.wantMsg.ID || events[i].From != tt.wantMsg.From || events[i].To != tt.wantMsg.To {
t.Fatalf("event[%d] metadata mismatch: %+v", i, events[i])
}
}
})
}
}
// TestClientReceiveLogsOnlyBusinessMessages delivers a text message and then a
// server error message to the client. It checks that only the text message
// produces receiver-side latency events: B-RX-software followed by
// B-app-recv, both carrying the text message's ID.
func TestClientReceiveLogsOnlyBusinessMessages(t *testing.T) {
	logger := &recordingLogger{}
	opts := []transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-b")}
	clientConn, sender := newClientTransportPair(t, opts, nil)
	client := &Client{id: "peer-b", conn: clientConn, logger: logger}

	textMsg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   21,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}
	errorMsg := protocol.Message{
		Type: protocol.MessageTypeError,
		ID:   22,
		From: protocol.ServerPeerID,
		To:   "peer-b",
		Body: []byte("failure"),
	}

	// Each Send runs in its own goroutine so the write and the client's read
	// can proceed concurrently over the socket pair.
	textSent := make(chan error, 1)
	go func() { textSent <- sender.Send(textMsg) }()
	if _, err := client.Receive(); err != nil {
		t.Fatalf("client.Receive(text) error = %v", err)
	}
	if err := <-textSent; err != nil {
		t.Fatalf("sender.Send(text) error = %v", err)
	}

	errorSent := make(chan error, 1)
	go func() { errorSent <- sender.Send(errorMsg) }()
	if _, err := client.Receive(); err != nil {
		t.Fatalf("client.Receive(error) error = %v", err)
	}
	if err := <-errorSent; err != nil {
		t.Fatalf("sender.Send(error) error = %v", err)
	}

	events := logger.Events()
	if len(events) != 2 {
		t.Fatalf("event count = %d, want 2", len(events))
	}
	first, second := events[0], events[1]
	if first.Event != latencylog.EventBRXSoftware {
		t.Fatalf("first event = %q, want %q", first.Event, latencylog.EventBRXSoftware)
	}
	if second.Event != latencylog.EventBAppRecv {
		t.Fatalf("second event = %q, want %q", second.Event, latencylog.EventBAppRecv)
	}
	if first.MessageID != textMsg.ID || second.MessageID != textMsg.ID {
		t.Fatalf("message IDs = %d,%d, want %d", first.MessageID, second.MessageID, textMsg.ID)
	}
}
// TestClientPersistTextMessageWritesInboxFileAndLogs persists a text message
// and checks three things: the JSON record lands in the text inbox file, it
// round-trips ID/From/To/Body, and the persist begin/end events are logged in
// that order.
func TestClientPersistTextMessageWritesInboxFileAndLogs(t *testing.T) {
	dir := t.TempDir()
	logger := &recordingLogger{}
	client := &Client{id: "peer-b", logger: logger}
	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   31,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	gotPath, err := client.PersistMessage(msg, dir)
	if err != nil {
		t.Fatalf("PersistMessage() error = %v", err)
	}
	if wantPath := filepath.Join(dir, textInboxFileName); gotPath != wantPath {
		t.Fatalf("path = %q, want %q", gotPath, wantPath)
	}

	raw, err := os.ReadFile(gotPath)
	if err != nil {
		t.Fatalf("os.ReadFile() error = %v", err)
	}
	var record textInboxRecord
	// Trim before decoding: the file may carry surrounding whitespace, such as
	// a trailing newline.
	if err := json.Unmarshal(bytes.TrimSpace(raw), &record); err != nil {
		t.Fatalf("json.Unmarshal() error = %v", err)
	}
	matches := record.MessageID == msg.ID &&
		record.From == msg.From &&
		record.To == msg.To &&
		record.Body == "hello"
	if !matches {
		t.Fatalf("record mismatch: got %+v want message %+v", record, msg)
	}

	events := logger.Events()
	if len(events) != 2 {
		t.Fatalf("event count = %d, want 2", len(events))
	}
	if got := events[0].Event; got != latencylog.EventBPersistBegin {
		t.Fatalf("first event = %q, want %q", got, latencylog.EventBPersistBegin)
	}
	if got := events[1].Event; got != latencylog.EventBPersistEnd {
		t.Fatalf("second event = %q, want %q", got, latencylog.EventBPersistEnd)
	}
}
// TestClientPersistFileMessageWritesInboxFileAndLogs persists a file message
// and checks that:
//   - the body is written verbatim to "<from>-<id>-<filename>" in the inbox
//     directory;
//   - the persist begin/end events are logged in that order.
func TestClientPersistFileMessageWritesInboxFileAndLogs(t *testing.T) {
	inboxDir := t.TempDir()
	logger := &recordingLogger{}
	client := &Client{
		id:     "peer-b",
		logger: logger,
	}
	msg := protocol.Message{
		Type:     protocol.MessageTypeFile,
		ID:       32,
		From:     "peer-a",
		To:       "peer-b",
		FileName: "payload.bin",
		Body:     []byte{0x01, 0x02, 0x03},
	}
	path, err := client.PersistMessage(msg, inboxDir)
	if err != nil {
		t.Fatalf("PersistMessage() error = %v", err)
	}
	wantPath := filepath.Join(inboxDir, "peer-a-32-payload.bin")
	if path != wantPath {
		t.Fatalf("path = %q, want %q", path, wantPath)
	}
	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("os.ReadFile() error = %v", err)
	}
	// bytes.Equal is the idiomatic byte-slice comparison (no reflection, and
	// nil/empty compare equal).
	if !bytes.Equal(data, msg.Body) {
		t.Fatalf("file body mismatch: got %v want %v", data, msg.Body)
	}
	events := logger.Events()
	if len(events) != 2 {
		t.Fatalf("event count = %d, want 2", len(events))
	}
	if events[0].Event != latencylog.EventBPersistBegin {
		t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin)
	}
	if events[1].Event != latencylog.EventBPersistEnd {
		t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBPersistEnd)
	}
}
// TestClientPersistMessageReturnsErrorOnWriteFailure points PersistMessage at
// a regular file instead of a directory. It expects an error, and expects that
// only the persist-begin event was logged.
func TestClientPersistMessageReturnsErrorOnWriteFailure(t *testing.T) {
	notADir := filepath.Join(t.TempDir(), "blocker")
	if err := os.WriteFile(notADir, []byte("not a directory"), 0o644); err != nil {
		t.Fatalf("os.WriteFile() error = %v", err)
	}

	logger := &recordingLogger{}
	client := &Client{id: "peer-b", logger: logger}
	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   33,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	if _, err := client.PersistMessage(msg, notADir); err == nil {
		t.Fatal("PersistMessage() error = nil, want non-nil")
	}

	events := logger.Events()
	if n := len(events); n != 1 {
		t.Fatalf("event count = %d, want 1", n)
	}
	if got := events[0].Event; got != latencylog.EventBPersistBegin {
		t.Fatalf("event = %q, want %q", got, latencylog.EventBPersistBegin)
	}
}
func TestClientIgnoresLoggerFailure(t *testing.T) {
clientConn, receiver := newClientTransportPair(
t,
[]transport.Option{transport.WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")},
nil,
)
client := &Client{
id: "peer-a",
conn: clientConn,
logger: failingLogger{},
}
sendErr := make(chan error, 1)
go func() {
sendErr <- client.SendText("peer-b", "hello")
}()
got, err := receiver.Receive()
if err != nil {
t.Fatalf("receiver.Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("SendText() error = %v, want nil even if logger fails", err)
}
if string(got.Body) != "hello" {
t.Fatalf("body = %q, want hello", got.Body)
}
}
// TestClientPersistIgnoresLoggerFailure checks that PersistMessage still
// succeeds and reports a path when every latency log call fails.
func TestClientPersistIgnoresLoggerFailure(t *testing.T) {
	client := &Client{id: "peer-b", logger: failingLogger{}}
	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   34,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}
	inbox := t.TempDir()

	path, err := client.PersistMessage(msg, inbox)
	switch {
	case err != nil:
		t.Fatalf("PersistMessage() error = %v, want nil even if logger fails", err)
	case path == "":
		t.Fatal("PersistMessage() path = empty, want non-empty")
	}
}
// TestClientsExchangeMessagesWithLatencyLogs is the end-to-end latency check.
// peer-a sends a text message and a file message through a real hub; peer-b
// receives and persists each one. Each side must then log its expected
// per-message event sequence for message IDs 1 and 2.
func TestClientsExchangeMessagesWithLatencyLogs(t *testing.T) {
	hub := server.NewHub()
	restoreDial := stubDialToHub(t, hub)
	defer restoreDial()

	senderLog := &recordingLogger{}
	peerA, err := Dial("ignored", "peer-a", WithLogger(senderLog))
	if err != nil {
		t.Fatalf("Dial(peer-a) error = %v", err)
	}
	defer func() { _ = peerA.Close() }()

	receiverLog := &recordingLogger{}
	peerB, err := Dial("ignored", "peer-b", WithLogger(receiverLog))
	if err != nil {
		t.Fatalf("Dial(peer-b) error = %v", err)
	}
	defer func() { _ = peerB.Close() }()

	inbox := t.TempDir()
	waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")

	if err := peerA.SendText("peer-b", "hello"); err != nil {
		t.Fatalf("SendText() error = %v", err)
	}
	textMsg, err := peerB.Receive()
	if err != nil {
		t.Fatalf("peerB.Receive(text) error = %v", err)
	}
	if _, err := peerB.PersistMessage(textMsg, inbox); err != nil {
		t.Fatalf("peerB.PersistMessage(text) error = %v", err)
	}

	if err := peerA.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}); err != nil {
		t.Fatalf("SendFile() error = %v", err)
	}
	fileMsg, err := peerB.Receive()
	if err != nil {
		t.Fatalf("peerB.Receive(file) error = %v", err)
	}
	if _, err := peerB.PersistMessage(fileMsg, inbox); err != nil {
		t.Fatalf("peerB.PersistMessage(file) error = %v", err)
	}

	senderSeq := []string{
		latencylog.EventAAppPrepBegin,
		latencylog.EventSendHandoffBegin,
		latencylog.EventATXSched,
		latencylog.EventATXSoftware,
		latencylog.EventSendHandoffEnd,
	}
	receiverSeq := []string{
		latencylog.EventBRXSoftware,
		latencylog.EventBAppRecv,
		latencylog.EventBPersistBegin,
		latencylog.EventBPersistEnd,
	}
	// Two messages per side. Some events may be logged asynchronously, so poll
	// until both totals are reached.
	waitFor(t, func() bool { return len(senderLog.Events()) == 2*len(senderSeq) }, "peer-a latency events")
	waitFor(t, func() bool { return len(receiverLog.Events()) == 2*len(receiverSeq) }, "peer-b latency events")

	assertEventSequencesByMessage(t, senderLog.Events(), map[uint64][]string{1: senderSeq, 2: senderSeq})
	assertEventSequencesByMessage(t, receiverLog.Events(), map[uint64][]string{1: receiverSeq, 2: receiverSeq})
}
// stubDialToHub starts a real hub server via startRealHubServer. It then
// replaces the package-level dialServer so that every Dial connects to that
// server, whatever address the caller passes. The caller's clientOptions still
// shape the dialer through buildDialer.
//
// The returned func restores the previous dialServer and runs the cleanup
// returned by startRealHubServer; callers defer it. Because dialServer is
// package state, tests using this helper must not run in parallel.
func stubDialToHub(t *testing.T, hub *server.Hub) func() {
	t.Helper()
	previous := dialServer
	addr, stopServer := startRealHubServer(t, hub)

	dialServer = func(_ string, opts clientOptions) (net.Conn, error) {
		d, err := buildDialer(opts)
		if err != nil {
			return nil, err
		}
		return d.Dial("tcp", addr)
	}

	restore := func() {
		dialServer = previous
		stopServer()
	}
	return restore
}
// waitFor polls condition every pollInterval until it returns true. If the
// condition still does not hold when the timeout expires, the test fails with
// "timed out waiting for <description>".
//
// The condition is evaluated once more after the deadline. Without that final
// check, a condition that became true during the last sleep would be reported
// as a timeout.
func waitFor(t *testing.T, condition func() bool, description string) {
	t.Helper()
	// The budget is generous enough to absorb scheduler jitter under -race on
	// loaded CI machines. A passing condition still returns as soon as it is
	// observed.
	const (
		timeout      = 2 * time.Second
		pollInterval = 10 * time.Millisecond
	)
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if condition() {
			return
		}
		time.Sleep(pollInterval)
	}
	if condition() {
		return
	}
	t.Fatalf("timed out waiting for %s", description)
}
// assertEventSequencesByMessage buckets events by MessageID, keeping the order
// in which they were logged, and fails the test unless:
//   - every event has a positive TsUnixNano;
//   - the number of distinct message IDs equals len(want);
//   - for each ID in want, the bucket's event names match the wanted sequence
//     exactly.
func assertEventSequencesByMessage(t *testing.T, events []latencylog.Event, want map[uint64][]string) {
	t.Helper()

	byMessage := make(map[uint64][]latencylog.Event)
	for i := range events {
		ev := events[i]
		byMessage[ev.MessageID] = append(byMessage[ev.MessageID], ev)
		if ev.TsUnixNano <= 0 {
			t.Fatalf("event timestamp must be positive: %+v", ev)
		}
	}

	if gotCount, wantCount := len(byMessage), len(want); gotCount != wantCount {
		t.Fatalf("message group count = %d, want %d", gotCount, wantCount)
	}

	for id, names := range want {
		bucket := byMessage[id]
		if len(bucket) != len(names) {
			t.Fatalf("message %d event count = %d, want %d", id, len(bucket), len(names))
		}
		for i, name := range names {
			if got := bucket[i].Event; got != name {
				t.Fatalf("message %d event[%d] = %q, want %q", id, i, got, name)
			}
		}
	}
}
// newClientTransportPair connects two transport.TCPConn values over a
// loopback TCP socket:
//   - the first wraps the dialing side and is built with clientOpts;
//   - the second wraps the accepted side and is built with peerOpts.
//
// The temporary listener is closed before returning, and both connections are
// closed via t.Cleanup. If any setup step fails, the test aborts after closing
// every socket opened so far.
func newClientTransportPair(t *testing.T, clientOpts []transport.Option, peerOpts []transport.Option) (*transport.TCPConn, *transport.TCPConn) {
	t.Helper()
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("net.Listen() error = %v", err)
	}
	type acceptResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the accept goroutine can always deliver its result and exit,
	// even on the early-abort path below.
	accepted := make(chan acceptResult, 1)
	go func() {
		conn, acceptErr := listener.Accept()
		accepted <- acceptResult{conn: conn, err: acceptErr}
	}()
	clientSide, err := net.Dial("tcp", listener.Addr().String())
	if err != nil {
		// Closing the listener also unblocks the pending Accept.
		_ = listener.Close()
		t.Fatalf("net.Dial() error = %v", err)
	}
	result := <-accepted
	if err := listener.Close(); err != nil {
		// Release both ends before aborting so the sockets do not leak.
		_ = clientSide.Close()
		if result.conn != nil {
			_ = result.conn.Close()
		}
		t.Fatalf("listener.Close() error = %v", err)
	}
	if result.err != nil {
		_ = clientSide.Close()
		t.Fatalf("listener.Accept() error = %v", result.err)
	}
	clientConn, err := transport.NewTCPConn(clientSide, clientOpts...)
	if err != nil {
		_ = clientSide.Close()
		_ = result.conn.Close()
		t.Fatalf("transport.NewTCPConn(client) error = %v", err)
	}
	peerConn, err := transport.NewTCPConn(result.conn, peerOpts...)
	if err != nil {
		_ = clientConn.Close()
		_ = result.conn.Close()
		t.Fatalf("transport.NewTCPConn(peer) error = %v", err)
	}
	t.Cleanup(func() {
		_ = clientConn.Close()
		_ = peerConn.Close()
	})
	return clientConn, peerConn
}