Files
OmniSocketGo/cmd/internal/transport/tcp_test.go

437 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package transport
import (
"errors"
"io"
"reflect"
"strings"
"sync"
"testing"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// recordingLogger is a latencylog sink that keeps every event it receives in
// memory so tests can inspect them afterwards. Both methods take the mutex, so
// it may be shared between goroutines.
type recordingLogger struct {
	mu     sync.Mutex
	events []latencylog.Event
}

// LogEvent appends event to the in-memory record. It never returns an error.
func (r *recordingLogger) LogEvent(event latencylog.Event) error {
	r.mu.Lock()
	r.events = append(r.events, event)
	r.mu.Unlock()
	return nil
}

// Events returns a copy of the recorded events (nil if none were logged), so
// the caller never aliases the logger's internal slice.
func (r *recordingLogger) Events() []latencylog.Event {
	r.mu.Lock()
	defer r.mu.Unlock()
	var snapshot []latencylog.Event
	return append(snapshot, r.events...)
}
// failingLogger is a latencylog sink whose LogEvent always fails. Tests use it
// to check that logging errors do not break the data path.
type failingLogger struct{}

// LogEvent discards the event and returns a fresh "log failed" error.
func (failingLogger) LogEvent(_ latencylog.Event) error {
	failure := errors.New("log failed")
	return failure
}
// recordingTXTimestampDebugLogger collects every TXTimestampDebugRecord passed
// to it in memory for later inspection by tests. Both methods take the mutex,
// so it may be shared between goroutines.
type recordingTXTimestampDebugLogger struct {
	mu      sync.Mutex
	records []TXTimestampDebugRecord
}

// LogTXTimestampDebugRecord appends record to the in-memory list. It never
// returns an error.
func (r *recordingTXTimestampDebugLogger) LogTXTimestampDebugRecord(record TXTimestampDebugRecord) error {
	r.mu.Lock()
	r.records = append(r.records, record)
	r.mu.Unlock()
	return nil
}

// Records returns a copy of the collected records (nil if none were logged),
// so the caller never aliases the logger's internal slice.
func (r *recordingTXTimestampDebugLogger) Records() []TXTimestampDebugRecord {
	r.mu.Lock()
	defer r.mu.Unlock()
	var snapshot []TXTimestampDebugRecord
	return append(snapshot, r.records...)
}
// TestSendReceiveMessage verifies that the transport can send and receive both
// a text message and a file message intact over a single connection.
func TestSendReceiveMessage(t *testing.T) {
	cases := []struct {
		name string
		msg  protocol.Message
	}{
		{
			name: "text",
			msg: protocol.Message{
				Type: protocol.MessageTypeText,
				ID:   1,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte("hello"),
			},
		},
		{
			name: "file",
			msg: protocol.Message{
				Type:     protocol.MessageTypeFile,
				ID:       2,
				From:     "peer-a",
				To:       "peer-b",
				FileName: "data.bin",
				Body:     []byte{0x00, 0x10, 0xff},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			sender, receiver := newTransportConnPair(t, nil, nil)

			// Send from another goroutine so the blocking Receive below can
			// drain the connection. The one-slot buffer lets that goroutine
			// deliver its result and exit even if the test stops early.
			sendDone := make(chan error, 1)
			go func(m protocol.Message) {
				sendDone <- sender.Send(m)
			}(tc.msg)

			got, recvErr := receiver.Receive()
			if recvErr != nil {
				t.Fatalf("Receive() error = %v", recvErr)
			}
			// Collect the send result; a failed send fails the test.
			if sendErr := <-sendDone; sendErr != nil {
				t.Fatalf("Send() error = %v", sendErr)
			}
			if !reflect.DeepEqual(got, tc.msg) {
				t.Fatalf("message mismatch: got %+v want %+v", got, tc.msg)
			}
		})
	}
}
// TestSendLogsHandoffEvents verifies that Send on a connection configured with
// WithLogger records exactly four events in this order:
// EventSendHandoffBegin, EventATXSched, EventATXSoftware, EventSendHandoffEnd.
// Every event must carry the message ID and a positive timestamp, and the
// first event must carry the configured node role and ID.
func TestSendLogsHandoffEvents(t *testing.T) {
	logger := &recordingLogger{}
	senderOpts := []Option{WithLogger(logger, latencylog.NodeRolePeer, "peer-a")}
	sender, receiver := newTransportConnPair(t, senderOpts, nil)

	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   7,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	sendErr := make(chan error, 1)
	go func() { sendErr <- sender.Send(msg) }()

	got, err := receiver.Receive()
	if err != nil {
		t.Fatalf("Receive() error = %v", err)
	}
	if err := <-sendErr; err != nil {
		t.Fatalf("Send() error = %v", err)
	}
	if !reflect.DeepEqual(got, msg) {
		t.Fatalf("message mismatch: got %+v want %+v", got, msg)
	}

	// Send has returned at this point; the test assumes every event it logs
	// has already reached the logger.
	events := logger.Events()
	if len(events) != 4 {
		t.Fatalf("event count = %d, want 4", len(events))
	}

	// Expected event sequence; only the Event field of each entry is compared.
	wantSeq := []latencylog.Event{
		{Event: latencylog.EventSendHandoffBegin},
		{Event: latencylog.EventATXSched},
		{Event: latencylog.EventATXSoftware},
		{Event: latencylog.EventSendHandoffEnd},
	}
	ordinals := []string{"first", "second", "third", "fourth"}
	for i, w := range wantSeq {
		if events[i].Event != w.Event {
			t.Fatalf("%s event = %q, want %q", ordinals[i], events[i].Event, w.Event)
		}
	}

	for i, ev := range events {
		if ev.MessageID != msg.ID {
			t.Fatalf("event[%d] message ID = %d, want %d", i, ev.MessageID, msg.ID)
		}
	}

	if head := events[0]; head.NodeRole != latencylog.NodeRolePeer || head.NodeID != "peer-a" {
		t.Fatalf("node info = (%s,%s), want (%s,%s)", head.NodeRole, head.NodeID, latencylog.NodeRolePeer, "peer-a")
	}

	for _, ev := range events {
		if ev.TsUnixNano <= 0 {
			t.Fatalf("timestamps must be positive: %+v", events)
		}
	}
}
func TestSendIgnoresLoggerFailure(t *testing.T) {
sender, receiver := newTransportConnPair(
t,
[]Option{WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")},
nil,
)
msg := protocol.Message{
Type: protocol.MessageTypeText,
ID: 9,
From: "peer-a",
To: "peer-b",
Body: []byte("hello"),
}
sendErr := make(chan error, 1)
go func() {
sendErr <- sender.Send(msg)
}()
got, err := receiver.Receive()
if err != nil {
t.Fatalf("Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("Send() error = %v, want nil even if logger fails", err)
}
if !reflect.DeepEqual(got, msg) {
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
}
}
// TestReceiveLoopDeliversMessages verifies that ReceiveLoop hands each message
// that arrives back-to-back to the handler, in order. It also checks that,
// once the peer closes the connection, ReceiveLoop returns an error carrying
// "receive loop read" context.
func TestReceiveLoopDeliversMessages(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	want := []protocol.Message{
		{
			Type: protocol.MessageTypeText,
			ID:   1,
			From: "peer-a",
			To:   "peer-b",
			Body: []byte("hello"),
		},
		{
			Type:     protocol.MessageTypeFile,
			ID:       2,
			From:     "peer-a",
			To:       "peer-b",
			FileName: "payload.bin",
			Body:     []byte{0x01, 0x02, 0x03},
		},
	}

	// The handler appends to delivered on the ReceiveLoop goroutine, and the
	// test goroutine reads it, so every access goes through deliveredMu.
	var deliveredMu sync.Mutex
	var delivered []protocol.Message
	collect := func(m protocol.Message) error {
		deliveredMu.Lock()
		defer deliveredMu.Unlock()
		delivered = append(delivered, m)
		return nil
	}

	loopDone := make(chan error, 1)
	go func() {
		loopDone <- receiver.ReceiveLoop(collect)
	}()

	for i := range want {
		if err := sender.Send(want[i]); err != nil {
			t.Fatalf("Send() error = %v", err)
		}
	}
	// Closing the sending side ends the stream, so ReceiveLoop should return.
	if err := sender.Close(); err != nil {
		t.Fatalf("sender.Close() error = %v", err)
	}

	loopErr := <-loopDone
	switch {
	case loopErr == nil:
		t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close")
	case !strings.Contains(loopErr.Error(), "receive loop read"):
		t.Fatalf("ReceiveLoop() error = %v, want read context", loopErr)
	}

	deliveredMu.Lock()
	defer deliveredMu.Unlock()
	if !reflect.DeepEqual(delivered, want) {
		t.Fatalf("received messages mismatch: got %+v want %+v", delivered, want)
	}
}
// TestConcurrentSendKeepsMessagesIntact verifies that concurrent Send calls on
// one connection do not interleave their writes. Every message must arrive
// complete and unmodified; arrival order is not checked.
func TestConcurrentSendKeepsMessagesIntact(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	// Several goroutines send on the same connection at once. A single reader
	// goroutine calls Receive once per expected message. Results are matched
	// by ID because concurrent sends have no defined order.
	want := []protocol.Message{
		{Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("one")},
		{Type: protocol.MessageTypeText, ID: 2, From: "peer-a", To: "peer-b", Body: []byte("two")},
		{Type: protocol.MessageTypeText, ID: 3, From: "peer-a", To: "peer-b", Body: []byte("three")},
		{Type: protocol.MessageTypeText, ID: 4, From: "peer-a", To: "peer-b", Body: []byte("four")},
	}

	// Both channels are buffered so the reader never blocks on them. It can
	// exit as soon as it has read len(want) messages or hit an error.
	received := make(chan protocol.Message, len(want))
	readErr := make(chan error, 1)
	go func() {
		for range want {
			msg, err := receiver.Receive()
			if err != nil {
				readErr <- err
				return
			}
			received <- msg
		}
		readErr <- nil
	}()

	var wg sync.WaitGroup
	for _, msg := range want {
		msg := msg // per-iteration copy; needed for pre-Go 1.22 loop semantics
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := sender.Send(msg); err != nil {
				t.Errorf("Send() error = %v", err)
			}
		}()
	}
	wg.Wait()

	// If any Send failed, fewer than len(want) messages will arrive. The reader
	// would then block in Receive forever and <-readErr would hang the test
	// until the global timeout. Stop here instead: FailNow runs the t.Cleanup
	// registered by newTransportConnPair, which closes both connections and
	// should unblock the reader.
	if t.Failed() {
		t.FailNow()
	}

	if err := <-readErr; err != nil {
		t.Fatalf("Receive() error = %v", err)
	}

	gotByID := make(map[uint64]protocol.Message, len(want))
	for range want {
		msg := <-received
		gotByID[msg.ID] = msg
	}
	for _, msg := range want {
		got, ok := gotByID[msg.ID]
		if !ok {
			t.Fatalf("missing message with ID %d", msg.ID)
		}
		if !reflect.DeepEqual(got, msg) {
			t.Fatalf("message mismatch for ID %d: got %+v want %+v", msg.ID, got, msg)
		}
	}
}
// TestReceiveLoopStopsOnHandlerError verifies that ReceiveLoop returns once the
// handler reports an error. The returned error must wrap the handler's error
// (errors.Is) and carry "receive loop handler" context. It does not check
// whether the connection is closed afterwards.
func TestReceiveLoopStopsOnHandlerError(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	wantErr := errors.New("stop loop")
	rejectAll := func(protocol.Message) error { return wantErr }

	loopDone := make(chan error, 1)
	go func() { loopDone <- receiver.ReceiveLoop(rejectAll) }()

	first := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}
	if sendErr := sender.Send(first); sendErr != nil {
		t.Fatalf("Send(first) error = %v", sendErr)
	}

	switch err := <-loopDone; {
	case !errors.Is(err, wantErr):
		t.Fatalf("ReceiveLoop() error = %v, want %v", err, wantErr)
	case !strings.Contains(err.Error(), "receive loop handler"):
		t.Fatalf("ReceiveLoop() error = %v, want handler context", err)
	}
}
// TestReceiveLoopStopsOnReadError verifies that when the peer closes the
// connection, ReceiveLoop returns a non-nil error with "receive loop read"
// context.
func TestReceiveLoopStopsOnReadError(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	acceptAll := func(protocol.Message) error { return nil }
	done := make(chan error, 1)
	go func() { done <- receiver.ReceiveLoop(acceptAll) }()

	if closeErr := sender.Close(); closeErr != nil {
		t.Fatalf("sender.Close() error = %v", closeErr)
	}

	switch err := <-done; {
	case err == nil:
		t.Fatal("ReceiveLoop() error = nil, want non-nil")
	case !strings.Contains(err.Error(), "receive loop read"):
		t.Fatalf("ReceiveLoop() error = %v, want read context", err)
	}
}
// TestCloseIsIdempotent verifies that Close can be called twice on the same
// TCPConn, with both calls returning nil, and that the other end can still be
// closed afterwards.
func TestCloseIsIdempotent(t *testing.T) {
	conn, peer := newTransportConnPair(t, nil, nil)

	if firstErr := conn.Close(); firstErr != nil {
		t.Fatalf("Close(first) error = %v", firstErr)
	}
	// A repeated Close on an already-closed connection must also succeed.
	if secondErr := conn.Close(); secondErr != nil {
		t.Fatalf("Close(second) error = %v, want nil", secondErr)
	}

	// For the peer, an error whose text mentions "closed" is tolerated.
	peerErr := peer.Close()
	if peerErr == nil || strings.Contains(peerErr.Error(), "closed") {
		return
	}
	t.Fatalf("peer.Close() error = %v", peerErr)
}
// TestReceiveReturnsWrappedReadError verifies what Receive returns when the
// underlying read fails because the peer closed its end. The error must keep
// the "transport: receive message" context and still expose the cause:
// either errors.Is(err, io.EOF), or text mentioning "closed".
func TestReceiveReturnsWrappedReadError(t *testing.T) {
	conn, peer := newTransportConnPair(t, nil, nil)

	// Close the peer concurrently with the Receive call below.
	go func() { _ = peer.Close() }()

	_, err := conn.Receive()
	if err == nil {
		t.Fatal("Receive() error = nil, want non-nil")
	}

	text := err.Error()
	if !strings.Contains(text, "transport: receive message") {
		t.Fatalf("Receive() error = %v, want wrapped receive error", err)
	}
	if !(errors.Is(err, io.EOF) || strings.Contains(text, "closed")) {
		t.Fatalf("Receive() error = %v, want underlying read failure", err)
	}
}
// newTransportConnPair wraps the two ends returned by newTCPPair (presumably a
// connected TCP socket pair) in TCPConn values. The first result is meant to
// be used as the sender and is built with senderOpts; the second is the
// receiver, built with receiverOpts. Either option slice may be nil. Both
// wrappers are closed via t.Cleanup when the test finishes, and the test is
// failed immediately if either wrapper cannot be constructed.
func newTransportConnPair(t *testing.T, senderOpts []Option, receiverOpts []Option) (*TCPConn, *TCPConn) {
	t.Helper()

	left, right := newTCPPair(t)

	sender, err := NewTCPConn(left, senderOpts...)
	if err != nil {
		t.Fatalf("NewTCPConn(sender) error = %v", err)
	}
	// Register the sender's cleanup before building the receiver. Otherwise a
	// failure below (t.Fatalf) would leave the sender unclosed. Close errors
	// are ignored because several tests close these connections themselves.
	t.Cleanup(func() { _ = sender.Close() })

	receiver, err := NewTCPConn(right, receiverOpts...)
	if err != nil {
		t.Fatalf("NewTCPConn(receiver) error = %v", err)
	}
	t.Cleanup(func() { _ = receiver.Close() })

	return sender, receiver
}