package peer import ( "bytes" "encoding/json" "net" "os" "path/filepath" "reflect" "strings" "sync" "testing" "time" "omnisocketgo/cmd/internal/latencylog" "omnisocketgo/cmd/internal/protocol" "omnisocketgo/cmd/internal/server" "omnisocketgo/cmd/internal/transport" ) type recordingLogger struct { mu sync.Mutex events []latencylog.Event } func (l *recordingLogger) LogEvent(event latencylog.Event) error { l.mu.Lock() defer l.mu.Unlock() l.events = append(l.events, event) return nil } func (l *recordingLogger) Events() []latencylog.Event { l.mu.Lock() defer l.mu.Unlock() return append([]latencylog.Event(nil), l.events...) } type failingLogger struct{} func (failingLogger) LogEvent(latencylog.Event) error { return net.ErrClosed } func TestDialRegistersPeer(t *testing.T) { hub := server.NewHub() cleanup := stubDialToHub(t, hub) defer cleanup() client, err := Dial("ignored", "peer-a") if err != nil { t.Fatalf("Dial() error = %v", err) } defer func() { _ = client.Close() }() waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") } func TestClientsExchangeTextAndFileMessages(t *testing.T) { hub := server.NewHub() cleanup := stubDialToHub(t, hub) defer cleanup() peerA, err := Dial("ignored", "peer-a") if err != nil { t.Fatalf("Dial(peer-a) error = %v", err) } defer func() { _ = peerA.Close() }() peerB, err := Dial("ignored", "peer-b") if err != nil { t.Fatalf("Dial(peer-b) error = %v", err) } defer func() { _ = peerB.Close() }() waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") received := make(chan protocol.Message, 2) receiveErr := make(chan error, 1) go func() { for i := 0; i < 2; i++ { msg, err := peerB.Receive() if err != nil { receiveErr <- err return } received <- msg } receiveErr <- nil }() if err := peerA.SendText("peer-b", "hello"); err != nil { t.Fatalf("SendText() error = %v", err) } fileBody := []byte{0x01, 0x02, 0x03} if err := peerA.SendFile("peer-b", "payload.bin", fileBody); err != nil { t.Fatalf("SendFile() error = %v", err) } if err := <-receiveErr; err != nil { t.Fatalf("peerB.Receive() error = %v", err) } gotFirst := <-received wantFirst := protocol.Message{ Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("hello"), } if !reflect.DeepEqual(gotFirst, wantFirst) { t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, wantFirst) } gotSecond := <-received wantSecond := protocol.Message{ Type: protocol.MessageTypeFile, ID: 2, From: "peer-a", To: "peer-b", FileName: "payload.bin", Body: fileBody, } if !reflect.DeepEqual(gotSecond, wantSecond) { t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, wantSecond) } } func TestClientReceivesServerErrorForUnknownTarget(t *testing.T) { hub := server.NewHub() cleanup := stubDialToHub(t, hub) defer cleanup() client, err := Dial("ignored", "peer-a") if err != nil { t.Fatalf("Dial() error = %v", err) } defer func() { _ = client.Close() }() waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") if err := client.SendText("missing-peer", "hello"); err != nil { t.Fatalf("SendText() error = %v", err) } got, err := client.Receive() if err != nil { t.Fatalf("Receive() error = %v", err) } if got.Type != protocol.MessageTypeError { t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError) } if string(got.Body) != "unknown target: missing-peer" { t.Fatalf("error body = %q, want unknown target message", got.Body) } } func TestClientReceiveLoopHandlesForwardedMessages(t *testing.T) { hub := server.NewHub() cleanup := stubDialToHub(t, hub) defer cleanup() peerA, err := Dial("ignored", "peer-a") if err != nil { t.Fatalf("Dial(peer-a) error = %v", err) } defer func() { _ = peerA.Close() }() peerB, err := Dial("ignored", "peer-b") if err != nil { t.Fatalf("Dial(peer-b) error = %v", err) } defer func() { _ = peerB.Close() }() waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") var ( mu sync.Mutex got []protocol.Message ) loopErr := make(chan error, 1) go func() { loopErr <- peerB.ReceiveLoop(func(msg protocol.Message) error { mu.Lock() defer mu.Unlock() got = append(got, msg) if len(got) == 1 { return peerB.Close() } return nil }) }() if err := peerA.SendText("peer-b", "hello"); err != nil { t.Fatalf("SendText() error = %v", err) } err = <-loopErr if err == nil || (!strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "use of closed network connection")) { t.Fatalf("ReceiveLoop() error = %v, want close-related error", err) } mu.Lock() defer mu.Unlock() want := []protocol.Message{ { Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("hello"), }, } if !reflect.DeepEqual(got, want) { t.Fatalf("received messages mismatch: got %+v want %+v", got, want) } } func TestClientSendLogsLatencyEvents(t *testing.T) { tests := []struct { name string setup func(*testing.T) string send func(*Client, string) error wantMsg protocol.Message wantEvents []string }{ { name: "text", send: func(client *Client, _ string) error { return client.SendText("peer-b", "hello") }, wantMsg: protocol.Message{ Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("hello"), }, wantEvents: []string{ latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd, }, }, { name: "file-bytes", send: func(client *Client, _ string) error { return client.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}) }, wantMsg: protocol.Message{ Type: protocol.MessageTypeFile, ID: 1, From: "peer-a", To: "peer-b", FileName: "payload.bin", Body: []byte{0x01, 0x02, 0x03}, }, wantEvents: []string{ latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd, }, }, { name: "file-path", setup: func(t *testing.T) string { t.Helper() path := filepath.Join(t.TempDir(), "payload.bin") if err := os.WriteFile(path, []byte{0x01, 0x02, 0x03}, 0o644); err != nil { t.Fatalf("os.WriteFile() error = %v", err) } return path }, send: func(client *Client, path string) error { return client.SendFilePath("peer-b", path) }, wantMsg: protocol.Message{ Type: protocol.MessageTypeFile, ID: 1, From: "peer-a", To: "peer-b", FileName: "payload.bin", Body: []byte{0x01, 0x02, 0x03}, }, wantEvents: []string{ latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { inputPath := "" if tt.setup != nil { inputPath = tt.setup(t) } logger := &recordingLogger{} clientConn, receiver := newClientTransportPair( t, []transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-a")}, nil, ) client := &Client{ id: "peer-a", conn: clientConn, logger: logger, } sendErr := make(chan error, 1) go func() { sendErr <- tt.send(client, inputPath) }() got, err := receiver.Receive() if err != nil { t.Fatalf("receiver.Receive() error = %v", err) } if err := <-sendErr; err != nil { t.Fatalf("send() error = %v", err) } if !reflect.DeepEqual(got, tt.wantMsg) { t.Fatalf("message mismatch: got %+v want %+v", got, tt.wantMsg) } events := logger.Events() if len(events) != len(tt.wantEvents) { t.Fatalf("event count = %d, want %d", len(events), len(tt.wantEvents)) } for i, wantEvent := range tt.wantEvents { if events[i].Event != wantEvent { t.Fatalf("event[%d] = %q, want %q", i, events[i].Event, wantEvent) } if events[i].MessageID != tt.wantMsg.ID || events[i].From != tt.wantMsg.From || events[i].To != tt.wantMsg.To { t.Fatalf("event[%d] metadata mismatch: %+v", i, events[i]) } } }) } } func TestClientReceiveLogsOnlyBusinessMessages(t *testing.T) { logger := &recordingLogger{} clientConn, sender := newClientTransportPair( t, []transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-b")}, nil, ) client := &Client{ id: "peer-b", conn: clientConn, logger: logger, } textMsg := protocol.Message{ Type: protocol.MessageTypeText, ID: 21, From: "peer-a", To: "peer-b", Body: []byte("hello"), } sendErr := make(chan error, 1) go func() { sendErr <- sender.Send(textMsg) }() if _, err := client.Receive(); err != nil { t.Fatalf("client.Receive(text) error = %v", err) } if err := <-sendErr; err != nil { t.Fatalf("sender.Send(text) error = %v", err) } errorMsg := protocol.Message{ Type: protocol.MessageTypeError, ID: 22, From: protocol.ServerPeerID, To: "peer-b", Body: []byte("failure"), } sendErr = make(chan error, 1) go func() { sendErr <- sender.Send(errorMsg) }() if _, err := client.Receive(); err != nil { t.Fatalf("client.Receive(error) error = %v", err) } if err := <-sendErr; err != nil { t.Fatalf("sender.Send(error) error = %v", err) } events := logger.Events() if len(events) != 2 { t.Fatalf("event count = %d, want 2", len(events)) } if events[0].Event != latencylog.EventBRXSoftware { t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBRXSoftware) } if events[1].Event != latencylog.EventBAppRecv { t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBAppRecv) } if events[0].MessageID != textMsg.ID || events[1].MessageID != textMsg.ID { t.Fatalf("message IDs = %d,%d, want %d", events[0].MessageID, events[1].MessageID, textMsg.ID) } } func TestClientPersistTextMessageWritesInboxFileAndLogs(t *testing.T) { inboxDir := t.TempDir() logger := &recordingLogger{} client := &Client{ id: "peer-b", logger: logger, } msg := protocol.Message{ Type: protocol.MessageTypeText, ID: 31, From: "peer-a", To: "peer-b", Body: []byte("hello"), } path, err := client.PersistMessage(msg, inboxDir) if err != nil { t.Fatalf("PersistMessage() error = %v", err) } if path != filepath.Join(inboxDir, textInboxFileName) { t.Fatalf("path = %q, want %q", path, filepath.Join(inboxDir, textInboxFileName)) } data, err := os.ReadFile(path) if err != nil { t.Fatalf("os.ReadFile() error = %v", err) } var record textInboxRecord if err := json.Unmarshal(bytes.TrimSpace(data), &record); err != nil { t.Fatalf("json.Unmarshal() error = %v", err) } if record.MessageID != msg.ID || record.From != msg.From || record.To != msg.To || record.Body != "hello" { t.Fatalf("record mismatch: got %+v want message %+v", record, msg) } events := logger.Events() if len(events) != 2 { t.Fatalf("event count = %d, want 2", len(events)) } if events[0].Event != latencylog.EventBPersistBegin { t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin) } if events[1].Event != latencylog.EventBPersistEnd { t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBPersistEnd) } } func TestClientPersistFileMessageWritesInboxFileAndLogs(t *testing.T) { inboxDir := t.TempDir() logger := &recordingLogger{} client := &Client{ id: "peer-b", logger: logger, } msg := protocol.Message{ Type: protocol.MessageTypeFile, ID: 32, From: "peer-a", To: "peer-b", FileName: "payload.bin", Body: []byte{0x01, 0x02, 0x03}, } path, err := client.PersistMessage(msg, inboxDir) if err != nil { t.Fatalf("PersistMessage() error = %v", err) } wantPath := filepath.Join(inboxDir, "peer-a-32-payload.bin") if path != wantPath { t.Fatalf("path = %q, want %q", path, wantPath) } data, err := os.ReadFile(path) if err != nil { t.Fatalf("os.ReadFile() error = %v", err) } if !reflect.DeepEqual(data, msg.Body) { t.Fatalf("file body mismatch: got %v want %v", data, msg.Body) } events := logger.Events() if len(events) != 2 { t.Fatalf("event count = %d, want 2", len(events)) } if events[0].Event != latencylog.EventBPersistBegin { t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin) } if events[1].Event != latencylog.EventBPersistEnd { t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBPersistEnd) } } func TestClientPersistMessageReturnsErrorOnWriteFailure(t *testing.T) { blocker := filepath.Join(t.TempDir(), "blocker") if err := os.WriteFile(blocker, []byte("not a directory"), 0o644); err != nil { t.Fatalf("os.WriteFile() error = %v", err) } logger := &recordingLogger{} client := &Client{ id: "peer-b", logger: logger, } msg := protocol.Message{ Type: protocol.MessageTypeText, ID: 33, From: "peer-a", To: "peer-b", Body: []byte("hello"), } if _, err := client.PersistMessage(msg, blocker); err == nil { t.Fatal("PersistMessage() error = nil, want non-nil") } events := logger.Events() if len(events) != 1 { t.Fatalf("event count = %d, want 1", len(events)) } if events[0].Event != latencylog.EventBPersistBegin { t.Fatalf("event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin) } } func TestClientIgnoresLoggerFailure(t *testing.T) { clientConn, receiver := newClientTransportPair( t, []transport.Option{transport.WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")}, nil, ) client := &Client{ id: "peer-a", conn: clientConn, logger: failingLogger{}, } sendErr := make(chan error, 1) go func() { sendErr <- client.SendText("peer-b", "hello") }() got, err := receiver.Receive() if err != nil { t.Fatalf("receiver.Receive() error = %v", err) } if err := <-sendErr; err != nil { t.Fatalf("SendText() error = %v, want nil even if logger fails", err) } if string(got.Body) != "hello" { t.Fatalf("body = %q, want hello", got.Body) } } func TestClientPersistIgnoresLoggerFailure(t *testing.T) { client := &Client{ id: "peer-b", logger: failingLogger{}, } msg := protocol.Message{ Type: protocol.MessageTypeText, ID: 34, From: "peer-a", To: "peer-b", Body: []byte("hello"), } path, err := client.PersistMessage(msg, t.TempDir()) if err != nil { t.Fatalf("PersistMessage() error = %v, want nil even if logger fails", err) } if path == "" { t.Fatal("PersistMessage() path = empty, want non-empty") } } func TestClientsExchangeMessagesWithLatencyLogs(t *testing.T) { hub := server.NewHub() cleanup := stubDialToHub(t, hub) defer cleanup() peerALogger := &recordingLogger{} peerA, err := Dial("ignored", "peer-a", WithLogger(peerALogger)) if err != nil { t.Fatalf("Dial(peer-a) error = %v", err) } defer func() { _ = peerA.Close() }() peerBLogger := &recordingLogger{} peerB, err := Dial("ignored", "peer-b", WithLogger(peerBLogger)) if err != nil { t.Fatalf("Dial(peer-b) error = %v", err) } defer func() { _ = peerB.Close() }() inboxDir := t.TempDir() waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") if err := peerA.SendText("peer-b", "hello"); err != nil { t.Fatalf("SendText() error = %v", err) } textMsg, err := peerB.Receive() if err != nil { t.Fatalf("peerB.Receive(text) error = %v", err) } if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil { t.Fatalf("peerB.PersistMessage(text) error = %v", err) } if err := peerA.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}); err != nil { t.Fatalf("SendFile() error = %v", err) } fileMsg, err := peerB.Receive() if err != nil { t.Fatalf("peerB.Receive(file) error = %v", err) } if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil { t.Fatalf("peerB.PersistMessage(file) error = %v", err) } waitFor(t, func() bool { return len(peerALogger.Events()) == 10 }, "peer-a latency events") waitFor(t, func() bool { return len(peerBLogger.Events()) == 8 }, "peer-b latency events") assertEventSequencesByMessage(t, peerALogger.Events(), map[uint64][]string{ 1: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd}, 2: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd}, }) assertEventSequencesByMessage(t, peerBLogger.Events(), map[uint64][]string{ 1: {latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd}, 2: {latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd}, }) } func stubDialToHub(t *testing.T, hub *server.Hub) func() { t.Helper() originalDial := dialServer serverAddr, cleanup := startRealHubServer(t, hub) dialServer = func(network, addr string) (net.Conn, error) { return net.Dial(network, serverAddr) } return func() { dialServer = originalDial cleanup() } } func waitFor(t *testing.T, condition func() bool, description string) { t.Helper() deadline := time.Now().Add(500 * time.Millisecond) for time.Now().Before(deadline) { if condition() { return } time.Sleep(10 * time.Millisecond) } t.Fatalf("timed out waiting for %s", description) } func assertEventSequencesByMessage(t *testing.T, events []latencylog.Event, want map[uint64][]string) { t.Helper() grouped := make(map[uint64][]latencylog.Event) for _, event := range events { grouped[event.MessageID] = append(grouped[event.MessageID], event) if event.TsUnixNano <= 0 { t.Fatalf("event timestamp must be positive: %+v", event) } } if len(grouped) != len(want) { t.Fatalf("message group count = %d, want %d", len(grouped), len(want)) } for messageID, wantEvents := range want { gotEvents := grouped[messageID] if len(gotEvents) != len(wantEvents) { t.Fatalf("message %d event count = %d, want %d", messageID, len(gotEvents), len(wantEvents)) } for i, wantEvent := range wantEvents { if gotEvents[i].Event != wantEvent { t.Fatalf("message %d event[%d] = %q, want %q", messageID, i, gotEvents[i].Event, wantEvent) } } } } func newClientTransportPair(t *testing.T, clientOpts []transport.Option, peerOpts []transport.Option) (*transport.TCPConn, *transport.TCPConn) { t.Helper() listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("net.Listen() error = %v", err) } type acceptResult struct { conn net.Conn err error } accepted := make(chan acceptResult, 1) go func() { conn, acceptErr := listener.Accept() accepted <- acceptResult{conn: conn, err: acceptErr} }() clientSide, err := net.Dial("tcp", listener.Addr().String()) if err != nil { _ = listener.Close() t.Fatalf("net.Dial() error = %v", err) } result := <-accepted if err := listener.Close(); err != nil { t.Fatalf("listener.Close() error = %v", err) } if result.err != nil { _ = clientSide.Close() t.Fatalf("listener.Accept() error = %v", result.err) } clientConn, err := transport.NewTCPConn(clientSide, clientOpts...) if err != nil { _ = clientSide.Close() _ = result.conn.Close() t.Fatalf("transport.NewTCPConn(client) error = %v", err) } peerConn, err := transport.NewTCPConn(result.conn, peerOpts...) if err != nil { _ = clientConn.Close() _ = result.conn.Close() t.Fatalf("transport.NewTCPConn(peer) error = %v", err) } t.Cleanup(func() { _ = clientConn.Close() _ = peerConn.Close() }) return clientConn, peerConn }