package transport

import (
	"errors"
	"io"
	"reflect"
	"strings"
	"sync"
	"testing"

	"omnisocketgo/cmd/internal/latencylog"
	"omnisocketgo/cmd/internal/protocol"
)

// recordingLogger is a test logger that stores every event passed to LogEvent.
// Access to the stored events is guarded by mu.
type recordingLogger struct {
	mu     sync.Mutex
	events []latencylog.Event
}

// LogEvent appends event to the recorded list and always returns nil.
func (l *recordingLogger) LogEvent(event latencylog.Event) error {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.events = append(l.events, event)
	return nil
}

// Events returns a copy of the events recorded so far, so callers can inspect
// the result without holding the lock or racing with later LogEvent calls.
func (l *recordingLogger) Events() []latencylog.Event {
	l.mu.Lock()
	defer l.mu.Unlock()
	return append([]latencylog.Event(nil), l.events...)
}

// failingLogger is a test logger whose LogEvent always returns an error.
type failingLogger struct{}

// LogEvent discards the event and returns a non-nil error.
func (failingLogger) LogEvent(latencylog.Event) error {
	return errors.New("log failed")
}

// TestSendReceiveMessage verifies that the transport can send and receive both
// text and file messages over a single connection.
func TestSendReceiveMessage(t *testing.T) {
	tests := []struct {
		name string
		msg  protocol.Message
	}{
		{
			name: "text",
			msg: protocol.Message{
				Type: protocol.MessageTypeText,
				ID:   1,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte("hello"),
			},
		},
		{
			name: "file",
			msg: protocol.Message{
				Type:     protocol.MessageTypeFile,
				ID:       2,
				From:     "peer-a",
				To:       "peer-b",
				FileName: "data.bin",
				Body:     []byte{0x00, 0x10, 0xff},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sender, receiver := newTransportConnPair(t, nil, nil)

			// Buffered (capacity 1) so the sending goroutine never blocks on
			// reporting its result, even if the test fails before reading it.
			sendErr := make(chan error, 1)
			go func() {
				// Send concurrently with Receive and report the result (error or nil).
				sendErr <- sender.Send(tt.msg)
			}()

			got, err := receiver.Receive()
			if err != nil {
				t.Fatalf("Receive() error = %v", err)
			}
			// Collect the send result; a send-side error fails the test.
			if err := <-sendErr; err != nil {
				t.Fatalf("Send() error = %v", err)
			}
			if !reflect.DeepEqual(got, tt.msg) {
				t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
			}
		})
	}
}

// TestSendLogsHandoffEvents verifies that a Send on a connection configured
// with WithLogger emits exactly four events, in order, all tagged with the
// message ID and carrying positive timestamps.
func TestSendLogsHandoffEvents(t *testing.T) {
	logger := &recordingLogger{}
	sender, receiver := newTransportConnPair(
		t,
		[]Option{WithLogger(logger, latencylog.NodeRolePeer, "peer-a")},
		nil,
	)

	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   7,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	sendErr := make(chan error, 1)
	go func() {
		sendErr <- sender.Send(msg)
	}()

	got, err := receiver.Receive()
	if err != nil {
		t.Fatalf("Receive() error = %v", err)
	}
	if err := <-sendErr; err != nil {
		t.Fatalf("Send() error = %v", err)
	}
	if !reflect.DeepEqual(got, msg) {
		t.Fatalf("message mismatch: got %+v want %+v", got, msg)
	}

	// Events are read only after Send has returned, so everything Send logged
	// is already recorded.
	events := logger.Events()
	if len(events) != 4 {
		t.Fatalf("event count = %d, want 4", len(events))
	}
	if events[0].Event != latencylog.EventSendHandoffBegin {
		t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin)
	}
	if events[1].Event != latencylog.EventATXSched {
		t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventATXSched)
	}
	if events[2].Event != latencylog.EventATXSoftware {
		t.Fatalf("third event = %q, want %q", events[2].Event, latencylog.EventATXSoftware)
	}
	if events[3].Event != latencylog.EventSendHandoffEnd {
		t.Fatalf("fourth event = %q, want %q", events[3].Event, latencylog.EventSendHandoffEnd)
	}

	for i, event := range events {
		if event.MessageID != msg.ID {
			t.Fatalf("event[%d] message ID = %d, want %d", i, event.MessageID, msg.ID)
		}
		if event.TsUnixNano <= 0 {
			t.Fatalf("event[%d] timestamp must be positive: %+v", i, events)
		}
	}

	if events[0].NodeRole != latencylog.NodeRolePeer || events[0].NodeID != "peer-a" {
		t.Fatalf("node info = (%s,%s), want (%s,%s)",
			events[0].NodeRole, events[0].NodeID, latencylog.NodeRolePeer, "peer-a")
	}
}

