init
This commit is contained in:
416
cmd/internal/transport/tcp_test.go
Normal file
416
cmd/internal/transport/tcp_test.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
type recordingLogger struct {
|
||||
mu sync.Mutex
|
||||
events []latencylog.Event
|
||||
}
|
||||
|
||||
func (l *recordingLogger) LogEvent(event latencylog.Event) error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.events = append(l.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *recordingLogger) Events() []latencylog.Event {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
return append([]latencylog.Event(nil), l.events...)
|
||||
}
|
||||
|
||||
type failingLogger struct{}
|
||||
|
||||
func (failingLogger) LogEvent(latencylog.Event) error {
|
||||
return errors.New("log failed")
|
||||
}
|
||||
|
||||
// TestSendReceiveMessage 验证 transport 可以在单条连接上正常收发 text 和 file 消息。
|
||||
func TestSendReceiveMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg protocol.Message
|
||||
}{
|
||||
{
|
||||
name: "text",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "data.bin",
|
||||
Body: []byte{0x00, 0x10, 0xff},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
//创建一个容量为1的缓冲通道sendErr,用于接收发送操作的错误结果。
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(tt.msg) //发送消息,并将结果(错误或nil)发送到sendErr通道。
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil { //接受发送结果,如果发送过程中发生错误,则测试失败。
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tt.msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendLogsHandoffEvents(t *testing.T) {
|
||||
logger := &recordingLogger{}
|
||||
sender, receiver := newTransportConnPair(
|
||||
t,
|
||||
[]Option{WithLogger(logger, latencylog.NodeRolePeer, "peer-a")},
|
||||
nil,
|
||||
)
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(msg)
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
|
||||
}
|
||||
|
||||
events := logger.Events()
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("event count = %d, want 4", len(events))
|
||||
}
|
||||
if events[0].Event != latencylog.EventSendHandoffBegin {
|
||||
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin)
|
||||
}
|
||||
if events[1].Event != latencylog.EventATXSched {
|
||||
t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventATXSched)
|
||||
}
|
||||
if events[2].Event != latencylog.EventATXSoftware {
|
||||
t.Fatalf("third event = %q, want %q", events[2].Event, latencylog.EventATXSoftware)
|
||||
}
|
||||
if events[3].Event != latencylog.EventSendHandoffEnd {
|
||||
t.Fatalf("fourth event = %q, want %q", events[3].Event, latencylog.EventSendHandoffEnd)
|
||||
}
|
||||
for i, event := range events {
|
||||
if event.MessageID != msg.ID {
|
||||
t.Fatalf("event[%d] message ID = %d, want %d", i, event.MessageID, msg.ID)
|
||||
}
|
||||
}
|
||||
if events[0].NodeRole != latencylog.NodeRolePeer || events[0].NodeID != "peer-a" {
|
||||
t.Fatalf("node info = (%s,%s), want (%s,%s)", events[0].NodeRole, events[0].NodeID, latencylog.NodeRolePeer, "peer-a")
|
||||
}
|
||||
if events[0].TsUnixNano <= 0 || events[1].TsUnixNano <= 0 || events[2].TsUnixNano <= 0 || events[3].TsUnixNano <= 0 {
|
||||
t.Fatalf("timestamps must be positive: %+v", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendIgnoresLoggerFailure(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(
|
||||
t,
|
||||
[]Option{WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")},
|
||||
nil,
|
||||
)
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 9,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(msg)
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v, want nil even if logger fails", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveLoopDeliversMessages 验证 ReceiveLoop 会逐条交付连续到达的消息。
|
||||
func TestReceiveLoopDeliversMessages(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
|
||||
want := []protocol.Message{
|
||||
{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
got []protocol.Message
|
||||
)
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
got = append(got, msg)
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
for _, msg := range want {
|
||||
if err := sender.Send(msg); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
}
|
||||
if err := sender.Close(); err != nil {
|
||||
t.Fatalf("sender.Close() error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if err == nil {
|
||||
t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "receive loop read") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want read context", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentSendKeepsMessagesIntact 验证并发发送时消息不会因为写入交叉而损坏。
|
||||
func TestConcurrentSendKeepsMessagesIntact(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
// 发送方将多条消息并发发送到接收方,接收方通过 ReceiveLoop 逐条读取并验证每条消息的完整性和正确性。
|
||||
want := []protocol.Message{
|
||||
{Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("one")},
|
||||
{Type: protocol.MessageTypeText, ID: 2, From: "peer-a", To: "peer-b", Body: []byte("two")},
|
||||
{Type: protocol.MessageTypeText, ID: 3, From: "peer-a", To: "peer-b", Body: []byte("three")},
|
||||
{Type: protocol.MessageTypeText, ID: 4, From: "peer-a", To: "peer-b", Body: []byte("four")},
|
||||
}
|
||||
|
||||
received := make(chan protocol.Message, len(want))
|
||||
readErr := make(chan error, 1)
|
||||
go func() { //异步地运行一个 goroutine
|
||||
for range want {
|
||||
msg, err := receiver.Receive()
|
||||
if err != nil {
|
||||
readErr <- err
|
||||
return
|
||||
}
|
||||
received <- msg
|
||||
}
|
||||
readErr <- nil
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, msg := range want {
|
||||
msg := msg
|
||||
wg.Add(1)
|
||||
go func() { //异步处理
|
||||
defer wg.Done()
|
||||
if err := sender.Send(msg); err != nil {
|
||||
t.Errorf("Send() error = %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if err := <-readErr; err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
|
||||
gotByID := make(map[uint64]protocol.Message, len(want))
|
||||
for range want {
|
||||
msg := <-received
|
||||
gotByID[msg.ID] = msg
|
||||
}
|
||||
|
||||
for _, msg := range want {
|
||||
got, ok := gotByID[msg.ID]
|
||||
if !ok {
|
||||
t.Fatalf("missing message with ID %d", msg.ID)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch for ID %d: got %+v want %+v", msg.ID, got, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveLoopStopsOnHandlerError 验证 handler 返回错误时 ReceiveLoop 会退出并关闭连接。
|
||||
func TestReceiveLoopStopsOnHandlerError(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
|
||||
wantErr := errors.New("stop loop")
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
|
||||
return wantErr
|
||||
})
|
||||
}()
|
||||
|
||||
first := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
if err := sender.Send(first); err != nil {
|
||||
t.Fatalf("Send(first) error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want %v", err, wantErr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "receive loop handler") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want handler context", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveLoopStopsOnReadError 验证对端关闭时 ReceiveLoop 会以读取错误退出。
|
||||
func TestReceiveLoopStopsOnReadError(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
if err := sender.Close(); err != nil {
|
||||
t.Fatalf("sender.Close() error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if err == nil {
|
||||
t.Fatal("ReceiveLoop() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "receive loop read") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want read context", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCloseIsIdempotent 验证 Close 可以安全地被重复调用。
|
||||
func TestCloseIsIdempotent(t *testing.T) {
|
||||
conn, peer := newTransportConnPair(t, nil, nil)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Fatalf("Close(first) error = %v", err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Fatalf("Close(second) error = %v, want nil", err)
|
||||
}
|
||||
if err := peer.Close(); err != nil && !strings.Contains(err.Error(), "closed") {
|
||||
t.Fatalf("peer.Close() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveReturnsWrappedReadError 验证 Receive 在底层读取失败时会保留 transport 上下文。
|
||||
func TestReceiveReturnsWrappedReadError(t *testing.T) {
|
||||
conn, peer := newTransportConnPair(t, nil, nil)
|
||||
go func() {
|
||||
_ = peer.Close()
|
||||
}()
|
||||
|
||||
_, err := conn.Receive()
|
||||
if err == nil {
|
||||
t.Fatal("Receive() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "transport: receive message") {
|
||||
t.Fatalf("Receive() error = %v, want wrapped receive error", err)
|
||||
}
|
||||
if !errors.Is(err, io.EOF) && !strings.Contains(err.Error(), "closed") {
|
||||
t.Fatalf("Receive() error = %v, want underlying read failure", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newTransportConnPair(t *testing.T, senderOpts []Option, receiverOpts []Option) (*TCPConn, *TCPConn) {
|
||||
t.Helper()
|
||||
|
||||
left, right := newTCPPair(t)
|
||||
|
||||
sender, err := NewTCPConn(left, senderOpts...)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPConn(sender) error = %v", err)
|
||||
}
|
||||
receiver, err := NewTCPConn(right, receiverOpts...)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPConn(receiver) error = %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = sender.Close()
|
||||
_ = receiver.Close()
|
||||
})
|
||||
|
||||
return sender, receiver
|
||||
}
|
||||
Reference in New Issue
Block a user