commit 48246752444e0bb786578645570279e4f3db9d6b Author: nnbcccscdscdsc <2709767634@qq.com> Date: Mon Mar 23 20:18:53 2026 +0800 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..582ec45 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +bin/* + diff --git a/README.md b/README.md new file mode 100644 index 0000000..d8b47ed --- /dev/null +++ b/README.md @@ -0,0 +1,91 @@ +# OmniSocketGo + +Linux only. Go 1.22. + +如果目标机器只运行 `server`,只需要编译并拷贝 `server` 二进制。 +如果目标机器只运行 `peer`,只需要编译并拷贝 `peer` 二进制。 + +`go build ./cmd/server` 和 `go build ./cmd/peer` 会把各自依赖到的功能一起编译进最终二进制,不需要再单独编译 `cmd/internal/...` 包。 + +- `server` 二进制会包含它依赖到的转发、协议、传输等代码 +- `peer` 二进制会包含它依赖到的注册、交互发送、接收落盘、协议、传输等代码 +- 只有没有被这个可执行程序引用的其他命令,才不在该二进制里,比如 `cmd/latencysummary` + +## Build + +按目标架构分别编译。 + +### Linux amd64 + +```bash +CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/server-linux-amd64 ./cmd/server +CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/peer-linux-amd64 ./cmd/peer +``` + +### Linux arm64 + +```bash +CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/server-linux-arm64 ./cmd/server +CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/peer-linux-arm64 ./cmd/peer +``` + +### Linux armv7 + +```bash +CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/server-linux-armv7 ./cmd/server +CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/peer-linux-armv7 ./cmd/peer +``` + +## Deploy + +把对应架构的二进制拷到目标机器,并赋予可执行权限。 + +```bash +scp bin/server-linux-amd64 user@server-host:~/omnisocket/server +scp bin/peer-linux-amd64 user@peer-host:~/omnisocket/peer +``` + +```bash +ssh user@server-host 'chmod +x ~/omnisocket/server' +ssh user@peer-host 'chmod +x ~/omnisocket/peer' +``` + +## Run On Different Machines + +`server` 所在机器监听 `0.0.0.0:9000`。 + +```bash +~/omnisocket/server -listen 0.0.0.0:9000 +``` + +每个 `peer` 所在机器连接 `server` 的实际 IP,例如 `192.168.1.50:9000`。 + +### peer-a + +```bash +~/omnisocket/peer \ + -id peer-a \ + -server 192.168.1.50:9000 \ + -inbox-dir /tmp/peer-a-inbox +``` + +### peer-b + +```bash +~/omnisocket/peer \ + -id peer-b \ + -server 192.168.1.50:9000 \ + -inbox-dir /tmp/peer-b-inbox +``` + +## Interactive Commands + +`peer` 启动后可以在终端里持续使用同一条长连接发送多次消息。 + +```text +help +text peer-b hello +text peer-a hi +file peer-b /tmp/test.bin +quit +``` diff --git a/cmd/internal/latencylog/logger.go b/cmd/internal/latencylog/logger.go new file mode 100644 index 0000000..f1f1f31 --- /dev/null +++ b/cmd/internal/latencylog/logger.go @@ -0,0 +1,166 @@ +package latencylog + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "time" + + "omnisocketgo/cmd/internal/protocol" +) + +const ( + NodeRolePeer = "peer" //客户端节点 + NodeRoleServer = "server" //云端转发节点 +) + +// 记录的消息事件的类型常量。 +const ( + EventAAppPrepBegin = "A_APP_PREP_BEGIN" // A 端应用开始准备这条消息 + EventATXSched = "A_TX_SCHED" // A 端进入 Linux qdisc 之前 + EventATXSoftware = "A_TX_SOFTWARE" // A 端即将交给网卡驱动 + EventATXHardware = "A_TX_HARDWARE" // A 端网卡真正发出到物理介质 + EventBRXHardware = "B_RX_HARDWARE" // B 端网卡真正从物理介质收到 + EventBRXSoftware = "B_RX_SOFTWARE" // B 端驱动把数据交给 Linux 接收栈 + EventBAppRecv = "B_APP_RECV" // B 端应用真正读到完整消息 + EventBPersistBegin = "B_PERSIST_BEGIN" // B 端开始写盘 + EventBPersistEnd = "B_PERSIST_END" // B 端写盘完成 + + EventSendHandoffBegin = "send_handoff_begin" // 调试事件:应用把消息交给传输层开始 + EventSendHandoffEnd = "send_handoff_end" // 调试事件:应用把消息交给传输层结束 +) + +// Event 是一条时延时间戳日志记录。 +type Event struct { + TsUnixNano int64 `json:"ts_unix_nano"` + NodeRole string `json:"node_role"` + NodeID string `json:"node_id"` + Event string `json:"event"` + MessageType protocol.MessageType `json:"message_type"` + MessageID uint64 `json:"message_id"` + From string `json:"from"` + To string `json:"to"` + FileName string `json:"file_name,omitempty"` + BodySize int `json:"body_size"` +} + +// Logger 负责接收事件并将其写入外部介质。 +type Logger interface { + LogEvent(Event) error +} + +// NoopLogger 是默认的空实现。 +type NoopLogger struct{} + +// LogEvent 对空日志实现始终返回 nil。 +func (NoopLogger) LogEvent(Event) error { + return nil +} + +// JSONLLogger 以 JSONL 形式追加写日志文件。 +type JSONLLogger struct { + mu sync.Mutex + closeOnce sync.Once + closeErr error + file *os.File +} + +// NewJSONLLogger 创建一个线程安全的 JSONL 文件日志器。 +func NewJSONLLogger(path string) (*JSONLLogger, error) { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, err + } + + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + return nil, err + } + + return &JSONLLogger{file: file}, nil +} + +// LogEvent 以单行 JSON 的形式追加一条事件。 +func (l *JSONLLogger) LogEvent(event Event) error { + line, err := json.Marshal(event) + if err != nil { + return err + } + + l.mu.Lock() + defer l.mu.Unlock() + + if _, err := l.file.Write(append(line, '\n')); err != nil { + return err + } + + return nil +} + +// Close 关闭底层文件;重复调用是安全的。 +func (l *JSONLLogger) Close() error { + l.closeOnce.Do(func() { + l.closeErr = l.file.Close() + }) + + return l.closeErr +} + +// IsBusinessMessage 判断消息是否属于要参与 A-C-B 时延分析的业务消息。 +func IsBusinessMessage(msg protocol.Message) bool { + switch msg.Type { + case protocol.MessageTypeText, protocol.MessageTypeFile: + return true + default: + return false + } +} + +// NewMessageEvent 用当前 UTC 时间为一条业务消息构造事件。 +func NewMessageEvent(nodeRole, nodeID, eventName string, msg protocol.Message) Event { + return NewMessageEventAt(time.Now().UTC().UnixNano(), nodeRole, nodeID, eventName, msg) +} + +// NewMessageEventAt 用指定的 UnixNano 时间为一条业务消息构造事件。 +func NewMessageEventAt(tsUnixNano int64, nodeRole, nodeID, eventName string, msg protocol.Message) Event { + return Event{ + TsUnixNano: tsUnixNano, + NodeRole: nodeRole, + NodeID: nodeID, + Event: eventName, + MessageType: msg.Type, + MessageID: msg.ID, + From: msg.From, + To: msg.To, + FileName: msg.FileName, + BodySize: len(msg.Body), + } +} + +// LogBestEffort 写一条事件,失败时静默忽略,避免打断主收发流程。 +func LogBestEffort(logger Logger, event Event) { + if logger == nil { + return + } + + _ = logger.LogEvent(event) +} + +// LogMessageEvent 为业务消息构造并写入一条事件。 +func LogMessageEvent(logger Logger, nodeRole, nodeID, eventName string, msg protocol.Message) { + if !IsBusinessMessage(msg) { + return + } + + LogBestEffort(logger, NewMessageEvent(nodeRole, nodeID, eventName, msg)) +} + +// LogMessageEventAt 为业务消息写入一条指定时间戳的事件。 +func LogMessageEventAt(logger Logger, nodeRole, nodeID, eventName string, tsUnixNano int64, msg protocol.Message) { + if !IsBusinessMessage(msg) { + return + } + + LogBestEffort(logger, NewMessageEventAt(tsUnixNano, nodeRole, nodeID, eventName, msg)) +} diff --git a/cmd/internal/latencylog/logger_test.go b/cmd/internal/latencylog/logger_test.go new file mode 100644 index 0000000..d1850fe --- /dev/null +++ b/cmd/internal/latencylog/logger_test.go @@ -0,0 +1,131 @@ +package latencylog + +import ( + "bufio" + "encoding/json" + "os" + "path/filepath" + "sync" + "testing" + + "omnisocketgo/cmd/internal/protocol" +) + +func TestJSONLLoggerWritesOneEventPerLine(t *testing.T) { + path := filepath.Join(t.TempDir(), "latency.jsonl") + + logger, err := NewJSONLLogger(path) + if err != nil { + t.Fatalf("NewJSONLLogger() error = %v", err) + } + t.Cleanup(func() { + _ = logger.Close() + }) + + event := Event{ + TsUnixNano: 123, + NodeRole: NodeRolePeer, + NodeID: "peer-a", + Event: EventAAppPrepBegin, + MessageType: protocol.MessageTypeText, + MessageID: 1, + From: "peer-a", + To: "peer-b", + BodySize: 5, + } + if err := logger.LogEvent(event); err != nil { + t.Fatalf("LogEvent() error = %v", err) + } + + file, err := os.Open(path) + if err != nil { + t.Fatalf("os.Open() error = %v", err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + if !scanner.Scan() { + t.Fatal("expected one JSONL line, got none") + } + + var got Event + if err := json.Unmarshal(scanner.Bytes(), &got); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if got != event { + t.Fatalf("event mismatch: got %+v want %+v", got, event) + } + if scanner.Scan() { + t.Fatal("expected exactly one JSONL line") + } + if err := scanner.Err(); err != nil { + t.Fatalf("scanner.Err() = %v", err) + } +} + +func TestJSONLLoggerHandlesConcurrentWrites(t *testing.T) { + path := filepath.Join(t.TempDir(), "latency.jsonl") + + logger, err := NewJSONLLogger(path) + if err != nil { + t.Fatalf("NewJSONLLogger() error = %v", err) + } + t.Cleanup(func() { + _ = logger.Close() + }) + + const total = 32 + + var wg sync.WaitGroup + for i := 0; i < total; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + + err := logger.LogEvent(Event{ + TsUnixNano: int64(i + 1), + NodeRole: NodeRoleServer, + NodeID: protocol.ServerPeerID, + Event: EventBAppRecv, + MessageType: protocol.MessageTypeFile, + MessageID: uint64(i + 1), + From: "peer-a", + To: "peer-b", + FileName: "payload.bin", + BodySize: 3, + }) + if err != nil { + t.Errorf("LogEvent() error = %v", err) + } + }() + } + wg.Wait() + + file, err := os.Open(path) + if err != nil { + t.Fatalf("os.Open() error = %v", err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + var count int + seen := make(map[uint64]bool, total) + for scanner.Scan() { + var event Event + if err := json.Unmarshal(scanner.Bytes(), &event); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + count++ + seen[event.MessageID] = true + } + if err := scanner.Err(); err != nil { + t.Fatalf("scanner.Err() = %v", err) + } + if count != total { + t.Fatalf("line count = %d, want %d", count, total) + } + if len(seen) != total { + t.Fatalf("unique message count = %d, want %d", len(seen), total) + } +} diff --git a/cmd/internal/latencylog/summary.go b/cmd/internal/latencylog/summary.go new file mode 100644 index 0000000..13337b2 --- /dev/null +++ b/cmd/internal/latencylog/summary.go @@ -0,0 +1,254 @@ +package latencylog + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + + "omnisocketgo/cmd/internal/protocol" +) + +// Summary 是针对单条消息的时延的规则列表。 +var requiredTimestampNames = []string{ + EventAAppPrepBegin, // A 端应用开始准备这条消息 + EventATXSched, // A 端进入 Linux qdisc 之前 + EventATXSoftware, // A 端即将交给网卡驱动 + EventBRXSoftware, // B 端网卡驱动把数据交给 Linux 接收栈 + EventBAppRecv, // B 端应用真正读到完整消息 + EventBPersistEnd, // B 端写盘完成 +} + +// Summary 是针对单条消息的时延整理结果。 +type Summary struct { + MessageType protocol.MessageType `json:"message_type"` //消息类型 + MessageID uint64 `json:"message_id"` //消息ID + From string `json:"from"` //发送方 + To string `json:"to"` //接收方 + FileName string `json:"file_name,omitempty"` //文件名(仅文件消息) + BodySize int `json:"body_size"` //消息体大小(字节数) + Timestamps map[string]int64 `json:"timestamps"` //事件时间戳,key 是事件名称,value 是 UnixNano 时间戳 + + AProcessingLatencyNS *int64 `json:"a_processing_latency_ns,omitempty"` // A 处理时延:A_TX_SCHED - A_APP_PREP_BEGIN + AQueueLatencyNS *int64 `json:"a_queue_latency_ns,omitempty"` // A 排队时延:A_TX_SOFTWARE - A_TX_SCHED + ABTransportPropagationBQueueLatencyNS *int64 `json:"a_b_transport_propagation_b_queue_latency_ns,omitempty"` // A-B 传输时延 + B 排队时延:B_APP_RECV - A_TX_SOFTWARE + BKernelReceivePathLatencyNS *int64 `json:"b_kernel_receive_path_latency_ns,omitempty"` // B 内核接收路径近似:B_APP_RECV - B_RX_SOFTWARE + BProcessingLatencyNS *int64 `json:"b_processing_latency_ns,omitempty"` // B 处理时延:B_PERSIST_END - B_APP_RECV + EndToEndLatencyNS *int64 `json:"end_to_end_latency_ns,omitempty"` // 端到端时延:B_PERSIST_END - A_APP_PREP_BEGIN + MissingTimestamps []string `json:"missing_timestamps,omitempty"` // 缺失的时间戳列表,包含 requiredTimestampNames 中但在原始事件中没有的事件名称 +} + +// LoadEventsFromFiles 从JSONL 原始日志文件中加载事件。 +type messageKey struct { + MessageType protocol.MessageType //消息类型 + MessageID uint64 //消息ID + From string //发送方 + To string //接收方 +} + +// LoadEventsFromFiles 从多个 JSONL 原始日志文件中加载事件。 +func LoadEventsFromFiles(paths []string) ([]Event, error) { + var events []Event + for _, path := range paths { + fileEvents, err := LoadEventsFromFile(path) + if err != nil { + return nil, err + } + events = append(events, fileEvents...) + } + + return events, nil +} + +// LoadEventsFromFile 从单个 JSONL 原始日志文件中加载事件。 +func LoadEventsFromFile(path string) ([]Event, error) { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("latencylog: open raw log %s: %w", path, err) + } + defer file.Close() + + var events []Event + scanner := bufio.NewScanner(file) + for scanner.Scan() { + if len(scanner.Bytes()) == 0 { + continue + } + + var event Event + if err := json.Unmarshal(scanner.Bytes(), &event); err != nil { //解析 JSONL 行失败,返回错误 + return nil, fmt.Errorf("latencylog: decode event from %s: %w", path, err) + } + events = append(events, event) + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("latencylog: scan raw log %s: %w", path, err) + } + + return events, nil +} + +// SummarizeEvents 将原始事件整理成按消息分组的时延结果。 +func SummarizeEvents(events []Event) []Summary { + grouped := make(map[messageKey]*Summary) + + for _, event := range events { + if !IsBusinessEvent(event) { + continue + } + + key := messageKey{ + MessageType: event.MessageType, + MessageID: event.MessageID, + From: event.From, + To: event.To, + } + + summary, ok := grouped[key] + if !ok { + summary = &Summary{ + MessageType: event.MessageType, + MessageID: event.MessageID, + From: event.From, + To: event.To, + FileName: event.FileName, + BodySize: event.BodySize, + Timestamps: make(map[string]int64), + } + grouped[key] = summary + } + + if summary.FileName == "" { + summary.FileName = event.FileName + } + if event.BodySize > 0 { + summary.BodySize = event.BodySize + } + + if existing, exists := summary.Timestamps[event.Event]; !exists || event.TsUnixNano < existing { + summary.Timestamps[event.Event] = event.TsUnixNano + } + } + + summaries := make([]Summary, 0, len(grouped)) + for _, summary := range grouped { + completeSummary(summary) //补全时延指标和缺失时间戳信息 + summaries = append(summaries, *summary) + } + //对整理结果进行排序,先按发送方、再按接收方、再按消息 ID、最后按消息类型排序,保证输出的稳定性和可读性。 + sort.Slice(summaries, func(i, j int) bool { + if summaries[i].From != summaries[j].From { + return summaries[i].From < summaries[j].From + } + if summaries[i].To != summaries[j].To { + return summaries[i].To < summaries[j].To + } + if summaries[i].MessageID != summaries[j].MessageID { + return summaries[i].MessageID < summaries[j].MessageID + } + return summaries[i].MessageType < summaries[j].MessageType + }) + + return summaries +} + +// WriteSummariesJSONL 将整理结果写成 JSONL 汇总文件。 +func WriteSummariesJSONL(path string, summaries []Summary) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("latencylog: create summary dir for %s: %w", path, err) + } + + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) + if err != nil { + return fmt.Errorf("latencylog: open summary file %s: %w", path, err) + } + defer file.Close() + + writer := bufio.NewWriter(file) + for _, summary := range summaries { //将每条整理结果编码成 JSONL 行并写入文件 + line, err := json.Marshal(summary) + if err != nil { + return fmt.Errorf("latencylog: encode summary for message %d: %w", summary.MessageID, err) + } + if _, err := writer.Write(append(line, '\n')); err != nil { + return fmt.Errorf("latencylog: write summary file %s: %w", path, err) + } + } + + if err := writer.Flush(); err != nil { //将缓冲区内容写入文件 + return fmt.Errorf("latencylog: flush summary file %s: %w", path, err) + } + + return nil +} + +// completeSummary 根据事件时间戳计算时延指标,并找出缺失的时间戳。 +func completeSummary(summary *Summary) { + summary.MissingTimestamps = missingTimestampNames(summary.Timestamps) + + if value := subtractIfPresent(summary.Timestamps, EventATXSched, EventAAppPrepBegin); value != nil { + summary.AProcessingLatencyNS = value + } + if value := subtractIfPresent(summary.Timestamps, EventATXSoftware, EventATXSched); value != nil { + summary.AQueueLatencyNS = value + } + if value := subtractIfPresent(summary.Timestamps, EventBAppRecv, EventATXSoftware); value != nil { + summary.ABTransportPropagationBQueueLatencyNS = value + } + if value := subtractIfPresent(summary.Timestamps, EventBAppRecv, EventBRXSoftware); value != nil { + summary.BKernelReceivePathLatencyNS = value + } + if value := subtractIfPresent(summary.Timestamps, EventBPersistEnd, EventBAppRecv); value != nil { + summary.BProcessingLatencyNS = value + } + if value := subtractIfPresent(summary.Timestamps, EventBPersistEnd, EventAAppPrepBegin); value != nil { + summary.EndToEndLatencyNS = value + } +} + +// 返回 requiredTimestampNames 中哪些在给定的 timestamps 中缺失。 +func missingTimestampNames(timestamps map[string]int64) []string { + var missing []string + for _, name := range requiredTimestampNames { + if _, ok := timestamps[name]; !ok { + missing = append(missing, name) + } + } + + return missing +} + +// 如果 timestamps 中同时存在 endName 和 beginName,则返回它们的差值;否则返回 nil。 +func subtractIfPresent(timestamps map[string]int64, endName, beginName string) *int64 { + end, ok := timestamps[endName] + if !ok { + return nil + } + begin, ok := timestamps[beginName] + if !ok { + return nil + } + + value := end - begin + return &value +} + +// 判断事件是否是业务相关的时延事件(其中一项) +func IsBusinessEvent(event Event) bool { + switch event.Event { + case EventAAppPrepBegin, + EventATXSched, + EventATXSoftware, + EventATXHardware, + EventBRXHardware, + EventBRXSoftware, + EventBAppRecv, + EventBPersistBegin, + EventBPersistEnd: + return true + default: + return false + } +} diff --git a/cmd/internal/latencylog/summary_chart.go b/cmd/internal/latencylog/summary_chart.go new file mode 100644 index 0000000..1953fe9 --- /dev/null +++ b/cmd/internal/latencylog/summary_chart.go @@ -0,0 +1,440 @@ +package latencylog + +import ( + "bufio" + "fmt" + "html/template" + "os" + "path/filepath" + "strings" + + "omnisocketgo/cmd/internal/protocol" +) + +const summaryChartHTMLTemplate = ` + + + + + Latency Summary Chart + + + +
+

Latency Summary

+

A simple per-message end-to-end latency chart generated from summarized JSONL records.

+ +
+
+
Messages
+
{{.TotalMessages}}
+
+
+
With End-To-End
+
{{.MessagesWithEndToEnd}}
+
+
+
Average End-To-End
+
{{.AverageEndToEnd}}
+
+
+
Max End-To-End
+
{{.MaxEndToEnd}}
+
+
+ +
+ {{range .Legend}} + + + {{.Label}} + + {{end}} +
+ + {{if .Rows}} +
+ {{range .Rows}} +
+
+

{{.Title}}

+
{{.EndToEnd}}
+
+
{{.Subtitle}}
+
+ {{range .Segments}} +
+ {{end}} +
+ {{if .Segments}} +
+ {{range .Segments}} + + + {{.Label}} {{.Value}} + + {{end}} +
+ {{end}} + {{if .MissingTimestamps}} +
Missing timestamps: {{.MissingTimestamps}}
+ {{end}} +
+ {{end}} +
+ {{else}} +
No summarized messages were available for chart rendering.
+ {{end}} +
+ + +` + +type summaryChartPage struct { + TotalMessages int + MessagesWithEndToEnd int + AverageEndToEnd string + MaxEndToEnd string + Legend []summaryChartLegendItem + Rows []summaryChartRow +} + +type summaryChartLegendItem struct { + Label string + Color string +} + +type summaryChartRow struct { + Title string + Subtitle string + EndToEnd string + MissingTimestamps string + Segments []summaryChartSegment +} + +type summaryChartSegment struct { + Label string + Value string + Color string + WidthPercent float64 +} + +type summaryChartSegmentMetric struct { + label string + value *int64 + color string +} + +// WriteSummariesHTMLChart 将整理结果写成一个可直接在浏览器中打开的简单 HTML 图表。 +func WriteSummariesHTMLChart(path string, summaries []Summary) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("latencylog: create chart dir for %s: %w", path, err) + } + + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) + if err != nil { + return fmt.Errorf("latencylog: open chart file %s: %w", path, err) + } + defer file.Close() + + page := buildSummaryChartPage(summaries) + tmpl, err := template.New("summary-chart").Parse(summaryChartHTMLTemplate) + if err != nil { + return fmt.Errorf("latencylog: parse chart template: %w", err) + } + + writer := bufio.NewWriter(file) + if err := tmpl.Execute(writer, page); err != nil { + return fmt.Errorf("latencylog: render chart %s: %w", path, err) + } + if err := writer.Flush(); err != nil { + return fmt.Errorf("latencylog: flush chart %s: %w", path, err) + } + + return nil +} + +func buildSummaryChartPage(summaries []Summary) summaryChartPage { + page := summaryChartPage{ + TotalMessages: len(summaries), + Legend: []summaryChartLegendItem{ + {Label: "A processing", Color: "var(--a-proc)"}, + {Label: "A queue", Color: "var(--a-queue)"}, + {Label: "Transport + B queue", Color: "var(--transport)"}, + {Label: "B processing", Color: "var(--b-proc)"}, + {Label: "Unknown / missing", Color: "var(--unknown)"}, + }, + Rows: make([]summaryChartRow, 0, len(summaries)), + } + + var ( + endToEndValues []int64 + totalEndToEnd int64 + maxEndToEnd int64 + ) + + for _, summary := range summaries { + page.Rows = append(page.Rows, buildSummaryChartRow(summary)) + + if summary.EndToEndLatencyNS == nil { + continue + } + endToEnd := *summary.EndToEndLatencyNS + endToEndValues = append(endToEndValues, endToEnd) + totalEndToEnd += endToEnd + if endToEnd > maxEndToEnd { + maxEndToEnd = endToEnd + } + } + + page.MessagesWithEndToEnd = len(endToEndValues) + page.AverageEndToEnd = "n/a" + page.MaxEndToEnd = "n/a" + if len(endToEndValues) > 0 { + page.AverageEndToEnd = formatLatencyNS(totalEndToEnd / int64(len(endToEndValues))) + page.MaxEndToEnd = formatLatencyNS(maxEndToEnd) + } + + return page +} + +func buildSummaryChartRow(summary Summary) summaryChartRow { + row := summaryChartRow{ + Title: buildSummaryChartTitle(summary), + Subtitle: buildSummaryChartSubtitle(summary), + EndToEnd: "End-to-end: n/a", + MissingTimestamps: strings.Join(summary.MissingTimestamps, ", "), + } + + if summary.EndToEndLatencyNS == nil || *summary.EndToEndLatencyNS <= 0 { + return row + } + + total := *summary.EndToEndLatencyNS + row.EndToEnd = fmt.Sprintf("End-to-end: %s", formatLatencyNS(total)) + + metrics := []summaryChartSegmentMetric{ + {label: "A processing", value: summary.AProcessingLatencyNS, color: "var(--a-proc)"}, + {label: "A queue", value: summary.AQueueLatencyNS, color: "var(--a-queue)"}, + {label: "Transport + B queue", value: summary.ABTransportPropagationBQueueLatencyNS, color: "var(--transport)"}, + {label: "B processing", value: summary.BProcessingLatencyNS, color: "var(--b-proc)"}, + } + + var knownTotal int64 + for _, metric := range metrics { + if metric.value == nil || *metric.value <= 0 { + continue + } + knownTotal += *metric.value + } + + scaleTotal := total + if knownTotal > scaleTotal { + scaleTotal = knownTotal + } + if scaleTotal <= 0 { + return row + } + + for _, metric := range metrics { + if metric.value == nil || *metric.value <= 0 { + continue + } + row.Segments = append(row.Segments, summaryChartSegment{ + Label: metric.label, + Value: formatLatencyNS(*metric.value), + Color: metric.color, + WidthPercent: float64(*metric.value) * 100 / float64(scaleTotal), + }) + } + + if remaining := total - knownTotal; remaining > 0 { + row.Segments = append(row.Segments, summaryChartSegment{ + Label: "Unknown / missing", + Value: formatLatencyNS(remaining), + Color: "var(--unknown)", + WidthPercent: float64(remaining) * 100 / float64(scaleTotal), + }) + } + + return row +} + +func buildSummaryChartTitle(summary Summary) string { + if summary.MessageType == protocol.MessageTypeFile && summary.FileName != "" { + return fmt.Sprintf("%s #%d (%s)", summary.MessageType, summary.MessageID, summary.FileName) + } + + return fmt.Sprintf("%s #%d", summary.MessageType, summary.MessageID) +} + +func buildSummaryChartSubtitle(summary Summary) string { + parts := []string{ + fmt.Sprintf("%s -> %s", summary.From, summary.To), + fmt.Sprintf("%d bytes", summary.BodySize), + } + + if summary.MessageType == protocol.MessageTypeFile && summary.FileName != "" { + parts = append(parts, fmt.Sprintf("file: %s", summary.FileName)) + } + + return strings.Join(parts, " | ") +} + +func formatLatencyNS(ns int64) string { + return fmt.Sprintf("%.3f ms", float64(ns)/1_000_000) +} diff --git a/cmd/internal/latencylog/summary_chart_test.go b/cmd/internal/latencylog/summary_chart_test.go new file mode 100644 index 0000000..a7a7d0d --- /dev/null +++ b/cmd/internal/latencylog/summary_chart_test.go @@ -0,0 +1,67 @@ +package latencylog + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "omnisocketgo/cmd/internal/protocol" +) + +func TestWriteSummariesHTMLChart(t *testing.T) { + aProcessing := int64(20_000_000) + aQueue := int64(10_000_000) + transport := int64(40_000_000) + bProcessing := int64(30_000_000) + endToEnd := int64(100_000_000) + + summaries := []Summary{ + { + MessageType: protocol.MessageTypeText, + MessageID: 7, + From: "peer-a", + To: "peer-b", + BodySize: 5, + AProcessingLatencyNS: &aProcessing, + AQueueLatencyNS: &aQueue, + ABTransportPropagationBQueueLatencyNS: &transport, + BProcessingLatencyNS: &bProcessing, + EndToEndLatencyNS: &endToEnd, + }, + { + MessageType: protocol.MessageTypeFile, + MessageID: 8, + From: "peer-b", + To: "peer-a", + FileName: "payload.bin", + BodySize: 128, + MissingTimestamps: []string{EventBRXSoftware}, + }, + } + + path := filepath.Join(t.TempDir(), "charts", "latency-summary.html") + if err := WriteSummariesHTMLChart(path, summaries); err != nil { + t.Fatalf("WriteSummariesHTMLChart() error = %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("os.ReadFile() error = %v", err) + } + + content := string(data) + for _, want := range []string{ + "Latency Summary", + "text #7", + "peer-a -> peer-b | 5 bytes", + "End-to-end: 100.000 ms", + "A processing 20.000 ms", + "file #8 (payload.bin)", + "Missing timestamps: B_RX_SOFTWARE", + } { + if !strings.Contains(content, want) { + t.Fatalf("chart content missing %q\n%s", want, content) + } + } +} diff --git a/cmd/internal/latencylog/summary_test.go b/cmd/internal/latencylog/summary_test.go new file mode 100644 index 0000000..a7a0eb6 --- /dev/null +++ b/cmd/internal/latencylog/summary_test.go @@ -0,0 +1,144 @@ +package latencylog + +import ( + "bufio" + "encoding/json" + "os" + "path/filepath" + "reflect" + "testing" + + "omnisocketgo/cmd/internal/protocol" +) + +func TestSummarizeEventsComputesLatencyMetrics(t *testing.T) { + events := []Event{ + {TsUnixNano: 100, Event: EventAAppPrepBegin, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 120, Event: EventATXSched, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 140, Event: EventATXSoftware, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 180, Event: EventBRXSoftware, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 220, Event: EventBAppRecv, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 230, Event: EventBPersistBegin, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 260, Event: EventBPersistEnd, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"}, + } + + summaries := SummarizeEvents(events) + if len(summaries) != 1 { + t.Fatalf("summary count = %d, want 1", len(summaries)) + } + + summary := summaries[0] + if got := ptrValue(summary.AProcessingLatencyNS); got != 20 { + t.Fatalf("AProcessingLatencyNS = %d, want 20", got) + } + if got := ptrValue(summary.AQueueLatencyNS); got != 20 { + t.Fatalf("AQueueLatencyNS = %d, want 20", got) + } + if got := ptrValue(summary.ABTransportPropagationBQueueLatencyNS); got != 80 { + t.Fatalf("ABTransportPropagationBQueueLatencyNS = %d, want 80", got) + } + if got := ptrValue(summary.BKernelReceivePathLatencyNS); got != 40 { + t.Fatalf("BKernelReceivePathLatencyNS = %d, want 40", got) + } + if got := ptrValue(summary.BProcessingLatencyNS); got != 40 { + t.Fatalf("BProcessingLatencyNS = %d, want 40", got) + } + if got := ptrValue(summary.EndToEndLatencyNS); got != 160 { + t.Fatalf("EndToEndLatencyNS = %d, want 160", got) + } + if got := summary.Timestamps[EventBRXSoftware]; got != 180 { + t.Fatalf("timestamps[%q] = %d, want 180", EventBRXSoftware, got) + } + if len(summary.MissingTimestamps) != 0 { + t.Fatalf("MissingTimestamps = %v, want empty", summary.MissingTimestamps) + } +} + +func TestSummarizeEventsReportsMissingTimestamps(t *testing.T) { + events := []Event{ + {TsUnixNano: 100, Event: EventAAppPrepBegin, MessageType: protocol.MessageTypeText, MessageID: 2, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 240, Event: EventBPersistEnd, MessageType: protocol.MessageTypeText, MessageID: 2, From: "peer-a", To: "peer-b"}, + } + + summaries := SummarizeEvents(events) + if len(summaries) != 1 { + t.Fatalf("summary count = %d, want 1", len(summaries)) + } + + wantMissing := []string{EventATXSched, EventATXSoftware, EventBRXSoftware, EventBAppRecv} + if !reflect.DeepEqual(summaries[0].MissingTimestamps, wantMissing) { + t.Fatalf("MissingTimestamps = %v, want %v", summaries[0].MissingTimestamps, wantMissing) + } + if summaries[0].AProcessingLatencyNS != nil { + t.Fatalf("AProcessingLatencyNS = %v, want nil", ptrValue(summaries[0].AProcessingLatencyNS)) + } + if summaries[0].EndToEndLatencyNS == nil { + t.Fatal("EndToEndLatencyNS = nil, want non-nil because endpoints are present") + } +} + +func TestLoadAndWriteSummaryFiles(t *testing.T) { + rawPath := filepath.Join(t.TempDir(), "raw.jsonl") + rawLogger, err := NewJSONLLogger(rawPath) + if err != nil { + t.Fatalf("NewJSONLLogger() error = %v", err) + } + t.Cleanup(func() { + _ = rawLogger.Close() + }) + + for _, event := range []Event{ + {TsUnixNano: 100, Event: EventAAppPrepBegin, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 120, Event: EventATXSched, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 140, Event: EventATXSoftware, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 180, Event: EventBRXSoftware, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 220, Event: EventBAppRecv, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"}, + {TsUnixNano: 260, Event: EventBPersistEnd, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"}, + } { + if err := rawLogger.LogEvent(event); err != nil { + t.Fatalf("LogEvent() error = %v", err) + } + } + + events, err := LoadEventsFromFile(rawPath) + if err != nil { + t.Fatalf("LoadEventsFromFile() error = %v", err) + } + + summaryPath := filepath.Join(t.TempDir(), "summary.jsonl") + if err := WriteSummariesJSONL(summaryPath, SummarizeEvents(events)); err != nil { + t.Fatalf("WriteSummariesJSONL() error = %v", err) + } + + file, err := os.Open(summaryPath) + if err != nil { + t.Fatalf("os.Open() error = %v", err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + if !scanner.Scan() { + t.Fatal("expected one summary line, got none") + } + + var summary Summary + if err := json.Unmarshal(scanner.Bytes(), &summary); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if summary.MessageID != 3 { + t.Fatalf("MessageID = %d, want 3", summary.MessageID) + } + if got := ptrValue(summary.BKernelReceivePathLatencyNS); got != 40 { + t.Fatalf("BKernelReceivePathLatencyNS = %d, want 40", got) + } + if got := ptrValue(summary.EndToEndLatencyNS); got != 160 { + t.Fatalf("EndToEndLatencyNS = %d, want 160", got) + } +} + +func ptrValue(value *int64) int64 { + if value == nil { + return 0 + } + return *value +} diff --git a/cmd/internal/peer/client.go b/cmd/internal/peer/client.go new file mode 100644 index 0000000..1c586fa --- /dev/null +++ b/cmd/internal/peer/client.go @@ -0,0 +1,179 @@ +package peer + +import ( + "fmt" + "net" + "os" + "path/filepath" + "sync/atomic" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" + "omnisocketgo/cmd/internal/transport" +) + +var dialServer = net.Dial + +type clientOptions struct { + logger latencylog.Logger +} + +// Option 用于配置 Client 的可选行为,例如时延日志。 +type Option func(*clientOptions) + +// WithLogger 为 client 注入时延日志记录器。 +func WithLogger(logger latencylog.Logger) Option { + return func(options *clientOptions) { + options.logger = logger + } +} + +// Client 表示一个已经连接到 server 的 peer。 +type Client struct { + id string + conn *transport.TCPConn + logger latencylog.Logger + + nextID uint64 +} + +// Dial 连接到 server,并立即发送 register 消息完成身份注册。 +func Dial(serverAddr, peerID string, opts ...Option) (*Client, error) { + options := clientOptions{ + logger: latencylog.NoopLogger{}, + } + for _, opt := range opts { + opt(&options) + } + if options.logger == nil { + options.logger = latencylog.NoopLogger{} + } + + rawConn, err := dialServer("tcp", serverAddr) //使用 net.Dial 连接到 serverAddr 指定的 TCP 地址,返回一个 net.Conn。 + if err != nil { + return nil, fmt.Errorf("peer: dial server: %w", err) + } + + conn, err := transport.NewTCPConn( + rawConn, + transport.WithLogger(options.logger, latencylog.NodeRolePeer, peerID), + ) + if err != nil { + _ = rawConn.Close() + return nil, fmt.Errorf("peer: create transport conn: %w", err) + } + client := &Client{ + id: peerID, + conn: conn, + logger: options.logger, + } + + if err := conn.Send(protocol.Message{ //向 server 发送一条 register 消息,完成身份注册。 + Type: protocol.MessageTypeRegister, + From: peerID, + To: protocol.ServerPeerID, + }); err != nil { + _ = conn.Close() + return nil, fmt.Errorf("peer: register with server: %w", err) + } + + return client, nil +} + +// ID 返回当前 client 的 peer 标识。 +func (c *Client) ID() string { + return c.id +} + +// SendText 向目标 peer 发送一条文本消息。 +func (c *Client) SendText(to, body string) error { + msg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: c.nextMessageID(), + From: c.id, + To: to, + } + // 记录 A 端应用开始准备消息的时间点。 + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg) + + msg.Body = []byte(body) + + return c.conn.Send(msg) +} + +// SendFile 向目标 peer 发送一条文件消息。 +func (c *Client) SendFile(to, fileName string, body []byte) error { + msg := protocol.Message{ + Type: protocol.MessageTypeFile, + ID: c.nextMessageID(), + From: c.id, + To: to, + FileName: fileName, + } + // 记录 A 端应用开始准备消息的时间点。 + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg) + + bodyCopy := make([]byte, len(body)) + copy(bodyCopy, body) + + msg.Body = bodyCopy + + return c.conn.Send(msg) +} + +// SendFilePath 从本地文件读取内容并发送给目标 peer。 +func (c *Client) SendFilePath(to, path string) error { + msg := protocol.Message{ + Type: protocol.MessageTypeFile, + ID: c.nextMessageID(), + From: c.id, + To: to, + FileName: filepath.Base(path), + } + // 记录 A 端应用开始准备消息的时间点。 + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg) + + body, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("peer: read file %s: %w", path, err) + } + + msg.Body = body + + return c.conn.Send(msg) +} + +// Receive 读取一条来自 server 的消息。 +func (c *Client) Receive() (protocol.Message, error) { + msg, err := c.conn.Receive() //从底层 TCP 连接读取一条消息,返回一个 protocol.Message 结构体。 + if err != nil { + return protocol.Message{}, fmt.Errorf("peer: receive from server: %w", err) + } + + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg) + + return msg, nil +} + +// ReceiveLoop 持续接收 server 消息并交给 handler 处理。 +func (c *Client) ReceiveLoop(handler func(protocol.Message) error) error { + return c.conn.ReceiveLoop(func(msg protocol.Message) error { + switch msg.Type { + case protocol.MessageTypeText, protocol.MessageTypeFile, protocol.MessageTypeError: + // 记录 B 端应用真正读到完整消息的时间点。 + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg) + return handler(msg) + default: + return fmt.Errorf("peer: unexpected message type from server: %s", msg.Type) + } + }) +} + +// Close 关闭与 server 的连接。 +func (c *Client) Close() error { + return c.conn.Close() +} + +func (c *Client) nextMessageID() uint64 { + return atomic.AddUint64(&c.nextID, 1) +} diff --git a/cmd/internal/peer/client_linux_test.go b/cmd/internal/peer/client_linux_test.go new file mode 100644 index 0000000..f5432af --- /dev/null +++ b/cmd/internal/peer/client_linux_test.go @@ -0,0 +1,188 @@ +//go:build linux + +package peer + +import ( + "net" + "strings" + "sync" + "testing" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" + "omnisocketgo/cmd/internal/server" +) + +func TestClientsExchangeMessagesWithLinuxTimestamps(t *testing.T) { + hub := server.NewHub() + serverAddr, cleanup := startRealHubServer(t, hub) + defer cleanup() + + peerALogger := &recordingLogger{} + peerA, err := Dial(serverAddr, "peer-a", WithLogger(peerALogger)) + if err != nil { + t.Fatalf("Dial(peer-a) error = %v", err) + } + defer func() { _ = peerA.Close() }() + + peerBLogger := &recordingLogger{} + peerB, err := Dial(serverAddr, "peer-b", WithLogger(peerBLogger)) + if err != nil { + t.Fatalf("Dial(peer-b) error = %v", err) + } + defer func() { _ = peerB.Close() }() + + inboxDir := t.TempDir() + + waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") + + if err := peerA.SendText("peer-b", "hello"); err != nil { + t.Fatalf("SendText() error = %v", err) + } + textMsg, err := peerB.Receive() + if err != nil { + t.Fatalf("peerB.Receive(text) error = %v", err) + } + if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil { + t.Fatalf("peerB.PersistMessage(text) error = %v", err) + } + + if err := peerA.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}); err != nil { + t.Fatalf("SendFile() error = %v", err) + } + fileMsg, err := peerB.Receive() + if err != nil { + t.Fatalf("peerB.Receive(file) error = %v", err) + } + if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil { + t.Fatalf("peerB.PersistMessage(file) error = %v", err) + } + + waitFor(t, func() bool { return hasMessageEvents(peerALogger.Events(), 1, latencylog.EventAAppPrepBegin, latencylog.EventATXSched, latencylog.EventATXSoftware) }, "peer-a text kernel timestamps") + waitFor(t, func() bool { return hasMessageEvents(peerALogger.Events(), 2, latencylog.EventAAppPrepBegin, latencylog.EventATXSched, latencylog.EventATXSoftware) }, "peer-a file kernel timestamps") + waitFor(t, func() bool { return hasMessageEvents(peerBLogger.Events(), 1, latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd) }, "peer-b text receive timestamps") + waitFor(t, func() bool { return hasMessageEvents(peerBLogger.Events(), 2, latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd) }, "peer-b file receive timestamps") +} + +func startRealHubServer(t *testing.T, hub *server.Hub) (string, func()) { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen() error = %v", err) + } + + var ( + wg sync.WaitGroup + stop = make(chan struct{}) + errOnce sync.Once + ) + + wg.Add(1) + go func() { + defer wg.Done() + for { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + select { + case <-stop: + return + default: + } + if strings.Contains(acceptErr.Error(), "closed") { + return + } + t.Errorf("listener.Accept() error = %v", acceptErr) + return + } + + wg.Add(1) + go func(rawConn net.Conn) { + defer wg.Done() + if serveErr := hub.ServeConn(rawConn); serveErr != nil && !isExpectedHubServeExit(serveErr) { + errOnce.Do(func() { + t.Logf("hub.ServeConn() ended with %v", serveErr) + }) + } + }(conn) + } + }() + + cleanup := func() { + close(stop) + _ = listener.Close() + wg.Wait() + } + + return listener.Addr().String(), cleanup +} + +func hasMessageEvents(events []latencylog.Event, messageID uint64, wantEvents ...string) bool { + seen := make(map[string]bool, len(wantEvents)) + for _, event := range events { + if event.MessageID != messageID { + continue + } + if event.TsUnixNano <= 0 { + return false + } + seen[event.Event] = true + } + + for _, wantEvent := range wantEvents { + if !seen[wantEvent] { + return false + } + } + + return true +} + +func isExpectedHubServeExit(err error) bool { + if err == nil { + return true + } + + message := err.Error() + return strings.Contains(message, "closed") || strings.Contains(message, "protocol: read frame: EOF") +} + +func TestLinuxTimestampedReceivePreservesBusinessMessageShape(t *testing.T) { + hub := server.NewHub() + serverAddr, cleanup := startRealHubServer(t, hub) + defer cleanup() + + peerA, err := Dial(serverAddr, "peer-a") + if err != nil { + t.Fatalf("Dial(peer-a) error = %v", err) + } + defer func() { _ = peerA.Close() }() + + peerB, err := Dial(serverAddr, "peer-b") + if err != nil { + t.Fatalf("Dial(peer-b) error = %v", err) + } + defer func() { _ = peerB.Close() }() + + waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") + + want := protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 1, + From: "peer-a", + To: "peer-b", + FileName: "payload.bin", + Body: []byte{0xde, 0xad, 0xbe, 0xef}, + } + if err := peerA.SendFile(want.To, want.FileName, want.Body); err != nil { + t.Fatalf("SendFile() error = %v", err) + } + + got, err := peerB.Receive() + if err != nil { + t.Fatalf("peerB.Receive() error = %v", err) + } + if got.Type != want.Type || got.ID != want.ID || got.From != want.From || got.To != want.To || got.FileName != want.FileName || string(got.Body) != string(want.Body) { + t.Fatalf("received message mismatch: got %+v want %+v", got, want) + } +} diff --git a/cmd/internal/peer/client_test.go b/cmd/internal/peer/client_test.go new file mode 100644 index 0000000..cdf7cef --- /dev/null +++ b/cmd/internal/peer/client_test.go @@ -0,0 +1,770 @@ +package peer + +import ( + "bytes" + "encoding/json" + "net" + "os" + "path/filepath" + "reflect" + "strings" + "sync" + "testing" + "time" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" + "omnisocketgo/cmd/internal/server" + "omnisocketgo/cmd/internal/transport" +) + +type recordingLogger struct { + mu sync.Mutex + events []latencylog.Event +} + +func (l *recordingLogger) LogEvent(event latencylog.Event) error { + l.mu.Lock() + defer l.mu.Unlock() + + l.events = append(l.events, event) + return nil +} + +func (l *recordingLogger) Events() []latencylog.Event { + l.mu.Lock() + defer l.mu.Unlock() + + return append([]latencylog.Event(nil), l.events...) +} + +type failingLogger struct{} + +func (failingLogger) LogEvent(latencylog.Event) error { + return net.ErrClosed +} + +func TestDialRegistersPeer(t *testing.T) { + hub := server.NewHub() + cleanup := stubDialToHub(t, hub) + defer cleanup() + + client, err := Dial("ignored", "peer-a") + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer func() { _ = client.Close() }() + + waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") +} + +func TestClientsExchangeTextAndFileMessages(t *testing.T) { + hub := server.NewHub() + cleanup := stubDialToHub(t, hub) + defer cleanup() + + peerA, err := Dial("ignored", "peer-a") + if err != nil { + t.Fatalf("Dial(peer-a) error = %v", err) + } + defer func() { _ = peerA.Close() }() + + peerB, err := Dial("ignored", "peer-b") + if err != nil { + t.Fatalf("Dial(peer-b) error = %v", err) + } + defer func() { _ = peerB.Close() }() + + waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") + + received := make(chan protocol.Message, 2) + receiveErr := make(chan error, 1) + go func() { + for i := 0; i < 2; i++ { + msg, err := peerB.Receive() + if err != nil { + receiveErr <- err + return + } + received <- msg + } + receiveErr <- nil + }() + + if err := peerA.SendText("peer-b", "hello"); err != nil { + t.Fatalf("SendText() error = %v", err) + } + fileBody := []byte{0x01, 0x02, 0x03} + if err := peerA.SendFile("peer-b", "payload.bin", fileBody); err != nil { + t.Fatalf("SendFile() error = %v", err) + } + + if err := <-receiveErr; err != nil { + t.Fatalf("peerB.Receive() error = %v", err) + } + + gotFirst := <-received + wantFirst := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + } + if !reflect.DeepEqual(gotFirst, wantFirst) { + t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, wantFirst) + } + + gotSecond := <-received + wantSecond := protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 2, + From: "peer-a", + To: "peer-b", + FileName: "payload.bin", + Body: fileBody, + } + if !reflect.DeepEqual(gotSecond, wantSecond) { + t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, wantSecond) + } +} + +func TestClientReceivesServerErrorForUnknownTarget(t *testing.T) { + hub := server.NewHub() + cleanup := stubDialToHub(t, hub) + defer cleanup() + + client, err := Dial("ignored", "peer-a") + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer func() { _ = client.Close() }() + + waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") + + if err := client.SendText("missing-peer", "hello"); err != nil { + t.Fatalf("SendText() error = %v", err) + } + + got, err := client.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if got.Type != protocol.MessageTypeError { + t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError) + } + if string(got.Body) != "unknown target: missing-peer" { + t.Fatalf("error body = %q, want unknown target message", got.Body) + } +} + +func TestClientReceiveLoopHandlesForwardedMessages(t *testing.T) { + hub := server.NewHub() + cleanup := stubDialToHub(t, hub) + defer cleanup() + + peerA, err := Dial("ignored", "peer-a") + if err != nil { + t.Fatalf("Dial(peer-a) error = %v", err) + } + defer func() { _ = peerA.Close() }() + + peerB, err := Dial("ignored", "peer-b") + if err != nil { + t.Fatalf("Dial(peer-b) error = %v", err) + } + defer func() { _ = peerB.Close() }() + + waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") + + var ( + mu sync.Mutex + got []protocol.Message + ) + loopErr := make(chan error, 1) + go func() { + loopErr <- peerB.ReceiveLoop(func(msg protocol.Message) error { + mu.Lock() + defer mu.Unlock() + got = append(got, msg) + if len(got) == 1 { + return peerB.Close() + } + return nil + }) + }() + + if err := peerA.SendText("peer-b", "hello"); err != nil { + t.Fatalf("SendText() error = %v", err) + } + + err = <-loopErr + if err == nil || (!strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "use of closed network connection")) { + t.Fatalf("ReceiveLoop() error = %v, want close-related error", err) + } + + mu.Lock() + defer mu.Unlock() + want := []protocol.Message{ + { + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + }, + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("received messages mismatch: got %+v want %+v", got, want) + } +} + +func TestClientSendLogsLatencyEvents(t *testing.T) { + tests := []struct { + name string + setup func(*testing.T) string + send func(*Client, string) error + wantMsg protocol.Message + wantEvents []string + }{ + { + name: "text", + send: func(client *Client, _ string) error { + return client.SendText("peer-b", "hello") + }, + wantMsg: protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + }, + wantEvents: []string{ + latencylog.EventAAppPrepBegin, + latencylog.EventSendHandoffBegin, + latencylog.EventATXSched, + latencylog.EventATXSoftware, + latencylog.EventSendHandoffEnd, + }, + }, + { + name: "file-bytes", + send: func(client *Client, _ string) error { + return client.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}) + }, + wantMsg: protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 1, + From: "peer-a", + To: "peer-b", + FileName: "payload.bin", + Body: []byte{0x01, 0x02, 0x03}, + }, + wantEvents: []string{ + latencylog.EventAAppPrepBegin, + latencylog.EventSendHandoffBegin, + latencylog.EventATXSched, + latencylog.EventATXSoftware, + latencylog.EventSendHandoffEnd, + }, + }, + { + name: "file-path", + setup: func(t *testing.T) string { + t.Helper() + + path := filepath.Join(t.TempDir(), "payload.bin") + if err := os.WriteFile(path, []byte{0x01, 0x02, 0x03}, 0o644); err != nil { + t.Fatalf("os.WriteFile() error = %v", err) + } + + return path + }, + send: func(client *Client, path string) error { + return client.SendFilePath("peer-b", path) + }, + wantMsg: protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 1, + From: "peer-a", + To: "peer-b", + FileName: "payload.bin", + Body: []byte{0x01, 0x02, 0x03}, + }, + wantEvents: []string{ + latencylog.EventAAppPrepBegin, + latencylog.EventSendHandoffBegin, + latencylog.EventATXSched, + latencylog.EventATXSoftware, + latencylog.EventSendHandoffEnd, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inputPath := "" + if tt.setup != nil { + inputPath = tt.setup(t) + } + + logger := &recordingLogger{} + clientConn, receiver := newClientTransportPair( + t, + []transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-a")}, + nil, + ) + client := &Client{ + id: "peer-a", + conn: clientConn, + logger: logger, + } + + sendErr := make(chan error, 1) + go func() { + sendErr <- tt.send(client, inputPath) + }() + + got, err := receiver.Receive() + if err != nil { + t.Fatalf("receiver.Receive() error = %v", err) + } + if err := <-sendErr; err != nil { + t.Fatalf("send() error = %v", err) + } + if !reflect.DeepEqual(got, tt.wantMsg) { + t.Fatalf("message mismatch: got %+v want %+v", got, tt.wantMsg) + } + + events := logger.Events() + if len(events) != len(tt.wantEvents) { + t.Fatalf("event count = %d, want %d", len(events), len(tt.wantEvents)) + } + for i, wantEvent := range tt.wantEvents { + if events[i].Event != wantEvent { + t.Fatalf("event[%d] = %q, want %q", i, events[i].Event, wantEvent) + } + if events[i].MessageID != tt.wantMsg.ID || events[i].From != tt.wantMsg.From || events[i].To != tt.wantMsg.To { + t.Fatalf("event[%d] metadata mismatch: %+v", i, events[i]) + } + } + }) + } +} + +func TestClientReceiveLogsOnlyBusinessMessages(t *testing.T) { + logger := &recordingLogger{} + clientConn, sender := newClientTransportPair( + t, + []transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-b")}, + nil, + ) + client := &Client{ + id: "peer-b", + conn: clientConn, + logger: logger, + } + + textMsg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 21, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + } + sendErr := make(chan error, 1) + go func() { + sendErr <- sender.Send(textMsg) + }() + if _, err := client.Receive(); err != nil { + t.Fatalf("client.Receive(text) error = %v", err) + } + if err := <-sendErr; err != nil { + t.Fatalf("sender.Send(text) error = %v", err) + } + + errorMsg := protocol.Message{ + Type: protocol.MessageTypeError, + ID: 22, + From: protocol.ServerPeerID, + To: "peer-b", + Body: []byte("failure"), + } + sendErr = make(chan error, 1) + go func() { + sendErr <- sender.Send(errorMsg) + }() + if _, err := client.Receive(); err != nil { + t.Fatalf("client.Receive(error) error = %v", err) + } + if err := <-sendErr; err != nil { + t.Fatalf("sender.Send(error) error = %v", err) + } + + events := logger.Events() + if len(events) != 2 { + t.Fatalf("event count = %d, want 2", len(events)) + } + if events[0].Event != latencylog.EventBRXSoftware { + t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBRXSoftware) + } + if events[1].Event != latencylog.EventBAppRecv { + t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBAppRecv) + } + if events[0].MessageID != textMsg.ID || events[1].MessageID != textMsg.ID { + t.Fatalf("message IDs = %d,%d, want %d", events[0].MessageID, events[1].MessageID, textMsg.ID) + } +} + +func TestClientPersistTextMessageWritesInboxFileAndLogs(t *testing.T) { + inboxDir := t.TempDir() + logger := &recordingLogger{} + client := &Client{ + id: "peer-b", + logger: logger, + } + + msg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 31, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + } + + path, err := client.PersistMessage(msg, inboxDir) + if err != nil { + t.Fatalf("PersistMessage() error = %v", err) + } + + if path != filepath.Join(inboxDir, textInboxFileName) { + t.Fatalf("path = %q, want %q", path, filepath.Join(inboxDir, textInboxFileName)) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("os.ReadFile() error = %v", err) + } + + var record textInboxRecord + if err := json.Unmarshal(bytes.TrimSpace(data), &record); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if record.MessageID != msg.ID || record.From != msg.From || record.To != msg.To || record.Body != "hello" { + t.Fatalf("record mismatch: got %+v want message %+v", record, msg) + } + + events := logger.Events() + if len(events) != 2 { + t.Fatalf("event count = %d, want 2", len(events)) + } + if events[0].Event != latencylog.EventBPersistBegin { + t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin) + } + if events[1].Event != latencylog.EventBPersistEnd { + t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBPersistEnd) + } +} + +func TestClientPersistFileMessageWritesInboxFileAndLogs(t *testing.T) { + inboxDir := t.TempDir() + logger := &recordingLogger{} + client := &Client{ + id: "peer-b", + logger: logger, + } + + msg := protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 32, + From: "peer-a", + To: "peer-b", + FileName: "payload.bin", + Body: []byte{0x01, 0x02, 0x03}, + } + + path, err := client.PersistMessage(msg, inboxDir) + if err != nil { + t.Fatalf("PersistMessage() error = %v", err) + } + + wantPath := filepath.Join(inboxDir, "peer-a-32-payload.bin") + if path != wantPath { + t.Fatalf("path = %q, want %q", path, wantPath) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("os.ReadFile() error = %v", err) + } + if !reflect.DeepEqual(data, msg.Body) { + t.Fatalf("file body mismatch: got %v want %v", data, msg.Body) + } + + events := logger.Events() + if len(events) != 2 { + t.Fatalf("event count = %d, want 2", len(events)) + } + if events[0].Event != latencylog.EventBPersistBegin { + t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin) + } + if events[1].Event != latencylog.EventBPersistEnd { + t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBPersistEnd) + } +} + +func TestClientPersistMessageReturnsErrorOnWriteFailure(t *testing.T) { + blocker := filepath.Join(t.TempDir(), "blocker") + if err := os.WriteFile(blocker, []byte("not a directory"), 0o644); err != nil { + t.Fatalf("os.WriteFile() error = %v", err) + } + + logger := &recordingLogger{} + client := &Client{ + id: "peer-b", + logger: logger, + } + + msg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 33, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + } + + if _, err := client.PersistMessage(msg, blocker); err == nil { + t.Fatal("PersistMessage() error = nil, want non-nil") + } + + events := logger.Events() + if len(events) != 1 { + t.Fatalf("event count = %d, want 1", len(events)) + } + if events[0].Event != latencylog.EventBPersistBegin { + t.Fatalf("event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin) + } +} + +func TestClientIgnoresLoggerFailure(t *testing.T) { + clientConn, receiver := newClientTransportPair( + t, + []transport.Option{transport.WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")}, + nil, + ) + client := &Client{ + id: "peer-a", + conn: clientConn, + logger: failingLogger{}, + } + + sendErr := make(chan error, 1) + go func() { + sendErr <- client.SendText("peer-b", "hello") + }() + + got, err := receiver.Receive() + if err != nil { + t.Fatalf("receiver.Receive() error = %v", err) + } + if err := <-sendErr; err != nil { + t.Fatalf("SendText() error = %v, want nil even if logger fails", err) + } + if string(got.Body) != "hello" { + t.Fatalf("body = %q, want hello", got.Body) + } +} + +func TestClientPersistIgnoresLoggerFailure(t *testing.T) { + client := &Client{ + id: "peer-b", + logger: failingLogger{}, + } + + msg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 34, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + } + + path, err := client.PersistMessage(msg, t.TempDir()) + if err != nil { + t.Fatalf("PersistMessage() error = %v, want nil even if logger fails", err) + } + if path == "" { + t.Fatal("PersistMessage() path = empty, want non-empty") + } +} + +func TestClientsExchangeMessagesWithLatencyLogs(t *testing.T) { + hub := server.NewHub() + cleanup := stubDialToHub(t, hub) + defer cleanup() + + peerALogger := &recordingLogger{} + peerA, err := Dial("ignored", "peer-a", WithLogger(peerALogger)) + if err != nil { + t.Fatalf("Dial(peer-a) error = %v", err) + } + defer func() { _ = peerA.Close() }() + + peerBLogger := &recordingLogger{} + peerB, err := Dial("ignored", "peer-b", WithLogger(peerBLogger)) + if err != nil { + t.Fatalf("Dial(peer-b) error = %v", err) + } + defer func() { _ = peerB.Close() }() + + inboxDir := t.TempDir() + + waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") + + if err := peerA.SendText("peer-b", "hello"); err != nil { + t.Fatalf("SendText() error = %v", err) + } + textMsg, err := peerB.Receive() + if err != nil { + t.Fatalf("peerB.Receive(text) error = %v", err) + } + if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil { + t.Fatalf("peerB.PersistMessage(text) error = %v", err) + } + + if err := peerA.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}); err != nil { + t.Fatalf("SendFile() error = %v", err) + } + fileMsg, err := peerB.Receive() + if err != nil { + t.Fatalf("peerB.Receive(file) error = %v", err) + } + if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil { + t.Fatalf("peerB.PersistMessage(file) error = %v", err) + } + + waitFor(t, func() bool { return len(peerALogger.Events()) == 10 }, "peer-a latency events") + waitFor(t, func() bool { return len(peerBLogger.Events()) == 8 }, "peer-b latency events") + + assertEventSequencesByMessage(t, peerALogger.Events(), map[uint64][]string{ + 1: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd}, + 2: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd}, + }) + assertEventSequencesByMessage(t, peerBLogger.Events(), map[uint64][]string{ + 1: {latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd}, + 2: {latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd}, + }) +} + +func stubDialToHub(t *testing.T, hub *server.Hub) func() { + t.Helper() + + originalDial := dialServer + serverAddr, cleanup := startRealHubServer(t, hub) + + dialServer = func(network, addr string) (net.Conn, error) { + return net.Dial(network, serverAddr) + } + + return func() { + dialServer = originalDial + cleanup() + } +} + +func waitFor(t *testing.T, condition func() bool, description string) { + t.Helper() + + deadline := time.Now().Add(500 * time.Millisecond) + for time.Now().Before(deadline) { + if condition() { + return + } + time.Sleep(10 * time.Millisecond) + } + + t.Fatalf("timed out waiting for %s", description) +} + +func assertEventSequencesByMessage(t *testing.T, events []latencylog.Event, want map[uint64][]string) { + t.Helper() + + grouped := make(map[uint64][]latencylog.Event) + for _, event := range events { + grouped[event.MessageID] = append(grouped[event.MessageID], event) + if event.TsUnixNano <= 0 { + t.Fatalf("event timestamp must be positive: %+v", event) + } + } + + if len(grouped) != len(want) { + t.Fatalf("message group count = %d, want %d", len(grouped), len(want)) + } + + for messageID, wantEvents := range want { + gotEvents := grouped[messageID] + if len(gotEvents) != len(wantEvents) { + t.Fatalf("message %d event count = %d, want %d", messageID, len(gotEvents), len(wantEvents)) + } + for i, wantEvent := range wantEvents { + if gotEvents[i].Event != wantEvent { + t.Fatalf("message %d event[%d] = %q, want %q", messageID, i, gotEvents[i].Event, wantEvent) + } + } + } +} + +func newClientTransportPair(t *testing.T, clientOpts []transport.Option, peerOpts []transport.Option) (*transport.TCPConn, *transport.TCPConn) { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen() error = %v", err) + } + + type acceptResult struct { + conn net.Conn + err error + } + + accepted := make(chan acceptResult, 1) + go func() { + conn, acceptErr := listener.Accept() + accepted <- acceptResult{conn: conn, err: acceptErr} + }() + + clientSide, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + _ = listener.Close() + t.Fatalf("net.Dial() error = %v", err) + } + + result := <-accepted + if err := listener.Close(); err != nil { + t.Fatalf("listener.Close() error = %v", err) + } + if result.err != nil { + _ = clientSide.Close() + t.Fatalf("listener.Accept() error = %v", result.err) + } + + clientConn, err := transport.NewTCPConn(clientSide, clientOpts...) + if err != nil { + _ = clientSide.Close() + _ = result.conn.Close() + t.Fatalf("transport.NewTCPConn(client) error = %v", err) + } + peerConn, err := transport.NewTCPConn(result.conn, peerOpts...) + if err != nil { + _ = clientConn.Close() + _ = result.conn.Close() + t.Fatalf("transport.NewTCPConn(peer) error = %v", err) + } + + t.Cleanup(func() { + _ = clientConn.Close() + _ = peerConn.Close() + }) + + return clientConn, peerConn +} diff --git a/cmd/internal/peer/persist.go b/cmd/internal/peer/persist.go new file mode 100644 index 0000000..cdd982c --- /dev/null +++ b/cmd/internal/peer/persist.go @@ -0,0 +1,99 @@ +package peer + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" +) + +const textInboxFileName = "messages.log" + +type textInboxRecord struct { + MessageType protocol.MessageType `json:"message_type"` + MessageID uint64 `json:"message_id"` + From string `json:"from"` + To string `json:"to"` + Body string `json:"body"` +} + +// PersistMessage 将收到的业务消息写入本地磁盘,并记录处理完成节点。 +func (c *Client) PersistMessage(msg protocol.Message, inboxDir string) (string, error) { + if !latencylog.IsBusinessMessage(msg) { + return "", fmt.Errorf("peer: cannot persist message type %s", msg.Type) + } + if inboxDir == "" { + return "", fmt.Errorf("peer: inbox directory is required") + } + + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistBegin, msg) + + if err := os.MkdirAll(inboxDir, 0o755); err != nil { + return "", fmt.Errorf("peer: create inbox dir %s: %w", inboxDir, err) + } + + path, err := persistMessageToDisk(msg, inboxDir) + if err != nil { + return "", err + } + + latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistEnd, msg) + + return path, nil +} + +// persistMessageToDisk 根据消息类型将消息内容写入磁盘,文本消息追加到文本日志文件,文件消息写成独立文件。 +func persistMessageToDisk(msg protocol.Message, inboxDir string) (string, error) { + switch msg.Type { + case protocol.MessageTypeText: + return persistTextMessage(msg, inboxDir) + case protocol.MessageTypeFile: + return persistFileMessage(msg, inboxDir) + default: + return "", fmt.Errorf("peer: cannot persist unsupported message type %s", msg.Type) + } +} + +// registerPeer 验证 peer ID 的合法性和唯一性,并将其与连接关联起来。 +func persistTextMessage(msg protocol.Message, inboxDir string) (string, error) { + record := textInboxRecord{ + MessageType: msg.Type, + MessageID: msg.ID, + From: msg.From, + To: msg.To, + Body: string(msg.Body), + } + + line, err := json.Marshal(record) + if err != nil { + return "", fmt.Errorf("peer: encode text inbox record: %w", err) + } + + path := filepath.Join(inboxDir, textInboxFileName) + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + return "", fmt.Errorf("peer: open text inbox %s: %w", path, err) + } + defer file.Close() + + if _, err := file.Write(append(line, '\n')); err != nil { + return "", fmt.Errorf("peer: append text inbox %s: %w", path, err) + } + + return path, nil +} + +// persistFileMessage 将文件消息的内容写成独立文件,文件名包含发送方、消息 ID 和原始文件名,保证唯一性和可读性。 +func persistFileMessage(msg protocol.Message, inboxDir string) (string, error) { + fileName := filepath.Base(msg.FileName) + path := filepath.Join(inboxDir, fmt.Sprintf("%s-%d-%s", msg.From, msg.ID, fileName)) + + if err := os.WriteFile(path, msg.Body, 0o644); err != nil { + return "", fmt.Errorf("peer: write received file %s: %w", path, err) + } + + return path, nil +} diff --git a/cmd/internal/protocol/codec.go b/cmd/internal/protocol/codec.go new file mode 100644 index 0000000..fef658b --- /dev/null +++ b/cmd/internal/protocol/codec.go @@ -0,0 +1,279 @@ +package protocol + +import ( + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "unicode/utf8" +) + +// MaxFrameSize 用于限制单个帧的最大长度, +// 避免异常对端通过伪造超大长度值导致接收方无上限分配内存。 +const MaxFrameSize = 8 * 1024 * 1024 // 先临时设置传输的视频帧不超过8MB + +var ( + ErrInvalidFrameLength = errors.New("protocol: invalid frame length") // 表示帧长度非法,例如长度为 0。 + ErrFrameTooLarge = errors.New("protocol: frame too large") // 表示帧长度超过允许的上限。 + ErrInvalidMessageType = errors.New("protocol: invalid message type") // 表示消息类型不是当前协议支持的类型。 + ErrMissingFrom = errors.New("protocol: missing from") // 表示消息缺少发送方标识。 + ErrMissingTo = errors.New("protocol: missing to") // 表示消息缺少接收方标识。 + ErrMissingFileName = errors.New("protocol: missing file name") // 表示 file 消息缺少文件名。 + ErrUnexpectedFileName = errors.New("protocol: unexpected file name") // 表示 text 消息错误地携带了文件名。 + ErrInvalidTextBody = errors.New("protocol: invalid text body") // 表示 text 消息正文不是合法 UTF-8。 + ErrUnexpectedBody = errors.New("protocol: unexpected body") // 表示某些控制消息不允许携带正文。 + ErrInvalidRegisterTarget = errors.New("protocol: invalid register target") // 表示 register 消息没有发往 server。 + ErrInvalidErrorSource = errors.New("protocol: invalid error source") // 表示 error 消息不是由 server 发出。 + ErrInvalidHeaderLength = errors.New("protocol: invalid header length") // 表示 header 长度字段为 0、越界或无法完整切分。 + ErrInvalidHeaderJSON = errors.New("protocol: invalid header json") // 表示 header JSON 无法解析,可能是格式错误或缺少必要字段。 + ErrInvalidContentLength = errors.New("protocol: invalid content length") // 表示头部记录的正文长度与实际正文不一致。 +) + +// 应用层消息:[4字节 frameLength][4字节 headerLen][header JSON(下面自定义的Message头)][body bytes] +// 写了 tag:JSON 字段名是你指定的 type;不写 tag:JSON 字段名默认是 Go 字段名 Type +type messageHeader struct { + Type MessageType `json:"type"` + ID uint64 `json:"id"` + From string `json:"from"` + To string `json:"to"` + FileName string `json:"file_name,omitempty"` + ContentLength int `json:"content_length"` +} + +// EncodeMessage 将逻辑消息编码为帧内字节格式: +// 1. 4 字节大端序 header 长度 +// 2. header JSON +// 3. 原始 body 字节 +func EncodeMessage(msg Message) ([]byte, error) { + if err := validateMessage(msg); err != nil { + return nil, err + } + + header := messageHeader{ + Type: msg.Type, + ID: msg.ID, + From: msg.From, + To: msg.To, + FileName: msg.FileName, + ContentLength: len(msg.Body), + } + + headerPayload, err := json.Marshal(header) + if err != nil { + return nil, fmt.Errorf("protocol: encode header: %w", err) + } + // 创建一个新的字节切片来存储完整的帧内容,避免直接在 headerPayload 上修改导致数据混乱。 + payload := make([]byte, 4+len(headerPayload)+len(msg.Body)) + // 在 payload 前 4 字节写入 header 长度,后续内容依次是 header JSON(第五个字节开始) 和 body。 + binary.BigEndian.PutUint32(payload[:4], uint32(len(headerPayload))) + copy(payload[4:], headerPayload) + copy(payload[4+len(headerPayload):], msg.Body) + + //检查整个帧长度是否合法,避免上层调用者构造的消息过大导致发送失败。 + if len(payload) > MaxFrameSize { + return nil, ErrFrameTooLarge + } + + return payload, nil +} + +// DecodeMessage 将帧内字节格式还原为 Message。 +func DecodeMessage(data []byte) (Message, error) { + if len(data) > MaxFrameSize { + return Message{}, ErrFrameTooLarge + } + if len(data) < 4 { + return Message{}, ErrInvalidHeaderLength + } + + headerLen := int(binary.BigEndian.Uint32(data[:4])) + if headerLen == 0 || headerLen > len(data)-4 { + return Message{}, ErrInvalidHeaderLength + } + + headerPayload := data[4 : 4+headerLen] + body := data[4+headerLen:] + + var header messageHeader + if err := json.Unmarshal(headerPayload, &header); err != nil { + return Message{}, fmt.Errorf("protocol: decode header: %w", errors.Join(ErrInvalidHeaderJSON, err)) + } + + if header.ContentLength < 0 || header.ContentLength != len(body) { + return Message{}, ErrInvalidContentLength + } + + bodyCopy := make([]byte, len(body)) + copy(bodyCopy, body) + + msg := Message{ + Type: header.Type, + ID: header.ID, + From: header.From, + To: header.To, + FileName: header.FileName, + Body: bodyCopy, + } + + if err := validateMessage(msg); err != nil { + return Message{}, err + } + + return msg, nil +} + +// WriteFrame 向流中写入一个带长度前缀的帧。 +// TCP帧格式如下: +// 1. 4 字节大端序长度 +// 2. 后续 payload 内容 +// +// TCP 是字节流协议,没有天然的消息边界。 +// 增加显式长度前缀后,接收方就知道一条完整消息应该读取多少字节, +// 从而解决粘包和拆包问题。 +func WriteFrame(w io.Writer, payload []byte) error { + size := len(payload) + //空帧 + if size == 0 { + return ErrInvalidFrameLength + } + //帧过大 + if size > MaxFrameSize { + return ErrFrameTooLarge + } + + var header [4]byte + binary.BigEndian.PutUint32(header[:], uint32(size)) + + // 先写长度头,接收方才能根据长度一次性读取完整消息体。 + if err := writeFull(w, header[:]); err != nil { + return err + } + + return writeFull(w, payload) +} + +// ReadFrame 从流中读取一个完整的长度前缀帧。 +// 它会先读取固定 4 字节长度头,校验长度是否合法, +// 再使用 io.ReadFull 按长度读取完整消息体, +// 这样即使底层 TCP 发生分段读取,也不会把半条消息暴露给上层。 +func ReadFrame(r io.Reader) ([]byte, error) { + var header [4]byte + if _, err := io.ReadFull(r, header[:]); err != nil { + return nil, err + } + + size := binary.BigEndian.Uint32(header[:]) + // 长度为 0 的帧被认为是非法输入,而不是合法的空消息。 + if size == 0 { + return nil, ErrInvalidFrameLength + } + // 长度超过上限的帧会被拒绝,避免接收方无上限分配内存。 + if size > MaxFrameSize { + return nil, ErrFrameTooLarge + } + + payload := make([]byte, int(size)) + if _, err := io.ReadFull(r, payload); err != nil { + return nil, err + } + + return payload, nil +} + +// WriteMessage 是给上层直接使用的完整发送路径: +// 把一条结构化消息完整编码并发送出去”的总入口。 +// Message -> header+body -> 长度前缀帧 -> io.Writer。 +func WriteMessage(w io.Writer, msg Message) error { + payload, err := EncodeMessage(msg) + if err != nil { + return fmt.Errorf("protocol: encode message: %w", err) + } + + if err := WriteFrame(w, payload); err != nil { + return fmt.Errorf("protocol: write frame: %w", err) + } + + return nil +} + +// ReadMessage 是给上层直接使用的完整接收路径: +// io.Reader -> 长度前缀帧 -> header+body -> Message。 +func ReadMessage(r io.Reader) (Message, error) { + payload, err := ReadFrame(r) + if err != nil { + return Message{}, fmt.Errorf("protocol: read frame: %w", err) + } + + msg, err := DecodeMessage(payload) + if err != nil { + return Message{}, fmt.Errorf("protocol: decode message: %w", err) + } + + return msg, nil +} + +// validateMessage 检查 Message 传输的类型(只接受 text 和 file )。 +func validateMessage(msg Message) error { + if msg.From == "" { + return ErrMissingFrom + } + if msg.To == "" { + return ErrMissingTo + } + + switch msg.Type { + case MessageTypeText: + if msg.FileName != "" { + return ErrUnexpectedFileName + } + if !utf8.Valid(msg.Body) { + return ErrInvalidTextBody + } + case MessageTypeFile: + if msg.FileName == "" { + return ErrMissingFileName + } + case MessageTypeRegister: + if msg.To != ServerPeerID { + return ErrInvalidRegisterTarget + } + if msg.FileName != "" { + return ErrUnexpectedFileName + } + if len(msg.Body) != 0 { + return ErrUnexpectedBody + } + case MessageTypeError: + if msg.From != ServerPeerID { + return ErrInvalidErrorSource + } + if msg.FileName != "" { + return ErrUnexpectedFileName + } + if !utf8.Valid(msg.Body) { + return ErrInvalidTextBody + } + default: + return ErrInvalidMessageType + } + + return nil +} + +// writeFull 会持续写入,直到所有字节都写完或者底层返回错误。 +// 这样可以避免某些 Writer 发生部分写入时破坏帧格式。 +func writeFull(w io.Writer, data []byte) error { + for len(data) > 0 { + n, err := w.Write(data) + if err != nil { + return err + } + if n == 0 { + return io.ErrShortWrite + } + data = data[n:] + } + + return nil +} diff --git a/cmd/internal/protocol/codec_test.go b/cmd/internal/protocol/codec_test.go new file mode 100644 index 0000000..b229b36 --- /dev/null +++ b/cmd/internal/protocol/codec_test.go @@ -0,0 +1,507 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + "reflect" + "strings" + "testing" +) + +// TestEncodeDecodeMessageTextASCII 验证 ASCII 文本可以按 text 消息往返编解码。 +func TestEncodeDecodeMessageTextASCII(t *testing.T) { + original := Message{ + Type: MessageTypeText, + ID: 42, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + } + + data, err := EncodeMessage(original) + if err != nil { + t.Fatalf("EncodeMessage() error = %v", err) + } + + decoded, err := DecodeMessage(data) + if err != nil { + t.Fatalf("DecodeMessage() error = %v", err) + } + + if !reflect.DeepEqual(decoded, original) { + t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original) + } +} + +// TestEncodeDecodeMessageTextUTF8 验证 text 消息允许合法 UTF-8, +// 从而天然兼容 ASCII 之外的普通文本。 +func TestEncodeDecodeMessageTextUTF8(t *testing.T) { + original := Message{ + Type: MessageTypeText, + ID: 43, + From: "peer-a", + To: "peer-b", + Body: []byte("你好, world"), + } + + data, err := EncodeMessage(original) + if err != nil { + t.Fatalf("EncodeMessage() error = %v", err) + } + + decoded, err := DecodeMessage(data) + if err != nil { + t.Fatalf("DecodeMessage() error = %v", err) + } + + if !reflect.DeepEqual(decoded, original) { + t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original) + } +} + +// TestEncodeDecodeMessageFile 验证 file 消息会保留文件名和原始二进制正文。 +func TestEncodeDecodeMessageFile(t *testing.T) { + original := Message{ + Type: MessageTypeFile, + ID: 44, + From: "peer-a", + To: "peer-b", + FileName: "data.bin", + Body: []byte{0x00, 0xff, 0x10, 0x7f}, + } + + data, err := EncodeMessage(original) + if err != nil { + t.Fatalf("EncodeMessage() error = %v", err) + } + + decoded, err := DecodeMessage(data) + if err != nil { + t.Fatalf("DecodeMessage() error = %v", err) + } + + if !reflect.DeepEqual(decoded, original) { + t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original) + } +} + +// TestEncodeDecodeMessageRegister 验证 register 控制消息也能正常编解码。 +func TestEncodeDecodeMessageRegister(t *testing.T) { + original := Message{ + Type: MessageTypeRegister, + ID: 45, + From: "peer-a", + To: ServerPeerID, + Body: []byte{}, + } + + data, err := EncodeMessage(original) + if err != nil { + t.Fatalf("EncodeMessage() error = %v", err) + } + + decoded, err := DecodeMessage(data) + if err != nil { + t.Fatalf("DecodeMessage() error = %v", err) + } + + if !reflect.DeepEqual(decoded, original) { + t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original) + } +} + +// TestEncodeDecodeMessageError 验证 error 控制消息会保留 UTF-8 错误文本。 +func TestEncodeDecodeMessageError(t *testing.T) { + original := Message{ + Type: MessageTypeError, + ID: 46, + From: ServerPeerID, + To: "peer-a", + Body: []byte("unknown target"), + } + + data, err := EncodeMessage(original) + if err != nil { + t.Fatalf("EncodeMessage() error = %v", err) + } + + decoded, err := DecodeMessage(data) + if err != nil { + t.Fatalf("DecodeMessage() error = %v", err) + } + + if !reflect.DeepEqual(decoded, original) { + t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original) + } +} + +// TestWriteReadFrame 单独验证最底层的长度前缀帧逻辑, +// 不依赖 Message 结构,方便确认 TCP 粘包拆包问题是否被正确处理。 +func TestWriteReadFrame(t *testing.T) { + var buf bytes.Buffer + payload := []byte("header+body") + + if err := WriteFrame(&buf, payload); err != nil { + t.Fatalf("WriteFrame() error = %v", err) + } + + got, err := ReadFrame(&buf) + if err != nil { + t.Fatalf("ReadFrame() error = %v", err) + } + + if !bytes.Equal(got, payload) { + t.Fatalf("payload mismatch: got %q want %q", got, payload) + } +} + +// TestWriteReadMessageAllowsEmptyBody 验证空文本和空文件都可以正常通过协议层, +// 因为外层帧非空的前提下,空正文是合法业务内容。 +func TestWriteReadMessageAllowsEmptyBody(t *testing.T) { + tests := []struct { + name string + message Message + }{ + { + name: "empty text", + message: Message{ + Type: MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte(""), + }, + }, + { + name: "empty file", + message: Message{ + Type: MessageTypeFile, + ID: 2, + From: "peer-a", + To: "peer-b", + FileName: "empty.txt", + Body: []byte{}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + + if err := WriteMessage(&buf, tt.message); err != nil { + t.Fatalf("WriteMessage() error = %v", err) + } + + got, err := ReadMessage(&buf) + if err != nil { + t.Fatalf("ReadMessage() error = %v", err) + } + + if !reflect.DeepEqual(got, tt.message) { + t.Fatalf("round trip mismatch: got %+v want %+v", got, tt.message) + } + }) + } +} + +// TestWriteReadMessageRejectsInvalidMessages 验证协议层会在编码前拦住明显非法的消息。 +func TestWriteReadMessageRejectsInvalidMessages(t *testing.T) { + tests := []struct { + name string + message Message + wantErr error + }{ + { + name: "invalid type", + message: Message{ + Type: MessageType("unknown"), + ID: 1, + From: "peer-a", + To: "peer-b", + }, + wantErr: ErrInvalidMessageType, + }, + { + name: "missing from", + message: Message{ + Type: MessageTypeText, + ID: 2, + To: "peer-b", + }, + wantErr: ErrMissingFrom, + }, + { + name: "missing to", + message: Message{ + Type: MessageTypeText, + ID: 3, + From: "peer-a", + }, + wantErr: ErrMissingTo, + }, + { + name: "text with file name", + message: Message{ + Type: MessageTypeText, + ID: 4, + From: "peer-a", + To: "peer-b", + FileName: "bad.txt", + Body: []byte("hello"), + }, + wantErr: ErrUnexpectedFileName, + }, + { + name: "text with invalid utf8", + message: Message{ + Type: MessageTypeText, + ID: 5, + From: "peer-a", + To: "peer-b", + Body: []byte{0xff, 0xfe}, + }, + wantErr: ErrInvalidTextBody, + }, + { + name: "file without file name", + message: Message{ + Type: MessageTypeFile, + ID: 6, + From: "peer-a", + To: "peer-b", + Body: []byte{0x01}, + }, + wantErr: ErrMissingFileName, + }, + { + name: "register with wrong target", + message: Message{ + Type: MessageTypeRegister, + ID: 7, + From: "peer-a", + To: "peer-b", + }, + wantErr: ErrInvalidRegisterTarget, + }, + { + name: "register with body", + message: Message{ + Type: MessageTypeRegister, + ID: 8, + From: "peer-a", + To: ServerPeerID, + Body: []byte("unexpected"), + }, + wantErr: ErrUnexpectedBody, + }, + { + name: "error with wrong source", + message: Message{ + Type: MessageTypeError, + ID: 9, + From: "peer-a", + To: "peer-b", + Body: []byte("bad"), + }, + wantErr: ErrInvalidErrorSource, + }, + { + name: "error with file name", + message: Message{ + Type: MessageTypeError, + ID: 10, + From: ServerPeerID, + To: "peer-a", + FileName: "bad.txt", + Body: []byte("bad"), + }, + wantErr: ErrUnexpectedFileName, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := EncodeMessage(tt.message) + if !errors.Is(err, tt.wantErr) { + t.Fatalf("EncodeMessage() error = %v, want %v", err, tt.wantErr) + } + }) + } +} + +// TestReadFrameRejectsInvalidLength 验证长度为 0 的帧会被当成非法输入, +// 而不是被当成一条合法的空消息。 +func TestReadFrameRejectsInvalidLength(t *testing.T) { + var buf bytes.Buffer + + if err := binary.Write(&buf, binary.BigEndian, uint32(0)); err != nil { + t.Fatalf("binary.Write() error = %v", err) + } + + _, err := ReadFrame(&buf) + if !errors.Is(err, ErrInvalidFrameLength) { + t.Fatalf("ReadFrame() error = %v, want %v", err, ErrInvalidFrameLength) + } +} + +// TestReadFrameRejectsTooLargeFrame 验证超大帧会在分配消息体前被拒绝, +// 从而保证最大长度限制真正生效。 +func TestReadFrameRejectsTooLargeFrame(t *testing.T) { + var buf bytes.Buffer + + if err := binary.Write(&buf, binary.BigEndian, uint32(MaxFrameSize+1)); err != nil { + t.Fatalf("binary.Write() error = %v", err) + } + + _, err := ReadFrame(&buf) + if !errors.Is(err, ErrFrameTooLarge) { + t.Fatalf("ReadFrame() error = %v, want %v", err, ErrFrameTooLarge) + } +} + +// TestWriteFrameRejectsEmptyPayload 验证写入端和读取端的约束保持一致: +// 既然读取端不接受 0 长度帧,写入端也不应该产生这种帧。 +func TestWriteFrameRejectsEmptyPayload(t *testing.T) { + var buf bytes.Buffer + + err := WriteFrame(&buf, nil) + if !errors.Is(err, ErrInvalidFrameLength) { + t.Fatalf("WriteFrame() error = %v, want %v", err, ErrInvalidFrameLength) + } +} + +// TestDecodeMessageRejectsInvalidHeaderLength 验证无法切出完整头部时会被立即拒绝。 +func TestDecodeMessageRejectsInvalidHeaderLength(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + { + name: "too short for header len", + data: []byte{0x00, 0x00, 0x00}, + }, + { + name: "zero header len", + data: []byte{0x00, 0x00, 0x00, 0x00}, + }, + { + name: "header len exceeds payload", + data: []byte{0x00, 0x00, 0x00, 0x10, '{', '}'}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := DecodeMessage(tt.data) + if !errors.Is(err, ErrInvalidHeaderLength) { + t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderLength) + } + }) + } +} + +// TestDecodeMessageRejectsInvalidHeaderJSON 验证头部 JSON 非法时能返回明确错误。 +func TestDecodeMessageRejectsInvalidHeaderJSON(t *testing.T) { + data := append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...) + + _, err := DecodeMessage(data) + if !errors.Is(err, ErrInvalidHeaderJSON) { + t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderJSON) + } +} + +// TestDecodeMessageRejectsContentLengthMismatch 验证头部声明长度和实际正文不一致时会失败。 +func TestDecodeMessageRejectsContentLengthMismatch(t *testing.T) { + headerPayload, err := json.Marshal(messageHeader{ + Type: MessageTypeText, + ID: 7, + From: "peer-a", + To: "peer-b", + ContentLength: 10, + }) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + + var data bytes.Buffer + if err := binary.Write(&data, binary.BigEndian, uint32(len(headerPayload))); err != nil { + t.Fatalf("binary.Write() error = %v", err) + } + if _, err := data.Write(headerPayload); err != nil { + t.Fatalf("data.Write(headerPayload) error = %v", err) + } + if _, err := data.Write([]byte("hello")); err != nil { + t.Fatalf("data.Write(body) error = %v", err) + } + + _, err = DecodeMessage(data.Bytes()) + if !errors.Is(err, ErrInvalidContentLength) { + t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidContentLength) + } +} + +// TestReadMultipleMessages 模拟同一条流中连续写入 text 和 file, +// 验证读取端每次都能严格停在当前帧边界,不会串包。 +func TestReadMultipleMessages(t *testing.T) { + var buf bytes.Buffer + + first := Message{ + Type: MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + } + + second := Message{ + Type: MessageTypeFile, + ID: 2, + From: "peer-b", + To: "peer-a", + FileName: "payload.bin", + Body: []byte{0x01, 0x02, 0x03}, + } + + if err := WriteMessage(&buf, first); err != nil { + t.Fatalf("WriteMessage(first) error = %v", err) + } + if err := WriteMessage(&buf, second); err != nil { + t.Fatalf("WriteMessage(second) error = %v", err) + } + + gotFirst, err := ReadMessage(&buf) + if err != nil { + t.Fatalf("ReadMessage(first) error = %v", err) + } + gotSecond, err := ReadMessage(&buf) + if err != nil { + t.Fatalf("ReadMessage(second) error = %v", err) + } + + if !reflect.DeepEqual(gotFirst, first) { + t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, first) + } + if !reflect.DeepEqual(gotSecond, second) { + t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, second) + } +} + +// TestReadMessageWrapsDecodeError 验证 ReadMessage 在返回错误时会保留解码阶段上下文。 +func TestReadMessageWrapsDecodeError(t *testing.T) { + var buf bytes.Buffer + + if err := WriteFrame(&buf, append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)); err != nil { + t.Fatalf("WriteFrame() error = %v", err) + } + + _, err := ReadMessage(&buf) + if err == nil { + t.Fatal("ReadMessage() error = nil, want non-nil") + } + if !strings.Contains(err.Error(), "decode message") { + t.Fatalf("ReadMessage() error = %v, want wrapped decode error", err) + } +} diff --git a/cmd/internal/protocol/message.go b/cmd/internal/protocol/message.go new file mode 100644 index 0000000..5f5d28b --- /dev/null +++ b/cmd/internal/protocol/message.go @@ -0,0 +1,33 @@ +package protocol + +// MessageType 表示一条消息的传输类型。 +// v1 只区分普通文本和文件两类负载。 +type MessageType string + +const ( + // MessageTypeText 表示正文按 UTF-8 文本解释,天然兼容 ASCII。 + MessageTypeText MessageType = "text" + // MessageTypeFile 表示正文是原始文件字节。 + MessageTypeFile MessageType = "file" + // MessageTypeRegister 表示 peer 向 server 显式注册自己的身份。 + MessageTypeRegister MessageType = "register" + // MessageTypeError 表示 server 向 peer 返回错误信息。 + MessageTypeError MessageType = "error" +) + +// ServerPeerID 是协议中约定的 server 端固定标识。 +const ServerPeerID = "server" + +// Message 是 peer 和 server 共用的传输消息结构。 +// 头部元信息会被编码为 JSON,Body 则作为原始字节拼接在头部之后。 +type Message struct { + Type MessageType `json:"type"` // 消息类型,只允许 text 或 file。 + ID uint64 `json:"id"` // 由发送方生成,用于追踪消息。 + From string `json:"from"` // 发送方标识。 + To string `json:"to"` // 接收方标识。 + + // FileName 仅在 Type 为 file 时使用。 + FileName string `json:"file_name,omitempty"` + // Body 是真正传输的正文内容,不进入头部 JSON。 + Body []byte `json:"-"` +} diff --git a/cmd/internal/server/hub.go b/cmd/internal/server/hub.go new file mode 100644 index 0000000..ed11bc3 --- /dev/null +++ b/cmd/internal/server/hub.go @@ -0,0 +1,192 @@ +package server + +import ( + "fmt" + "net" + "sync" + "time" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" + "omnisocketgo/cmd/internal/transport" +) + +const gracefulRejectCloseTimeout = 100 * time.Millisecond + +// Hub 管理已注册 peer 的连接,并负责在它们之间转发消息。 +type Hub struct { + mu sync.RWMutex + peers map[string]*transport.TCPConn + logger latencylog.Logger +} + +// Option 用于配置 Hub 的可选行为,例如时延日志。 +type Option func(*Hub) + +// WithLogger 为 hub 注入时延日志记录器。 +func WithLogger(logger latencylog.Logger) Option { + return func(hub *Hub) { + hub.logger = logger + } +} + +// NewHub 创建一个空的连接中心。 +func NewHub(opts ...Option) *Hub { + hub := &Hub{ + peers: make(map[string]*transport.TCPConn), + logger: latencylog.NoopLogger{}, + } + + for _, opt := range opts { + opt(hub) + } + + if hub.logger == nil { + hub.logger = latencylog.NoopLogger{} + } + + return hub +} + +// HasPeer 返回给定 ID 是否已经注册到 hub。 +func (h *Hub) HasPeer(peerID string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + + _, ok := h.peers[peerID] + return ok +} + +// ServeConn 处理一条新接入的底层 TCP 连接。 +// 连接上的第一条消息必须是 register,之后才允许转发 text/file。 +func (h *Hub) ServeConn(rawConn net.Conn) error { + conn, err := transport.NewTCPConn(rawConn) + if err != nil { + _ = rawConn.Close() + return fmt.Errorf("server: create transport conn: %w", err) + } + + peerID, gracefulClose, err := h.registerConn(conn) + if err != nil { + h.closeConn(conn, gracefulClose) + return err + } + defer h.unregister(peerID, conn) + + if err := h.receivePeerLoop(peerID, conn); err != nil { + return err + } + + return nil +} + +// registerConn 从新连接上读取第一条消息,验证它是 register 消息,并把连接注册到 hub。 +func (h *Hub) registerConn(conn *transport.TCPConn) (string, bool, error) { + msg, err := conn.Receive() + if err != nil { + return "", false, fmt.Errorf("server: receive register: %w", err) + } + + if msg.Type != protocol.MessageTypeRegister { + if sendErr := sendServerError(conn, msg.From, "first message must be register"); sendErr != nil { + return "", false, fmt.Errorf("server: reject unregistered peer: %w", sendErr) + } + return "", true, fmt.Errorf("server: first message must be register, got %s", msg.Type) + } + + h.mu.Lock() + defer h.mu.Unlock() + + if _, exists := h.peers[msg.From]; exists { + if sendErr := sendServerError(conn, msg.From, fmt.Sprintf("duplicate peer id: %s", msg.From)); sendErr != nil { + return "", false, fmt.Errorf("server: duplicate peer id %s: %w", msg.From, sendErr) + } + return "", true, fmt.Errorf("server: duplicate peer id: %s", msg.From) + } + + h.peers[msg.From] = conn + return msg.From, false, nil +} + +// handlePeerMessage 验证消息类型并执行相应的转发或错误响应。 +func (h *Hub) handlePeerMessage(peerID string, conn *transport.TCPConn, msg protocol.Message) (bool, error) { + switch msg.Type { + case protocol.MessageTypeText, protocol.MessageTypeFile: //只允许已注册的 peer 发送文本或文件消息,其他类型都视为协议错误。 + msg.From = peerID + targetConn, ok := h.lookup(msg.To) + if !ok { + return false, sendServerError(conn, peerID, fmt.Sprintf("unknown target: %s", msg.To)) + } + if err := targetConn.Send(msg); err != nil { //转发消息,如果发送失败,说明目标连接可能已经不可用,此时从 hub 中注销该连接并关闭它,并向发送方返回错误响应。 + h.unregister(msg.To, targetConn) + _ = targetConn.Close() + return false, sendServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", msg.To)) + } + return false, nil + case protocol.MessageTypeRegister, protocol.MessageTypeError: //已注册的 peer 不允许再发送 register 或 error 消息,这些都视为协议错误。 + if err := sendServerError(conn, peerID, "registered peers can only send text or file messages"); err != nil { + return false, fmt.Errorf("server: send protocol error: %w", err) + } + return true, fmt.Errorf("server: unexpected message type from peer %s: %s", peerID, msg.Type) + default: // 其他任何消息类型都视为协议错误。 + if err := sendServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); err != nil { + return false, fmt.Errorf("server: send unsupported type error: %w", err) + } + return true, fmt.Errorf("server: unsupported message type: %s", msg.Type) + } +} + +func (h *Hub) receivePeerLoop(peerID string, conn *transport.TCPConn) error { + for { + msg, err := conn.Receive() + if err != nil { + _ = conn.Close() + return fmt.Errorf("transport: receive loop read: %w", err) + } + + gracefulClose, err := h.handlePeerMessage(peerID, conn, msg) + if err != nil { + h.closeConn(conn, gracefulClose) + return fmt.Errorf("transport: receive loop handler: %w", err) + } + } +} + +// lookup 在 hub 中查找目标 peer 的连接。 +func (h *Hub) lookup(peerID string) (*transport.TCPConn, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + conn, ok := h.peers[peerID] + return conn, ok +} + +// unregister 从 hub 中移除指定 peer 的连接,通常在连接关闭或发生错误时调用。 +func (h *Hub) unregister(peerID string, conn *transport.TCPConn) { + h.mu.Lock() + defer h.mu.Unlock() + + current, ok := h.peers[peerID] + if ok && current == conn { + delete(h.peers, peerID) + } +} + +func (h *Hub) closeConn(conn *transport.TCPConn, graceful bool) { + if graceful { + _ = conn.CloseGracefully(gracefulRejectCloseTimeout) + return + } + + _ = conn.Close() +} + +// sendServerError 是一个辅助函数,用于向指定 peer 发送错误消息。 +func sendServerError(conn *transport.TCPConn, to, message string) error { + return conn.Send(protocol.Message{ + Type: protocol.MessageTypeError, + From: protocol.ServerPeerID, + To: to, + Body: []byte(message), + }) +} diff --git a/cmd/internal/server/hub_test.go b/cmd/internal/server/hub_test.go new file mode 100644 index 0000000..2cee305 --- /dev/null +++ b/cmd/internal/server/hub_test.go @@ -0,0 +1,398 @@ +package server + +import ( + "net" + "reflect" + "strings" + "sync" + "testing" + "time" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" + "omnisocketgo/cmd/internal/transport" +) + +type recordingLogger struct { + mu sync.Mutex + events []latencylog.Event +} + +func (l *recordingLogger) LogEvent(event latencylog.Event) error { + l.mu.Lock() + defer l.mu.Unlock() + + l.events = append(l.events, event) + return nil +} + +func (l *recordingLogger) Events() []latencylog.Event { + l.mu.Lock() + defer l.mu.Unlock() + + return append([]latencylog.Event(nil), l.events...) +} + +func TestServeConnRegistersPeer(t *testing.T) { + hub := NewHub() + client, done := startHubConn(t, hub) + + if err := client.Send(protocol.Message{ + Type: protocol.MessageTypeRegister, + From: "peer-a", + To: protocol.ServerPeerID, + }); err != nil { + t.Fatalf("Send(register) error = %v", err) + } + + waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") + + if err := client.Close(); err != nil { + t.Fatalf("client.Close() error = %v", err) + } + + err := <-done + if err == nil || !strings.Contains(err.Error(), "receive loop read") { + t.Fatalf("ServeConn() error = %v, want read-loop shutdown error", err) + } +} + +func TestServeConnRejectsDuplicatePeerID(t *testing.T) { + hub := NewHub() + + first, firstDone := startHubConn(t, hub) + registerPeer(t, first, "peer-a") + waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") + + second, secondDone := startHubConn(t, hub) + registerPeer(t, second, "peer-a") + + got, err := second.Receive() + if err != nil { + t.Fatalf("second.Receive() error = %v", err) + } + if got.Type != protocol.MessageTypeError { + t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError) + } + if string(got.Body) != "duplicate peer id: peer-a" { + t.Fatalf("error body = %q, want duplicate peer message", got.Body) + } + + waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "original peer-a to remain registered") + + if err := first.Close(); err != nil { + t.Fatalf("first.Close() error = %v", err) + } + + if err := <-secondDone; err == nil || !strings.Contains(err.Error(), "duplicate peer id") { + t.Fatalf("second ServeConn() error = %v, want duplicate peer id error", err) + } + if err := <-firstDone; err == nil || !strings.Contains(err.Error(), "receive loop read") { + t.Fatalf("first ServeConn() error = %v, want read-loop shutdown error", err) + } +} + +func TestServeConnForwardsMessages(t *testing.T) { + tests := []struct { + name string + msg protocol.Message + }{ + { + name: "text", + msg: protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "spoofed", + To: "peer-b", + Body: []byte("hello"), + }, + }, + { + name: "file", + msg: protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 2, + From: "spoofed", + To: "peer-b", + FileName: "payload.bin", + Body: []byte{0x01, 0x02, 0x03}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hub := NewHub() + + sender, senderDone := startHubConn(t, hub) + receiver, receiverDone := startHubConn(t, hub) + registerPeer(t, sender, "peer-a") + registerPeer(t, receiver, "peer-b") + waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") + + if err := sender.Send(tt.msg); err != nil { + t.Fatalf("sender.Send() error = %v", err) + } + + got, err := receiver.Receive() + if err != nil { + t.Fatalf("receiver.Receive() error = %v", err) + } + + want := tt.msg + want.From = "peer-a" + if !reflect.DeepEqual(got, want) { + t.Fatalf("forwarded message mismatch: got %+v want %+v", got, want) + } + + _ = sender.Close() + _ = receiver.Close() + <-senderDone + <-receiverDone + }) + } +} + +func TestServeConnReturnsErrorForUnknownTarget(t *testing.T) { + hub := NewHub() + + client, done := startHubConn(t, hub) + registerPeer(t, client, "peer-a") + waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") + + if err := client.Send(protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "missing-peer", + Body: []byte("hello"), + }); err != nil { + t.Fatalf("Send(text) error = %v", err) + } + + got, err := client.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if got.Type != protocol.MessageTypeError { + t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError) + } + if string(got.Body) != "unknown target: missing-peer" { + t.Fatalf("error body = %q, want unknown target message", got.Body) + } + if !hub.HasPeer("peer-a") { + t.Fatal("peer-a should remain registered after unknown target error") + } + + _ = client.Close() + <-done +} + +func TestServeConnRejectsRegisterAfterRegistration(t *testing.T) { + hub := NewHub() + + client, done := startHubConn(t, hub) + registerPeer(t, client, "peer-a") + waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") + + if err := client.Send(protocol.Message{ + Type: protocol.MessageTypeRegister, + From: "peer-a", + To: protocol.ServerPeerID, + }); err != nil { + t.Fatalf("Send(register again) error = %v", err) + } + + got, err := client.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if got.Type != protocol.MessageTypeError { + t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError) + } + if string(got.Body) != "registered peers can only send text or file messages" { + t.Fatalf("error body = %q, want registered-peer protocol error", got.Body) + } + + if err := <-done; err == nil || !strings.Contains(err.Error(), "unexpected message type from peer peer-a: register") { + t.Fatalf("ServeConn() error = %v, want unexpected register message error", err) + } +} + +func TestServeConnUnregistersPeerOnClose(t *testing.T) { + hub := NewHub() + + client, done := startHubConn(t, hub) + registerPeer(t, client, "peer-a") + waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") + + if err := client.Close(); err != nil { + t.Fatalf("client.Close() error = %v", err) + } + <-done + + waitFor(t, func() bool { return !hub.HasPeer("peer-a") }, "peer-a to be unregistered") +} + +func TestServeConnDoesNotEmitEndpointLatencyEventsOnForward(t *testing.T) { + logger := &recordingLogger{} + hub := NewHub(WithLogger(logger)) + + sender, senderDone := startHubConn(t, hub) + receiver, receiverDone := startHubConn(t, hub) + registerPeer(t, sender, "peer-a") + registerPeer(t, receiver, "peer-b") + waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered") + + msg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 11, + From: "spoofed", + To: "peer-b", + Body: []byte("hello"), + } + if err := sender.Send(msg); err != nil { + t.Fatalf("sender.Send() error = %v", err) + } + + got, err := receiver.Receive() + if err != nil { + t.Fatalf("receiver.Receive() error = %v", err) + } + msg.From = "peer-a" + if !reflect.DeepEqual(got, msg) { + t.Fatalf("forwarded message mismatch: got %+v want %+v", got, msg) + } + + events := logger.Events() + if len(events) != 0 { + t.Fatalf("event count = %d, want 0 because server is a black-box relay", len(events)) + } + + _ = sender.Close() + _ = receiver.Close() + <-senderDone + <-receiverDone +} + +func TestServeConnDoesNotLogLatencyEventsForUnknownTarget(t *testing.T) { + logger := &recordingLogger{} + hub := NewHub(WithLogger(logger)) + + client, done := startHubConn(t, hub) + registerPeer(t, client, "peer-a") + waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") + + if err := client.Send(protocol.Message{ + Type: protocol.MessageTypeText, + ID: 15, + From: "peer-a", + To: "missing-peer", + Body: []byte("hello"), + }); err != nil { + t.Fatalf("Send(text) error = %v", err) + } + + got, err := client.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if got.Type != protocol.MessageTypeError { + t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError) + } + if events := logger.Events(); len(events) != 0 { + t.Fatalf("event count = %d, want 0 for unknown target path", len(events)) + } + + _ = client.Close() + <-done +} + +func TestServeConnDoesNotLogLatencyEventsForDuplicateRegister(t *testing.T) { + logger := &recordingLogger{} + hub := NewHub(WithLogger(logger)) + + first, firstDone := startHubConn(t, hub) + registerPeer(t, first, "peer-a") + waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered") + + second, secondDone := startHubConn(t, hub) + registerPeer(t, second, "peer-a") + + got, err := second.Receive() + if err != nil { + t.Fatalf("second.Receive() error = %v", err) + } + if got.Type != protocol.MessageTypeError { + t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError) + } + if events := logger.Events(); len(events) != 0 { + t.Fatalf("event count = %d, want 0 for duplicate register path", len(events)) + } + + _ = first.Close() + <-secondDone + <-firstDone +} + +func startHubConn(t *testing.T, hub *Hub) (*transport.TCPConn, <-chan error) { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen() error = %v", err) + } + done := make(chan error, 1) + + go func() { + serverSide, acceptErr := listener.Accept() + if acceptErr != nil { + done <- acceptErr + return + } + done <- hub.ServeConn(serverSide) + }() + + clientSide, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + _ = listener.Close() + t.Fatalf("net.Dial() error = %v", err) + } + if err := listener.Close(); err != nil { + t.Fatalf("listener.Close() error = %v", err) + } + + conn, err := transport.NewTCPConn(clientSide) + if err != nil { + _ = clientSide.Close() + t.Fatalf("transport.NewTCPConn() error = %v", err) + } + + return conn, done +} + +func registerPeer(t *testing.T, conn *transport.TCPConn, peerID string) { + t.Helper() + + if err := conn.Send(protocol.Message{ + Type: protocol.MessageTypeRegister, + From: peerID, + To: protocol.ServerPeerID, + }); err != nil { + t.Fatalf("Send(register %s) error = %v", peerID, err) + } +} + +func waitFor(t *testing.T, condition func() bool, description string) { + t.Helper() + + deadline := time.Now().Add(500 * time.Millisecond) + for time.Now().Before(deadline) { + if condition() { + return + } + time.Sleep(10 * time.Millisecond) + } + + t.Fatalf("timed out waiting for %s", description) +} diff --git a/cmd/internal/transport/tcp.go b/cmd/internal/transport/tcp.go new file mode 100644 index 0000000..d0a0f8b --- /dev/null +++ b/cmd/internal/transport/tcp.go @@ -0,0 +1,151 @@ +package transport + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "syscall" + "time" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" +) + +// TCPConn 是对单条活跃 TCP 连接的轻量封装。 +// 它负责把协议层的单条消息读写,提升为可复用的收发接口。 +type TCPConn struct { + conn net.Conn + raw syscall.RawConn // 连接对应的底层 syscall 句柄,用于 Linux socket timestamping 收发。 + + logger latencylog.Logger + nodeRole string // 日志中记录的节点角色,例如 "server" 或 "peer" + nodeID string // 日志中记录的节点 ID,例如 peer 的 ID 或 server 的 "hub" + writeMu sync.Mutex // 保护 Send 方法的互斥锁,确保同一时刻只有一条完整协议消息被写入连接,防止多条消息字节交叉 + closeOnce sync.Once // 保护 Close 方法的 sync.Once,确保连接只被关闭一次 + closeErr error // 连接关闭时的错误,如果连接成功关闭则为 nil,重复调用 Close 时会返回同样的错误 +} + +// Option 用于为 TCPConn 注入可选行为,例如时延日志。 +type Option func(*TCPConn) + +// WithLogger 为连接发送路径注入业务消息日志上下文。 +func WithLogger(logger latencylog.Logger, nodeRole, nodeID string) Option { + return func(conn *TCPConn) { + conn.logger = logger + conn.nodeRole = nodeRole + conn.nodeID = nodeID + } +} + +// NewTCPConn 用已有的 net.Conn 创建 transport 连接封装。 +func NewTCPConn(conn net.Conn, opts ...Option) (*TCPConn, error) { + tcpConn := &TCPConn{ + conn: conn, + logger: latencylog.NoopLogger{}, + } + + for _, opt := range opts { + opt(tcpConn) + } + + if tcpConn.logger == nil { + tcpConn.logger = latencylog.NoopLogger{} + } + + if err := tcpConn.initLinuxTimestamping(); err != nil { + return nil, err + } + + return tcpConn, nil +} + +// Send 将一条协议消息完整写入底层连接。 +// 多个 goroutine 可以并发调用,内部会串行化写入。 +func (c *TCPConn) Send(msg protocol.Message) error { + c.writeMu.Lock() //“同一时刻只能有一条完整协议消息往连接里写,防止多条消息字节交叉 + defer c.writeMu.Unlock() + latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg) + + if err := c.sendMessageLinux(msg); err != nil { + return fmt.Errorf("transport: send message: %w", err) + } + //记录发送完成的时延日志事件,事件类型为 EventSendHandoffEnd,包含消息的基本信息(类型、ID、来源、目标)。 + latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg) + + return nil +} + +// Receive 从底层连接读取一条完整协议消息。 +// 同一条连接应只由单个 reader 持续调用该方法。 +func (c *TCPConn) Receive() (protocol.Message, error) { + msg, err := c.receiveMessageLinux() + if err != nil { + return protocol.Message{}, fmt.Errorf("transport: receive message: %w", err) + } + + return msg, nil +} + +// ReceiveLoop 持续读取消息并交给 handler 处理。 +// 读取错误、handler 错误或连接关闭都会结束循环,并关闭连接。 +func (c *TCPConn) ReceiveLoop(handler func(protocol.Message) error) error { + for { + msg, err := c.Receive() + if err != nil { + _ = c.Close() + return fmt.Errorf("transport: receive loop read: %w", err) + } + + if err := handler(msg); err != nil { + _ = c.Close() + return fmt.Errorf("transport: receive loop handler: %w", err) + } + } +} + +// CloseGracefully 在支持 half-close 的连接上先关闭写方向,给对端留出读取最终响应的机会, +// 然后在短暂等待后再彻底关闭连接。 +func (c *TCPConn) CloseGracefully(drainTimeout time.Duration) error { + if closeWriter, ok := c.conn.(interface{ CloseWrite() error }); ok { + if err := closeWriter.CloseWrite(); err != nil && !errors.Is(err, net.ErrClosed) { + return c.Close() + } + + if drainTimeout > 0 { + _ = c.conn.SetReadDeadline(time.Now().Add(drainTimeout)) + defer func() { + _ = c.conn.SetReadDeadline(time.Time{}) + }() + + var buf [256]byte + for { + _, err := c.conn.Read(buf[:]) + switch { + case err == nil: + continue + case errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed): + return c.Close() + default: + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return c.Close() + } + return c.Close() + } + } + } + } + + return c.Close() +} + +// Close 关闭底层连接,并保证重复调用是安全的。 +func (c *TCPConn) Close() error { + c.closeOnce.Do(func() { + c.closeErr = c.conn.Close() + }) + + return c.closeErr +} diff --git a/cmd/internal/transport/tcp_linux.go b/cmd/internal/transport/tcp_linux.go new file mode 100644 index 0000000..4264310 --- /dev/null +++ b/cmd/internal/transport/tcp_linux.go @@ -0,0 +1,462 @@ +//go:build linux + +package transport + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "syscall" + "time" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" +) + +const ( + linuxTimestampControlBufferSize = 256 // 控制消息缓冲区。 + linuxTXTimestampWaitTimeout = 250 * time.Millisecond // 等待 TX 时间戳的上限。 + linuxTXTimestampPollInterval = time.Millisecond // 轮询 errqueue 的间隔。 + + linuxSOTimestampingNew = 0x41 + linuxSCMTimestampingNew = linuxSOTimestampingNew + linuxSOEEOriginTimestamping = 4 // timestamping errqueue 事件。 + linuxSCMTstampSnd = 0 // 对应 A_TX_SOFTWARE。 + linuxSCMTstampSched = 1 // 对应 A_TX_SCHED。 + + linuxSOFTimestampingTXSoftware = 1 << 1 // 打开 TX software timestamp。 + linuxSOFTimestampingRXSoftware = 1 << 3 // 打开 RX software timestamp。 + linuxSOFTimestampingSoftware = 1 << 4 // software timestamp 总开关。 + linuxSOFTimestampingOptID = 1 << 7 // 给时间戳关联 ID。 + linuxSOFTimestampingTXSched = 1 << 8 // 打开 TX sched timestamp。 + linuxSOFTimestampingOptTSONLY = 1 << 11 // 只回时间戳。 + linuxSOFTimestampingOptIDTCP = 1 << 16 // 让 TCP 也带 timestamp ID。 +) + +// 拿到底层 fd,并打开 Linux timestamping。 +func (c *TCPConn) initLinuxTimestamping() error { + sysConn, ok := c.conn.(interface { + SyscallConn() (syscall.RawConn, error) + }) + if !ok { + return fmt.Errorf("transport: connection does not support SyscallConn") + } + + rawConn, err := sysConn.SyscallConn() + if err != nil || rawConn == nil { + if err != nil { + return fmt.Errorf("transport: get syscall conn: %w", err) + } + return fmt.Errorf("transport: missing syscall conn") + } + + //socket是否可以成功打开 timestamping 取决于内核版本和配置,尝试多个 flag 组合直到成功或遇到非 EINVAL 错误。 + if err := enableLinuxTimestamping(rawConn); err != nil { + return fmt.Errorf("transport: enable linux timestamping: %w", err) + } + //成功打开 timestamping 后,rawConn 就可以用来收 TX/RX 时间戳了。 + c.raw = rawConn + return nil +} + +// 给 socket开权限打开TX software timestamping。 +func enableLinuxTimestamping(rawConn syscall.RawConn) error { + flagCandidates := []int{ //不同linux版本可能支持不同的 flag 组合,尝试多个组合直到成功。 + linuxSOFTimestampingTXSched | + linuxSOFTimestampingTXSoftware | + linuxSOFTimestampingRXSoftware | + linuxSOFTimestampingSoftware | + linuxSOFTimestampingOptID | //TCP 协议栈给每个时间戳生成一个序列号 + linuxSOFTimestampingOptIDTCP | + linuxSOFTimestampingOptTSONLY, + linuxSOFTimestampingTXSched | + linuxSOFTimestampingTXSoftware | + linuxSOFTimestampingRXSoftware | + linuxSOFTimestampingSoftware | + linuxSOFTimestampingOptID | + linuxSOFTimestampingOptTSONLY, + linuxSOFTimestampingTXSched | + linuxSOFTimestampingTXSoftware | + linuxSOFTimestampingRXSoftware | + linuxSOFTimestampingSoftware | + linuxSOFTimestampingOptTSONLY, + } + + var lastErr error + for _, flags := range flagCandidates { //尝试不同的 flag 组合,直到成功或遇到非 EINVAL 错误。 + // 内核根据 fd 找到对应的内存结构体(Socket 缓冲区) + err := rawConn.Control(func(fd uintptr) { //Control 方法保证在回调里 fd 是有效的,可以安全地调用 syscall.SetsockoptInt。 + lastErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags) + }) + if err != nil { + return err + } + if lastErr == nil { + return nil + } + if !errors.Is(lastErr, syscall.EINVAL) { + return lastErr + } + } + + return lastErr +} + +// sendMessageLinux 编码消息、写完整帧,再记录 TX 时间戳。 +func (c *TCPConn) sendMessageLinux(msg protocol.Message) error { + payload, err := protocol.EncodeMessage(msg) + if err != nil { + return fmt.Errorf("protocol: encode message: %w", err) + } + + //编码后的消息 payload 前面加 4 字节长度,构成完整帧。 + frame := make([]byte, 4+len(payload)) + binary.BigEndian.PutUint32(frame[:4], uint32(len(payload))) + copy(frame[4:], payload) + + if err := c.writeFrameLinux(frame); err != nil { + return fmt.Errorf("protocol: write frame: %w", err) + } + //记录发送延时日志 + c.logTXTimestampEvents(msg) + return nil +} + +// writeFrameLinux 用 sendmsg 写完整帧。 +func (c *TCPConn) writeFrameLinux(frame []byte) error { + written := 0 + var opErr error + + err := c.raw.Write(func(fd uintptr) bool { + if written >= len(frame) { + return true + } + + n, sendErr := syscall.SendmsgN(int(fd), frame[written:], nil, nil, 0) + switch { + case sendErr == nil: + if n <= 0 { + opErr = io.ErrShortWrite + return true + } + written += n + return written >= len(frame) + case errors.Is(sendErr, syscall.EAGAIN), errors.Is(sendErr, syscall.EWOULDBLOCK): + return false + default: + opErr = sendErr + return true + } + }) + if err != nil { + return err + } + if opErr != nil { + return opErr + } + if written != len(frame) { + return io.ErrShortWrite + } + + return nil +} + +// 把 A_TX_SCHED / A_TX_SOFTWARE 写入日志。(发送过程中) +func (c *TCPConn) logTXTimestampEvents(msg protocol.Message) { + timestamps := c.collectTXTimestampEvents() + + if ts, ok := timestamps[latencylog.EventATXSched]; ok { + latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSched, ts, msg) + } + if ts, ok := timestamps[latencylog.EventATXSoftware]; ok { + latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSoftware, ts, msg) + } +} + +// 在 errqueue 里等两类 TX 时间戳。 +func (c *TCPConn) collectTXTimestampEvents() map[string]int64 { + timestamps := make(map[string]int64, 2) + //设置合理等待上限 + deadline := time.Now().Add(linuxTXTimestampWaitTimeout) + + //轮询 errqueue 直到拿到两类时间戳,或超时,或遇到非 EAGAIN 错误。 + for len(timestamps) < 2 && time.Now().Before(deadline) { + eventName, ts, err := c.recvTXTimestampOnce() + if err != nil { + if isWouldBlock(err) { + time.Sleep(linuxTXTimestampPollInterval) + continue + } + break + } + if eventName == "" || ts <= 0 { + continue + } + if _, exists := timestamps[eventName]; !exists { + timestamps[eventName] = ts + } + } + + return timestamps +} + +// recvTXTimestampOnce 从 errqueue 读一次时间戳事件。 +func (c *TCPConn) recvTXTimestampOnce() (string, int64, error) { + var ( + eventName string // 事件名,例如 A_TX_SCHED 或 A_TX_SOFTWARE。 + tsUnixNS int64 // 时间戳的 UnixNano 表示。 + opErr error + ) + + err := c.raw.Control(func(fd uintptr) { + //设置足够大的 oob buffer 来接收控制消息,调用 recvmsg 从 errqueue 读一条消息。 + oob := make([]byte, linuxTimestampControlBufferSize) + //recvmsg 的 flags 里必须带 MSG_ERRQUEUE,才能从 errqueue 里读消息,非阻塞模式下如果没有消息可读会返回 EAGAIN。 + _, oobn, _, _, recvErr := syscall.Recvmsg(int(fd), nil, oob, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT) + if recvErr != nil { + opErr = recvErr + return + } + //解析控制消息,看看是不是我们关心的 TX 时间戳事件,如果是就拿到事件名和时间戳。 + eventName, tsUnixNS = parseTXTimestampControlMessages(oob[:oobn]) + }) + if err != nil { + return "", 0, err + } + if opErr != nil { + return "", 0, opErr + } + + return eventName, tsUnixNS, nil //如果成功拿到时间戳事件,eventName 会是 A_TX_SCHED 或 A_TX_SOFTWARE 之一,tsUnixNS 是对应的时间戳;如果没有拿到事件或时间戳无效,eventName 会是空字符串,tsUnixNS 会是 0。 +} + +// 把底层时间戳映射成日志事件名。 +func parseTXTimestampControlMessages(oob []byte) (string, int64) { + if len(oob) == 0 { + return "", 0 + } + //解析控制消息,看看是不是我们关心的 TX 时间戳事件,如果是就拿到事件名和时间戳。 + controlMessages, err := syscall.ParseSocketControlMessage(oob) + if err != nil { + return "", 0 + } + + var ( + tsUnixNS int64 //时间戳的 UnixNano 表示。 + tsKind uint32 //extended err里,告诉我们这个时间戳是 sched 还是 software。 + hasTS bool // 是否拿到时间戳了。 + hasKind bool // 是否拿到时间戳类型了。 + ) + //一个 recvmsg 可能会收到多个控制消息,循环找我们关心的时间戳事件,拿到时间戳和事件类型。 + for _, controlMessage := range controlMessages { + switch { + case controlMessage.Header.Level == syscall.SOL_SOCKET && controlMessage.Header.Type == linuxSCMTimestampingNew: + if ts := parseSCMTimestampingData(controlMessage.Data); ts > 0 { + tsUnixNS = ts + hasTS = true + } + case isSocketExtendedErr(controlMessage): //判断时间戳是否进入了errqueue, + if info, ok := parseSocketExtendedErrInfo(controlMessage.Data); ok { + tsKind = info //时间戳类型被内核放在 extended err 的附加信息里,解析出来。 + hasKind = true + } + } + } + + if !hasTS || !hasKind { + return "", 0 + } + + switch tsKind { //把内核的时间戳类型映射成日志事件名。(记录时只关心 sched 和 software 两类时间戳) + case linuxSCMTstampSched: + return latencylog.EventATXSched, tsUnixNS + case linuxSCMTstampSnd: + return latencylog.EventATXSoftware, tsUnixNS + default: + return "", 0 + } +} + +// 判断控制消息是否来自 socket extended err。 +// 内核产生的时间戳并不会混合在普通的数据流里,而是被包装成一种特殊的“错误消息”丢进 Error Queue。 +func isSocketExtendedErr(controlMessage syscall.SocketControlMessage) bool { + switch { + case controlMessage.Header.Level == syscall.SOL_IP && controlMessage.Header.Type == syscall.IP_RECVERR: + return true + case controlMessage.Header.Level == syscall.SOL_IPV6 && controlMessage.Header.Type == syscall.IPV6_RECVERR: + return true + default: + return false + } +} + +// 从 socket extended err 的数据里取 origin timestamping 信息。 +func parseSocketExtendedErrInfo(data []byte) (uint32, bool) { + if len(data) < 16 { + return 0, false + } + if data[4] != linuxSOEEOriginTimestamping { + return 0, false + } + + return binary.NativeEndian.Uint32(data[8:12]), true +} + +// 读一条完整消息,并记录 B_RX_SOFTWARE。 +func (c *TCPConn) receiveMessageLinux() (protocol.Message, error) { + payload, rxTimestamp, err := c.readFrameLinux() + if err != nil { + return protocol.Message{}, fmt.Errorf("protocol: read frame: %w", err) + } + + msg, err := protocol.DecodeMessage(payload) + if err != nil { + return protocol.Message{}, fmt.Errorf("protocol: decode message: %w", err) + } + + if rxTimestamp > 0 { + latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventBRXSoftware, rxTimestamp, msg) + } + + return msg, nil +} + +// readFrameLinux 先读 4 字节长度,再读整条 payload。 +func (c *TCPConn) readFrameLinux() ([]byte, int64, error) { + var frameHeader [4]byte + rxTimestamp, err := c.readFullLinux(frameHeader[:]) + if err != nil { + return nil, rxTimestamp, err + } + + size := binary.BigEndian.Uint32(frameHeader[:]) + switch { + case size == 0: + return nil, rxTimestamp, protocol.ErrInvalidFrameLength + case size > protocol.MaxFrameSize: + return nil, rxTimestamp, protocol.ErrFrameTooLarge + } + + payload := make([]byte, int(size)) + bodyTimestamp, err := c.readFullLinux(payload) + if rxTimestamp == 0 { + rxTimestamp = bodyTimestamp + } + if err != nil { + return nil, rxTimestamp, err + } + + return payload, rxTimestamp, nil +} + +// 读满 buf,并保留首个 RX_SOFTWARE(返回进入tcp协议栈的时间戳)。 +func (c *TCPConn) readFullLinux(buf []byte) (int64, error) { + if len(buf) == 0 { + return 0, nil + } + + var ( + offset int + firstRXTime int64 + ) + + for offset < len(buf) { + n, rxTimestamp, err := c.recvmsgLinux(buf[offset:]) + if firstRXTime == 0 && rxTimestamp > 0 { + firstRXTime = rxTimestamp + } + if err != nil { + if errors.Is(err, io.EOF) && offset > 0 { + return firstRXTime, io.ErrUnexpectedEOF + } + return firstRXTime, err + } + + offset += n + } + + return firstRXTime, nil +} + +// recvmsgLinux 用 recvmsg 同时读取数据和控制消息。 +func (c *TCPConn) recvmsgLinux(buf []byte) (int, int64, error) { + var ( + n int + rxTimeNS int64 + opErr error + ) + + err := c.raw.Read(func(fd uintptr) bool { + oob := make([]byte, linuxTimestampControlBufferSize) + readN, oobN, _, _, recvErr := syscall.Recvmsg(int(fd), buf, oob, 0) + switch { + case recvErr == nil: + if readN == 0 { + opErr = io.EOF + return true + } + n = readN + rxTimeNS = parseRXTimestampControlMessages(oob[:oobN]) + return true + case errors.Is(recvErr, syscall.EAGAIN), errors.Is(recvErr, syscall.EWOULDBLOCK): + return false + default: + opErr = recvErr + return true + } + }) + if err != nil { + return 0, 0, err + } + if opErr != nil { + return 0, 0, opErr + } + + return n, rxTimeNS, nil +} + +// 从控制消息里取 RX_SOFTWARE。 +func parseRXTimestampControlMessages(oob []byte) int64 { + if len(oob) == 0 { + return 0 + } + + controlMessages, err := syscall.ParseSocketControlMessage(oob) + if err != nil { + return 0 + } + + for _, controlMessage := range controlMessages { + if controlMessage.Header.Level != syscall.SOL_SOCKET || controlMessage.Header.Type != linuxSCMTimestampingNew { + continue + } + + if ts := parseSCMTimestampingData(controlMessage.Data); ts > 0 { + return ts + } + } + + return 0 +} + +// 取第一个非零 timespec。 +func parseSCMTimestampingData(data []byte) int64 { + const timespec64Size = 16 + + for offset := 0; offset+timespec64Size <= len(data); offset += timespec64Size { + sec := int64(binary.NativeEndian.Uint64(data[offset : offset+8])) + nsec := int64(binary.NativeEndian.Uint64(data[offset+8 : offset+16])) + if sec == 0 && nsec == 0 { + continue + } + return sec*int64(time.Second) + nsec + } + + return 0 +} + +// 判断错误是否是 EAGAIN 或 EWOULDBLOCK。 +func isWouldBlock(err error) bool { + return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EWOULDBLOCK) +} diff --git a/cmd/internal/transport/tcp_linux_test.go b/cmd/internal/transport/tcp_linux_test.go new file mode 100644 index 0000000..5d99c20 --- /dev/null +++ b/cmd/internal/transport/tcp_linux_test.go @@ -0,0 +1,140 @@ +//go:build linux + +package transport + +import ( + "net" + "reflect" + "testing" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" +) + +func TestLinuxTimestampingRecordsKernelEvents(t *testing.T) { + tests := []struct { + name string + msg protocol.Message + }{ + { + name: "text", + msg: protocol.Message{ + Type: protocol.MessageTypeText, + ID: 41, + From: "peer-a", + To: "peer-b", + Body: []byte("hello over tcp"), + }, + }, + { + name: "file", + msg: protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 42, + From: "peer-a", + To: "peer-b", + FileName: "payload.bin", + Body: []byte{0x00, 0x01, 0x02, 0xff}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientConn, serverConn := newTCPPair(t) + + senderLogger := &recordingLogger{} + receiverLogger := &recordingLogger{} + sender, err := NewTCPConn( + clientConn, + WithLogger(senderLogger, latencylog.NodeRolePeer, "peer-a"), + ) + if err != nil { + t.Fatalf("NewTCPConn(sender) error = %v", err) + } + receiver, err := NewTCPConn( + serverConn, + WithLogger(receiverLogger, latencylog.NodeRolePeer, "peer-b"), + ) + if err != nil { + t.Fatalf("NewTCPConn(receiver) error = %v", err) + } + t.Cleanup(func() { + _ = sender.Close() + _ = receiver.Close() + }) + + sendErr := make(chan error, 1) + go func() { + sendErr <- sender.Send(tt.msg) + }() + + got, err := receiver.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if err := <-sendErr; err != nil { + t.Fatalf("Send() error = %v", err) + } + if !reflect.DeepEqual(got, tt.msg) { + t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg) + } + + assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSched, tt.msg.ID) + assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSoftware, tt.msg.ID) + assertHasEvent(t, receiverLogger.Events(), latencylog.EventBRXSoftware, tt.msg.ID) + }) + } +} + +func newTCPPair(t *testing.T) (net.Conn, net.Conn) { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen() error = %v", err) + } + + type acceptResult struct { + conn net.Conn + err error + } + + accepted := make(chan acceptResult, 1) + go func() { + conn, acceptErr := listener.Accept() + accepted <- acceptResult{conn: conn, err: acceptErr} + }() + + clientConn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + _ = listener.Close() + t.Fatalf("net.Dial() error = %v", err) + } + + result := <-accepted + if err := listener.Close(); err != nil { + t.Fatalf("listener.Close() error = %v", err) + } + if result.err != nil { + _ = clientConn.Close() + t.Fatalf("listener.Accept() error = %v", result.err) + } + + return clientConn, result.conn +} + +func assertHasEvent(t *testing.T, events []latencylog.Event, wantEvent string, wantMessageID uint64) { + t.Helper() + + for _, event := range events { + if event.Event == wantEvent && event.MessageID == wantMessageID { + if event.TsUnixNano <= 0 { + t.Fatalf("event %s timestamp must be positive: %+v", wantEvent, event) + } + return + } + } + + t.Fatalf("missing event %s for message %d in %+v", wantEvent, wantMessageID, events) +} diff --git a/cmd/internal/transport/tcp_test.go b/cmd/internal/transport/tcp_test.go new file mode 100644 index 0000000..8d7d5ca --- /dev/null +++ b/cmd/internal/transport/tcp_test.go @@ -0,0 +1,416 @@ +package transport + +import ( + "errors" + "io" + "reflect" + "strings" + "sync" + "testing" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/protocol" +) + +type recordingLogger struct { + mu sync.Mutex + events []latencylog.Event +} + +func (l *recordingLogger) LogEvent(event latencylog.Event) error { + l.mu.Lock() + defer l.mu.Unlock() + + l.events = append(l.events, event) + return nil +} + +func (l *recordingLogger) Events() []latencylog.Event { + l.mu.Lock() + defer l.mu.Unlock() + + return append([]latencylog.Event(nil), l.events...) +} + +type failingLogger struct{} + +func (failingLogger) LogEvent(latencylog.Event) error { + return errors.New("log failed") +} + +// TestSendReceiveMessage 验证 transport 可以在单条连接上正常收发 text 和 file 消息。 +func TestSendReceiveMessage(t *testing.T) { + tests := []struct { + name string + msg protocol.Message + }{ + { + name: "text", + msg: protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + }, + }, + { + name: "file", + msg: protocol.Message{ + Type: protocol.MessageTypeFile, + ID: 2, + From: "peer-a", + To: "peer-b", + FileName: "data.bin", + Body: []byte{0x00, 0x10, 0xff}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sender, receiver := newTransportConnPair(t, nil, nil) + //创建一个容量为1的缓冲通道sendErr,用于接收发送操作的错误结果。 + sendErr := make(chan error, 1) + go func() { + sendErr <- sender.Send(tt.msg) //发送消息,并将结果(错误或nil)发送到sendErr通道。 + }() + + got, err := receiver.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if err := <-sendErr; err != nil { //接受发送结果,如果发送过程中发生错误,则测试失败。 + t.Fatalf("Send() error = %v", err) + } + + if !reflect.DeepEqual(got, tt.msg) { + t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg) + } + }) + } +} + +func TestSendLogsHandoffEvents(t *testing.T) { + logger := &recordingLogger{} + sender, receiver := newTransportConnPair( + t, + []Option{WithLogger(logger, latencylog.NodeRolePeer, "peer-a")}, + nil, + ) + + msg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 7, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + } + + sendErr := make(chan error, 1) + go func() { + sendErr <- sender.Send(msg) + }() + + got, err := receiver.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if err := <-sendErr; err != nil { + t.Fatalf("Send() error = %v", err) + } + if !reflect.DeepEqual(got, msg) { + t.Fatalf("message mismatch: got %+v want %+v", got, msg) + } + + events := logger.Events() + if len(events) != 4 { + t.Fatalf("event count = %d, want 4", len(events)) + } + if events[0].Event != latencylog.EventSendHandoffBegin { + t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin) + } + if events[1].Event != latencylog.EventATXSched { + t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventATXSched) + } + if events[2].Event != latencylog.EventATXSoftware { + t.Fatalf("third event = %q, want %q", events[2].Event, latencylog.EventATXSoftware) + } + if events[3].Event != latencylog.EventSendHandoffEnd { + t.Fatalf("fourth event = %q, want %q", events[3].Event, latencylog.EventSendHandoffEnd) + } + for i, event := range events { + if event.MessageID != msg.ID { + t.Fatalf("event[%d] message ID = %d, want %d", i, event.MessageID, msg.ID) + } + } + if events[0].NodeRole != latencylog.NodeRolePeer || events[0].NodeID != "peer-a" { + t.Fatalf("node info = (%s,%s), want (%s,%s)", events[0].NodeRole, events[0].NodeID, latencylog.NodeRolePeer, "peer-a") + } + if events[0].TsUnixNano <= 0 || events[1].TsUnixNano <= 0 || events[2].TsUnixNano <= 0 || events[3].TsUnixNano <= 0 { + t.Fatalf("timestamps must be positive: %+v", events) + } +} + +func TestSendIgnoresLoggerFailure(t *testing.T) { + sender, receiver := newTransportConnPair( + t, + []Option{WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")}, + nil, + ) + + msg := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 9, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + } + + sendErr := make(chan error, 1) + go func() { + sendErr <- sender.Send(msg) + }() + + got, err := receiver.Receive() + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + if err := <-sendErr; err != nil { + t.Fatalf("Send() error = %v, want nil even if logger fails", err) + } + if !reflect.DeepEqual(got, msg) { + t.Fatalf("message mismatch: got %+v want %+v", got, msg) + } +} + +// TestReceiveLoopDeliversMessages 验证 ReceiveLoop 会逐条交付连续到达的消息。 +func TestReceiveLoopDeliversMessages(t *testing.T) { + sender, receiver := newTransportConnPair(t, nil, nil) + + want := []protocol.Message{ + { + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + }, + { + Type: protocol.MessageTypeFile, + ID: 2, + From: "peer-a", + To: "peer-b", + FileName: "payload.bin", + Body: []byte{0x01, 0x02, 0x03}, + }, + } + + var ( + mu sync.Mutex + got []protocol.Message + ) + loopErr := make(chan error, 1) + go func() { + loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error { + mu.Lock() + defer mu.Unlock() + got = append(got, msg) + return nil + }) + }() + + for _, msg := range want { + if err := sender.Send(msg); err != nil { + t.Fatalf("Send() error = %v", err) + } + } + if err := sender.Close(); err != nil { + t.Fatalf("sender.Close() error = %v", err) + } + + err := <-loopErr + if err == nil { + t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close") + } + if !strings.Contains(err.Error(), "receive loop read") { + t.Fatalf("ReceiveLoop() error = %v, want read context", err) + } + + mu.Lock() + defer mu.Unlock() + if !reflect.DeepEqual(got, want) { + t.Fatalf("received messages mismatch: got %+v want %+v", got, want) + } +} + +// TestConcurrentSendKeepsMessagesIntact 验证并发发送时消息不会因为写入交叉而损坏。 +func TestConcurrentSendKeepsMessagesIntact(t *testing.T) { + sender, receiver := newTransportConnPair(t, nil, nil) + // 发送方将多条消息并发发送到接收方,接收方通过 ReceiveLoop 逐条读取并验证每条消息的完整性和正确性。 + want := []protocol.Message{ + {Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("one")}, + {Type: protocol.MessageTypeText, ID: 2, From: "peer-a", To: "peer-b", Body: []byte("two")}, + {Type: protocol.MessageTypeText, ID: 3, From: "peer-a", To: "peer-b", Body: []byte("three")}, + {Type: protocol.MessageTypeText, ID: 4, From: "peer-a", To: "peer-b", Body: []byte("four")}, + } + + received := make(chan protocol.Message, len(want)) + readErr := make(chan error, 1) + go func() { //异步地运行一个 goroutine + for range want { + msg, err := receiver.Receive() + if err != nil { + readErr <- err + return + } + received <- msg + } + readErr <- nil + }() + + var wg sync.WaitGroup + for _, msg := range want { + msg := msg + wg.Add(1) + go func() { //异步处理 + defer wg.Done() + if err := sender.Send(msg); err != nil { + t.Errorf("Send() error = %v", err) + } + }() + } + wg.Wait() + + if err := <-readErr; err != nil { + t.Fatalf("Receive() error = %v", err) + } + + gotByID := make(map[uint64]protocol.Message, len(want)) + for range want { + msg := <-received + gotByID[msg.ID] = msg + } + + for _, msg := range want { + got, ok := gotByID[msg.ID] + if !ok { + t.Fatalf("missing message with ID %d", msg.ID) + } + if !reflect.DeepEqual(got, msg) { + t.Fatalf("message mismatch for ID %d: got %+v want %+v", msg.ID, got, msg) + } + } +} + +// TestReceiveLoopStopsOnHandlerError 验证 handler 返回错误时 ReceiveLoop 会退出并关闭连接。 +func TestReceiveLoopStopsOnHandlerError(t *testing.T) { + sender, receiver := newTransportConnPair(t, nil, nil) + + wantErr := errors.New("stop loop") + loopErr := make(chan error, 1) + go func() { + loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error { + return wantErr + }) + }() + + first := protocol.Message{ + Type: protocol.MessageTypeText, + ID: 1, + From: "peer-a", + To: "peer-b", + Body: []byte("hello"), + } + if err := sender.Send(first); err != nil { + t.Fatalf("Send(first) error = %v", err) + } + + err := <-loopErr + if !errors.Is(err, wantErr) { + t.Fatalf("ReceiveLoop() error = %v, want %v", err, wantErr) + } + if !strings.Contains(err.Error(), "receive loop handler") { + t.Fatalf("ReceiveLoop() error = %v, want handler context", err) + } +} + +// TestReceiveLoopStopsOnReadError 验证对端关闭时 ReceiveLoop 会以读取错误退出。 +func TestReceiveLoopStopsOnReadError(t *testing.T) { + sender, receiver := newTransportConnPair(t, nil, nil) + + loopErr := make(chan error, 1) + go func() { + loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error { + return nil + }) + }() + + if err := sender.Close(); err != nil { + t.Fatalf("sender.Close() error = %v", err) + } + + err := <-loopErr + if err == nil { + t.Fatal("ReceiveLoop() error = nil, want non-nil") + } + if !strings.Contains(err.Error(), "receive loop read") { + t.Fatalf("ReceiveLoop() error = %v, want read context", err) + } +} + +// TestCloseIsIdempotent 验证 Close 可以安全地被重复调用。 +func TestCloseIsIdempotent(t *testing.T) { + conn, peer := newTransportConnPair(t, nil, nil) + + if err := conn.Close(); err != nil { + t.Fatalf("Close(first) error = %v", err) + } + if err := conn.Close(); err != nil { + t.Fatalf("Close(second) error = %v, want nil", err) + } + if err := peer.Close(); err != nil && !strings.Contains(err.Error(), "closed") { + t.Fatalf("peer.Close() error = %v", err) + } +} + +// TestReceiveReturnsWrappedReadError 验证 Receive 在底层读取失败时会保留 transport 上下文。 +func TestReceiveReturnsWrappedReadError(t *testing.T) { + conn, peer := newTransportConnPair(t, nil, nil) + go func() { + _ = peer.Close() + }() + + _, err := conn.Receive() + if err == nil { + t.Fatal("Receive() error = nil, want non-nil") + } + if !strings.Contains(err.Error(), "transport: receive message") { + t.Fatalf("Receive() error = %v, want wrapped receive error", err) + } + if !errors.Is(err, io.EOF) && !strings.Contains(err.Error(), "closed") { + t.Fatalf("Receive() error = %v, want underlying read failure", err) + } +} + +func newTransportConnPair(t *testing.T, senderOpts []Option, receiverOpts []Option) (*TCPConn, *TCPConn) { + t.Helper() + + left, right := newTCPPair(t) + + sender, err := NewTCPConn(left, senderOpts...) + if err != nil { + t.Fatalf("NewTCPConn(sender) error = %v", err) + } + receiver, err := NewTCPConn(right, receiverOpts...) + if err != nil { + t.Fatalf("NewTCPConn(receiver) error = %v", err) + } + + t.Cleanup(func() { + _ = sender.Close() + _ = receiver.Close() + }) + + return sender, receiver +} diff --git a/cmd/latencysummary/main.go b/cmd/latencysummary/main.go new file mode 100644 index 0000000..76c4df1 --- /dev/null +++ b/cmd/latencysummary/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "flag" + "log" + "path/filepath" + "strings" + + "omnisocketgo/cmd/internal/latencylog" +) + +type stringListFlag []string + +func (f *stringListFlag) String() string { + return "" +} + +func (f *stringListFlag) Set(value string) error { + *f = append(*f, value) + return nil +} + +func main() { + var inputPaths stringListFlag + outputPath := flag.String("output", "latency-summary.jsonl", "output JSONL file for summarized latency metrics") + flag.Var(&inputPaths, "input", "raw latency JSONL file path; can be provided multiple times") + flag.Parse() + + if len(inputPaths) == 0 { + log.Fatal("at least one -input raw latency log file is required") + } + + events, err := latencylog.LoadEventsFromFiles(inputPaths) + if err != nil { + log.Fatalf("load raw latency logs: %v", err) + } + + summaries := latencylog.SummarizeEvents(events) + if err := latencylog.WriteSummariesJSONL(*outputPath, summaries); err != nil { + log.Fatalf("write latency summary: %v", err) + } + + chartPath := replaceFileExt(*outputPath, ".html") + if err := latencylog.WriteSummariesHTMLChart(chartPath, summaries); err != nil { + log.Fatalf("write latency chart: %v", err) + } + + log.Printf("wrote %d summarized message records to %s", len(summaries), *outputPath) + log.Printf("wrote simple latency chart to %s", chartPath) +} + +func replaceFileExt(path, ext string) string { + currentExt := filepath.Ext(path) + if currentExt == "" { + return path + ext + } + + return strings.TrimSuffix(path, currentExt) + ext +} diff --git a/cmd/peer/interactive.go b/cmd/peer/interactive.go new file mode 100644 index 0000000..8137bc0 --- /dev/null +++ b/cmd/peer/interactive.go @@ -0,0 +1,89 @@ +package main + +import ( + "errors" + "fmt" + "io" + "strings" +) + +const ( + interactiveCommandHelp = "help" + interactiveCommandQuit = "quit" + interactiveCommandText = "text" + interactiveCommandFile = "file" +) + +// 交互式命令行界面,允许用户在连接建立后反复发送文本或文件消息。 +var errEmptyInteractiveCommand = errors.New("interactive command is empty") + +type interactiveCommand struct { + name string + to string + value string +} + +// 解析用户输入的交互式命令,支持发送文本或文件消息,以及查看帮助和退出。 +func parseInteractiveCommand(line string) (interactiveCommand, error) { + commandName, rest, ok := cutInteractiveField(strings.TrimSpace(line)) + if !ok { + return interactiveCommand{}, errEmptyInteractiveCommand + } + + switch strings.ToLower(commandName) { + case "help", "h", "?": + return interactiveCommand{name: interactiveCommandHelp}, nil + case "quit", "exit": + return interactiveCommand{name: interactiveCommandQuit}, nil + case interactiveCommandText: + to, body, err := parseInteractiveTargetValue(rest, interactiveCommandText) + if err != nil { + return interactiveCommand{}, err + } + return interactiveCommand{name: interactiveCommandText, to: to, value: body}, nil + case interactiveCommandFile: + to, path, err := parseInteractiveTargetValue(rest, interactiveCommandFile) + if err != nil { + return interactiveCommand{}, err + } + return interactiveCommand{name: interactiveCommandFile, to: to, value: path}, nil + default: + return interactiveCommand{}, fmt.Errorf("unknown command %q; type help for usage", commandName) + } +} + +func parseInteractiveTargetValue(rest, commandName string) (string, string, error) { + to, value, ok := cutInteractiveField(strings.TrimSpace(rest)) + if !ok { + return "", "", fmt.Errorf("%s command requires a target peer and payload", commandName) + } + if strings.TrimSpace(value) == "" { + return "", "", fmt.Errorf("%s command requires a non-empty payload", commandName) + } + + return to, strings.TrimSpace(value), nil +} + +func cutInteractiveField(input string) (string, string, bool) { + trimmed := strings.TrimSpace(input) + if trimmed == "" { + return "", "", false + } + + for i, r := range trimmed { + if r == ' ' || r == '\t' { + return trimmed[:i], strings.TrimSpace(trimmed[i+1:]), true + } + } + + return trimmed, "", true +} + +// 打印交互式命令帮助信息,列出可用的命令和用法说明。 +func printInteractiveHelp(w io.Writer) { + _, _ = fmt.Fprintln(w, "interactive mode commands:") + _, _ = fmt.Fprintln(w, " help show this help") + _, _ = fmt.Fprintln(w, " text send one text message over the existing connection") + _, _ = fmt.Fprintln(w, " file send one file over the existing connection") + _, _ = fmt.Fprintln(w, " quit exit this peer process") +} diff --git a/cmd/peer/interactive_test.go b/cmd/peer/interactive_test.go new file mode 100644 index 0000000..74abf87 --- /dev/null +++ b/cmd/peer/interactive_test.go @@ -0,0 +1,87 @@ +package main + +import "testing" + +func TestParseInteractiveCommand(t *testing.T) { + tests := []struct { + name string + line string + want interactiveCommand + wantErr string + }{ + { + name: "text command preserves spaces in body", + line: "text peer-b hello over the same connection", + want: interactiveCommand{ + name: interactiveCommandText, + to: "peer-b", + value: "hello over the same connection", + }, + }, + { + name: "file command preserves spaces in path", + line: "file peer-b /tmp/demo payload.bin", + want: interactiveCommand{ + name: interactiveCommandFile, + to: "peer-b", + value: "/tmp/demo payload.bin", + }, + }, + { + name: "help alias", + line: "?", + want: interactiveCommand{ + name: interactiveCommandHelp, + }, + }, + { + name: "quit alias", + line: "exit", + want: interactiveCommand{ + name: interactiveCommandQuit, + }, + }, + { + name: "empty command", + line: " ", + wantErr: errEmptyInteractiveCommand.Error(), + }, + { + name: "text requires payload", + line: "text peer-b", + wantErr: "text command requires a non-empty payload", + }, + { + name: "file requires target and payload", + line: "file", + wantErr: "file command requires a target peer and payload", + }, + { + name: "unknown command", + line: "ping peer-b", + wantErr: `unknown command "ping"; type help for usage`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseInteractiveCommand(tt.line) + if tt.wantErr != "" { + if err == nil { + t.Fatalf("parseInteractiveCommand(%q) error = nil, want %q", tt.line, tt.wantErr) + } + if err.Error() != tt.wantErr { + t.Fatalf("parseInteractiveCommand(%q) error = %q, want %q", tt.line, err.Error(), tt.wantErr) + } + return + } + + if err != nil { + t.Fatalf("parseInteractiveCommand(%q) error = %v", tt.line, err) + } + if got != tt.want { + t.Fatalf("parseInteractiveCommand(%q) = %+v, want %+v", tt.line, got, tt.want) + } + }) + } +} diff --git a/cmd/peer/main.go b/cmd/peer/main.go new file mode 100644 index 0000000..ab3244b --- /dev/null +++ b/cmd/peer/main.go @@ -0,0 +1,173 @@ +package main + +import ( + "bufio" + "flag" + "fmt" + "io" + "log" + "os" + + "omnisocketgo/cmd/internal/latencylog" + peerpkg "omnisocketgo/cmd/internal/peer" + "omnisocketgo/cmd/internal/protocol" +) + +func main() { + peerID := flag.String("id", "peer-a", "peer identity") // peer 的标识 + serverAddr := flag.String("server", "127.0.0.1:9000", "server address") // server 的地址 + targetPeer := flag.String("to", "", "optional target peer for one outgoing message") // 可选的目标 peer 标识 + text := flag.String("text", "", "optional text to send after connecting") // 可选的文本消息内容 + filePath := flag.String("file", "", "optional file path to send after connecting") + inboxDir := flag.String("inbox-dir", "inbox", "directory used to persist received text and file messages") + logPath := flag.String("latency-log", "", "optional JSONL file path for latency timestamp logs") + interactive := flag.Bool("interactive", true, "enable interactive REPL for repeated text/file sends on the same connection") + flag.Parse() + + clientOptions := make([]peerpkg.Option, 0, 1) + if *logPath != "" { + logger, err := latencylog.NewJSONLLogger(*logPath) + if err != nil { + log.Fatalf("create latency logger %s: %v", *logPath, err) + } + defer logger.Close() + clientOptions = append(clientOptions, peerpkg.WithLogger(logger)) + } + + client, err := peerpkg.Dial(*serverAddr, *peerID, clientOptions...) + if err != nil { + log.Fatalf("dial server %s: %v", *serverAddr, err) + } + defer client.Close() + + log.Printf("connected to %s as %s", *serverAddr, client.ID()) + + receiveErr := make(chan error, 1) + go func() { + receiveErr <- client.ReceiveLoop(func(msg protocol.Message) error { + switch msg.Type { + case protocol.MessageTypeText: + path, err := client.PersistMessage(msg, *inboxDir) + if err != nil { + return err + } + log.Printf("received text from %s to %s and persisted to %s", msg.From, msg.To, path) + case protocol.MessageTypeFile: + path, err := client.PersistMessage(msg, *inboxDir) + if err != nil { + return err + } + log.Printf("received file from %s to %s: %s (%d bytes) -> %s", msg.From, msg.To, msg.FileName, len(msg.Body), path) + case protocol.MessageTypeError: + log.Printf("received %s from %s to %s: %s", msg.Type, msg.From, msg.To, string(msg.Body)) + default: + log.Printf("received unexpected message type %s from %s", msg.Type, msg.From) + } + return nil + }) + }() + + if *text != "" && *filePath != "" { + log.Fatal("only one of -text or -file may be specified") + } + + if (*text != "" || *filePath != "") && *targetPeer == "" { + log.Fatal("flag -to is required when sending text or file") + } + + //如果指定了目标 peer 和文本消息内容,则向目标 peer 发送一条文本消息,如果发送失败,打印错误日志并退出。 + if *targetPeer != "" && *text != "" { + if err := client.SendText(*targetPeer, *text); err != nil { + log.Fatalf("send text to %s: %v", *targetPeer, err) + } + log.Printf("sent text to %s", *targetPeer) + } + + if *targetPeer != "" && *filePath != "" { + if err := client.SendFilePath(*targetPeer, *filePath); err != nil { + log.Fatalf("send file %s to %s: %v", *filePath, *targetPeer, err) + } + log.Printf("sent file %s to %s", *filePath, *targetPeer) + } + + if *interactive { + if err := runInteractiveShell(client, os.Stdin, os.Stdout, receiveErr); err != nil { + log.Printf("interactive shell ended: %v", err) + } + return + } + + if err := <-receiveErr; err != nil { + log.Printf("receive loop ended: %v", err) + } +} + +func runInteractiveShell(client *peerpkg.Client, in io.Reader, out io.Writer, receiveErr <-chan error) error { + printInteractiveHelp(out) + lines, inputErr := readInteractiveLines(in, out, fmt.Sprintf("%s> ", client.ID())) + + for { + select { + case err := <-receiveErr: + return err + case line, ok := <-lines: + if !ok { + return <-inputErr + } + + command, err := parseInteractiveCommand(line) + if err != nil { + if err == errEmptyInteractiveCommand { + continue + } + log.Printf("interactive command error: %v", err) + continue + } + + switch command.name { + case interactiveCommandHelp: + printInteractiveHelp(out) + case interactiveCommandQuit: + return nil + case interactiveCommandText: + if err := client.SendText(command.to, command.value); err != nil { + log.Printf("send text to %s: %v", command.to, err) + continue + } + log.Printf("sent text to %s", command.to) + case interactiveCommandFile: + if err := client.SendFilePath(command.to, command.value); err != nil { + log.Printf("send file %s to %s: %v", command.value, command.to, err) + continue + } + log.Printf("sent file %s to %s", command.value, command.to) + } + } + } +} + +func readInteractiveLines(in io.Reader, out io.Writer, prompt string) (<-chan string, <-chan error) { + lines := make(chan string) + errs := make(chan error, 1) + + go func() { + defer close(lines) + + scanner := bufio.NewScanner(in) + scanner.Buffer(make([]byte, 0, 1024), 1024*1024) + + for { + if _, err := fmt.Fprint(out, prompt); err != nil { + errs <- err + return + } + if !scanner.Scan() { + errs <- scanner.Err() + return + } + lines <- scanner.Text() + } + }() + + return lines, errs +} diff --git a/cmd/server/main.go b/cmd/server/main.go new file mode 100644 index 0000000..390e145 --- /dev/null +++ b/cmd/server/main.go @@ -0,0 +1,49 @@ +package main + +import ( + "flag" + "log" + "net" + + "omnisocketgo/cmd/internal/latencylog" + "omnisocketgo/cmd/internal/server" +) + +func main() { + listenAddr := flag.String("listen", ":9000", "server listen address") //监听地址 + logPath := flag.String("latency-log", "", "optional JSONL file path for latency timestamp logs") + flag.Parse() //查看命令行参数 + + hubOptions := make([]server.Option, 0, 1) + if *logPath != "" { + logger, err := latencylog.NewJSONLLogger(*logPath) + if err != nil { + log.Fatalf("create latency logger %s: %v", *logPath, err) + } + defer logger.Close() + hubOptions = append(hubOptions, server.WithLogger(logger)) + } + + listener, err := net.Listen("tcp", *listenAddr) //开启tcp监听器,监听来自客户端的连接请求 + if err != nil { + log.Fatalf("listen on %s: %v", *listenAddr, err) + } + defer listener.Close() //确保在 main 函数退出时关闭监听器 + + hub := server.NewHub(hubOptions...) //创建一个新的 Hub 实例,负责管理客户端连接和消息转发 + log.Printf("server listening on %s", listener.Addr()) + + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("accept connection: %v", err) + continue + } + + go func(rawConn net.Conn) { + if err := hub.ServeConn(rawConn); err != nil { + log.Printf("connection closed: %v", err) + } + }(conn) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c8fbf8c --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module omnisocketgo + +go 1.22 diff --git a/latencysummary b/latencysummary new file mode 100755 index 0000000..e8ecfe9 Binary files /dev/null and b/latencysummary differ