// TestSendIgnoresLoggerFailure verifies that a failing logger does not cause
// Send to fail or alter the delivered message.
func TestSendIgnoresLoggerFailure(t *testing.T) {
	sender, receiver := newTransportConnPair(
		t,
		[]Option{WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")},
		nil,
	)

	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   9,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	sendErr := make(chan error, 1)
	go func() {
		sendErr <- sender.Send(msg)
	}()

	got, err := receiver.Receive()
	if err != nil {
		t.Fatalf("Receive() error = %v", err)
	}
	if err := <-sendErr; err != nil {
		t.Fatalf("Send() error = %v, want nil even if logger fails", err)
	}
	if !reflect.DeepEqual(got, msg) {
		t.Fatalf("message mismatch: got %+v want %+v", got, msg)
	}
}

// TestReceiveLoopDeliversMessages verifies that ReceiveLoop hands each message
// that arrives back-to-back to the handler, in order, and exits with a read
// error once the peer closes.
func TestReceiveLoopDeliversMessages(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	want := []protocol.Message{
		{
			Type: protocol.MessageTypeText,
			ID:   1,
			From: "peer-a",
			To:   "peer-b",
			Body: []byte("hello"),
		},
		{
			Type:     protocol.MessageTypeFile,
			ID:       2,
			From:     "peer-a",
			To:       "peer-b",
			FileName: "payload.bin",
			Body:     []byte{0x01, 0x02, 0x03},
		},
	}

	var (
		mu  sync.Mutex
		got []protocol.Message
	)

	loopErr := make(chan error, 1)
	go func() {
		loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
			mu.Lock()
			defer mu.Unlock()
			got = append(got, msg)
			return nil
		})
	}()

	for _, msg := range want {
		if err := sender.Send(msg); err != nil {
			t.Fatalf("Send() error = %v", err)
		}
	}
	// Closing the sender is what makes ReceiveLoop return.
	if err := sender.Close(); err != nil {
		t.Fatalf("sender.Close() error = %v", err)
	}

	err := <-loopErr
	if err == nil {
		t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close")
	}
	if !strings.Contains(err.Error(), "receive loop read") {
		t.Fatalf("ReceiveLoop() error = %v, want read context", err)
	}

	mu.Lock()
	defer mu.Unlock()
	if !reflect.DeepEqual(got, want) {
		t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
	}
}

// TestConcurrentSendKeepsMessagesIntact verifies that concurrent Send calls do
// not corrupt messages by interleaving their writes on the wire.
func TestConcurrentSendKeepsMessagesIntact(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	// Several goroutines send concurrently while a single reader drains the
	// connection; every message must arrive whole and unchanged.
	want := []protocol.Message{
		{Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("one")},
		{Type: protocol.MessageTypeText, ID: 2, From: "peer-a", To: "peer-b", Body: []byte("two")},
		{Type: protocol.MessageTypeText, ID: 3, From: "peer-a", To: "peer-b", Body: []byte("three")},
		{Type: protocol.MessageTypeText, ID: 4, From: "peer-a", To: "peer-b", Body: []byte("four")},
	}

	received := make(chan protocol.Message, len(want))
	readErr := make(chan error, 1)
	go func() {
		// Reader: expects exactly len(want) messages, then reports nil.
		for range want {
			msg, err := receiver.Receive()
			if err != nil {
				readErr <- err
				return
			}
			received <- msg
		}
		readErr <- nil
	}()

	var wg sync.WaitGroup
	for _, msg := range want {
		msg := msg // per-iteration copy for the goroutine (required before Go 1.22)
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := sender.Send(msg); err != nil {
				t.Errorf("Send() error = %v", err)
			}
		}()
	}
	wg.Wait()

	if t.Failed() {
		// A failed Send means fewer than len(want) messages are coming, so the
		// reader would block forever and so would <-readErr below (cleanup only
		// runs after the test returns). Closing the sender makes the reader's
		// next read fail instead.
		_ = sender.Close()
	}

	if err := <-readErr; err != nil {
		t.Fatalf("Receive() error = %v", err)
	}

	// Arrival order across concurrent senders is unspecified, so compare by ID.
	gotByID := make(map[uint64]protocol.Message, len(want))
	for range want {
		msg := <-received
		gotByID[msg.ID] = msg
	}

	for _, msg := range want {
		got, ok := gotByID[msg.ID]
		if !ok {
			t.Fatalf("missing message with ID %d", msg.ID)
		}
		if !reflect.DeepEqual(got, msg) {
			t.Fatalf("message mismatch for ID %d: got %+v want %+v", msg.ID, got, msg)
		}
	}
}

