init
This commit is contained in:
166
cmd/internal/latencylog/logger.go
Normal file
166
cmd/internal/latencylog/logger.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package latencylog
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// Node role values stored in Event.NodeRole.
const (
	NodeRolePeer   = "peer"   // client (peer) node
	NodeRoleServer = "server" // cloud relay node
)
|
||||
|
||||
// Event name constants for the message events that get recorded.
const (
	EventAAppPrepBegin = "A_APP_PREP_BEGIN" // A side: the application starts preparing the message
	EventATXSched      = "A_TX_SCHED"       // A side: just before the message enters the Linux qdisc
	EventATXSoftware   = "A_TX_SOFTWARE"    // A side: about to be handed to the NIC driver
	EventATXHardware   = "A_TX_HARDWARE"    // A side: the NIC actually puts it on the physical medium
	EventBRXHardware   = "B_RX_HARDWARE"    // B side: the NIC actually receives it from the physical medium
	EventBRXSoftware   = "B_RX_SOFTWARE"    // B side: the driver hands the data to the Linux receive stack
	EventBAppRecv      = "B_APP_RECV"       // B side: the application has read the complete message
	EventBPersistBegin = "B_PERSIST_BEGIN"  // B side: writing to disk begins
	EventBPersistEnd   = "B_PERSIST_END"    // B side: writing to disk has finished

	EventSendHandoffBegin = "send_handoff_begin" // debug event: the application starts handing the message to the transport layer
	EventSendHandoffEnd   = "send_handoff_end"   // debug event: the application has finished handing the message to the transport layer
)
|
||||
|
||||
// Event is a single latency-timestamp log record. It is serialized as one
// JSON object per line by JSONLLogger.
type Event struct {
	TsUnixNano  int64                `json:"ts_unix_nano"`        // event time as Unix nanoseconds
	NodeRole    string               `json:"node_role"`           // role of the node (see NodeRolePeer, NodeRoleServer)
	NodeID      string               `json:"node_id"`             // node ID
	Event       string               `json:"event"`               // event name, e.g. EventAAppPrepBegin
	MessageType protocol.MessageType `json:"message_type"`        // type of the message the event refers to
	MessageID   uint64               `json:"message_id"`          // message ID
	From        string               `json:"from"`                // message sender
	To          string               `json:"to"`                  // message receiver
	FileName    string               `json:"file_name,omitempty"` // file name; left out of the JSON when empty
	BodySize    int                  `json:"body_size"`           // message body length in bytes
}
|
||||
|
||||
// Logger receives events and writes them to an external medium.
type Logger interface {
	// LogEvent records one event and reports any failure to do so.
	LogEvent(Event) error
}
|
||||
|
||||
// NoopLogger is the default no-op Logger implementation.
type NoopLogger struct{}