// TestReceiveLoopStopsOnHandlerError verifies that when the handler returns an
// error, ReceiveLoop exits with an error that wraps the handler's error and
// carries handler context.
// NOTE(review): whether ReceiveLoop also closes the connection is not asserted here.
func TestReceiveLoopStopsOnHandlerError(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	wantErr := errors.New("stop loop")
	loopErr := make(chan error, 1)
	go func() {
		loopErr <- receiver.ReceiveLoop(func(protocol.Message) error {
			return wantErr
		})
	}()

	first := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}
	if err := sender.Send(first); err != nil {
		t.Fatalf("Send(first) error = %v", err)
	}

	err := <-loopErr
	if !errors.Is(err, wantErr) {
		t.Fatalf("ReceiveLoop() error = %v, want %v", err, wantErr)
	}
	if !strings.Contains(err.Error(), "receive loop handler") {
		t.Fatalf("ReceiveLoop() error = %v, want handler context", err)
	}
}

// TestReceiveLoopStopsOnReadError verifies that ReceiveLoop exits with a read
// error when the peer closes the connection.
func TestReceiveLoopStopsOnReadError(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	loopErr := make(chan error, 1)
	go func() {
		loopErr <- receiver.ReceiveLoop(func(protocol.Message) error {
			return nil
		})
	}()

	if err := sender.Close(); err != nil {
		t.Fatalf("sender.Close() error = %v", err)
	}

	err := <-loopErr
	if err == nil {
		t.Fatal("ReceiveLoop() error = nil, want non-nil")
	}
	if !strings.Contains(err.Error(), "receive loop read") {
		t.Fatalf("ReceiveLoop() error = %v, want read context", err)
	}
}

// TestCloseIsIdempotent verifies that Close can safely be called more than once.
func TestCloseIsIdempotent(t *testing.T) {
	conn, peer := newTransportConnPair(t, nil, nil)

	if err := conn.Close(); err != nil {
		t.Fatalf("Close(first) error = %v", err)
	}
	if err := conn.Close(); err != nil {
		t.Fatalf("Close(second) error = %v, want nil", err)
	}
	// The peer's Close may report an already-closed connection; tolerate that.
	if err := peer.Close(); err != nil && !strings.Contains(err.Error(), "closed") {
		t.Fatalf("peer.Close() error = %v", err)
	}
}

// TestReceiveReturnsWrappedReadError verifies that when the underlying read
// fails, Receive returns an error that keeps the transport context and the
// underlying cause.
func TestReceiveReturnsWrappedReadError(t *testing.T) {
	conn, peer := newTransportConnPair(t, nil, nil)

	go func() {
		_ = peer.Close()
	}()

	_, err := conn.Receive()
	if err == nil {
		t.Fatal("Receive() error = nil, want non-nil")
	}
	if !strings.Contains(err.Error(), "transport: receive message") {
		t.Fatalf("Receive() error = %v, want wrapped receive error", err)
	}
	if !errors.Is(err, io.EOF) && !strings.Contains(err.Error(), "closed") {
		t.Fatalf("Receive() error = %v, want underlying read failure", err)
	}
}

// newTransportConnPair wraps both ends of a connected TCP pair from newTCPPair
// in TCPConn values using the given options. It returns (sender, receiver).
// Both are closed via t.Cleanup, and any construction failure fails the test.
func newTransportConnPair(t *testing.T, senderOpts []Option, receiverOpts []Option) (*TCPConn, *TCPConn) {
	t.Helper()

	// NOTE(review): the raw connections are assumed to be cleaned up by
	// newTCPPair itself if wrapping fails — confirm.
	left, right := newTCPPair(t)

	sender, err := NewTCPConn(left, senderOpts...)
	if err != nil {
		t.Fatalf("NewTCPConn(sender) error = %v", err)
	}
	// Register cleanup right away so a receiver construction failure below
	// (which stops the test via t.Fatalf) cannot leak the sender.
	t.Cleanup(func() {
		_ = sender.Close()
	})

	receiver, err := NewTCPConn(right, receiverOpts...)
	if err != nil {
		t.Fatalf("NewTCPConn(receiver) error = %v", err)
	}
	t.Cleanup(func() {
		_ = receiver.Close()
	})

	return sender, receiver
}