// LogEvent discards the event and always returns nil.
func (NoopLogger) LogEvent(Event) error {
	return nil
}
|
||||
|
||||
// JSONLLogger appends events to a log file in JSONL form (one JSON object per
// line).
type JSONLLogger struct {
	mu        sync.Mutex // serializes writes to file
	closeOnce sync.Once  // makes Close run the real close at most once
	closeErr  error      // result of the one real close, returned by every Close call
	file      *os.File   // destination file, opened in append mode by NewJSONLLogger
}
|
||||
|
||||
// NewJSONLLogger 创建一个线程安全的 JSONL 文件日志器。
|
||||
func NewJSONLLogger(path string) (*JSONLLogger, error) {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &JSONLLogger{file: file}, nil
|
||||
}
|
||||
|
||||
// LogEvent 以单行 JSON 的形式追加一条事件。
|
||||
func (l *JSONLLogger) LogEvent(event Event) error {
|
||||
line, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if _, err := l.file.Write(append(line, '\n')); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭底层文件;重复调用是安全的。
|
||||
func (l *JSONLLogger) Close() error {
|
||||
l.closeOnce.Do(func() {
|
||||
l.closeErr = l.file.Close()
|
||||
})
|
||||
|
||||
return l.closeErr
|
||||
}
|
||||
|
||||
// IsBusinessMessage 判断消息是否属于要参与 A-C-B 时延分析的业务消息。
|
||||
func IsBusinessMessage(msg protocol.Message) bool {
|
||||
switch msg.Type {
|
||||
case protocol.MessageTypeText, protocol.MessageTypeFile:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessageEvent 用当前 UTC 时间为一条业务消息构造事件。
|
||||
func NewMessageEvent(nodeRole, nodeID, eventName string, msg protocol.Message) Event {
|
||||
return NewMessageEventAt(time.Now().UTC().UnixNano(), nodeRole, nodeID, eventName, msg)
|
||||
}
|
||||
|
||||
// NewMessageEventAt 用指定的 UnixNano 时间为一条业务消息构造事件。
|
||||
func NewMessageEventAt(tsUnixNano int64, nodeRole, nodeID, eventName string, msg protocol.Message) Event {
|
||||
return Event{
|
||||
TsUnixNano: tsUnixNano,
|
||||
NodeRole: nodeRole,
|
||||
NodeID: nodeID,
|
||||
Event: eventName,
|
||||
MessageType: msg.Type,
|
||||
MessageID: msg.ID,
|
||||
From: msg.From,
|
||||
To: msg.To,
|
||||
FileName: msg.FileName,
|
||||
BodySize: len(msg.Body),
|
||||
}
|
||||
}
|
||||
|
||||
// LogBestEffort 写一条事件,失败时静默忽略,避免打断主收发流程。
|
||||
func LogBestEffort(logger Logger, event Event) {
|
||||
if logger == nil {
|
||||
return
|
||||
}
|
||||
|
||||
_ = logger.LogEvent(event)
|
||||
}
|
||||
|
||||
// LogMessageEvent 为业务消息构造并写入一条事件。
|
||||
func LogMessageEvent(logger Logger, nodeRole, nodeID, eventName string, msg protocol.Message) {
|
||||
if !IsBusinessMessage(msg) {
|
||||
return
|
||||
}
|
||||
|
||||
LogBestEffort(logger, NewMessageEvent(nodeRole, nodeID, eventName, msg))
|
||||
}
|
||||
|
||||
// LogMessageEventAt 为业务消息写入一条指定时间戳的事件。
|
||||
func LogMessageEventAt(logger Logger, nodeRole, nodeID, eventName string, tsUnixNano int64, msg protocol.Message) {
|
||||
if !IsBusinessMessage(msg) {
|
||||
return
|
||||
}
|
||||
|
||||
LogBestEffort(logger, NewMessageEventAt(tsUnixNano, nodeRole, nodeID, eventName, msg))
|
||||
}
|
||||
131
cmd/internal/latencylog/logger_test.go
Normal file
131
cmd/internal/latencylog/logger_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package latencylog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
func TestJSONLLoggerWritesOneEventPerLine(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "latency.jsonl")
|
||||
|
||||
logger, err := NewJSONLLogger(path)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLLogger() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = logger.Close()
|
||||
})
|
||||
|
||||
event := Event{
|
||||
TsUnixNano: 123,
|
||||
NodeRole: NodeRolePeer,
|
||||
NodeID: "peer-a",
|
||||
Event: EventAAppPrepBegin,
|
||||
MessageType: protocol.MessageTypeText,
|
||||
MessageID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
BodySize: 5,
|
||||
}
|
||||
if err := logger.LogEvent(event); err != nil {
|
||||
t.Fatalf("LogEvent() error = %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
t.Fatalf("os.Open() error = %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
if !scanner.Scan() {
|
||||
t.Fatal("expected one JSONL line, got none")
|
||||
}
|
||||
|
||||
var got Event
|
||||
if err := json.Unmarshal(scanner.Bytes(), &got); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
if got != event {
|
||||
t.Fatalf("event mismatch: got %+v want %+v", got, event)
|
||||
}
|
||||
if scanner.Scan() {
|
||||
t.Fatal("expected exactly one JSONL line")
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
t.Fatalf("scanner.Err() = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLLoggerHandlesConcurrentWrites(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "latency.jsonl")
|
||||
|
||||
logger, err := NewJSONLLogger(path)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLLogger() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = logger.Close()
|
||||
})
|
||||
|
||||
const total = 32
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < total; i++ {
|
||||
i := i
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
err := logger.LogEvent(Event{
|
||||
TsUnixNano: int64(i + 1),
|
||||
NodeRole: NodeRoleServer,
|
||||
NodeID: protocol.ServerPeerID,
|
||||
Event: EventBAppRecv,
|
||||
MessageType: protocol.MessageTypeFile,
|
||||
MessageID: uint64(i + 1),
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
BodySize: 3,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("LogEvent() error = %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
t.Fatalf("os.Open() error = %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
var count int
|
||||
seen := make(map[uint64]bool, total)
|
||||
for scanner.Scan() {
|
||||
var event Event
|
||||
if err := json.Unmarshal(scanner.Bytes(), &event); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
count++
|
||||
seen[event.MessageID] = true
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
t.Fatalf("scanner.Err() = %v", err)
|
||||
}
|
||||
if count != total {
|
||||
t.Fatalf("line count = %d, want %d", count, total)
|
||||
}
|
||||
if len(seen) != total {
|
||||
t.Fatalf("unique message count = %d, want %d", len(seen), total)
|
||||
}
|
||||
}
|
||||
254
cmd/internal/latencylog/summary.go
Normal file
254
cmd/internal/latencylog/summary.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package latencylog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// requiredTimestampNames lists the event names checked by
// missingTimestampNames; any of them absent from a message's timestamps is
// reported in Summary.MissingTimestamps.
var requiredTimestampNames = []string{
	EventAAppPrepBegin, // A side: the application starts preparing the message
	EventATXSched,      // A side: just before the message enters the Linux qdisc
	EventATXSoftware,   // A side: about to be handed to the NIC driver
	EventBRXSoftware,   // B side: the NIC driver hands the data to the Linux receive stack
	EventBAppRecv,      // B side: the application has read the complete message
	EventBPersistEnd,   // B side: writing to disk has finished
}
|
||||
|
||||
// Summary is the consolidated latency result for a single message.
//
// Each *LatencyNS field is nil (and omitted from JSON) when one of the two
// timestamps it is derived from is unavailable.
type Summary struct {
	MessageType protocol.MessageType `json:"message_type"`        // message type
	MessageID   uint64               `json:"message_id"`          // message ID
	From        string               `json:"from"`                // sender
	To          string               `json:"to"`                  // receiver
	FileName    string               `json:"file_name,omitempty"` // file name (file messages only)
	BodySize    int                  `json:"body_size"`           // message body size in bytes
	Timestamps  map[string]int64     `json:"timestamps"`          // event timestamps: key is the event name, value the UnixNano timestamp

	AProcessingLatencyNS                  *int64   `json:"a_processing_latency_ns,omitempty"`                      // A processing latency: A_TX_SCHED - A_APP_PREP_BEGIN
	AQueueLatencyNS                       *int64   `json:"a_queue_latency_ns,omitempty"`                           // A queueing latency: A_TX_SOFTWARE - A_TX_SCHED
	ABTransportPropagationBQueueLatencyNS *int64   `json:"a_b_transport_propagation_b_queue_latency_ns,omitempty"` // A-to-B transport latency plus B queueing latency: B_APP_RECV - A_TX_SOFTWARE
	BKernelReceivePathLatencyNS           *int64   `json:"b_kernel_receive_path_latency_ns,omitempty"`             // approximate B kernel receive path: B_APP_RECV - B_RX_SOFTWARE
	BProcessingLatencyNS                  *int64   `json:"b_processing_latency_ns,omitempty"`                      // B processing latency: B_PERSIST_END - B_APP_RECV
	EndToEndLatencyNS                     *int64   `json:"end_to_end_latency_ns,omitempty"`                        // end-to-end latency: B_PERSIST_END - A_APP_PREP_BEGIN
	MissingTimestamps                     []string `json:"missing_timestamps,omitempty"`                           // event names in requiredTimestampNames that the raw events did not contain
}
|
||||
|
||||
// messageKey is the map key SummarizeEvents uses to group the events that
// belong to the same message.
type messageKey struct {
	MessageType protocol.MessageType // message type
	MessageID   uint64               // message ID
	From        string               // sender
	To          string               // receiver
}
|
||||
|
||||
// LoadEventsFromFiles 从多个 JSONL 原始日志文件中加载事件。
|
||||
func LoadEventsFromFiles(paths []string) ([]Event, error) {
|
||||
var events []Event
|
||||
for _, path := range paths {
|
||||
fileEvents, err := LoadEventsFromFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
events = append(events, fileEvents...)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// LoadEventsFromFile 从单个 JSONL 原始日志文件中加载事件。
|
||||
func LoadEventsFromFile(path string) ([]Event, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("latencylog: open raw log %s: %w", path, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var events []Event
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
if len(scanner.Bytes()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var event Event
|
||||
if err := json.Unmarshal(scanner.Bytes(), &event); err != nil { //解析 JSONL 行失败,返回错误
|
||||
return nil, fmt.Errorf("latencylog: decode event from %s: %w", path, err)
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("latencylog: scan raw log %s: %w", path, err)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// SummarizeEvents 将原始事件整理成按消息分组的时延结果。
|
||||
func SummarizeEvents(events []Event) []Summary {
|
||||
grouped := make(map[messageKey]*Summary)
|
||||
|
||||
for _, event := range events {
|
||||
if !IsBusinessEvent(event) {
|
||||
continue
|
||||
}
|
||||
|
||||
key := messageKey{
|
||||
MessageType: event.MessageType,
|
||||
MessageID: event.MessageID,
|
||||
From: event.From,
|
||||
To: event.To,
|
||||
}
|
||||
|
||||
summary, ok := grouped[key]
|
||||
if !ok {
|
||||
summary = &Summary{
|
||||
MessageType: event.MessageType,
|
||||
MessageID: event.MessageID,
|
||||
From: event.From,
|
||||
To: event.To,
|
||||
FileName: event.FileName,
|
||||
BodySize: event.BodySize,
|
||||
Timestamps: make(map[string]int64),
|
||||
}
|
||||
grouped[key] = summary
|
||||
}
|
||||
|
||||
if summary.FileName == "" {
|
||||
summary.FileName = event.FileName
|
||||
}
|
||||
if event.BodySize > 0 {
|
||||
summary.BodySize = event.BodySize
|
||||
}
|
||||
|
||||
if existing, exists := summary.Timestamps[event.Event]; !exists || event.TsUnixNano < existing {
|
||||
summary.Timestamps[event.Event] = event.TsUnixNano
|
||||
}
|
||||
}
|
||||
|
||||
summaries := make([]Summary, 0, len(grouped))
|
||||
for _, summary := range grouped {
|
||||
completeSummary(summary) //补全时延指标和缺失时间戳信息
|
||||
summaries = append(summaries, *summary)
|
||||
}
|
||||
//对整理结果进行排序,先按发送方、再按接收方、再按消息 ID、最后按消息类型排序,保证输出的稳定性和可读性。
|
||||
sort.Slice(summaries, func(i, j int) bool {
|
||||
if summaries[i].From != summaries[j].From {
|
||||
return summaries[i].From < summaries[j].From
|
||||
}
|
||||
if summaries[i].To != summaries[j].To {
|
||||
return summaries[i].To < summaries[j].To
|
||||
}
|
||||
if summaries[i].MessageID != summaries[j].MessageID {
|
||||
return summaries[i].MessageID < summaries[j].MessageID
|
||||
}
|
||||
return summaries[i].MessageType < summaries[j].MessageType
|
||||
})
|
||||
|
||||
return summaries
|
||||
}
|
||||
|
||||
// WriteSummariesJSONL 将整理结果写成 JSONL 汇总文件。
|
||||
func WriteSummariesJSONL(path string, summaries []Summary) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return fmt.Errorf("latencylog: create summary dir for %s: %w", path, err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("latencylog: open summary file %s: %w", path, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
writer := bufio.NewWriter(file)
|
||||
for _, summary := range summaries { //将每条整理结果编码成 JSONL 行并写入文件
|
||||
line, err := json.Marshal(summary)
|
||||
if err != nil {
|
||||
return fmt.Errorf("latencylog: encode summary for message %d: %w", summary.MessageID, err)
|
||||
}
|
||||
if _, err := writer.Write(append(line, '\n')); err != nil {
|
||||
return fmt.Errorf("latencylog: write summary file %s: %w", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := writer.Flush(); err != nil { //将缓冲区内容写入文件
|
||||
return fmt.Errorf("latencylog: flush summary file %s: %w", path, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// completeSummary 根据事件时间戳计算时延指标,并找出缺失的时间戳。
|
||||
func completeSummary(summary *Summary) {
|
||||
summary.MissingTimestamps = missingTimestampNames(summary.Timestamps)
|
||||
|
||||
if value := subtractIfPresent(summary.Timestamps, EventATXSched, EventAAppPrepBegin); value != nil {
|
||||
summary.AProcessingLatencyNS = value
|
||||
}
|
||||
if value := subtractIfPresent(summary.Timestamps, EventATXSoftware, EventATXSched); value != nil {
|
||||
summary.AQueueLatencyNS = value
|
||||
}
|
||||
if value := subtractIfPresent(summary.Timestamps, EventBAppRecv, EventATXSoftware); value != nil {
|
||||
summary.ABTransportPropagationBQueueLatencyNS = value
|
||||
}
|
||||
if value := subtractIfPresent(summary.Timestamps, EventBAppRecv, EventBRXSoftware); value != nil {
|
||||
summary.BKernelReceivePathLatencyNS = value
|
||||
}
|
||||
if value := subtractIfPresent(summary.Timestamps, EventBPersistEnd, EventBAppRecv); value != nil {
|
||||
summary.BProcessingLatencyNS = value
|
||||
}
|
||||
if value := subtractIfPresent(summary.Timestamps, EventBPersistEnd, EventAAppPrepBegin); value != nil {
|
||||
summary.EndToEndLatencyNS = value
|
||||
}
|
||||
}
|
||||
|
||||
// 返回 requiredTimestampNames 中哪些在给定的 timestamps 中缺失。
|
||||
func missingTimestampNames(timestamps map[string]int64) []string {
|
||||
var missing []string
|
||||
for _, name := range requiredTimestampNames {
|
||||
if _, ok := timestamps[name]; !ok {
|
||||
missing = append(missing, name)
|
||||
}
|
||||
}
|
||||
|
||||
return missing
|
||||
}
|
||||
|
||||
// subtractIfPresent returns a pointer to timestamps[endName] minus
// timestamps[beginName] when both names are present, and nil otherwise.
func subtractIfPresent(timestamps map[string]int64, endName, beginName string) *int64 {
	end, hasEnd := timestamps[endName]
	begin, hasBegin := timestamps[beginName]
	if !hasEnd || !hasBegin {
		return nil
	}

	delta := end - begin
	return &delta
}
|
||||
|
||||
// 判断事件是否是业务相关的时延事件(其中一项)
|
||||
func IsBusinessEvent(event Event) bool {
|
||||
switch event.Event {
|
||||
case EventAAppPrepBegin,
|
||||
EventATXSched,
|
||||
EventATXSoftware,
|
||||
EventATXHardware,
|
||||
EventBRXHardware,
|
||||
EventBRXSoftware,
|
||||
EventBAppRecv,
|
||||
EventBPersistBegin,
|
||||
EventBPersistEnd:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
440
cmd/internal/latencylog/summary_chart.go
Normal file
440
cmd/internal/latencylog/summary_chart.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package latencylog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// summaryChartHTMLTemplate is the html/template source of the latency summary
// page. WriteSummariesHTMLChart parses it and executes it with a
// summaryChartPage value.
const summaryChartHTMLTemplate = `<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Latency Summary Chart</title>
|
||||
<style>
|
||||
:root {
|
||||
color-scheme: light;
|
||||
--bg: #f6f7fb;
|
||||
--panel: #ffffff;
|
||||
--text: #172033;
|
||||
--muted: #60708a;
|
||||
--border: #d9dfeb;
|
||||
--track: #e8edf5;
|
||||
--a-proc: #3b82f6;
|
||||
--a-queue: #14b8a6;
|
||||
--transport: #f59e0b;
|
||||
--b-proc: #22c55e;
|
||||
--unknown: #94a3b8;
|
||||
}
|
||||
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
margin: 0;
|
||||
font-family: "Segoe UI", "Helvetica Neue", Arial, sans-serif;
|
||||
background: linear-gradient(180deg, #eef3ff 0%, var(--bg) 220px);
|
||||
color: var(--text);
|
||||
}
|
||||
main {
|
||||
max-width: 1120px;
|
||||
margin: 0 auto;
|
||||
padding: 32px 20px 48px;
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 8px;
|
||||
font-size: 32px;
|
||||
line-height: 1.1;
|
||||
}
|
||||
.intro {
|
||||
color: var(--muted);
|
||||
margin: 0 0 24px;
|
||||
font-size: 15px;
|
||||
}
|
||||
.stats {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
|
||||
gap: 12px;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.stat {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 14px;
|
||||
padding: 14px 16px;
|
||||
box-shadow: 0 10px 30px rgba(23, 32, 51, 0.06);
|
||||
}
|
||||
.stat-label {
|
||||
color: var(--muted);
|
||||
font-size: 12px;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
margin-bottom: 6px;
|
||||
}
|
||||
.stat-value {
|
||||
font-size: 22px;
|
||||
font-weight: 700;
|
||||
}
|
||||
.legend {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 12px;
|
||||
margin-bottom: 20px;
|
||||
color: var(--muted);
|
||||
font-size: 13px;
|
||||
}
|
||||
.legend-item {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
.swatch {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 999px;
|
||||
flex: 0 0 auto;
|
||||
}
|
||||
.rows {
|
||||
display: grid;
|
||||
gap: 14px;
|
||||
}
|
||||
.row {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 16px;
|
||||
padding: 16px;
|
||||
box-shadow: 0 12px 30px rgba(23, 32, 51, 0.05);
|
||||
}
|
||||
.row-head {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: baseline;
|
||||
gap: 12px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.row-title {
|
||||
font-size: 18px;
|
||||
font-weight: 700;
|
||||
margin: 0;
|
||||
}
|
||||
.row-e2e {
|
||||
font-size: 16px;
|
||||
font-weight: 700;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.row-meta {
|
||||
color: var(--muted);
|
||||
font-size: 13px;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.bar {
|
||||
height: 24px;
|
||||
display: flex;
|
||||
overflow: hidden;
|
||||
background: var(--track);
|
||||
border-radius: 999px;
|
||||
}
|
||||
.segment {
|
||||
height: 100%;
|
||||
}
|
||||
.segment:first-child {
|
||||
border-top-left-radius: 999px;
|
||||
border-bottom-left-radius: 999px;
|
||||
}
|
||||
.segment:last-child {
|
||||
border-top-right-radius: 999px;
|
||||
border-bottom-right-radius: 999px;
|
||||
}
|
||||
.row-legend {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 10px 14px;
|
||||
margin-top: 10px;
|
||||
color: var(--muted);
|
||||
font-size: 12px;
|
||||
}
|
||||
.missing {
|
||||
margin-top: 10px;
|
||||
color: #a16207;
|
||||
font-size: 12px;
|
||||
}
|
||||
.empty {
|
||||
color: var(--muted);
|
||||
font-style: italic;
|
||||
padding: 24px 16px;
|
||||
background: var(--panel);
|
||||
border: 1px dashed var(--border);
|
||||
border-radius: 16px;
|
||||
text-align: center;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<main>
|
||||
<h1>Latency Summary</h1>
|
||||
<p class="intro">A simple per-message end-to-end latency chart generated from summarized JSONL records.</p>
|
||||
|
||||
<section class="stats">
|
||||
<div class="stat">
|
||||
<div class="stat-label">Messages</div>
|
||||
<div class="stat-value">{{.TotalMessages}}</div>
|
||||
</div>
|
||||
<div class="stat">
|
||||
<div class="stat-label">With End-To-End</div>
|
||||
<div class="stat-value">{{.MessagesWithEndToEnd}}</div>
|
||||
</div>
|
||||
<div class="stat">
|
||||
<div class="stat-label">Average End-To-End</div>
|
||||
<div class="stat-value">{{.AverageEndToEnd}}</div>
|
||||
</div>
|
||||
<div class="stat">
|
||||
<div class="stat-label">Max End-To-End</div>
|
||||
<div class="stat-value">{{.MaxEndToEnd}}</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section class="legend">
|
||||
{{range .Legend}}
|
||||
<span class="legend-item">
|
||||
<span class="swatch" style="background: {{.Color}}"></span>
|
||||
<span>{{.Label}}</span>
|
||||
</span>
|
||||
{{end}}
|
||||
</section>
|
||||
|
||||
{{if .Rows}}
|
||||
<section class="rows">
|
||||
{{range .Rows}}
|
||||
<article class="row">
|
||||
<div class="row-head">
|
||||
<h2 class="row-title">{{.Title}}</h2>
|
||||
<div class="row-e2e">{{.EndToEnd}}</div>
|
||||
</div>
|
||||
<div class="row-meta">{{.Subtitle}}</div>
|
||||
<div class="bar">
|
||||
{{range .Segments}}
|
||||
<div class="segment" style="width: {{printf "%.4f" .WidthPercent}}%; background: {{.Color}}" title="{{.Label}}: {{.Value}}"></div>
|
||||
{{end}}
|
||||
</div>
|
||||
{{if .Segments}}
|
||||
<div class="row-legend">
|
||||
{{range .Segments}}
|
||||
<span class="legend-item">
|
||||
<span class="swatch" style="background: {{.Color}}"></span>
|
||||
<span>{{.Label}} {{.Value}}</span>
|
||||
</span>
|
||||
{{end}}
|
||||
</div>
|
||||
{{end}}
|
||||
{{if .MissingTimestamps}}
|
||||
<div class="missing">Missing timestamps: {{.MissingTimestamps}}</div>
|
||||
{{end}}
|
||||
</article>
|
||||
{{end}}
|
||||
</section>
|
||||
{{else}}
|
||||
<section class="empty">No summarized messages were available for chart rendering.</section>
|
||||
{{end}}
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
// summaryChartPage is the data the chart template is executed with.
type summaryChartPage struct {
	TotalMessages        int                      // number of summaries passed in
	MessagesWithEndToEnd int                      // how many of them have an end-to-end latency
	AverageEndToEnd      string                   // formatted average end-to-end latency, or "n/a"
	MaxEndToEnd          string                   // formatted maximum end-to-end latency, or "n/a"
	Legend               []summaryChartLegendItem // page-level color legend
	Rows                 []summaryChartRow        // one row per summary, in input order
}
|
||||
|
||||
type summaryChartLegendItem struct {
|
||||
Label string
|
||||
Color string
|
||||
}
|
||||
|
||||
type summaryChartRow struct {
|
||||
Title string
|
||||
Subtitle string
|
||||
EndToEnd string
|
||||
MissingTimestamps string
|
||||
Segments []summaryChartSegment
|
||||
}
|
||||
|
||||
type summaryChartSegment struct {
|
||||
Label string
|
||||
Value string
|
||||
Color string
|
||||
WidthPercent float64
|
||||
}
|
||||
|
||||
type summaryChartSegmentMetric struct {
|
||||
label string
|
||||
value *int64
|
||||
color string
|
||||
}
|
||||
|
||||
// WriteSummariesHTMLChart 将整理结果写成一个可直接在浏览器中打开的简单 HTML 图表。
|
||||
func WriteSummariesHTMLChart(path string, summaries []Summary) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return fmt.Errorf("latencylog: create chart dir for %s: %w", path, err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("latencylog: open chart file %s: %w", path, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
page := buildSummaryChartPage(summaries)
|
||||
tmpl, err := template.New("summary-chart").Parse(summaryChartHTMLTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("latencylog: parse chart template: %w", err)
|
||||
}
|
||||
|
||||
writer := bufio.NewWriter(file)
|
||||
if err := tmpl.Execute(writer, page); err != nil {
|
||||
return fmt.Errorf("latencylog: render chart %s: %w", path, err)
|
||||
}
|
||||
if err := writer.Flush(); err != nil {
|
||||
return fmt.Errorf("latencylog: flush chart %s: %w", path, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildSummaryChartPage(summaries []Summary) summaryChartPage {
|
||||
page := summaryChartPage{
|
||||
TotalMessages: len(summaries),
|
||||
Legend: []summaryChartLegendItem{
|
||||
{Label: "A processing", Color: "var(--a-proc)"},
|
||||
{Label: "A queue", Color: "var(--a-queue)"},
|
||||
{Label: "Transport + B queue", Color: "var(--transport)"},
|
||||
{Label: "B processing", Color: "var(--b-proc)"},
|
||||
{Label: "Unknown / missing", Color: "var(--unknown)"},
|
||||
},
|
||||
Rows: make([]summaryChartRow, 0, len(summaries)),
|
||||
}
|
||||
|
||||
var (
|
||||
endToEndValues []int64
|
||||
totalEndToEnd int64
|
||||
maxEndToEnd int64
|
||||
)
|
||||
|
||||
for _, summary := range summaries {
|
||||
page.Rows = append(page.Rows, buildSummaryChartRow(summary))
|
||||
|
||||
if summary.EndToEndLatencyNS == nil {
|
||||
continue
|
||||
}
|
||||
endToEnd := *summary.EndToEndLatencyNS
|
||||
endToEndValues = append(endToEndValues, endToEnd)
|
||||
totalEndToEnd += endToEnd
|
||||
if endToEnd > maxEndToEnd {
|
||||
maxEndToEnd = endToEnd
|
||||
}
|
||||
}
|
||||
|
||||
page.MessagesWithEndToEnd = len(endToEndValues)
|
||||
page.AverageEndToEnd = "n/a"
|
||||
page.MaxEndToEnd = "n/a"
|
||||
if len(endToEndValues) > 0 {
|
||||
page.AverageEndToEnd = formatLatencyNS(totalEndToEnd / int64(len(endToEndValues)))
|
||||
page.MaxEndToEnd = formatLatencyNS(maxEndToEnd)
|
||||
}
|
||||
|
||||
return page
|
||||
}
|
||||
|
||||
func buildSummaryChartRow(summary Summary) summaryChartRow {
|
||||
row := summaryChartRow{
|
||||
Title: buildSummaryChartTitle(summary),
|
||||
Subtitle: buildSummaryChartSubtitle(summary),
|
||||
EndToEnd: "End-to-end: n/a",
|
||||
MissingTimestamps: strings.Join(summary.MissingTimestamps, ", "),
|
||||
}
|
||||
|
||||
if summary.EndToEndLatencyNS == nil || *summary.EndToEndLatencyNS <= 0 {
|
||||
return row
|
||||
}
|
||||
|
||||
total := *summary.EndToEndLatencyNS
|
||||
row.EndToEnd = fmt.Sprintf("End-to-end: %s", formatLatencyNS(total))
|
||||
|
||||
metrics := []summaryChartSegmentMetric{
|
||||
{label: "A processing", value: summary.AProcessingLatencyNS, color: "var(--a-proc)"},
|
||||
{label: "A queue", value: summary.AQueueLatencyNS, color: "var(--a-queue)"},
|
||||
{label: "Transport + B queue", value: summary.ABTransportPropagationBQueueLatencyNS, color: "var(--transport)"},
|
||||
{label: "B processing", value: summary.BProcessingLatencyNS, color: "var(--b-proc)"},
|
||||
}
|
||||
|
||||
var knownTotal int64
|
||||
for _, metric := range metrics {
|
||||
if metric.value == nil || *metric.value <= 0 {
|
||||
continue
|
||||
}
|
||||
knownTotal += *metric.value
|
||||
}
|
||||
|
||||
scaleTotal := total
|
||||
if knownTotal > scaleTotal {
|
||||
scaleTotal = knownTotal
|
||||
}
|
||||
if scaleTotal <= 0 {
|
||||
return row
|
||||
}
|
||||
|
||||
for _, metric := range metrics {
|
||||
if metric.value == nil || *metric.value <= 0 {
|
||||
continue
|
||||
}
|
||||
row.Segments = append(row.Segments, summaryChartSegment{
|
||||
Label: metric.label,
|
||||
Value: formatLatencyNS(*metric.value),
|
||||
Color: metric.color,
|
||||
WidthPercent: float64(*metric.value) * 100 / float64(scaleTotal),
|
||||
})
|
||||
}
|
||||
|
||||
if remaining := total - knownTotal; remaining > 0 {
|
||||
row.Segments = append(row.Segments, summaryChartSegment{
|
||||
Label: "Unknown / missing",
|
||||
Value: formatLatencyNS(remaining),
|
||||
Color: "var(--unknown)",
|
||||
WidthPercent: float64(remaining) * 100 / float64(scaleTotal),
|
||||
})
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
func buildSummaryChartTitle(summary Summary) string {
|
||||
if summary.MessageType == protocol.MessageTypeFile && summary.FileName != "" {
|
||||
return fmt.Sprintf("%s #%d (%s)", summary.MessageType, summary.MessageID, summary.FileName)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s #%d", summary.MessageType, summary.MessageID)
|
||||
}
|
||||
|
||||
func buildSummaryChartSubtitle(summary Summary) string {
|
||||
parts := []string{
|
||||
fmt.Sprintf("%s -> %s", summary.From, summary.To),
|
||||
fmt.Sprintf("%d bytes", summary.BodySize),
|
||||
}
|
||||
|
||||
if summary.MessageType == protocol.MessageTypeFile && summary.FileName != "" {
|
||||
parts = append(parts, fmt.Sprintf("file: %s", summary.FileName))
|
||||
}
|
||||
|
||||
return strings.Join(parts, " | ")
|
||||
}
|
||||
|
||||
// formatLatencyNS renders a duration given in nanoseconds as milliseconds
// with three decimal places, e.g. "20.000 ms".
func formatLatencyNS(ns int64) string {
	const nanosPerMilli = 1_000_000

	millis := float64(ns) / nanosPerMilli
	return fmt.Sprintf("%.3f ms", millis)
}
|
||||
67
cmd/internal/latencylog/summary_chart_test.go
Normal file
67
cmd/internal/latencylog/summary_chart_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package latencylog
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
func TestWriteSummariesHTMLChart(t *testing.T) {
|
||||
aProcessing := int64(20_000_000)
|
||||
aQueue := int64(10_000_000)
|
||||
transport := int64(40_000_000)
|
||||
bProcessing := int64(30_000_000)
|
||||
endToEnd := int64(100_000_000)
|
||||
|
||||
summaries := []Summary{
|
||||
{
|
||||
MessageType: protocol.MessageTypeText,
|
||||
MessageID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
BodySize: 5,
|
||||
AProcessingLatencyNS: &aProcessing,
|
||||
AQueueLatencyNS: &aQueue,
|
||||
ABTransportPropagationBQueueLatencyNS: &transport,
|
||||
BProcessingLatencyNS: &bProcessing,
|
||||
EndToEndLatencyNS: &endToEnd,
|
||||
},
|
||||
{
|
||||
MessageType: protocol.MessageTypeFile,
|
||||
MessageID: 8,
|
||||
From: "peer-b",
|
||||
To: "peer-a",
|
||||
FileName: "payload.bin",
|
||||
BodySize: 128,
|
||||
MissingTimestamps: []string{EventBRXSoftware},
|
||||
},
|
||||
}
|
||||
|
||||
path := filepath.Join(t.TempDir(), "charts", "latency-summary.html")
|
||||
if err := WriteSummariesHTMLChart(path, summaries); err != nil {
|
||||
t.Fatalf("WriteSummariesHTMLChart() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("os.ReadFile() error = %v", err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
for _, want := range []string{
|
||||
"Latency Summary",
|
||||
"text #7",
|
||||
"peer-a -> peer-b | 5 bytes",
|
||||
"End-to-end: 100.000 ms",
|
||||
"A processing 20.000 ms",
|
||||
"file #8 (payload.bin)",
|
||||
"Missing timestamps: B_RX_SOFTWARE",
|
||||
} {
|
||||
if !strings.Contains(content, want) {
|
||||
t.Fatalf("chart content missing %q\n%s", want, content)
|
||||
}
|
||||
}
|
||||
}
|
||||
144
cmd/internal/latencylog/summary_test.go
Normal file
144
cmd/internal/latencylog/summary_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package latencylog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
func TestSummarizeEventsComputesLatencyMetrics(t *testing.T) {
|
||||
events := []Event{
|
||||
{TsUnixNano: 100, Event: EventAAppPrepBegin, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 120, Event: EventATXSched, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 140, Event: EventATXSoftware, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 180, Event: EventBRXSoftware, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 220, Event: EventBAppRecv, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 230, Event: EventBPersistBegin, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 260, Event: EventBPersistEnd, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
|
||||
}
|
||||
|
||||
summaries := SummarizeEvents(events)
|
||||
if len(summaries) != 1 {
|
||||
t.Fatalf("summary count = %d, want 1", len(summaries))
|
||||
}
|
||||
|
||||
summary := summaries[0]
|
||||
if got := ptrValue(summary.AProcessingLatencyNS); got != 20 {
|
||||
t.Fatalf("AProcessingLatencyNS = %d, want 20", got)
|
||||
}
|
||||
if got := ptrValue(summary.AQueueLatencyNS); got != 20 {
|
||||
t.Fatalf("AQueueLatencyNS = %d, want 20", got)
|
||||
}
|
||||
if got := ptrValue(summary.ABTransportPropagationBQueueLatencyNS); got != 80 {
|
||||
t.Fatalf("ABTransportPropagationBQueueLatencyNS = %d, want 80", got)
|
||||
}
|
||||
if got := ptrValue(summary.BKernelReceivePathLatencyNS); got != 40 {
|
||||
t.Fatalf("BKernelReceivePathLatencyNS = %d, want 40", got)
|
||||
}
|
||||
if got := ptrValue(summary.BProcessingLatencyNS); got != 40 {
|
||||
t.Fatalf("BProcessingLatencyNS = %d, want 40", got)
|
||||
}
|
||||
if got := ptrValue(summary.EndToEndLatencyNS); got != 160 {
|
||||
t.Fatalf("EndToEndLatencyNS = %d, want 160", got)
|
||||
}
|
||||
if got := summary.Timestamps[EventBRXSoftware]; got != 180 {
|
||||
t.Fatalf("timestamps[%q] = %d, want 180", EventBRXSoftware, got)
|
||||
}
|
||||
if len(summary.MissingTimestamps) != 0 {
|
||||
t.Fatalf("MissingTimestamps = %v, want empty", summary.MissingTimestamps)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeEventsReportsMissingTimestamps(t *testing.T) {
|
||||
events := []Event{
|
||||
{TsUnixNano: 100, Event: EventAAppPrepBegin, MessageType: protocol.MessageTypeText, MessageID: 2, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 240, Event: EventBPersistEnd, MessageType: protocol.MessageTypeText, MessageID: 2, From: "peer-a", To: "peer-b"},
|
||||
}
|
||||
|
||||
summaries := SummarizeEvents(events)
|
||||
if len(summaries) != 1 {
|
||||
t.Fatalf("summary count = %d, want 1", len(summaries))
|
||||
}
|
||||
|
||||
wantMissing := []string{EventATXSched, EventATXSoftware, EventBRXSoftware, EventBAppRecv}
|
||||
if !reflect.DeepEqual(summaries[0].MissingTimestamps, wantMissing) {
|
||||
t.Fatalf("MissingTimestamps = %v, want %v", summaries[0].MissingTimestamps, wantMissing)
|
||||
}
|
||||
if summaries[0].AProcessingLatencyNS != nil {
|
||||
t.Fatalf("AProcessingLatencyNS = %v, want nil", ptrValue(summaries[0].AProcessingLatencyNS))
|
||||
}
|
||||
if summaries[0].EndToEndLatencyNS == nil {
|
||||
t.Fatal("EndToEndLatencyNS = nil, want non-nil because endpoints are present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAndWriteSummaryFiles(t *testing.T) {
|
||||
rawPath := filepath.Join(t.TempDir(), "raw.jsonl")
|
||||
rawLogger, err := NewJSONLLogger(rawPath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLLogger() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = rawLogger.Close()
|
||||
})
|
||||
|
||||
for _, event := range []Event{
|
||||
{TsUnixNano: 100, Event: EventAAppPrepBegin, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 120, Event: EventATXSched, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 140, Event: EventATXSoftware, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 180, Event: EventBRXSoftware, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 220, Event: EventBAppRecv, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
|
||||
{TsUnixNano: 260, Event: EventBPersistEnd, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
|
||||
} {
|
||||
if err := rawLogger.LogEvent(event); err != nil {
|
||||
t.Fatalf("LogEvent() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
events, err := LoadEventsFromFile(rawPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadEventsFromFile() error = %v", err)
|
||||
}
|
||||
|
||||
summaryPath := filepath.Join(t.TempDir(), "summary.jsonl")
|
||||
if err := WriteSummariesJSONL(summaryPath, SummarizeEvents(events)); err != nil {
|
||||
t.Fatalf("WriteSummariesJSONL() error = %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(summaryPath)
|
||||
if err != nil {
|
||||
t.Fatalf("os.Open() error = %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
if !scanner.Scan() {
|
||||
t.Fatal("expected one summary line, got none")
|
||||
}
|
||||
|
||||
var summary Summary
|
||||
if err := json.Unmarshal(scanner.Bytes(), &summary); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
if summary.MessageID != 3 {
|
||||
t.Fatalf("MessageID = %d, want 3", summary.MessageID)
|
||||
}
|
||||
if got := ptrValue(summary.BKernelReceivePathLatencyNS); got != 40 {
|
||||
t.Fatalf("BKernelReceivePathLatencyNS = %d, want 40", got)
|
||||
}
|
||||
if got := ptrValue(summary.EndToEndLatencyNS); got != 160 {
|
||||
t.Fatalf("EndToEndLatencyNS = %d, want 160", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ptrValue dereferences value, treating a nil pointer as 0.
func ptrValue(value *int64) int64 {
	if value != nil {
		return *value
	}
	return 0
}
|
||||
179
cmd/internal/peer/client.go
Normal file
179
cmd/internal/peer/client.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
// dialServer is the function Dial uses to open the connection to the server.
// It defaults to net.Dial; as a package-level variable it can be reassigned
// (presumably so tests can substitute a fake dialer — confirm against the tests).
var dialServer = net.Dial
|
||||
|
||||
type clientOptions struct {
|
||||
logger latencylog.Logger
|
||||
}
|
||||
|
||||
// Option 用于配置 Client 的可选行为,例如时延日志。
|
||||
type Option func(*clientOptions)
|
||||
|
||||
// WithLogger 为 client 注入时延日志记录器。
|
||||
func WithLogger(logger latencylog.Logger) Option {
|
||||
return func(options *clientOptions) {
|
||||
options.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// Client 表示一个已经连接到 server 的 peer。
|
||||
type Client struct {
|
||||
id string
|
||||
conn *transport.TCPConn
|
||||
logger latencylog.Logger
|
||||
|
||||
nextID uint64
|
||||
}
|
||||
|
||||
// Dial 连接到 server,并立即发送 register 消息完成身份注册。
|
||||
func Dial(serverAddr, peerID string, opts ...Option) (*Client, error) {
|
||||
options := clientOptions{
|
||||
logger: latencylog.NoopLogger{},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
if options.logger == nil {
|
||||
options.logger = latencylog.NoopLogger{}
|
||||
}
|
||||
|
||||
rawConn, err := dialServer("tcp", serverAddr) //使用 net.Dial 连接到 serverAddr 指定的 TCP 地址,返回一个 net.Conn。
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer: dial server: %w", err)
|
||||
}
|
||||
|
||||
conn, err := transport.NewTCPConn(
|
||||
rawConn,
|
||||
transport.WithLogger(options.logger, latencylog.NodeRolePeer, peerID),
|
||||
)
|
||||
if err != nil {
|
||||
_ = rawConn.Close()
|
||||
return nil, fmt.Errorf("peer: create transport conn: %w", err)
|
||||
}
|
||||
client := &Client{
|
||||
id: peerID,
|
||||
conn: conn,
|
||||
logger: options.logger,
|
||||
}
|
||||
|
||||
if err := conn.Send(protocol.Message{ //向 server 发送一条 register 消息,完成身份注册。
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: peerID,
|
||||
To: protocol.ServerPeerID,
|
||||
}); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("peer: register with server: %w", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// ID 返回当前 client 的 peer 标识。
|
||||
func (c *Client) ID() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
// SendText 向目标 peer 发送一条文本消息。
|
||||
func (c *Client) SendText(to, body string) error {
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: c.nextMessageID(),
|
||||
From: c.id,
|
||||
To: to,
|
||||
}
|
||||
// 记录 A 端应用开始准备消息的时间点。
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
|
||||
|
||||
msg.Body = []byte(body)
|
||||
|
||||
return c.conn.Send(msg)
|
||||
}
|
||||
|
||||
// SendFile 向目标 peer 发送一条文件消息。
|
||||
func (c *Client) SendFile(to, fileName string, body []byte) error {
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: c.nextMessageID(),
|
||||
From: c.id,
|
||||
To: to,
|
||||
FileName: fileName,
|
||||
}
|
||||
// 记录 A 端应用开始准备消息的时间点。
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
|
||||
|
||||
bodyCopy := make([]byte, len(body))
|
||||
copy(bodyCopy, body)
|
||||
|
||||
msg.Body = bodyCopy
|
||||
|
||||
return c.conn.Send(msg)
|
||||
}
|
||||
|
||||
// SendFilePath 从本地文件读取内容并发送给目标 peer。
|
||||
func (c *Client) SendFilePath(to, path string) error {
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: c.nextMessageID(),
|
||||
From: c.id,
|
||||
To: to,
|
||||
FileName: filepath.Base(path),
|
||||
}
|
||||
// 记录 A 端应用开始准备消息的时间点。
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
|
||||
|
||||
body, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("peer: read file %s: %w", path, err)
|
||||
}
|
||||
|
||||
msg.Body = body
|
||||
|
||||
return c.conn.Send(msg)
|
||||
}
|
||||
|
||||
// Receive 读取一条来自 server 的消息。
|
||||
func (c *Client) Receive() (protocol.Message, error) {
|
||||
msg, err := c.conn.Receive() //从底层 TCP 连接读取一条消息,返回一个 protocol.Message 结构体。
|
||||
if err != nil {
|
||||
return protocol.Message{}, fmt.Errorf("peer: receive from server: %w", err)
|
||||
}
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// ReceiveLoop 持续接收 server 消息并交给 handler 处理。
|
||||
func (c *Client) ReceiveLoop(handler func(protocol.Message) error) error {
|
||||
return c.conn.ReceiveLoop(func(msg protocol.Message) error {
|
||||
switch msg.Type {
|
||||
case protocol.MessageTypeText, protocol.MessageTypeFile, protocol.MessageTypeError:
|
||||
// 记录 B 端应用真正读到完整消息的时间点。
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg)
|
||||
return handler(msg)
|
||||
default:
|
||||
return fmt.Errorf("peer: unexpected message type from server: %s", msg.Type)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Close 关闭与 server 的连接。
|
||||
func (c *Client) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Client) nextMessageID() uint64 {
|
||||
return atomic.AddUint64(&c.nextID, 1)
|
||||
}
|
||||
188
cmd/internal/peer/client_linux_test.go
Normal file
188
cmd/internal/peer/client_linux_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
//go:build linux
|
||||
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/server"
|
||||
)
|
||||
|
||||
func TestClientsExchangeMessagesWithLinuxTimestamps(t *testing.T) {
|
||||
hub := server.NewHub()
|
||||
serverAddr, cleanup := startRealHubServer(t, hub)
|
||||
defer cleanup()
|
||||
|
||||
peerALogger := &recordingLogger{}
|
||||
peerA, err := Dial(serverAddr, "peer-a", WithLogger(peerALogger))
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(peer-a) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerA.Close() }()
|
||||
|
||||
peerBLogger := &recordingLogger{}
|
||||
peerB, err := Dial(serverAddr, "peer-b", WithLogger(peerBLogger))
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(peer-b) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerB.Close() }()
|
||||
|
||||
inboxDir := t.TempDir()
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
|
||||
|
||||
if err := peerA.SendText("peer-b", "hello"); err != nil {
|
||||
t.Fatalf("SendText() error = %v", err)
|
||||
}
|
||||
textMsg, err := peerB.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("peerB.Receive(text) error = %v", err)
|
||||
}
|
||||
if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil {
|
||||
t.Fatalf("peerB.PersistMessage(text) error = %v", err)
|
||||
}
|
||||
|
||||
if err := peerA.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}); err != nil {
|
||||
t.Fatalf("SendFile() error = %v", err)
|
||||
}
|
||||
fileMsg, err := peerB.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("peerB.Receive(file) error = %v", err)
|
||||
}
|
||||
if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil {
|
||||
t.Fatalf("peerB.PersistMessage(file) error = %v", err)
|
||||
}
|
||||
|
||||
waitFor(t, func() bool { return hasMessageEvents(peerALogger.Events(), 1, latencylog.EventAAppPrepBegin, latencylog.EventATXSched, latencylog.EventATXSoftware) }, "peer-a text kernel timestamps")
|
||||
waitFor(t, func() bool { return hasMessageEvents(peerALogger.Events(), 2, latencylog.EventAAppPrepBegin, latencylog.EventATXSched, latencylog.EventATXSoftware) }, "peer-a file kernel timestamps")
|
||||
waitFor(t, func() bool { return hasMessageEvents(peerBLogger.Events(), 1, latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd) }, "peer-b text receive timestamps")
|
||||
waitFor(t, func() bool { return hasMessageEvents(peerBLogger.Events(), 2, latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd) }, "peer-b file receive timestamps")
|
||||
}
|
||||
|
||||
func startRealHubServer(t *testing.T, hub *server.Hub) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen() error = %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
stop = make(chan struct{})
|
||||
errOnce sync.Once
|
||||
)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
conn, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
}
|
||||
if strings.Contains(acceptErr.Error(), "closed") {
|
||||
return
|
||||
}
|
||||
t.Errorf("listener.Accept() error = %v", acceptErr)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(rawConn net.Conn) {
|
||||
defer wg.Done()
|
||||
if serveErr := hub.ServeConn(rawConn); serveErr != nil && !isExpectedHubServeExit(serveErr) {
|
||||
errOnce.Do(func() {
|
||||
t.Logf("hub.ServeConn() ended with %v", serveErr)
|
||||
})
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
cleanup := func() {
|
||||
close(stop)
|
||||
_ = listener.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
return listener.Addr().String(), cleanup
|
||||
}
|
||||
|
||||
func hasMessageEvents(events []latencylog.Event, messageID uint64, wantEvents ...string) bool {
|
||||
seen := make(map[string]bool, len(wantEvents))
|
||||
for _, event := range events {
|
||||
if event.MessageID != messageID {
|
||||
continue
|
||||
}
|
||||
if event.TsUnixNano <= 0 {
|
||||
return false
|
||||
}
|
||||
seen[event.Event] = true
|
||||
}
|
||||
|
||||
for _, wantEvent := range wantEvents {
|
||||
if !seen[wantEvent] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isExpectedHubServeExit reports whether err is an acceptable way for a hub
// connection to end: nil, or an error whose text mentions "closed" or
// "protocol: read frame: EOF".
func isExpectedHubServeExit(err error) bool {
	if err == nil {
		return true
	}

	text := err.Error()
	for _, marker := range []string{"closed", "protocol: read frame: EOF"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func TestLinuxTimestampedReceivePreservesBusinessMessageShape(t *testing.T) {
|
||||
hub := server.NewHub()
|
||||
serverAddr, cleanup := startRealHubServer(t, hub)
|
||||
defer cleanup()
|
||||
|
||||
peerA, err := Dial(serverAddr, "peer-a")
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(peer-a) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerA.Close() }()
|
||||
|
||||
peerB, err := Dial(serverAddr, "peer-b")
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(peer-b) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerB.Close() }()
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
|
||||
|
||||
want := protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0xde, 0xad, 0xbe, 0xef},
|
||||
}
|
||||
if err := peerA.SendFile(want.To, want.FileName, want.Body); err != nil {
|
||||
t.Fatalf("SendFile() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := peerB.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("peerB.Receive() error = %v", err)
|
||||
}
|
||||
if got.Type != want.Type || got.ID != want.ID || got.From != want.From || got.To != want.To || got.FileName != want.FileName || string(got.Body) != string(want.Body) {
|
||||
t.Fatalf("received message mismatch: got %+v want %+v", got, want)
|
||||
}
|
||||
}
|
||||
770
cmd/internal/peer/client_test.go
Normal file
770
cmd/internal/peer/client_test.go
Normal file
@@ -0,0 +1,770 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/server"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
type recordingLogger struct {
|
||||
mu sync.Mutex
|
||||
events []latencylog.Event
|
||||
}
|
||||
|
||||
func (l *recordingLogger) LogEvent(event latencylog.Event) error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.events = append(l.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *recordingLogger) Events() []latencylog.Event {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
return append([]latencylog.Event(nil), l.events...)
|
||||
}
|
||||
|
||||
type failingLogger struct{}
|
||||
|
||||
func (failingLogger) LogEvent(latencylog.Event) error {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
func TestDialRegistersPeer(t *testing.T) {
|
||||
hub := server.NewHub()
|
||||
cleanup := stubDialToHub(t, hub)
|
||||
defer cleanup()
|
||||
|
||||
client, err := Dial("ignored", "peer-a")
|
||||
if err != nil {
|
||||
t.Fatalf("Dial() error = %v", err)
|
||||
}
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
||||
}
|
||||
|
||||
func TestClientsExchangeTextAndFileMessages(t *testing.T) {
|
||||
hub := server.NewHub()
|
||||
cleanup := stubDialToHub(t, hub)
|
||||
defer cleanup()
|
||||
|
||||
peerA, err := Dial("ignored", "peer-a")
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(peer-a) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerA.Close() }()
|
||||
|
||||
peerB, err := Dial("ignored", "peer-b")
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(peer-b) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerB.Close() }()
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
|
||||
|
||||
received := make(chan protocol.Message, 2)
|
||||
receiveErr := make(chan error, 1)
|
||||
go func() {
|
||||
for i := 0; i < 2; i++ {
|
||||
msg, err := peerB.Receive()
|
||||
if err != nil {
|
||||
receiveErr <- err
|
||||
return
|
||||
}
|
||||
received <- msg
|
||||
}
|
||||
receiveErr <- nil
|
||||
}()
|
||||
|
||||
if err := peerA.SendText("peer-b", "hello"); err != nil {
|
||||
t.Fatalf("SendText() error = %v", err)
|
||||
}
|
||||
fileBody := []byte{0x01, 0x02, 0x03}
|
||||
if err := peerA.SendFile("peer-b", "payload.bin", fileBody); err != nil {
|
||||
t.Fatalf("SendFile() error = %v", err)
|
||||
}
|
||||
|
||||
if err := <-receiveErr; err != nil {
|
||||
t.Fatalf("peerB.Receive() error = %v", err)
|
||||
}
|
||||
|
||||
gotFirst := <-received
|
||||
wantFirst := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
if !reflect.DeepEqual(gotFirst, wantFirst) {
|
||||
t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, wantFirst)
|
||||
}
|
||||
|
||||
gotSecond := <-received
|
||||
wantSecond := protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: fileBody,
|
||||
}
|
||||
if !reflect.DeepEqual(gotSecond, wantSecond) {
|
||||
t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, wantSecond)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReceivesServerErrorForUnknownTarget(t *testing.T) {
|
||||
hub := server.NewHub()
|
||||
cleanup := stubDialToHub(t, hub)
|
||||
defer cleanup()
|
||||
|
||||
client, err := Dial("ignored", "peer-a")
|
||||
if err != nil {
|
||||
t.Fatalf("Dial() error = %v", err)
|
||||
}
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
||||
|
||||
if err := client.SendText("missing-peer", "hello"); err != nil {
|
||||
t.Fatalf("SendText() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := client.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if got.Type != protocol.MessageTypeError {
|
||||
t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError)
|
||||
}
|
||||
if string(got.Body) != "unknown target: missing-peer" {
|
||||
t.Fatalf("error body = %q, want unknown target message", got.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReceiveLoopHandlesForwardedMessages(t *testing.T) {
|
||||
hub := server.NewHub()
|
||||
cleanup := stubDialToHub(t, hub)
|
||||
defer cleanup()
|
||||
|
||||
peerA, err := Dial("ignored", "peer-a")
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(peer-a) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerA.Close() }()
|
||||
|
||||
peerB, err := Dial("ignored", "peer-b")
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(peer-b) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerB.Close() }()
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
got []protocol.Message
|
||||
)
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- peerB.ReceiveLoop(func(msg protocol.Message) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
got = append(got, msg)
|
||||
if len(got) == 1 {
|
||||
return peerB.Close()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
if err := peerA.SendText("peer-b", "hello"); err != nil {
|
||||
t.Fatalf("SendText() error = %v", err)
|
||||
}
|
||||
|
||||
err = <-loopErr
|
||||
if err == nil || (!strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "use of closed network connection")) {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want close-related error", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
want := []protocol.Message{
|
||||
{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientSendLogsLatencyEvents(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*testing.T) string
|
||||
send func(*Client, string) error
|
||||
wantMsg protocol.Message
|
||||
wantEvents []string
|
||||
}{
|
||||
{
|
||||
name: "text",
|
||||
send: func(client *Client, _ string) error {
|
||||
return client.SendText("peer-b", "hello")
|
||||
},
|
||||
wantMsg: protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
wantEvents: []string{
|
||||
latencylog.EventAAppPrepBegin,
|
||||
latencylog.EventSendHandoffBegin,
|
||||
latencylog.EventATXSched,
|
||||
latencylog.EventATXSoftware,
|
||||
latencylog.EventSendHandoffEnd,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file-bytes",
|
||||
send: func(client *Client, _ string) error {
|
||||
return client.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03})
|
||||
},
|
||||
wantMsg: protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
wantEvents: []string{
|
||||
latencylog.EventAAppPrepBegin,
|
||||
latencylog.EventSendHandoffBegin,
|
||||
latencylog.EventATXSched,
|
||||
latencylog.EventATXSoftware,
|
||||
latencylog.EventSendHandoffEnd,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file-path",
|
||||
setup: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
path := filepath.Join(t.TempDir(), "payload.bin")
|
||||
if err := os.WriteFile(path, []byte{0x01, 0x02, 0x03}, 0o644); err != nil {
|
||||
t.Fatalf("os.WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
return path
|
||||
},
|
||||
send: func(client *Client, path string) error {
|
||||
return client.SendFilePath("peer-b", path)
|
||||
},
|
||||
wantMsg: protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
wantEvents: []string{
|
||||
latencylog.EventAAppPrepBegin,
|
||||
latencylog.EventSendHandoffBegin,
|
||||
latencylog.EventATXSched,
|
||||
latencylog.EventATXSoftware,
|
||||
latencylog.EventSendHandoffEnd,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
inputPath := ""
|
||||
if tt.setup != nil {
|
||||
inputPath = tt.setup(t)
|
||||
}
|
||||
|
||||
logger := &recordingLogger{}
|
||||
clientConn, receiver := newClientTransportPair(
|
||||
t,
|
||||
[]transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-a")},
|
||||
nil,
|
||||
)
|
||||
client := &Client{
|
||||
id: "peer-a",
|
||||
conn: clientConn,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- tt.send(client, inputPath)
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("receiver.Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("send() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.wantMsg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, tt.wantMsg)
|
||||
}
|
||||
|
||||
events := logger.Events()
|
||||
if len(events) != len(tt.wantEvents) {
|
||||
t.Fatalf("event count = %d, want %d", len(events), len(tt.wantEvents))
|
||||
}
|
||||
for i, wantEvent := range tt.wantEvents {
|
||||
if events[i].Event != wantEvent {
|
||||
t.Fatalf("event[%d] = %q, want %q", i, events[i].Event, wantEvent)
|
||||
}
|
||||
if events[i].MessageID != tt.wantMsg.ID || events[i].From != tt.wantMsg.From || events[i].To != tt.wantMsg.To {
|
||||
t.Fatalf("event[%d] metadata mismatch: %+v", i, events[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReceiveLogsOnlyBusinessMessages(t *testing.T) {
|
||||
logger := &recordingLogger{}
|
||||
clientConn, sender := newClientTransportPair(
|
||||
t,
|
||||
[]transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-b")},
|
||||
nil,
|
||||
)
|
||||
client := &Client{
|
||||
id: "peer-b",
|
||||
conn: clientConn,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
textMsg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 21,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(textMsg)
|
||||
}()
|
||||
if _, err := client.Receive(); err != nil {
|
||||
t.Fatalf("client.Receive(text) error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("sender.Send(text) error = %v", err)
|
||||
}
|
||||
|
||||
errorMsg := protocol.Message{
|
||||
Type: protocol.MessageTypeError,
|
||||
ID: 22,
|
||||
From: protocol.ServerPeerID,
|
||||
To: "peer-b",
|
||||
Body: []byte("failure"),
|
||||
}
|
||||
sendErr = make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(errorMsg)
|
||||
}()
|
||||
if _, err := client.Receive(); err != nil {
|
||||
t.Fatalf("client.Receive(error) error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("sender.Send(error) error = %v", err)
|
||||
}
|
||||
|
||||
events := logger.Events()
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("event count = %d, want 2", len(events))
|
||||
}
|
||||
if events[0].Event != latencylog.EventBRXSoftware {
|
||||
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBRXSoftware)
|
||||
}
|
||||
if events[1].Event != latencylog.EventBAppRecv {
|
||||
t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBAppRecv)
|
||||
}
|
||||
if events[0].MessageID != textMsg.ID || events[1].MessageID != textMsg.ID {
|
||||
t.Fatalf("message IDs = %d,%d, want %d", events[0].MessageID, events[1].MessageID, textMsg.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientPersistTextMessageWritesInboxFileAndLogs(t *testing.T) {
|
||||
inboxDir := t.TempDir()
|
||||
logger := &recordingLogger{}
|
||||
client := &Client{
|
||||
id: "peer-b",
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 31,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
path, err := client.PersistMessage(msg, inboxDir)
|
||||
if err != nil {
|
||||
t.Fatalf("PersistMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if path != filepath.Join(inboxDir, textInboxFileName) {
|
||||
t.Fatalf("path = %q, want %q", path, filepath.Join(inboxDir, textInboxFileName))
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("os.ReadFile() error = %v", err)
|
||||
}
|
||||
|
||||
var record textInboxRecord
|
||||
if err := json.Unmarshal(bytes.TrimSpace(data), &record); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
if record.MessageID != msg.ID || record.From != msg.From || record.To != msg.To || record.Body != "hello" {
|
||||
t.Fatalf("record mismatch: got %+v want message %+v", record, msg)
|
||||
}
|
||||
|
||||
events := logger.Events()
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("event count = %d, want 2", len(events))
|
||||
}
|
||||
if events[0].Event != latencylog.EventBPersistBegin {
|
||||
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin)
|
||||
}
|
||||
if events[1].Event != latencylog.EventBPersistEnd {
|
||||
t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBPersistEnd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientPersistFileMessageWritesInboxFileAndLogs(t *testing.T) {
|
||||
inboxDir := t.TempDir()
|
||||
logger := &recordingLogger{}
|
||||
client := &Client{
|
||||
id: "peer-b",
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 32,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
}
|
||||
|
||||
path, err := client.PersistMessage(msg, inboxDir)
|
||||
if err != nil {
|
||||
t.Fatalf("PersistMessage() error = %v", err)
|
||||
}
|
||||
|
||||
wantPath := filepath.Join(inboxDir, "peer-a-32-payload.bin")
|
||||
if path != wantPath {
|
||||
t.Fatalf("path = %q, want %q", path, wantPath)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("os.ReadFile() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(data, msg.Body) {
|
||||
t.Fatalf("file body mismatch: got %v want %v", data, msg.Body)
|
||||
}
|
||||
|
||||
events := logger.Events()
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("event count = %d, want 2", len(events))
|
||||
}
|
||||
if events[0].Event != latencylog.EventBPersistBegin {
|
||||
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin)
|
||||
}
|
||||
if events[1].Event != latencylog.EventBPersistEnd {
|
||||
t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBPersistEnd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientPersistMessageReturnsErrorOnWriteFailure(t *testing.T) {
|
||||
blocker := filepath.Join(t.TempDir(), "blocker")
|
||||
if err := os.WriteFile(blocker, []byte("not a directory"), 0o644); err != nil {
|
||||
t.Fatalf("os.WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
logger := &recordingLogger{}
|
||||
client := &Client{
|
||||
id: "peer-b",
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 33,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
if _, err := client.PersistMessage(msg, blocker); err == nil {
|
||||
t.Fatal("PersistMessage() error = nil, want non-nil")
|
||||
}
|
||||
|
||||
events := logger.Events()
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("event count = %d, want 1", len(events))
|
||||
}
|
||||
if events[0].Event != latencylog.EventBPersistBegin {
|
||||
t.Fatalf("event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIgnoresLoggerFailure(t *testing.T) {
|
||||
clientConn, receiver := newClientTransportPair(
|
||||
t,
|
||||
[]transport.Option{transport.WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")},
|
||||
nil,
|
||||
)
|
||||
client := &Client{
|
||||
id: "peer-a",
|
||||
conn: clientConn,
|
||||
logger: failingLogger{},
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- client.SendText("peer-b", "hello")
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("receiver.Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("SendText() error = %v, want nil even if logger fails", err)
|
||||
}
|
||||
if string(got.Body) != "hello" {
|
||||
t.Fatalf("body = %q, want hello", got.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientPersistIgnoresLoggerFailure(t *testing.T) {
|
||||
client := &Client{
|
||||
id: "peer-b",
|
||||
logger: failingLogger{},
|
||||
}
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 34,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
path, err := client.PersistMessage(msg, t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("PersistMessage() error = %v, want nil even if logger fails", err)
|
||||
}
|
||||
if path == "" {
|
||||
t.Fatal("PersistMessage() path = empty, want non-empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsExchangeMessagesWithLatencyLogs(t *testing.T) {
|
||||
hub := server.NewHub()
|
||||
cleanup := stubDialToHub(t, hub)
|
||||
defer cleanup()
|
||||
|
||||
peerALogger := &recordingLogger{}
|
||||
peerA, err := Dial("ignored", "peer-a", WithLogger(peerALogger))
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(peer-a) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerA.Close() }()
|
||||
|
||||
peerBLogger := &recordingLogger{}
|
||||
peerB, err := Dial("ignored", "peer-b", WithLogger(peerBLogger))
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(peer-b) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerB.Close() }()
|
||||
|
||||
inboxDir := t.TempDir()
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
|
||||
|
||||
if err := peerA.SendText("peer-b", "hello"); err != nil {
|
||||
t.Fatalf("SendText() error = %v", err)
|
||||
}
|
||||
textMsg, err := peerB.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("peerB.Receive(text) error = %v", err)
|
||||
}
|
||||
if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil {
|
||||
t.Fatalf("peerB.PersistMessage(text) error = %v", err)
|
||||
}
|
||||
|
||||
if err := peerA.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}); err != nil {
|
||||
t.Fatalf("SendFile() error = %v", err)
|
||||
}
|
||||
fileMsg, err := peerB.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("peerB.Receive(file) error = %v", err)
|
||||
}
|
||||
if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil {
|
||||
t.Fatalf("peerB.PersistMessage(file) error = %v", err)
|
||||
}
|
||||
|
||||
waitFor(t, func() bool { return len(peerALogger.Events()) == 10 }, "peer-a latency events")
|
||||
waitFor(t, func() bool { return len(peerBLogger.Events()) == 8 }, "peer-b latency events")
|
||||
|
||||
assertEventSequencesByMessage(t, peerALogger.Events(), map[uint64][]string{
|
||||
1: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd},
|
||||
2: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd},
|
||||
})
|
||||
assertEventSequencesByMessage(t, peerBLogger.Events(), map[uint64][]string{
|
||||
1: {latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd},
|
||||
2: {latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd},
|
||||
})
|
||||
}
|
||||
|
||||
func stubDialToHub(t *testing.T, hub *server.Hub) func() {
|
||||
t.Helper()
|
||||
|
||||
originalDial := dialServer
|
||||
serverAddr, cleanup := startRealHubServer(t, hub)
|
||||
|
||||
dialServer = func(network, addr string) (net.Conn, error) {
|
||||
return net.Dial(network, serverAddr)
|
||||
}
|
||||
|
||||
return func() {
|
||||
dialServer = originalDial
|
||||
cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// waitFor polls condition every 10ms for up to 500ms and returns as soon as
// it reports true. If the deadline passes first, the test fails with a
// message that includes description.
func waitFor(t *testing.T, condition func() bool, description string) {
	t.Helper()

	const (
		timeout  = 500 * time.Millisecond
		interval = 10 * time.Millisecond
	)

	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(interval) {
		if condition() {
			return
		}
	}

	t.Fatalf("timed out waiting for %s", description)
}
|
||||
|
||||
func assertEventSequencesByMessage(t *testing.T, events []latencylog.Event, want map[uint64][]string) {
|
||||
t.Helper()
|
||||
|
||||
grouped := make(map[uint64][]latencylog.Event)
|
||||
for _, event := range events {
|
||||
grouped[event.MessageID] = append(grouped[event.MessageID], event)
|
||||
if event.TsUnixNano <= 0 {
|
||||
t.Fatalf("event timestamp must be positive: %+v", event)
|
||||
}
|
||||
}
|
||||
|
||||
if len(grouped) != len(want) {
|
||||
t.Fatalf("message group count = %d, want %d", len(grouped), len(want))
|
||||
}
|
||||
|
||||
for messageID, wantEvents := range want {
|
||||
gotEvents := grouped[messageID]
|
||||
if len(gotEvents) != len(wantEvents) {
|
||||
t.Fatalf("message %d event count = %d, want %d", messageID, len(gotEvents), len(wantEvents))
|
||||
}
|
||||
for i, wantEvent := range wantEvents {
|
||||
if gotEvents[i].Event != wantEvent {
|
||||
t.Fatalf("message %d event[%d] = %q, want %q", messageID, i, gotEvents[i].Event, wantEvent)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newClientTransportPair(t *testing.T, clientOpts []transport.Option, peerOpts []transport.Option) (*transport.TCPConn, *transport.TCPConn) {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen() error = %v", err)
|
||||
}
|
||||
|
||||
type acceptResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
accepted := make(chan acceptResult, 1)
|
||||
go func() {
|
||||
conn, acceptErr := listener.Accept()
|
||||
accepted <- acceptResult{conn: conn, err: acceptErr}
|
||||
}()
|
||||
|
||||
clientSide, err := net.Dial("tcp", listener.Addr().String())
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
t.Fatalf("net.Dial() error = %v", err)
|
||||
}
|
||||
|
||||
result := <-accepted
|
||||
if err := listener.Close(); err != nil {
|
||||
t.Fatalf("listener.Close() error = %v", err)
|
||||
}
|
||||
if result.err != nil {
|
||||
_ = clientSide.Close()
|
||||
t.Fatalf("listener.Accept() error = %v", result.err)
|
||||
}
|
||||
|
||||
clientConn, err := transport.NewTCPConn(clientSide, clientOpts...)
|
||||
if err != nil {
|
||||
_ = clientSide.Close()
|
||||
_ = result.conn.Close()
|
||||
t.Fatalf("transport.NewTCPConn(client) error = %v", err)
|
||||
}
|
||||
peerConn, err := transport.NewTCPConn(result.conn, peerOpts...)
|
||||
if err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = result.conn.Close()
|
||||
t.Fatalf("transport.NewTCPConn(peer) error = %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = peerConn.Close()
|
||||
})
|
||||
|
||||
return clientConn, peerConn
|
||||
}
|
||||
99
cmd/internal/peer/persist.go
Normal file
99
cmd/internal/peer/persist.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// textInboxFileName is the name of the file, inside the inbox directory, to
// which received text messages are appended.
const textInboxFileName = "messages.log"
|
||||
|
||||
type textInboxRecord struct {
|
||||
MessageType protocol.MessageType `json:"message_type"`
|
||||
MessageID uint64 `json:"message_id"`
|
||||
From string `json:"from"`
|
||||
To string `json:"to"`
|
||||
Body string `json:"body"`
|
||||
}
|
||||
|
||||
// PersistMessage 将收到的业务消息写入本地磁盘,并记录处理完成节点。
|
||||
func (c *Client) PersistMessage(msg protocol.Message, inboxDir string) (string, error) {
|
||||
if !latencylog.IsBusinessMessage(msg) {
|
||||
return "", fmt.Errorf("peer: cannot persist message type %s", msg.Type)
|
||||
}
|
||||
if inboxDir == "" {
|
||||
return "", fmt.Errorf("peer: inbox directory is required")
|
||||
}
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistBegin, msg)
|
||||
|
||||
if err := os.MkdirAll(inboxDir, 0o755); err != nil {
|
||||
return "", fmt.Errorf("peer: create inbox dir %s: %w", inboxDir, err)
|
||||
}
|
||||
|
||||
path, err := persistMessageToDisk(msg, inboxDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistEnd, msg)
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// persistMessageToDisk 根据消息类型将消息内容写入磁盘,文本消息追加到文本日志文件,文件消息写成独立文件。
|
||||
func persistMessageToDisk(msg protocol.Message, inboxDir string) (string, error) {
|
||||
switch msg.Type {
|
||||
case protocol.MessageTypeText:
|
||||
return persistTextMessage(msg, inboxDir)
|
||||
case protocol.MessageTypeFile:
|
||||
return persistFileMessage(msg, inboxDir)
|
||||
default:
|
||||
return "", fmt.Errorf("peer: cannot persist unsupported message type %s", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// registerPeer 验证 peer ID 的合法性和唯一性,并将其与连接关联起来。
|
||||
func persistTextMessage(msg protocol.Message, inboxDir string) (string, error) {
|
||||
record := textInboxRecord{
|
||||
MessageType: msg.Type,
|
||||
MessageID: msg.ID,
|
||||
From: msg.From,
|
||||
To: msg.To,
|
||||
Body: string(msg.Body),
|
||||
}
|
||||
|
||||
line, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("peer: encode text inbox record: %w", err)
|
||||
}
|
||||
|
||||
path := filepath.Join(inboxDir, textInboxFileName)
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("peer: open text inbox %s: %w", path, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if _, err := file.Write(append(line, '\n')); err != nil {
|
||||
return "", fmt.Errorf("peer: append text inbox %s: %w", path, err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// persistFileMessage 将文件消息的内容写成独立文件,文件名包含发送方、消息 ID 和原始文件名,保证唯一性和可读性。
|
||||
func persistFileMessage(msg protocol.Message, inboxDir string) (string, error) {
|
||||
fileName := filepath.Base(msg.FileName)
|
||||
path := filepath.Join(inboxDir, fmt.Sprintf("%s-%d-%s", msg.From, msg.ID, fileName))
|
||||
|
||||
if err := os.WriteFile(path, msg.Body, 0o644); err != nil {
|
||||
return "", fmt.Errorf("peer: write received file %s: %w", path, err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
279
cmd/internal/protocol/codec.go
Normal file
279
cmd/internal/protocol/codec.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// MaxFrameSize caps the length of a single frame so that a misbehaving peer
// cannot force the receiver into an unbounded allocation by sending a forged,
// oversized length value.
//
// The 8 MiB limit is provisional: transmitted video frames are currently
// expected to stay below it.
const MaxFrameSize = 8 << 20
|
||||
|
||||
// Sentinel errors returned by the codec. Match them with errors.Is.
var (
	// ErrInvalidFrameLength reports an illegal frame length, such as zero.
	ErrInvalidFrameLength = errors.New("protocol: invalid frame length")

	// ErrFrameTooLarge reports a frame longer than the allowed maximum.
	ErrFrameTooLarge = errors.New("protocol: frame too large")

	// ErrInvalidMessageType reports a message type the protocol does not support.
	ErrInvalidMessageType = errors.New("protocol: invalid message type")

	// ErrMissingFrom reports a message without a sender ID.
	ErrMissingFrom = errors.New("protocol: missing from")

	// ErrMissingTo reports a message without a recipient ID.
	ErrMissingTo = errors.New("protocol: missing to")

	// ErrMissingFileName reports a file message without a file name.
	ErrMissingFileName = errors.New("protocol: missing file name")

	// ErrUnexpectedFileName reports a file name on a message type that must
	// not carry one, such as a text message.
	ErrUnexpectedFileName = errors.New("protocol: unexpected file name")

	// ErrInvalidTextBody reports a textual body that is not valid UTF-8.
	ErrInvalidTextBody = errors.New("protocol: invalid text body")

	// ErrUnexpectedBody reports a body on a control message that must not
	// carry one.
	ErrUnexpectedBody = errors.New("protocol: unexpected body")

	// ErrInvalidRegisterTarget reports a register message not addressed to
	// the server.
	ErrInvalidRegisterTarget = errors.New("protocol: invalid register target")

	// ErrInvalidErrorSource reports an error message not sent by the server.
	ErrInvalidErrorSource = errors.New("protocol: invalid error source")

	// ErrInvalidHeaderLength reports a header length field that is zero, out
	// of range, or does not leave a complete header to slice out.
	ErrInvalidHeaderLength = errors.New("protocol: invalid header length")

	// ErrInvalidHeaderJSON reports a header that cannot be parsed as JSON,
	// for example because it is malformed.
	ErrInvalidHeaderJSON = errors.New("protocol: invalid header json")

	// ErrInvalidContentLength reports a header whose declared body length
	// differs from the actual body length.
	ErrInvalidContentLength = errors.New("protocol: invalid content length")
)
|
||||
|
||||
// 应用层消息:[4字节 frameLength][4字节 headerLen][header JSON(下面自定义的Message头)][body bytes]
|
||||
// 写了 tag:JSON 字段名是你指定的 type;不写 tag:JSON 字段名默认是 Go 字段名 Type
|
||||
type messageHeader struct {
|
||||
Type MessageType `json:"type"`
|
||||
ID uint64 `json:"id"`
|
||||
From string `json:"from"`
|
||||
To string `json:"to"`
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
ContentLength int `json:"content_length"`
|
||||
}
|
||||
|
||||
// EncodeMessage 将逻辑消息编码为帧内字节格式:
|
||||
// 1. 4 字节大端序 header 长度
|
||||
// 2. header JSON
|
||||
// 3. 原始 body 字节
|
||||
func EncodeMessage(msg Message) ([]byte, error) {
|
||||
if err := validateMessage(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header := messageHeader{
|
||||
Type: msg.Type,
|
||||
ID: msg.ID,
|
||||
From: msg.From,
|
||||
To: msg.To,
|
||||
FileName: msg.FileName,
|
||||
ContentLength: len(msg.Body),
|
||||
}
|
||||
|
||||
headerPayload, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("protocol: encode header: %w", err)
|
||||
}
|
||||
// 创建一个新的字节切片来存储完整的帧内容,避免直接在 headerPayload 上修改导致数据混乱。
|
||||
payload := make([]byte, 4+len(headerPayload)+len(msg.Body))
|
||||
// 在 payload 前 4 字节写入 header 长度,后续内容依次是 header JSON(第五个字节开始) 和 body。
|
||||
binary.BigEndian.PutUint32(payload[:4], uint32(len(headerPayload)))
|
||||
copy(payload[4:], headerPayload)
|
||||
copy(payload[4+len(headerPayload):], msg.Body)
|
||||
|
||||
//检查整个帧长度是否合法,避免上层调用者构造的消息过大导致发送失败。
|
||||
if len(payload) > MaxFrameSize {
|
||||
return nil, ErrFrameTooLarge
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// DecodeMessage 将帧内字节格式还原为 Message。
|
||||
func DecodeMessage(data []byte) (Message, error) {
|
||||
if len(data) > MaxFrameSize {
|
||||
return Message{}, ErrFrameTooLarge
|
||||
}
|
||||
if len(data) < 4 {
|
||||
return Message{}, ErrInvalidHeaderLength
|
||||
}
|
||||
|
||||
headerLen := int(binary.BigEndian.Uint32(data[:4]))
|
||||
if headerLen == 0 || headerLen > len(data)-4 {
|
||||
return Message{}, ErrInvalidHeaderLength
|
||||
}
|
||||
|
||||
headerPayload := data[4 : 4+headerLen]
|
||||
body := data[4+headerLen:]
|
||||
|
||||
var header messageHeader
|
||||
if err := json.Unmarshal(headerPayload, &header); err != nil {
|
||||
return Message{}, fmt.Errorf("protocol: decode header: %w", errors.Join(ErrInvalidHeaderJSON, err))
|
||||
}
|
||||
|
||||
if header.ContentLength < 0 || header.ContentLength != len(body) {
|
||||
return Message{}, ErrInvalidContentLength
|
||||
}
|
||||
|
||||
bodyCopy := make([]byte, len(body))
|
||||
copy(bodyCopy, body)
|
||||
|
||||
msg := Message{
|
||||
Type: header.Type,
|
||||
ID: header.ID,
|
||||
From: header.From,
|
||||
To: header.To,
|
||||
FileName: header.FileName,
|
||||
Body: bodyCopy,
|
||||
}
|
||||
|
||||
if err := validateMessage(msg); err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// WriteFrame 向流中写入一个带长度前缀的帧。
|
||||
// TCP帧格式如下:
|
||||
// 1. 4 字节大端序长度
|
||||
// 2. 后续 payload 内容
|
||||
//
|
||||
// TCP 是字节流协议,没有天然的消息边界。
|
||||
// 增加显式长度前缀后,接收方就知道一条完整消息应该读取多少字节,
|
||||
// 从而解决粘包和拆包问题。
|
||||
func WriteFrame(w io.Writer, payload []byte) error {
|
||||
size := len(payload)
|
||||
//空帧
|
||||
if size == 0 {
|
||||
return ErrInvalidFrameLength
|
||||
}
|
||||
//帧过大
|
||||
if size > MaxFrameSize {
|
||||
return ErrFrameTooLarge
|
||||
}
|
||||
|
||||
var header [4]byte
|
||||
binary.BigEndian.PutUint32(header[:], uint32(size))
|
||||
|
||||
// 先写长度头,接收方才能根据长度一次性读取完整消息体。
|
||||
if err := writeFull(w, header[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeFull(w, payload)
|
||||
}
|
||||
|
||||
// ReadFrame 从流中读取一个完整的长度前缀帧。
|
||||
// 它会先读取固定 4 字节长度头,校验长度是否合法,
|
||||
// 再使用 io.ReadFull 按长度读取完整消息体,
|
||||
// 这样即使底层 TCP 发生分段读取,也不会把半条消息暴露给上层。
|
||||
func ReadFrame(r io.Reader) ([]byte, error) {
|
||||
var header [4]byte
|
||||
if _, err := io.ReadFull(r, header[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
size := binary.BigEndian.Uint32(header[:])
|
||||
// 长度为 0 的帧被认为是非法输入,而不是合法的空消息。
|
||||
if size == 0 {
|
||||
return nil, ErrInvalidFrameLength
|
||||
}
|
||||
// 长度超过上限的帧会被拒绝,避免接收方无上限分配内存。
|
||||
if size > MaxFrameSize {
|
||||
return nil, ErrFrameTooLarge
|
||||
}
|
||||
|
||||
payload := make([]byte, int(size))
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// WriteMessage 是给上层直接使用的完整发送路径:
|
||||
// 把一条结构化消息完整编码并发送出去”的总入口。
|
||||
// Message -> header+body -> 长度前缀帧 -> io.Writer。
|
||||
func WriteMessage(w io.Writer, msg Message) error {
|
||||
payload, err := EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("protocol: encode message: %w", err)
|
||||
}
|
||||
|
||||
if err := WriteFrame(w, payload); err != nil {
|
||||
return fmt.Errorf("protocol: write frame: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadMessage 是给上层直接使用的完整接收路径:
|
||||
// io.Reader -> 长度前缀帧 -> header+body -> Message。
|
||||
func ReadMessage(r io.Reader) (Message, error) {
|
||||
payload, err := ReadFrame(r)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("protocol: read frame: %w", err)
|
||||
}
|
||||
|
||||
msg, err := DecodeMessage(payload)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("protocol: decode message: %w", err)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// validateMessage 检查 Message 传输的类型(只接受 text 和 file )。
|
||||
func validateMessage(msg Message) error {
|
||||
if msg.From == "" {
|
||||
return ErrMissingFrom
|
||||
}
|
||||
if msg.To == "" {
|
||||
return ErrMissingTo
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case MessageTypeText:
|
||||
if msg.FileName != "" {
|
||||
return ErrUnexpectedFileName
|
||||
}
|
||||
if !utf8.Valid(msg.Body) {
|
||||
return ErrInvalidTextBody
|
||||
}
|
||||
case MessageTypeFile:
|
||||
if msg.FileName == "" {
|
||||
return ErrMissingFileName
|
||||
}
|
||||
case MessageTypeRegister:
|
||||
if msg.To != ServerPeerID {
|
||||
return ErrInvalidRegisterTarget
|
||||
}
|
||||
if msg.FileName != "" {
|
||||
return ErrUnexpectedFileName
|
||||
}
|
||||
if len(msg.Body) != 0 {
|
||||
return ErrUnexpectedBody
|
||||
}
|
||||
case MessageTypeError:
|
||||
if msg.From != ServerPeerID {
|
||||
return ErrInvalidErrorSource
|
||||
}
|
||||
if msg.FileName != "" {
|
||||
return ErrUnexpectedFileName
|
||||
}
|
||||
if !utf8.Valid(msg.Body) {
|
||||
return ErrInvalidTextBody
|
||||
}
|
||||
default:
|
||||
return ErrInvalidMessageType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeFull keeps calling w.Write until every byte of data has been written
// or the writer returns an error. This stops a writer that performs partial
// writes from corrupting the frame format. A write that makes no progress
// without reporting an error yields io.ErrShortWrite.
func writeFull(w io.Writer, data []byte) error {
	pending := data
	for len(pending) != 0 {
		switch n, err := w.Write(pending); {
		case err != nil:
			return err
		case n == 0:
			return io.ErrShortWrite
		default:
			pending = pending[n:]
		}
	}
	return nil
}
|
||||
507
cmd/internal/protocol/codec_test.go
Normal file
507
cmd/internal/protocol/codec_test.go
Normal file
@@ -0,0 +1,507 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestEncodeDecodeMessageTextASCII 验证 ASCII 文本可以按 text 消息往返编解码。
|
||||
func TestEncodeDecodeMessageTextASCII(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 42,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageTextUTF8 验证 text 消息允许合法 UTF-8,
|
||||
// 从而天然兼容 ASCII 之外的普通文本。
|
||||
func TestEncodeDecodeMessageTextUTF8(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 43,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("你好, world"),
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageFile 验证 file 消息会保留文件名和原始二进制正文。
|
||||
func TestEncodeDecodeMessageFile(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 44,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "data.bin",
|
||||
Body: []byte{0x00, 0xff, 0x10, 0x7f},
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageRegister 验证 register 控制消息也能正常编解码。
|
||||
func TestEncodeDecodeMessageRegister(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeRegister,
|
||||
ID: 45,
|
||||
From: "peer-a",
|
||||
To: ServerPeerID,
|
||||
Body: []byte{},
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageError 验证 error 控制消息会保留 UTF-8 错误文本。
|
||||
func TestEncodeDecodeMessageError(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeError,
|
||||
ID: 46,
|
||||
From: ServerPeerID,
|
||||
To: "peer-a",
|
||||
Body: []byte("unknown target"),
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteReadFrame 单独验证最底层的长度前缀帧逻辑,
|
||||
// 不依赖 Message 结构,方便确认 TCP 粘包拆包问题是否被正确处理。
|
||||
func TestWriteReadFrame(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
payload := []byte("header+body")
|
||||
|
||||
if err := WriteFrame(&buf, payload); err != nil {
|
||||
t.Fatalf("WriteFrame() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := ReadFrame(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame() error = %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(got, payload) {
|
||||
t.Fatalf("payload mismatch: got %q want %q", got, payload)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteReadMessageAllowsEmptyBody 验证空文本和空文件都可以正常通过协议层,
|
||||
// 因为外层帧非空的前提下,空正文是合法业务内容。
|
||||
func TestWriteReadMessageAllowsEmptyBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message Message
|
||||
}{
|
||||
{
|
||||
name: "empty text",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte(""),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty file",
|
||||
message: Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "empty.txt",
|
||||
Body: []byte{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := WriteMessage(&buf, tt.message); err != nil {
|
||||
t.Fatalf("WriteMessage() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := ReadMessage(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tt.message) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", got, tt.message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteReadMessageRejectsInvalidMessages 验证协议层会在编码前拦住明显非法的消息。
|
||||
func TestWriteReadMessageRejectsInvalidMessages(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message Message
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid type",
|
||||
message: Message{
|
||||
Type: MessageType("unknown"),
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
},
|
||||
wantErr: ErrInvalidMessageType,
|
||||
},
|
||||
{
|
||||
name: "missing from",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 2,
|
||||
To: "peer-b",
|
||||
},
|
||||
wantErr: ErrMissingFrom,
|
||||
},
|
||||
{
|
||||
name: "missing to",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 3,
|
||||
From: "peer-a",
|
||||
},
|
||||
wantErr: ErrMissingTo,
|
||||
},
|
||||
{
|
||||
name: "text with file name",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 4,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "bad.txt",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
wantErr: ErrUnexpectedFileName,
|
||||
},
|
||||
{
|
||||
name: "text with invalid utf8",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 5,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte{0xff, 0xfe},
|
||||
},
|
||||
wantErr: ErrInvalidTextBody,
|
||||
},
|
||||
{
|
||||
name: "file without file name",
|
||||
message: Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 6,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte{0x01},
|
||||
},
|
||||
wantErr: ErrMissingFileName,
|
||||
},
|
||||
{
|
||||
name: "register with wrong target",
|
||||
message: Message{
|
||||
Type: MessageTypeRegister,
|
||||
ID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
},
|
||||
wantErr: ErrInvalidRegisterTarget,
|
||||
},
|
||||
{
|
||||
name: "register with body",
|
||||
message: Message{
|
||||
Type: MessageTypeRegister,
|
||||
ID: 8,
|
||||
From: "peer-a",
|
||||
To: ServerPeerID,
|
||||
Body: []byte("unexpected"),
|
||||
},
|
||||
wantErr: ErrUnexpectedBody,
|
||||
},
|
||||
{
|
||||
name: "error with wrong source",
|
||||
message: Message{
|
||||
Type: MessageTypeError,
|
||||
ID: 9,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("bad"),
|
||||
},
|
||||
wantErr: ErrInvalidErrorSource,
|
||||
},
|
||||
{
|
||||
name: "error with file name",
|
||||
message: Message{
|
||||
Type: MessageTypeError,
|
||||
ID: 10,
|
||||
From: ServerPeerID,
|
||||
To: "peer-a",
|
||||
FileName: "bad.txt",
|
||||
Body: []byte("bad"),
|
||||
},
|
||||
wantErr: ErrUnexpectedFileName,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := EncodeMessage(tt.message)
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("EncodeMessage() error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFrameRejectsInvalidLength 验证长度为 0 的帧会被当成非法输入,
|
||||
// 而不是被当成一条合法的空消息。
|
||||
func TestReadFrameRejectsInvalidLength(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := binary.Write(&buf, binary.BigEndian, uint32(0)); err != nil {
|
||||
t.Fatalf("binary.Write() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := ReadFrame(&buf)
|
||||
if !errors.Is(err, ErrInvalidFrameLength) {
|
||||
t.Fatalf("ReadFrame() error = %v, want %v", err, ErrInvalidFrameLength)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFrameRejectsTooLargeFrame 验证超大帧会在分配消息体前被拒绝,
|
||||
// 从而保证最大长度限制真正生效。
|
||||
func TestReadFrameRejectsTooLargeFrame(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := binary.Write(&buf, binary.BigEndian, uint32(MaxFrameSize+1)); err != nil {
|
||||
t.Fatalf("binary.Write() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := ReadFrame(&buf)
|
||||
if !errors.Is(err, ErrFrameTooLarge) {
|
||||
t.Fatalf("ReadFrame() error = %v, want %v", err, ErrFrameTooLarge)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteFrameRejectsEmptyPayload 验证写入端和读取端的约束保持一致:
|
||||
// 既然读取端不接受 0 长度帧,写入端也不应该产生这种帧。
|
||||
func TestWriteFrameRejectsEmptyPayload(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
err := WriteFrame(&buf, nil)
|
||||
if !errors.Is(err, ErrInvalidFrameLength) {
|
||||
t.Fatalf("WriteFrame() error = %v, want %v", err, ErrInvalidFrameLength)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeMessageRejectsInvalidHeaderLength 验证无法切出完整头部时会被立即拒绝。
|
||||
func TestDecodeMessageRejectsInvalidHeaderLength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
}{
|
||||
{
|
||||
name: "too short for header len",
|
||||
data: []byte{0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
name: "zero header len",
|
||||
data: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
name: "header len exceeds payload",
|
||||
data: []byte{0x00, 0x00, 0x00, 0x10, '{', '}'},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := DecodeMessage(tt.data)
|
||||
if !errors.Is(err, ErrInvalidHeaderLength) {
|
||||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderLength)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeMessageRejectsInvalidHeaderJSON 验证头部 JSON 非法时能返回明确错误。
|
||||
func TestDecodeMessageRejectsInvalidHeaderJSON(t *testing.T) {
|
||||
data := append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)
|
||||
|
||||
_, err := DecodeMessage(data)
|
||||
if !errors.Is(err, ErrInvalidHeaderJSON) {
|
||||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeMessageRejectsContentLengthMismatch 验证头部声明长度和实际正文不一致时会失败。
|
||||
func TestDecodeMessageRejectsContentLengthMismatch(t *testing.T) {
|
||||
headerPayload, err := json.Marshal(messageHeader{
|
||||
Type: MessageTypeText,
|
||||
ID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
ContentLength: 10,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
var data bytes.Buffer
|
||||
if err := binary.Write(&data, binary.BigEndian, uint32(len(headerPayload))); err != nil {
|
||||
t.Fatalf("binary.Write() error = %v", err)
|
||||
}
|
||||
if _, err := data.Write(headerPayload); err != nil {
|
||||
t.Fatalf("data.Write(headerPayload) error = %v", err)
|
||||
}
|
||||
if _, err := data.Write([]byte("hello")); err != nil {
|
||||
t.Fatalf("data.Write(body) error = %v", err)
|
||||
}
|
||||
|
||||
_, err = DecodeMessage(data.Bytes())
|
||||
if !errors.Is(err, ErrInvalidContentLength) {
|
||||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadMultipleMessages 模拟同一条流中连续写入 text 和 file,
|
||||
// 验证读取端每次都能严格停在当前帧边界,不会串包。
|
||||
func TestReadMultipleMessages(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
first := Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
second := Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-b",
|
||||
To: "peer-a",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
}
|
||||
|
||||
if err := WriteMessage(&buf, first); err != nil {
|
||||
t.Fatalf("WriteMessage(first) error = %v", err)
|
||||
}
|
||||
if err := WriteMessage(&buf, second); err != nil {
|
||||
t.Fatalf("WriteMessage(second) error = %v", err)
|
||||
}
|
||||
|
||||
gotFirst, err := ReadMessage(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage(first) error = %v", err)
|
||||
}
|
||||
gotSecond, err := ReadMessage(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage(second) error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotFirst, first) {
|
||||
t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, first)
|
||||
}
|
||||
if !reflect.DeepEqual(gotSecond, second) {
|
||||
t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, second)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadMessageWrapsDecodeError 验证 ReadMessage 在返回错误时会保留解码阶段上下文。
|
||||
func TestReadMessageWrapsDecodeError(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := WriteFrame(&buf, append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)); err != nil {
|
||||
t.Fatalf("WriteFrame() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := ReadMessage(&buf)
|
||||
if err == nil {
|
||||
t.Fatal("ReadMessage() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "decode message") {
|
||||
t.Fatalf("ReadMessage() error = %v, want wrapped decode error", err)
|
||||
}
|
||||
}
|
||||
33
cmd/internal/protocol/message.go
Normal file
33
cmd/internal/protocol/message.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package protocol
|
||||
|
||||
// MessageType identifies the kind of a protocol message.
// Four kinds are declared below: text, file, register and error.
type MessageType string

const (
	// MessageTypeText means the body is interpreted as UTF-8 text (ASCII compatible).
	MessageTypeText MessageType = "text"
	// MessageTypeFile means the body is the raw bytes of a file.
	MessageTypeFile MessageType = "file"
	// MessageTypeRegister is sent by a peer to explicitly register its identity with the server.
	MessageTypeRegister MessageType = "register"
	// MessageTypeError is sent by the server to report an error to a peer.
	MessageTypeError MessageType = "error"
)

// ServerPeerID is the fixed identifier the protocol reserves for the server.
const ServerPeerID = "server"

// Message is the transport message structure shared by peers and the server.
// The metadata fields are encoded as the JSON header; Body is excluded from
// the JSON (tag "-") and travels as raw bytes after the header.
type Message struct {
	Type MessageType `json:"type"` // Message kind: one of the MessageType constants.
	ID   uint64      `json:"id"`   // Generated by the sender; used to trace the message.
	From string      `json:"from"` // Sender identifier.
	To   string      `json:"to"`   // Receiver identifier.

	// FileName is only used when Type is MessageTypeFile.
	FileName string `json:"file_name,omitempty"`
	// Body is the payload actually transferred; it is not part of the header JSON.
	Body []byte `json:"-"`
}
|
||||
192
cmd/internal/server/hub.go
Normal file
192
cmd/internal/server/hub.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
const gracefulRejectCloseTimeout = 100 * time.Millisecond
|
||||
|
||||
// Hub 管理已注册 peer 的连接,并负责在它们之间转发消息。
|
||||
type Hub struct {
|
||||
mu sync.RWMutex
|
||||
peers map[string]*transport.TCPConn
|
||||
logger latencylog.Logger
|
||||
}
|
||||
|
||||
// Option 用于配置 Hub 的可选行为,例如时延日志。
|
||||
type Option func(*Hub)
|
||||
|
||||
// WithLogger 为 hub 注入时延日志记录器。
|
||||
func WithLogger(logger latencylog.Logger) Option {
|
||||
return func(hub *Hub) {
|
||||
hub.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// NewHub 创建一个空的连接中心。
|
||||
func NewHub(opts ...Option) *Hub {
|
||||
hub := &Hub{
|
||||
peers: make(map[string]*transport.TCPConn),
|
||||
logger: latencylog.NoopLogger{},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(hub)
|
||||
}
|
||||
|
||||
if hub.logger == nil {
|
||||
hub.logger = latencylog.NoopLogger{}
|
||||
}
|
||||
|
||||
return hub
|
||||
}
|
||||
|
||||
// HasPeer 返回给定 ID 是否已经注册到 hub。
|
||||
func (h *Hub) HasPeer(peerID string) bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
_, ok := h.peers[peerID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ServeConn 处理一条新接入的底层 TCP 连接。
|
||||
// 连接上的第一条消息必须是 register,之后才允许转发 text/file。
|
||||
func (h *Hub) ServeConn(rawConn net.Conn) error {
|
||||
conn, err := transport.NewTCPConn(rawConn)
|
||||
if err != nil {
|
||||
_ = rawConn.Close()
|
||||
return fmt.Errorf("server: create transport conn: %w", err)
|
||||
}
|
||||
|
||||
peerID, gracefulClose, err := h.registerConn(conn)
|
||||
if err != nil {
|
||||
h.closeConn(conn, gracefulClose)
|
||||
return err
|
||||
}
|
||||
defer h.unregister(peerID, conn)
|
||||
|
||||
if err := h.receivePeerLoop(peerID, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerConn 从新连接上读取第一条消息,验证它是 register 消息,并把连接注册到 hub。
|
||||
func (h *Hub) registerConn(conn *transport.TCPConn) (string, bool, error) {
|
||||
msg, err := conn.Receive()
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("server: receive register: %w", err)
|
||||
}
|
||||
|
||||
if msg.Type != protocol.MessageTypeRegister {
|
||||
if sendErr := sendServerError(conn, msg.From, "first message must be register"); sendErr != nil {
|
||||
return "", false, fmt.Errorf("server: reject unregistered peer: %w", sendErr)
|
||||
}
|
||||
return "", true, fmt.Errorf("server: first message must be register, got %s", msg.Type)
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if _, exists := h.peers[msg.From]; exists {
|
||||
if sendErr := sendServerError(conn, msg.From, fmt.Sprintf("duplicate peer id: %s", msg.From)); sendErr != nil {
|
||||
return "", false, fmt.Errorf("server: duplicate peer id %s: %w", msg.From, sendErr)
|
||||
}
|
||||
return "", true, fmt.Errorf("server: duplicate peer id: %s", msg.From)
|
||||
}
|
||||
|
||||
h.peers[msg.From] = conn
|
||||
return msg.From, false, nil
|
||||
}
|
||||
|
||||
// handlePeerMessage 验证消息类型并执行相应的转发或错误响应。
|
||||
func (h *Hub) handlePeerMessage(peerID string, conn *transport.TCPConn, msg protocol.Message) (bool, error) {
|
||||
switch msg.Type {
|
||||
case protocol.MessageTypeText, protocol.MessageTypeFile: //只允许已注册的 peer 发送文本或文件消息,其他类型都视为协议错误。
|
||||
msg.From = peerID
|
||||
targetConn, ok := h.lookup(msg.To)
|
||||
if !ok {
|
||||
return false, sendServerError(conn, peerID, fmt.Sprintf("unknown target: %s", msg.To))
|
||||
}
|
||||
if err := targetConn.Send(msg); err != nil { //转发消息,如果发送失败,说明目标连接可能已经不可用,此时从 hub 中注销该连接并关闭它,并向发送方返回错误响应。
|
||||
h.unregister(msg.To, targetConn)
|
||||
_ = targetConn.Close()
|
||||
return false, sendServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", msg.To))
|
||||
}
|
||||
return false, nil
|
||||
case protocol.MessageTypeRegister, protocol.MessageTypeError: //已注册的 peer 不允许再发送 register 或 error 消息,这些都视为协议错误。
|
||||
if err := sendServerError(conn, peerID, "registered peers can only send text or file messages"); err != nil {
|
||||
return false, fmt.Errorf("server: send protocol error: %w", err)
|
||||
}
|
||||
return true, fmt.Errorf("server: unexpected message type from peer %s: %s", peerID, msg.Type)
|
||||
default: // 其他任何消息类型都视为协议错误。
|
||||
if err := sendServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); err != nil {
|
||||
return false, fmt.Errorf("server: send unsupported type error: %w", err)
|
||||
}
|
||||
return true, fmt.Errorf("server: unsupported message type: %s", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) receivePeerLoop(peerID string, conn *transport.TCPConn) error {
|
||||
for {
|
||||
msg, err := conn.Receive()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("transport: receive loop read: %w", err)
|
||||
}
|
||||
|
||||
gracefulClose, err := h.handlePeerMessage(peerID, conn, msg)
|
||||
if err != nil {
|
||||
h.closeConn(conn, gracefulClose)
|
||||
return fmt.Errorf("transport: receive loop handler: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// lookup 在 hub 中查找目标 peer 的连接。
|
||||
func (h *Hub) lookup(peerID string) (*transport.TCPConn, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
conn, ok := h.peers[peerID]
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
// unregister 从 hub 中移除指定 peer 的连接,通常在连接关闭或发生错误时调用。
|
||||
func (h *Hub) unregister(peerID string, conn *transport.TCPConn) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
current, ok := h.peers[peerID]
|
||||
if ok && current == conn {
|
||||
delete(h.peers, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) closeConn(conn *transport.TCPConn, graceful bool) {
|
||||
if graceful {
|
||||
_ = conn.CloseGracefully(gracefulRejectCloseTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
// sendServerError 是一个辅助函数,用于向指定 peer 发送错误消息。
|
||||
func sendServerError(conn *transport.TCPConn, to, message string) error {
|
||||
return conn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeError,
|
||||
From: protocol.ServerPeerID,
|
||||
To: to,
|
||||
Body: []byte(message),
|
||||
})
|
||||
}
|
||||
398
cmd/internal/server/hub_test.go
Normal file
398
cmd/internal/server/hub_test.go
Normal file
@@ -0,0 +1,398 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
type recordingLogger struct {
|
||||
mu sync.Mutex
|
||||
events []latencylog.Event
|
||||
}
|
||||
|
||||
func (l *recordingLogger) LogEvent(event latencylog.Event) error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.events = append(l.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *recordingLogger) Events() []latencylog.Event {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
return append([]latencylog.Event(nil), l.events...)
|
||||
}
|
||||
|
||||
func TestServeConnRegistersPeer(t *testing.T) {
|
||||
hub := NewHub()
|
||||
client, done := startHubConn(t, hub)
|
||||
|
||||
if err := client.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: "peer-a",
|
||||
To: protocol.ServerPeerID,
|
||||
}); err != nil {
|
||||
t.Fatalf("Send(register) error = %v", err)
|
||||
}
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
||||
|
||||
if err := client.Close(); err != nil {
|
||||
t.Fatalf("client.Close() error = %v", err)
|
||||
}
|
||||
|
||||
err := <-done
|
||||
if err == nil || !strings.Contains(err.Error(), "receive loop read") {
|
||||
t.Fatalf("ServeConn() error = %v, want read-loop shutdown error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeConnRejectsDuplicatePeerID(t *testing.T) {
|
||||
hub := NewHub()
|
||||
|
||||
first, firstDone := startHubConn(t, hub)
|
||||
registerPeer(t, first, "peer-a")
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
||||
|
||||
second, secondDone := startHubConn(t, hub)
|
||||
registerPeer(t, second, "peer-a")
|
||||
|
||||
got, err := second.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("second.Receive() error = %v", err)
|
||||
}
|
||||
if got.Type != protocol.MessageTypeError {
|
||||
t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
|
||||
}
|
||||
if string(got.Body) != "duplicate peer id: peer-a" {
|
||||
t.Fatalf("error body = %q, want duplicate peer message", got.Body)
|
||||
}
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "original peer-a to remain registered")
|
||||
|
||||
if err := first.Close(); err != nil {
|
||||
t.Fatalf("first.Close() error = %v", err)
|
||||
}
|
||||
|
||||
if err := <-secondDone; err == nil || !strings.Contains(err.Error(), "duplicate peer id") {
|
||||
t.Fatalf("second ServeConn() error = %v, want duplicate peer id error", err)
|
||||
}
|
||||
if err := <-firstDone; err == nil || !strings.Contains(err.Error(), "receive loop read") {
|
||||
t.Fatalf("first ServeConn() error = %v, want read-loop shutdown error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeConnForwardsMessages(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg protocol.Message
|
||||
}{
|
||||
{
|
||||
name: "text",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "spoofed",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "spoofed",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hub := NewHub()
|
||||
|
||||
sender, senderDone := startHubConn(t, hub)
|
||||
receiver, receiverDone := startHubConn(t, hub)
|
||||
registerPeer(t, sender, "peer-a")
|
||||
registerPeer(t, receiver, "peer-b")
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
|
||||
|
||||
if err := sender.Send(tt.msg); err != nil {
|
||||
t.Fatalf("sender.Send() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("receiver.Receive() error = %v", err)
|
||||
}
|
||||
|
||||
want := tt.msg
|
||||
want.From = "peer-a"
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("forwarded message mismatch: got %+v want %+v", got, want)
|
||||
}
|
||||
|
||||
_ = sender.Close()
|
||||
_ = receiver.Close()
|
||||
<-senderDone
|
||||
<-receiverDone
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeConnReturnsErrorForUnknownTarget(t *testing.T) {
|
||||
hub := NewHub()
|
||||
|
||||
client, done := startHubConn(t, hub)
|
||||
registerPeer(t, client, "peer-a")
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
||||
|
||||
if err := client.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "missing-peer",
|
||||
Body: []byte("hello"),
|
||||
}); err != nil {
|
||||
t.Fatalf("Send(text) error = %v", err)
|
||||
}
|
||||
|
||||
got, err := client.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if got.Type != protocol.MessageTypeError {
|
||||
t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
|
||||
}
|
||||
if string(got.Body) != "unknown target: missing-peer" {
|
||||
t.Fatalf("error body = %q, want unknown target message", got.Body)
|
||||
}
|
||||
if !hub.HasPeer("peer-a") {
|
||||
t.Fatal("peer-a should remain registered after unknown target error")
|
||||
}
|
||||
|
||||
_ = client.Close()
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestServeConnRejectsRegisterAfterRegistration(t *testing.T) {
|
||||
hub := NewHub()
|
||||
|
||||
client, done := startHubConn(t, hub)
|
||||
registerPeer(t, client, "peer-a")
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
||||
|
||||
if err := client.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: "peer-a",
|
||||
To: protocol.ServerPeerID,
|
||||
}); err != nil {
|
||||
t.Fatalf("Send(register again) error = %v", err)
|
||||
}
|
||||
|
||||
got, err := client.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if got.Type != protocol.MessageTypeError {
|
||||
t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
|
||||
}
|
||||
if string(got.Body) != "registered peers can only send text or file messages" {
|
||||
t.Fatalf("error body = %q, want registered-peer protocol error", got.Body)
|
||||
}
|
||||
|
||||
if err := <-done; err == nil || !strings.Contains(err.Error(), "unexpected message type from peer peer-a: register") {
|
||||
t.Fatalf("ServeConn() error = %v, want unexpected register message error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeConnUnregistersPeerOnClose(t *testing.T) {
|
||||
hub := NewHub()
|
||||
|
||||
client, done := startHubConn(t, hub)
|
||||
registerPeer(t, client, "peer-a")
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
||||
|
||||
if err := client.Close(); err != nil {
|
||||
t.Fatalf("client.Close() error = %v", err)
|
||||
}
|
||||
<-done
|
||||
|
||||
waitFor(t, func() bool { return !hub.HasPeer("peer-a") }, "peer-a to be unregistered")
|
||||
}
|
||||
|
||||
func TestServeConnDoesNotEmitEndpointLatencyEventsOnForward(t *testing.T) {
|
||||
logger := &recordingLogger{}
|
||||
hub := NewHub(WithLogger(logger))
|
||||
|
||||
sender, senderDone := startHubConn(t, hub)
|
||||
receiver, receiverDone := startHubConn(t, hub)
|
||||
registerPeer(t, sender, "peer-a")
|
||||
registerPeer(t, receiver, "peer-b")
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 11,
|
||||
From: "spoofed",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
if err := sender.Send(msg); err != nil {
|
||||
t.Fatalf("sender.Send() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("receiver.Receive() error = %v", err)
|
||||
}
|
||||
msg.From = "peer-a"
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("forwarded message mismatch: got %+v want %+v", got, msg)
|
||||
}
|
||||
|
||||
events := logger.Events()
|
||||
if len(events) != 0 {
|
||||
t.Fatalf("event count = %d, want 0 because server is a black-box relay", len(events))
|
||||
}
|
||||
|
||||
_ = sender.Close()
|
||||
_ = receiver.Close()
|
||||
<-senderDone
|
||||
<-receiverDone
|
||||
}
|
||||
|
||||
func TestServeConnDoesNotLogLatencyEventsForUnknownTarget(t *testing.T) {
|
||||
logger := &recordingLogger{}
|
||||
hub := NewHub(WithLogger(logger))
|
||||
|
||||
client, done := startHubConn(t, hub)
|
||||
registerPeer(t, client, "peer-a")
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
||||
|
||||
if err := client.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 15,
|
||||
From: "peer-a",
|
||||
To: "missing-peer",
|
||||
Body: []byte("hello"),
|
||||
}); err != nil {
|
||||
t.Fatalf("Send(text) error = %v", err)
|
||||
}
|
||||
|
||||
got, err := client.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if got.Type != protocol.MessageTypeError {
|
||||
t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
|
||||
}
|
||||
if events := logger.Events(); len(events) != 0 {
|
||||
t.Fatalf("event count = %d, want 0 for unknown target path", len(events))
|
||||
}
|
||||
|
||||
_ = client.Close()
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestServeConnDoesNotLogLatencyEventsForDuplicateRegister(t *testing.T) {
|
||||
logger := &recordingLogger{}
|
||||
hub := NewHub(WithLogger(logger))
|
||||
|
||||
first, firstDone := startHubConn(t, hub)
|
||||
registerPeer(t, first, "peer-a")
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
||||
|
||||
second, secondDone := startHubConn(t, hub)
|
||||
registerPeer(t, second, "peer-a")
|
||||
|
||||
got, err := second.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("second.Receive() error = %v", err)
|
||||
}
|
||||
if got.Type != protocol.MessageTypeError {
|
||||
t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError)
|
||||
}
|
||||
if events := logger.Events(); len(events) != 0 {
|
||||
t.Fatalf("event count = %d, want 0 for duplicate register path", len(events))
|
||||
}
|
||||
|
||||
_ = first.Close()
|
||||
<-secondDone
|
||||
<-firstDone
|
||||
}
|
||||
|
||||
func startHubConn(t *testing.T, hub *Hub) (*transport.TCPConn, <-chan error) {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen() error = %v", err)
|
||||
}
|
||||
done := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
serverSide, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
done <- acceptErr
|
||||
return
|
||||
}
|
||||
done <- hub.ServeConn(serverSide)
|
||||
}()
|
||||
|
||||
clientSide, err := net.Dial("tcp", listener.Addr().String())
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
t.Fatalf("net.Dial() error = %v", err)
|
||||
}
|
||||
if err := listener.Close(); err != nil {
|
||||
t.Fatalf("listener.Close() error = %v", err)
|
||||
}
|
||||
|
||||
conn, err := transport.NewTCPConn(clientSide)
|
||||
if err != nil {
|
||||
_ = clientSide.Close()
|
||||
t.Fatalf("transport.NewTCPConn() error = %v", err)
|
||||
}
|
||||
|
||||
return conn, done
|
||||
}
|
||||
|
||||
func registerPeer(t *testing.T, conn *transport.TCPConn, peerID string) {
|
||||
t.Helper()
|
||||
|
||||
if err := conn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: peerID,
|
||||
To: protocol.ServerPeerID,
|
||||
}); err != nil {
|
||||
t.Fatalf("Send(register %s) error = %v", peerID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// waitFor polls condition every 10ms for up to 500ms and returns as soon as
// it reports true. If it never does, the test fails with a message that
// includes description.
//
// The condition is evaluated once more after the deadline, so a state change
// during the final sleep interval is still observed and not reported as a
// timeout.
func waitFor(t *testing.T, condition func() bool, description string) {
	t.Helper()

	deadline := time.Now().Add(500 * time.Millisecond)
	for time.Now().Before(deadline) {
		if condition() {
			return
		}
		time.Sleep(10 * time.Millisecond)
	}

	// Final check covering the last sleep interval.
	if condition() {
		return
	}

	t.Fatalf("timed out waiting for %s", description)
}
|
||||
151
cmd/internal/transport/tcp.go
Normal file
151
cmd/internal/transport/tcp.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// TCPConn 是对单条活跃 TCP 连接的轻量封装。
|
||||
// 它负责把协议层的单条消息读写,提升为可复用的收发接口。
|
||||
type TCPConn struct {
|
||||
conn net.Conn
|
||||
raw syscall.RawConn // 连接对应的底层 syscall 句柄,用于 Linux socket timestamping 收发。
|
||||
|
||||
logger latencylog.Logger
|
||||
nodeRole string // 日志中记录的节点角色,例如 "server" 或 "peer"
|
||||
nodeID string // 日志中记录的节点 ID,例如 peer 的 ID 或 server 的 "hub"
|
||||
writeMu sync.Mutex // 保护 Send 方法的互斥锁,确保同一时刻只有一条完整协议消息被写入连接,防止多条消息字节交叉
|
||||
closeOnce sync.Once // 保护 Close 方法的 sync.Once,确保连接只被关闭一次
|
||||
closeErr error // 连接关闭时的错误,如果连接成功关闭则为 nil,重复调用 Close 时会返回同样的错误
|
||||
}
|
||||
|
||||
// Option 用于为 TCPConn 注入可选行为,例如时延日志。
|
||||
type Option func(*TCPConn)
|
||||
|
||||
// WithLogger 为连接发送路径注入业务消息日志上下文。
|
||||
func WithLogger(logger latencylog.Logger, nodeRole, nodeID string) Option {
|
||||
return func(conn *TCPConn) {
|
||||
conn.logger = logger
|
||||
conn.nodeRole = nodeRole
|
||||
conn.nodeID = nodeID
|
||||
}
|
||||
}
|
||||
|
||||
// NewTCPConn 用已有的 net.Conn 创建 transport 连接封装。
|
||||
func NewTCPConn(conn net.Conn, opts ...Option) (*TCPConn, error) {
|
||||
tcpConn := &TCPConn{
|
||||
conn: conn,
|
||||
logger: latencylog.NoopLogger{},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(tcpConn)
|
||||
}
|
||||
|
||||
if tcpConn.logger == nil {
|
||||
tcpConn.logger = latencylog.NoopLogger{}
|
||||
}
|
||||
|
||||
if err := tcpConn.initLinuxTimestamping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tcpConn, nil
|
||||
}
|
||||
|
||||
// Send 将一条协议消息完整写入底层连接。
|
||||
// 多个 goroutine 可以并发调用,内部会串行化写入。
|
||||
func (c *TCPConn) Send(msg protocol.Message) error {
|
||||
c.writeMu.Lock() //“同一时刻只能有一条完整协议消息往连接里写,防止多条消息字节交叉
|
||||
defer c.writeMu.Unlock()
|
||||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
|
||||
|
||||
if err := c.sendMessageLinux(msg); err != nil {
|
||||
return fmt.Errorf("transport: send message: %w", err)
|
||||
}
|
||||
//记录发送完成的时延日志事件,事件类型为 EventSendHandoffEnd,包含消息的基本信息(类型、ID、来源、目标)。
|
||||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receive 从底层连接读取一条完整协议消息。
|
||||
// 同一条连接应只由单个 reader 持续调用该方法。
|
||||
func (c *TCPConn) Receive() (protocol.Message, error) {
|
||||
msg, err := c.receiveMessageLinux()
|
||||
if err != nil {
|
||||
return protocol.Message{}, fmt.Errorf("transport: receive message: %w", err)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// ReceiveLoop 持续读取消息并交给 handler 处理。
|
||||
// 读取错误、handler 错误或连接关闭都会结束循环,并关闭连接。
|
||||
func (c *TCPConn) ReceiveLoop(handler func(protocol.Message) error) error {
|
||||
for {
|
||||
msg, err := c.Receive()
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("transport: receive loop read: %w", err)
|
||||
}
|
||||
|
||||
if err := handler(msg); err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("transport: receive loop handler: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CloseGracefully 在支持 half-close 的连接上先关闭写方向,给对端留出读取最终响应的机会,
|
||||
// 然后在短暂等待后再彻底关闭连接。
|
||||
func (c *TCPConn) CloseGracefully(drainTimeout time.Duration) error {
|
||||
if closeWriter, ok := c.conn.(interface{ CloseWrite() error }); ok {
|
||||
if err := closeWriter.CloseWrite(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
if drainTimeout > 0 {
|
||||
_ = c.conn.SetReadDeadline(time.Now().Add(drainTimeout))
|
||||
defer func() {
|
||||
_ = c.conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
|
||||
var buf [256]byte
|
||||
for {
|
||||
_, err := c.conn.Read(buf[:])
|
||||
switch {
|
||||
case err == nil:
|
||||
continue
|
||||
case errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
|
||||
return c.Close()
|
||||
default:
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return c.Close()
|
||||
}
|
||||
return c.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
// Close 关闭底层连接,并保证重复调用是安全的。
|
||||
func (c *TCPConn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
c.closeErr = c.conn.Close()
|
||||
})
|
||||
|
||||
return c.closeErr
|
||||
}
|
||||
462
cmd/internal/transport/tcp_linux.go
Normal file
462
cmd/internal/transport/tcp_linux.go
Normal file
@@ -0,0 +1,462 @@
|
||||
//go:build linux
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// Buffer size and timing used when collecting timestamps.
const (
	linuxTimestampControlBufferSize = 256                    // control message buffer size
	linuxTXTimestampWaitTimeout     = 250 * time.Millisecond // upper bound on waiting for a TX timestamp
	linuxTXTimestampPollInterval    = time.Millisecond       // interval between errqueue polls
)

// Socket option and control-message identifiers.
// NOTE(review): 0x41 looks like Linux SO_TIMESTAMPING_NEW; confirm against the kernel headers.
const (
	linuxSOTimestampingNew      = 0x41
	linuxSCMTimestampingNew     = linuxSOTimestampingNew
	linuxSOEEOriginTimestamping = 4 // timestamping errqueue event
	linuxSCMTstampSnd           = 0 // corresponds to A_TX_SOFTWARE
	linuxSCMTstampSched         = 1 // corresponds to A_TX_SCHED
)

// Timestamping flag bits that enableLinuxTimestamping combines.
const (
	linuxSOFTimestampingTXSoftware = 1 << 1  // enable TX software timestamps
	linuxSOFTimestampingRXSoftware = 1 << 3  // enable RX software timestamps
	linuxSOFTimestampingSoftware   = 1 << 4  // master switch for software timestamps
	linuxSOFTimestampingOptID      = 1 << 7  // associate an ID with each timestamp
	linuxSOFTimestampingTXSched    = 1 << 8  // enable TX sched timestamps
	linuxSOFTimestampingOptTSONLY  = 1 << 11 // return the timestamp only
	linuxSOFTimestampingOptIDTCP   = 1 << 16 // make TCP carry the timestamp ID as well
)
|
||||
|
||||
// 拿到底层 fd,并打开 Linux timestamping。
|
||||
func (c *TCPConn) initLinuxTimestamping() error {
|
||||
sysConn, ok := c.conn.(interface {
|
||||
SyscallConn() (syscall.RawConn, error)
|
||||
})
|
||||
if !ok {
|
||||
return fmt.Errorf("transport: connection does not support SyscallConn")
|
||||
}
|
||||
|
||||
rawConn, err := sysConn.SyscallConn()
|
||||
if err != nil || rawConn == nil {
|
||||
if err != nil {
|
||||
return fmt.Errorf("transport: get syscall conn: %w", err)
|
||||
}
|
||||
return fmt.Errorf("transport: missing syscall conn")
|
||||
}
|
||||
|
||||
//socket是否可以成功打开 timestamping 取决于内核版本和配置,尝试多个 flag 组合直到成功或遇到非 EINVAL 错误。
|
||||
if err := enableLinuxTimestamping(rawConn); err != nil {
|
||||
return fmt.Errorf("transport: enable linux timestamping: %w", err)
|
||||
}
|
||||
//成功打开 timestamping 后,rawConn 就可以用来收 TX/RX 时间戳了。
|
||||
c.raw = rawConn
|
||||
return nil
|
||||
}
|
||||
|
||||
// 给 socket开权限打开TX software timestamping。
|
||||
func enableLinuxTimestamping(rawConn syscall.RawConn) error {
|
||||
flagCandidates := []int{ //不同linux版本可能支持不同的 flag 组合,尝试多个组合直到成功。
|
||||
linuxSOFTimestampingTXSched |
|
||||
linuxSOFTimestampingTXSoftware |
|
||||
linuxSOFTimestampingRXSoftware |
|
||||
linuxSOFTimestampingSoftware |
|
||||
linuxSOFTimestampingOptID | //TCP 协议栈给每个时间戳生成一个序列号
|
||||
linuxSOFTimestampingOptIDTCP |
|
||||
linuxSOFTimestampingOptTSONLY,
|
||||
linuxSOFTimestampingTXSched |
|
||||
linuxSOFTimestampingTXSoftware |
|
||||
linuxSOFTimestampingRXSoftware |
|
||||
linuxSOFTimestampingSoftware |
|
||||
linuxSOFTimestampingOptID |
|
||||
linuxSOFTimestampingOptTSONLY,
|
||||
linuxSOFTimestampingTXSched |
|
||||
linuxSOFTimestampingTXSoftware |
|
||||
linuxSOFTimestampingRXSoftware |
|
||||
linuxSOFTimestampingSoftware |
|
||||
linuxSOFTimestampingOptTSONLY,
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, flags := range flagCandidates { //尝试不同的 flag 组合,直到成功或遇到非 EINVAL 错误。
|
||||
// 内核根据 fd 找到对应的内存结构体(Socket 缓冲区)
|
||||
err := rawConn.Control(func(fd uintptr) { //Control 方法保证在回调里 fd 是有效的,可以安全地调用 syscall.SetsockoptInt。
|
||||
lastErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if lastErr == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(lastErr, syscall.EINVAL) {
|
||||
return lastErr
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// sendMessageLinux 编码消息、写完整帧,再记录 TX 时间戳。
|
||||
func (c *TCPConn) sendMessageLinux(msg protocol.Message) error {
|
||||
payload, err := protocol.EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("protocol: encode message: %w", err)
|
||||
}
|
||||
|
||||
//编码后的消息 payload 前面加 4 字节长度,构成完整帧。
|
||||
frame := make([]byte, 4+len(payload))
|
||||
binary.BigEndian.PutUint32(frame[:4], uint32(len(payload)))
|
||||
copy(frame[4:], payload)
|
||||
|
||||
if err := c.writeFrameLinux(frame); err != nil {
|
||||
return fmt.Errorf("protocol: write frame: %w", err)
|
||||
}
|
||||
//记录发送延时日志
|
||||
c.logTXTimestampEvents(msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeFrameLinux 用 sendmsg 写完整帧。
|
||||
func (c *TCPConn) writeFrameLinux(frame []byte) error {
|
||||
written := 0
|
||||
var opErr error
|
||||
|
||||
err := c.raw.Write(func(fd uintptr) bool {
|
||||
if written >= len(frame) {
|
||||
return true
|
||||
}
|
||||
|
||||
n, sendErr := syscall.SendmsgN(int(fd), frame[written:], nil, nil, 0)
|
||||
switch {
|
||||
case sendErr == nil:
|
||||
if n <= 0 {
|
||||
opErr = io.ErrShortWrite
|
||||
return true
|
||||
}
|
||||
written += n
|
||||
return written >= len(frame)
|
||||
case errors.Is(sendErr, syscall.EAGAIN), errors.Is(sendErr, syscall.EWOULDBLOCK):
|
||||
return false
|
||||
default:
|
||||
opErr = sendErr
|
||||
return true
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if opErr != nil {
|
||||
return opErr
|
||||
}
|
||||
if written != len(frame) {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 把 A_TX_SCHED / A_TX_SOFTWARE 写入日志。(发送过程中)
|
||||
func (c *TCPConn) logTXTimestampEvents(msg protocol.Message) {
|
||||
timestamps := c.collectTXTimestampEvents()
|
||||
|
||||
if ts, ok := timestamps[latencylog.EventATXSched]; ok {
|
||||
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSched, ts, msg)
|
||||
}
|
||||
if ts, ok := timestamps[latencylog.EventATXSoftware]; ok {
|
||||
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSoftware, ts, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// 在 errqueue 里等两类 TX 时间戳。
|
||||
func (c *TCPConn) collectTXTimestampEvents() map[string]int64 {
|
||||
timestamps := make(map[string]int64, 2)
|
||||
//设置合理等待上限
|
||||
deadline := time.Now().Add(linuxTXTimestampWaitTimeout)
|
||||
|
||||
//轮询 errqueue 直到拿到两类时间戳,或超时,或遇到非 EAGAIN 错误。
|
||||
for len(timestamps) < 2 && time.Now().Before(deadline) {
|
||||
eventName, ts, err := c.recvTXTimestampOnce()
|
||||
if err != nil {
|
||||
if isWouldBlock(err) {
|
||||
time.Sleep(linuxTXTimestampPollInterval)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if eventName == "" || ts <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := timestamps[eventName]; !exists {
|
||||
timestamps[eventName] = ts
|
||||
}
|
||||
}
|
||||
|
||||
return timestamps
|
||||
}
|
||||
|
||||
// recvTXTimestampOnce 从 errqueue 读一次时间戳事件。
|
||||
func (c *TCPConn) recvTXTimestampOnce() (string, int64, error) {
|
||||
var (
|
||||
eventName string // 事件名,例如 A_TX_SCHED 或 A_TX_SOFTWARE。
|
||||
tsUnixNS int64 // 时间戳的 UnixNano 表示。
|
||||
opErr error
|
||||
)
|
||||
|
||||
err := c.raw.Control(func(fd uintptr) {
|
||||
//设置足够大的 oob buffer 来接收控制消息,调用 recvmsg 从 errqueue 读一条消息。
|
||||
oob := make([]byte, linuxTimestampControlBufferSize)
|
||||
//recvmsg 的 flags 里必须带 MSG_ERRQUEUE,才能从 errqueue 里读消息,非阻塞模式下如果没有消息可读会返回 EAGAIN。
|
||||
_, oobn, _, _, recvErr := syscall.Recvmsg(int(fd), nil, oob, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT)
|
||||
if recvErr != nil {
|
||||
opErr = recvErr
|
||||
return
|
||||
}
|
||||
//解析控制消息,看看是不是我们关心的 TX 时间戳事件,如果是就拿到事件名和时间戳。
|
||||
eventName, tsUnixNS = parseTXTimestampControlMessages(oob[:oobn])
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if opErr != nil {
|
||||
return "", 0, opErr
|
||||
}
|
||||
|
||||
return eventName, tsUnixNS, nil //如果成功拿到时间戳事件,eventName 会是 A_TX_SCHED 或 A_TX_SOFTWARE 之一,tsUnixNS 是对应的时间戳;如果没有拿到事件或时间戳无效,eventName 会是空字符串,tsUnixNS 会是 0。
|
||||
}
|
||||
|
||||
// 把底层时间戳映射成日志事件名。
|
||||
func parseTXTimestampControlMessages(oob []byte) (string, int64) {
|
||||
if len(oob) == 0 {
|
||||
return "", 0
|
||||
}
|
||||
//解析控制消息,看看是不是我们关心的 TX 时间戳事件,如果是就拿到事件名和时间戳。
|
||||
controlMessages, err := syscall.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var (
|
||||
tsUnixNS int64 //时间戳的 UnixNano 表示。
|
||||
tsKind uint32 //extended err里,告诉我们这个时间戳是 sched 还是 software。
|
||||
hasTS bool // 是否拿到时间戳了。
|
||||
hasKind bool // 是否拿到时间戳类型了。
|
||||
)
|
||||
//一个 recvmsg 可能会收到多个控制消息,循环找我们关心的时间戳事件,拿到时间戳和事件类型。
|
||||
for _, controlMessage := range controlMessages {
|
||||
switch {
|
||||
case controlMessage.Header.Level == syscall.SOL_SOCKET && controlMessage.Header.Type == linuxSCMTimestampingNew:
|
||||
if ts := parseSCMTimestampingData(controlMessage.Data); ts > 0 {
|
||||
tsUnixNS = ts
|
||||
hasTS = true
|
||||
}
|
||||
case isSocketExtendedErr(controlMessage): //判断时间戳是否进入了errqueue,
|
||||
if info, ok := parseSocketExtendedErrInfo(controlMessage.Data); ok {
|
||||
tsKind = info //时间戳类型被内核放在 extended err 的附加信息里,解析出来。
|
||||
hasKind = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasTS || !hasKind {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
switch tsKind { //把内核的时间戳类型映射成日志事件名。(记录时只关心 sched 和 software 两类时间戳)
|
||||
case linuxSCMTstampSched:
|
||||
return latencylog.EventATXSched, tsUnixNS
|
||||
case linuxSCMTstampSnd:
|
||||
return latencylog.EventATXSoftware, tsUnixNS
|
||||
default:
|
||||
return "", 0
|
||||
}
|
||||
}
|
||||
|
||||
// isSocketExtendedErr reports whether controlMessage is an IPv4 or IPv6
// RECVERR control message, i.e. one that carries a socket extended error.
//
// The kernel does not mix TX timestamps into the normal data stream; it
// wraps them as a special "error" entry on the socket error queue.
func isSocketExtendedErr(controlMessage syscall.SocketControlMessage) bool {
	hdr := controlMessage.Header

	overIPv4 := hdr.Level == syscall.SOL_IP && hdr.Type == syscall.IP_RECVERR
	overIPv6 := hdr.Level == syscall.SOL_IPV6 && hdr.Type == syscall.IPV6_RECVERR

	return overIPv4 || overIPv6
}
|
||||
|
||||
// 从 socket extended err 的数据里取 origin timestamping 信息。
|
||||
func parseSocketExtendedErrInfo(data []byte) (uint32, bool) {
|
||||
if len(data) < 16 {
|
||||
return 0, false
|
||||
}
|
||||
if data[4] != linuxSOEEOriginTimestamping {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return binary.NativeEndian.Uint32(data[8:12]), true
|
||||
}
|
||||
|
||||
// 读一条完整消息,并记录 B_RX_SOFTWARE。
|
||||
func (c *TCPConn) receiveMessageLinux() (protocol.Message, error) {
|
||||
payload, rxTimestamp, err := c.readFrameLinux()
|
||||
if err != nil {
|
||||
return protocol.Message{}, fmt.Errorf("protocol: read frame: %w", err)
|
||||
}
|
||||
|
||||
msg, err := protocol.DecodeMessage(payload)
|
||||
if err != nil {
|
||||
return protocol.Message{}, fmt.Errorf("protocol: decode message: %w", err)
|
||||
}
|
||||
|
||||
if rxTimestamp > 0 {
|
||||
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventBRXSoftware, rxTimestamp, msg)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// readFrameLinux 先读 4 字节长度,再读整条 payload。
|
||||
func (c *TCPConn) readFrameLinux() ([]byte, int64, error) {
|
||||
var frameHeader [4]byte
|
||||
rxTimestamp, err := c.readFullLinux(frameHeader[:])
|
||||
if err != nil {
|
||||
return nil, rxTimestamp, err
|
||||
}
|
||||
|
||||
size := binary.BigEndian.Uint32(frameHeader[:])
|
||||
switch {
|
||||
case size == 0:
|
||||
return nil, rxTimestamp, protocol.ErrInvalidFrameLength
|
||||
case size > protocol.MaxFrameSize:
|
||||
return nil, rxTimestamp, protocol.ErrFrameTooLarge
|
||||
}
|
||||
|
||||
payload := make([]byte, int(size))
|
||||
bodyTimestamp, err := c.readFullLinux(payload)
|
||||
if rxTimestamp == 0 {
|
||||
rxTimestamp = bodyTimestamp
|
||||
}
|
||||
if err != nil {
|
||||
return nil, rxTimestamp, err
|
||||
}
|
||||
|
||||
return payload, rxTimestamp, nil
|
||||
}
|
||||
|
||||
// 读满 buf,并保留首个 RX_SOFTWARE(返回进入tcp协议栈的时间戳)。
|
||||
func (c *TCPConn) readFullLinux(buf []byte) (int64, error) {
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
offset int
|
||||
firstRXTime int64
|
||||
)
|
||||
|
||||
for offset < len(buf) {
|
||||
n, rxTimestamp, err := c.recvmsgLinux(buf[offset:])
|
||||
if firstRXTime == 0 && rxTimestamp > 0 {
|
||||
firstRXTime = rxTimestamp
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) && offset > 0 {
|
||||
return firstRXTime, io.ErrUnexpectedEOF
|
||||
}
|
||||
return firstRXTime, err
|
||||
}
|
||||
|
||||
offset += n
|
||||
}
|
||||
|
||||
return firstRXTime, nil
|
||||
}
|
||||
|
||||
// recvmsgLinux 用 recvmsg 同时读取数据和控制消息。
|
||||
func (c *TCPConn) recvmsgLinux(buf []byte) (int, int64, error) {
|
||||
var (
|
||||
n int
|
||||
rxTimeNS int64
|
||||
opErr error
|
||||
)
|
||||
|
||||
err := c.raw.Read(func(fd uintptr) bool {
|
||||
oob := make([]byte, linuxTimestampControlBufferSize)
|
||||
readN, oobN, _, _, recvErr := syscall.Recvmsg(int(fd), buf, oob, 0)
|
||||
switch {
|
||||
case recvErr == nil:
|
||||
if readN == 0 {
|
||||
opErr = io.EOF
|
||||
return true
|
||||
}
|
||||
n = readN
|
||||
rxTimeNS = parseRXTimestampControlMessages(oob[:oobN])
|
||||
return true
|
||||
case errors.Is(recvErr, syscall.EAGAIN), errors.Is(recvErr, syscall.EWOULDBLOCK):
|
||||
return false
|
||||
default:
|
||||
opErr = recvErr
|
||||
return true
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if opErr != nil {
|
||||
return 0, 0, opErr
|
||||
}
|
||||
|
||||
return n, rxTimeNS, nil
|
||||
}
|
||||
|
||||
// 从控制消息里取 RX_SOFTWARE。
|
||||
func parseRXTimestampControlMessages(oob []byte) int64 {
|
||||
if len(oob) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
controlMessages, err := syscall.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
for _, controlMessage := range controlMessages {
|
||||
if controlMessage.Header.Level != syscall.SOL_SOCKET || controlMessage.Header.Type != linuxSCMTimestampingNew {
|
||||
continue
|
||||
}
|
||||
|
||||
if ts := parseSCMTimestampingData(controlMessage.Data); ts > 0 {
|
||||
return ts
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseSCMTimestampingData returns the first non-zero timestamp in data,
// expressed as nanoseconds (seconds*1e9 + nanoseconds).
//
// data is read as consecutive 16-byte records, each a native-endian 64-bit
// seconds value followed by a native-endian 64-bit nanoseconds value.
// Trailing bytes that do not form a whole record are ignored. It returns 0
// when every record is zero or data holds no complete record.
func parseSCMTimestampingData(data []byte) int64 {
	const (
		fieldSize      = 8
		timespec64Size = 2 * fieldSize
	)

	for rest := data; len(rest) >= timespec64Size; rest = rest[timespec64Size:] {
		sec := int64(binary.NativeEndian.Uint64(rest[:fieldSize]))
		nsec := int64(binary.NativeEndian.Uint64(rest[fieldSize:timespec64Size]))
		if sec != 0 || nsec != 0 {
			return sec*int64(time.Second) + nsec
		}
	}

	return 0
}
|
||||
|
||||
// isWouldBlock reports whether err is, or wraps, EAGAIN or EWOULDBLOCK.
func isWouldBlock(err error) bool {
	switch {
	case errors.Is(err, syscall.EAGAIN):
		return true
	case errors.Is(err, syscall.EWOULDBLOCK):
		return true
	}
	return false
}
|
||||
140
cmd/internal/transport/tcp_linux_test.go
Normal file
140
cmd/internal/transport/tcp_linux_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
//go:build linux
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
func TestLinuxTimestampingRecordsKernelEvents(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg protocol.Message
|
||||
}{
|
||||
{
|
||||
name: "text",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 41,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello over tcp"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 42,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x00, 0x01, 0x02, 0xff},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clientConn, serverConn := newTCPPair(t)
|
||||
|
||||
senderLogger := &recordingLogger{}
|
||||
receiverLogger := &recordingLogger{}
|
||||
sender, err := NewTCPConn(
|
||||
clientConn,
|
||||
WithLogger(senderLogger, latencylog.NodeRolePeer, "peer-a"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPConn(sender) error = %v", err)
|
||||
}
|
||||
receiver, err := NewTCPConn(
|
||||
serverConn,
|
||||
WithLogger(receiverLogger, latencylog.NodeRolePeer, "peer-b"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPConn(receiver) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = sender.Close()
|
||||
_ = receiver.Close()
|
||||
})
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(tt.msg)
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
|
||||
}
|
||||
|
||||
assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSched, tt.msg.ID)
|
||||
assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSoftware, tt.msg.ID)
|
||||
assertHasEvent(t, receiverLogger.Events(), latencylog.EventBRXSoftware, tt.msg.ID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newTCPPair returns the two ends of a loopback TCP connection: the dialing
// side first, then the accepted side. The temporary listener is closed before
// returning. Any setup failure fails the test.
func newTCPPair(t *testing.T) (net.Conn, net.Conn) {
	t.Helper()

	ln, listenErr := net.Listen("tcp", "127.0.0.1:0")
	if listenErr != nil {
		t.Fatalf("net.Listen() error = %v", listenErr)
	}

	// Accept on its own goroutine so the Dial below can complete.
	var (
		server    net.Conn
		acceptErr error
	)
	acceptDone := make(chan struct{})
	go func() {
		defer close(acceptDone)
		server, acceptErr = ln.Accept()
	}()

	client, dialErr := net.Dial("tcp", ln.Addr().String())
	if dialErr != nil {
		_ = ln.Close()
		t.Fatalf("net.Dial() error = %v", dialErr)
	}

	<-acceptDone
	if closeErr := ln.Close(); closeErr != nil {
		t.Fatalf("listener.Close() error = %v", closeErr)
	}
	if acceptErr != nil {
		_ = client.Close()
		t.Fatalf("listener.Accept() error = %v", acceptErr)
	}

	return client, server
}
|
||||
|
||||
func assertHasEvent(t *testing.T, events []latencylog.Event, wantEvent string, wantMessageID uint64) {
|
||||
t.Helper()
|
||||
|
||||
for _, event := range events {
|
||||
if event.Event == wantEvent && event.MessageID == wantMessageID {
|
||||
if event.TsUnixNano <= 0 {
|
||||
t.Fatalf("event %s timestamp must be positive: %+v", wantEvent, event)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatalf("missing event %s for message %d in %+v", wantEvent, wantMessageID, events)
|
||||
}
|
||||
416
cmd/internal/transport/tcp_test.go
Normal file
416
cmd/internal/transport/tcp_test.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
type recordingLogger struct {
|
||||
mu sync.Mutex
|
||||
events []latencylog.Event
|
||||
}
|
||||
|
||||
func (l *recordingLogger) LogEvent(event latencylog.Event) error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.events = append(l.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *recordingLogger) Events() []latencylog.Event {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
return append([]latencylog.Event(nil), l.events...)
|
||||
}
|
||||
|
||||
type failingLogger struct{}
|
||||
|
||||
func (failingLogger) LogEvent(latencylog.Event) error {
|
||||
return errors.New("log failed")
|
||||
}
|
||||
|
||||
// TestSendReceiveMessage 验证 transport 可以在单条连接上正常收发 text 和 file 消息。
|
||||
func TestSendReceiveMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg protocol.Message
|
||||
}{
|
||||
{
|
||||
name: "text",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "data.bin",
|
||||
Body: []byte{0x00, 0x10, 0xff},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
//创建一个容量为1的缓冲通道sendErr,用于接收发送操作的错误结果。
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(tt.msg) //发送消息,并将结果(错误或nil)发送到sendErr通道。
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil { //接受发送结果,如果发送过程中发生错误,则测试失败。
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tt.msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendLogsHandoffEvents(t *testing.T) {
|
||||
logger := &recordingLogger{}
|
||||
sender, receiver := newTransportConnPair(
|
||||
t,
|
||||
[]Option{WithLogger(logger, latencylog.NodeRolePeer, "peer-a")},
|
||||
nil,
|
||||
)
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(msg)
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
|
||||
}
|
||||
|
||||
events := logger.Events()
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("event count = %d, want 4", len(events))
|
||||
}
|
||||
if events[0].Event != latencylog.EventSendHandoffBegin {
|
||||
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin)
|
||||
}
|
||||
if events[1].Event != latencylog.EventATXSched {
|
||||
t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventATXSched)
|
||||
}
|
||||
if events[2].Event != latencylog.EventATXSoftware {
|
||||
t.Fatalf("third event = %q, want %q", events[2].Event, latencylog.EventATXSoftware)
|
||||
}
|
||||
if events[3].Event != latencylog.EventSendHandoffEnd {
|
||||
t.Fatalf("fourth event = %q, want %q", events[3].Event, latencylog.EventSendHandoffEnd)
|
||||
}
|
||||
for i, event := range events {
|
||||
if event.MessageID != msg.ID {
|
||||
t.Fatalf("event[%d] message ID = %d, want %d", i, event.MessageID, msg.ID)
|
||||
}
|
||||
}
|
||||
if events[0].NodeRole != latencylog.NodeRolePeer || events[0].NodeID != "peer-a" {
|
||||
t.Fatalf("node info = (%s,%s), want (%s,%s)", events[0].NodeRole, events[0].NodeID, latencylog.NodeRolePeer, "peer-a")
|
||||
}
|
||||
if events[0].TsUnixNano <= 0 || events[1].TsUnixNano <= 0 || events[2].TsUnixNano <= 0 || events[3].TsUnixNano <= 0 {
|
||||
t.Fatalf("timestamps must be positive: %+v", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendIgnoresLoggerFailure(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(
|
||||
t,
|
||||
[]Option{WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")},
|
||||
nil,
|
||||
)
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 9,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(msg)
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v, want nil even if logger fails", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveLoopDeliversMessages 验证 ReceiveLoop 会逐条交付连续到达的消息。
|
||||
func TestReceiveLoopDeliversMessages(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
|
||||
want := []protocol.Message{
|
||||
{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
got []protocol.Message
|
||||
)
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
got = append(got, msg)
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
for _, msg := range want {
|
||||
if err := sender.Send(msg); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
}
|
||||
if err := sender.Close(); err != nil {
|
||||
t.Fatalf("sender.Close() error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if err == nil {
|
||||
t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "receive loop read") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want read context", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentSendKeepsMessagesIntact 验证并发发送时消息不会因为写入交叉而损坏。
|
||||
func TestConcurrentSendKeepsMessagesIntact(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
// 发送方将多条消息并发发送到接收方,接收方通过 ReceiveLoop 逐条读取并验证每条消息的完整性和正确性。
|
||||
want := []protocol.Message{
|
||||
{Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("one")},
|
||||
{Type: protocol.MessageTypeText, ID: 2, From: "peer-a", To: "peer-b", Body: []byte("two")},
|
||||
{Type: protocol.MessageTypeText, ID: 3, From: "peer-a", To: "peer-b", Body: []byte("three")},
|
||||
{Type: protocol.MessageTypeText, ID: 4, From: "peer-a", To: "peer-b", Body: []byte("four")},
|
||||
}
|
||||
|
||||
received := make(chan protocol.Message, len(want))
|
||||
readErr := make(chan error, 1)
|
||||
go func() { //异步地运行一个 goroutine
|
||||
for range want {
|
||||
msg, err := receiver.Receive()
|
||||
if err != nil {
|
||||
readErr <- err
|
||||
return
|
||||
}
|
||||
received <- msg
|
||||
}
|
||||
readErr <- nil
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, msg := range want {
|
||||
msg := msg
|
||||
wg.Add(1)
|
||||
go func() { //异步处理
|
||||
defer wg.Done()
|
||||
if err := sender.Send(msg); err != nil {
|
||||
t.Errorf("Send() error = %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if err := <-readErr; err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
|
||||
gotByID := make(map[uint64]protocol.Message, len(want))
|
||||
for range want {
|
||||
msg := <-received
|
||||
gotByID[msg.ID] = msg
|
||||
}
|
||||
|
||||
for _, msg := range want {
|
||||
got, ok := gotByID[msg.ID]
|
||||
if !ok {
|
||||
t.Fatalf("missing message with ID %d", msg.ID)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch for ID %d: got %+v want %+v", msg.ID, got, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveLoopStopsOnHandlerError 验证 handler 返回错误时 ReceiveLoop 会退出并关闭连接。
|
||||
func TestReceiveLoopStopsOnHandlerError(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
|
||||
wantErr := errors.New("stop loop")
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
|
||||
return wantErr
|
||||
})
|
||||
}()
|
||||
|
||||
first := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
if err := sender.Send(first); err != nil {
|
||||
t.Fatalf("Send(first) error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want %v", err, wantErr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "receive loop handler") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want handler context", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveLoopStopsOnReadError 验证对端关闭时 ReceiveLoop 会以读取错误退出。
|
||||
func TestReceiveLoopStopsOnReadError(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
if err := sender.Close(); err != nil {
|
||||
t.Fatalf("sender.Close() error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if err == nil {
|
||||
t.Fatal("ReceiveLoop() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "receive loop read") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want read context", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCloseIsIdempotent 验证 Close 可以安全地被重复调用。
|
||||
func TestCloseIsIdempotent(t *testing.T) {
|
||||
conn, peer := newTransportConnPair(t, nil, nil)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Fatalf("Close(first) error = %v", err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Fatalf("Close(second) error = %v, want nil", err)
|
||||
}
|
||||
if err := peer.Close(); err != nil && !strings.Contains(err.Error(), "closed") {
|
||||
t.Fatalf("peer.Close() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveReturnsWrappedReadError 验证 Receive 在底层读取失败时会保留 transport 上下文。
|
||||
func TestReceiveReturnsWrappedReadError(t *testing.T) {
|
||||
conn, peer := newTransportConnPair(t, nil, nil)
|
||||
go func() {
|
||||
_ = peer.Close()
|
||||
}()
|
||||
|
||||
_, err := conn.Receive()
|
||||
if err == nil {
|
||||
t.Fatal("Receive() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "transport: receive message") {
|
||||
t.Fatalf("Receive() error = %v, want wrapped receive error", err)
|
||||
}
|
||||
if !errors.Is(err, io.EOF) && !strings.Contains(err.Error(), "closed") {
|
||||
t.Fatalf("Receive() error = %v, want underlying read failure", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newTransportConnPair(t *testing.T, senderOpts []Option, receiverOpts []Option) (*TCPConn, *TCPConn) {
|
||||
t.Helper()
|
||||
|
||||
left, right := newTCPPair(t)
|
||||
|
||||
sender, err := NewTCPConn(left, senderOpts...)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPConn(sender) error = %v", err)
|
||||
}
|
||||
receiver, err := NewTCPConn(right, receiverOpts...)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPConn(receiver) error = %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = sender.Close()
|
||||
_ = receiver.Close()
|
||||
})
|
||||
|
||||
return sender, receiver
|
||||
}
|
||||
Reference in New Issue
Block a user