commit 48246752444e0bb786578645570279e4f3db9d6b
Author: nnbcccscdscdsc <2709767634@qq.com>
Date: Mon Mar 23 20:18:53 2026 +0800
init
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..582ec45
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+bin/*
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..d8b47ed
--- /dev/null
+++ b/README.md
@@ -0,0 +1,91 @@
+# OmniSocketGo
+
+Linux only. Go 1.22.
+
+如果目标机器只运行 `server`,只需要编译并拷贝 `server` 二进制。
+如果目标机器只运行 `peer`,只需要编译并拷贝 `peer` 二进制。
+
+`go build ./cmd/server` 和 `go build ./cmd/peer` 会把各自依赖到的功能一起编译进最终二进制,不需要再单独编译 `cmd/internal/...` 包。
+
+- `server` 二进制会包含它依赖到的转发、协议、传输等代码
+- `peer` 二进制会包含它依赖到的注册、交互发送、接收落盘、协议、传输等代码
+- 只有没有被这个可执行程序引用的其他命令,才不在该二进制里,比如 `cmd/latencysummary`
+
+## Build
+
+按目标架构分别编译。
+
+### Linux amd64
+
+```bash
+CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/server-linux-amd64 ./cmd/server
+CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/peer-linux-amd64 ./cmd/peer
+```
+
+### Linux arm64
+
+```bash
+CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/server-linux-arm64 ./cmd/server
+CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/peer-linux-arm64 ./cmd/peer
+```
+
+### Linux armv7
+
+```bash
+CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/server-linux-armv7 ./cmd/server
+CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/peer-linux-armv7 ./cmd/peer
+```
+
+## Deploy
+
+把对应架构的二进制拷到目标机器,并赋予可执行权限。
+
+```bash
+scp bin/server-linux-amd64 user@server-host:~/omnisocket/server
+scp bin/peer-linux-amd64 user@peer-host:~/omnisocket/peer
+```
+
+```bash
+ssh user@server-host 'chmod +x ~/omnisocket/server'
+ssh user@peer-host 'chmod +x ~/omnisocket/peer'
+```
+
+## Run On Different Machines
+
+`server` 所在机器监听 `0.0.0.0:9000`。
+
+```bash
+~/omnisocket/server -listen 0.0.0.0:9000
+```
+
+每个 `peer` 所在机器连接 `server` 的实际 IP,例如 `192.168.1.50:9000`。
+
+### peer-a
+
+```bash
+~/omnisocket/peer \
+ -id peer-a \
+ -server 192.168.1.50:9000 \
+ -inbox-dir /tmp/peer-a-inbox
+```
+
+### peer-b
+
+```bash
+~/omnisocket/peer \
+ -id peer-b \
+ -server 192.168.1.50:9000 \
+ -inbox-dir /tmp/peer-b-inbox
+```
+
+## Interactive Commands
+
+`peer` 启动后可以在终端里持续使用同一条长连接发送多次消息。
+
+```text
+help
+text peer-b hello
+text peer-a hi
+file peer-b /tmp/test.bin
+quit
+```
diff --git a/cmd/internal/latencylog/logger.go b/cmd/internal/latencylog/logger.go
new file mode 100644
index 0000000..f1f1f31
--- /dev/null
+++ b/cmd/internal/latencylog/logger.go
@@ -0,0 +1,166 @@
+package latencylog
+
+import (
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "sync"
+ "time"
+
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+const (
+ NodeRolePeer = "peer" //客户端节点
+ NodeRoleServer = "server" //云端转发节点
+)
+
+// 记录的消息事件的类型常量。
+const (
+ EventAAppPrepBegin = "A_APP_PREP_BEGIN" // A 端应用开始准备这条消息
+ EventATXSched = "A_TX_SCHED" // A 端进入 Linux qdisc 之前
+ EventATXSoftware = "A_TX_SOFTWARE" // A 端即将交给网卡驱动
+ EventATXHardware = "A_TX_HARDWARE" // A 端网卡真正发出到物理介质
+ EventBRXHardware = "B_RX_HARDWARE" // B 端网卡真正从物理介质收到
+ EventBRXSoftware = "B_RX_SOFTWARE" // B 端驱动把数据交给 Linux 接收栈
+ EventBAppRecv = "B_APP_RECV" // B 端应用真正读到完整消息
+ EventBPersistBegin = "B_PERSIST_BEGIN" // B 端开始写盘
+ EventBPersistEnd = "B_PERSIST_END" // B 端写盘完成
+
+ EventSendHandoffBegin = "send_handoff_begin" // 调试事件:应用把消息交给传输层开始
+ EventSendHandoffEnd = "send_handoff_end" // 调试事件:应用把消息交给传输层结束
+)
+
+// Event 是一条时延时间戳日志记录。
+type Event struct {
+ TsUnixNano int64 `json:"ts_unix_nano"`
+ NodeRole string `json:"node_role"`
+ NodeID string `json:"node_id"`
+ Event string `json:"event"`
+ MessageType protocol.MessageType `json:"message_type"`
+ MessageID uint64 `json:"message_id"`
+ From string `json:"from"`
+ To string `json:"to"`
+ FileName string `json:"file_name,omitempty"`
+ BodySize int `json:"body_size"`
+}
+
+// Logger 负责接收事件并将其写入外部介质。
+type Logger interface {
+ LogEvent(Event) error
+}
+
+// NoopLogger 是默认的空实现。
+type NoopLogger struct{}
+
+// LogEvent 对空日志实现始终返回 nil。
+func (NoopLogger) LogEvent(Event) error {
+ return nil
+}
+
+// JSONLLogger 以 JSONL 形式追加写日志文件。
+type JSONLLogger struct {
+ mu sync.Mutex
+ closeOnce sync.Once
+ closeErr error
+ file *os.File
+}
+
+// NewJSONLLogger 创建一个线程安全的 JSONL 文件日志器。
+func NewJSONLLogger(path string) (*JSONLLogger, error) {
+ dir := filepath.Dir(path)
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return nil, err
+ }
+
+ file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
+ if err != nil {
+ return nil, err
+ }
+
+ return &JSONLLogger{file: file}, nil
+}
+
+// LogEvent 以单行 JSON 的形式追加一条事件。
+func (l *JSONLLogger) LogEvent(event Event) error {
+ line, err := json.Marshal(event)
+ if err != nil {
+ return err
+ }
+
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ if _, err := l.file.Write(append(line, '\n')); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// Close 关闭底层文件;重复调用是安全的。
+func (l *JSONLLogger) Close() error {
+ l.closeOnce.Do(func() {
+ l.closeErr = l.file.Close()
+ })
+
+ return l.closeErr
+}
+
+// IsBusinessMessage 判断消息是否属于要参与 A-C-B 时延分析的业务消息。
+func IsBusinessMessage(msg protocol.Message) bool {
+ switch msg.Type {
+ case protocol.MessageTypeText, protocol.MessageTypeFile:
+ return true
+ default:
+ return false
+ }
+}
+
+// NewMessageEvent 用当前 UTC 时间为一条业务消息构造事件。
+func NewMessageEvent(nodeRole, nodeID, eventName string, msg protocol.Message) Event {
+ return NewMessageEventAt(time.Now().UTC().UnixNano(), nodeRole, nodeID, eventName, msg)
+}
+
+// NewMessageEventAt 用指定的 UnixNano 时间为一条业务消息构造事件。
+func NewMessageEventAt(tsUnixNano int64, nodeRole, nodeID, eventName string, msg protocol.Message) Event {
+ return Event{
+ TsUnixNano: tsUnixNano,
+ NodeRole: nodeRole,
+ NodeID: nodeID,
+ Event: eventName,
+ MessageType: msg.Type,
+ MessageID: msg.ID,
+ From: msg.From,
+ To: msg.To,
+ FileName: msg.FileName,
+ BodySize: len(msg.Body),
+ }
+}
+
+// LogBestEffort 写一条事件,失败时静默忽略,避免打断主收发流程。
+func LogBestEffort(logger Logger, event Event) {
+ if logger == nil {
+ return
+ }
+
+ _ = logger.LogEvent(event)
+}
+
+// LogMessageEvent 为业务消息构造并写入一条事件。
+func LogMessageEvent(logger Logger, nodeRole, nodeID, eventName string, msg protocol.Message) {
+ if !IsBusinessMessage(msg) {
+ return
+ }
+
+ LogBestEffort(logger, NewMessageEvent(nodeRole, nodeID, eventName, msg))
+}
+
+// LogMessageEventAt 为业务消息写入一条指定时间戳的事件。
+func LogMessageEventAt(logger Logger, nodeRole, nodeID, eventName string, tsUnixNano int64, msg protocol.Message) {
+ if !IsBusinessMessage(msg) {
+ return
+ }
+
+ LogBestEffort(logger, NewMessageEventAt(tsUnixNano, nodeRole, nodeID, eventName, msg))
+}
diff --git a/cmd/internal/latencylog/logger_test.go b/cmd/internal/latencylog/logger_test.go
new file mode 100644
index 0000000..d1850fe
--- /dev/null
+++ b/cmd/internal/latencylog/logger_test.go
@@ -0,0 +1,131 @@
+package latencylog
+
+import (
+ "bufio"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "sync"
+ "testing"
+
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+func TestJSONLLoggerWritesOneEventPerLine(t *testing.T) {
+ path := filepath.Join(t.TempDir(), "latency.jsonl")
+
+ logger, err := NewJSONLLogger(path)
+ if err != nil {
+ t.Fatalf("NewJSONLLogger() error = %v", err)
+ }
+ t.Cleanup(func() {
+ _ = logger.Close()
+ })
+
+ event := Event{
+ TsUnixNano: 123,
+ NodeRole: NodeRolePeer,
+ NodeID: "peer-a",
+ Event: EventAAppPrepBegin,
+ MessageType: protocol.MessageTypeText,
+ MessageID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ BodySize: 5,
+ }
+ if err := logger.LogEvent(event); err != nil {
+ t.Fatalf("LogEvent() error = %v", err)
+ }
+
+ file, err := os.Open(path)
+ if err != nil {
+ t.Fatalf("os.Open() error = %v", err)
+ }
+ defer file.Close()
+
+ scanner := bufio.NewScanner(file)
+ if !scanner.Scan() {
+ t.Fatal("expected one JSONL line, got none")
+ }
+
+ var got Event
+ if err := json.Unmarshal(scanner.Bytes(), &got); err != nil {
+ t.Fatalf("json.Unmarshal() error = %v", err)
+ }
+ if got != event {
+ t.Fatalf("event mismatch: got %+v want %+v", got, event)
+ }
+ if scanner.Scan() {
+ t.Fatal("expected exactly one JSONL line")
+ }
+ if err := scanner.Err(); err != nil {
+ t.Fatalf("scanner.Err() = %v", err)
+ }
+}
+
+func TestJSONLLoggerHandlesConcurrentWrites(t *testing.T) {
+ path := filepath.Join(t.TempDir(), "latency.jsonl")
+
+ logger, err := NewJSONLLogger(path)
+ if err != nil {
+ t.Fatalf("NewJSONLLogger() error = %v", err)
+ }
+ t.Cleanup(func() {
+ _ = logger.Close()
+ })
+
+ const total = 32
+
+ var wg sync.WaitGroup
+ for i := 0; i < total; i++ {
+ i := i
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ err := logger.LogEvent(Event{
+ TsUnixNano: int64(i + 1),
+ NodeRole: NodeRoleServer,
+ NodeID: protocol.ServerPeerID,
+ Event: EventBAppRecv,
+ MessageType: protocol.MessageTypeFile,
+ MessageID: uint64(i + 1),
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "payload.bin",
+ BodySize: 3,
+ })
+ if err != nil {
+ t.Errorf("LogEvent() error = %v", err)
+ }
+ }()
+ }
+ wg.Wait()
+
+ file, err := os.Open(path)
+ if err != nil {
+ t.Fatalf("os.Open() error = %v", err)
+ }
+ defer file.Close()
+
+ scanner := bufio.NewScanner(file)
+ var count int
+ seen := make(map[uint64]bool, total)
+ for scanner.Scan() {
+ var event Event
+ if err := json.Unmarshal(scanner.Bytes(), &event); err != nil {
+ t.Fatalf("json.Unmarshal() error = %v", err)
+ }
+ count++
+ seen[event.MessageID] = true
+ }
+ if err := scanner.Err(); err != nil {
+ t.Fatalf("scanner.Err() = %v", err)
+ }
+ if count != total {
+ t.Fatalf("line count = %d, want %d", count, total)
+ }
+ if len(seen) != total {
+ t.Fatalf("unique message count = %d, want %d", len(seen), total)
+ }
+}
diff --git a/cmd/internal/latencylog/summary.go b/cmd/internal/latencylog/summary.go
new file mode 100644
index 0000000..13337b2
--- /dev/null
+++ b/cmd/internal/latencylog/summary.go
@@ -0,0 +1,254 @@
+package latencylog
+
+import (
+ "bufio"
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "sort"
+
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+// Summary 是针对单条消息的时延的规则列表。
+var requiredTimestampNames = []string{
+ EventAAppPrepBegin, // A 端应用开始准备这条消息
+ EventATXSched, // A 端进入 Linux qdisc 之前
+ EventATXSoftware, // A 端即将交给网卡驱动
+ EventBRXSoftware, // B 端网卡驱动把数据交给 Linux 接收栈
+ EventBAppRecv, // B 端应用真正读到完整消息
+ EventBPersistEnd, // B 端写盘完成
+}
+
+// Summary 是针对单条消息的时延整理结果。
+type Summary struct {
+ MessageType protocol.MessageType `json:"message_type"` //消息类型
+ MessageID uint64 `json:"message_id"` //消息ID
+ From string `json:"from"` //发送方
+ To string `json:"to"` //接收方
+ FileName string `json:"file_name,omitempty"` //文件名(仅文件消息)
+ BodySize int `json:"body_size"` //消息体大小(字节数)
+ Timestamps map[string]int64 `json:"timestamps"` //事件时间戳,key 是事件名称,value 是 UnixNano 时间戳
+
+ AProcessingLatencyNS *int64 `json:"a_processing_latency_ns,omitempty"` // A 处理时延:A_TX_SCHED - A_APP_PREP_BEGIN
+ AQueueLatencyNS *int64 `json:"a_queue_latency_ns,omitempty"` // A 排队时延:A_TX_SOFTWARE - A_TX_SCHED
+ ABTransportPropagationBQueueLatencyNS *int64 `json:"a_b_transport_propagation_b_queue_latency_ns,omitempty"` // A-B 传输时延 + B 排队时延:B_APP_RECV - A_TX_SOFTWARE
+ BKernelReceivePathLatencyNS *int64 `json:"b_kernel_receive_path_latency_ns,omitempty"` // B 内核接收路径近似:B_APP_RECV - B_RX_SOFTWARE
+ BProcessingLatencyNS *int64 `json:"b_processing_latency_ns,omitempty"` // B 处理时延:B_PERSIST_END - B_APP_RECV
+ EndToEndLatencyNS *int64 `json:"end_to_end_latency_ns,omitempty"` // 端到端时延:B_PERSIST_END - A_APP_PREP_BEGIN
+ MissingTimestamps []string `json:"missing_timestamps,omitempty"` // 缺失的时间戳列表,包含 requiredTimestampNames 中但在原始事件中没有的事件名称
+}
+
+// LoadEventsFromFiles 从JSONL 原始日志文件中加载事件。
+type messageKey struct {
+ MessageType protocol.MessageType //消息类型
+ MessageID uint64 //消息ID
+ From string //发送方
+ To string //接收方
+}
+
+// LoadEventsFromFiles 从多个 JSONL 原始日志文件中加载事件。
+func LoadEventsFromFiles(paths []string) ([]Event, error) {
+ var events []Event
+ for _, path := range paths {
+ fileEvents, err := LoadEventsFromFile(path)
+ if err != nil {
+ return nil, err
+ }
+ events = append(events, fileEvents...)
+ }
+
+ return events, nil
+}
+
+// LoadEventsFromFile 从单个 JSONL 原始日志文件中加载事件。
+func LoadEventsFromFile(path string) ([]Event, error) {
+ file, err := os.Open(path)
+ if err != nil {
+ return nil, fmt.Errorf("latencylog: open raw log %s: %w", path, err)
+ }
+ defer file.Close()
+
+ var events []Event
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ if len(scanner.Bytes()) == 0 {
+ continue
+ }
+
+ var event Event
+ if err := json.Unmarshal(scanner.Bytes(), &event); err != nil { //解析 JSONL 行失败,返回错误
+ return nil, fmt.Errorf("latencylog: decode event from %s: %w", path, err)
+ }
+ events = append(events, event)
+ }
+ if err := scanner.Err(); err != nil {
+ return nil, fmt.Errorf("latencylog: scan raw log %s: %w", path, err)
+ }
+
+ return events, nil
+}
+
+// SummarizeEvents 将原始事件整理成按消息分组的时延结果。
+func SummarizeEvents(events []Event) []Summary {
+ grouped := make(map[messageKey]*Summary)
+
+ for _, event := range events {
+ if !IsBusinessEvent(event) {
+ continue
+ }
+
+ key := messageKey{
+ MessageType: event.MessageType,
+ MessageID: event.MessageID,
+ From: event.From,
+ To: event.To,
+ }
+
+ summary, ok := grouped[key]
+ if !ok {
+ summary = &Summary{
+ MessageType: event.MessageType,
+ MessageID: event.MessageID,
+ From: event.From,
+ To: event.To,
+ FileName: event.FileName,
+ BodySize: event.BodySize,
+ Timestamps: make(map[string]int64),
+ }
+ grouped[key] = summary
+ }
+
+ if summary.FileName == "" {
+ summary.FileName = event.FileName
+ }
+ if event.BodySize > 0 {
+ summary.BodySize = event.BodySize
+ }
+
+ if existing, exists := summary.Timestamps[event.Event]; !exists || event.TsUnixNano < existing {
+ summary.Timestamps[event.Event] = event.TsUnixNano
+ }
+ }
+
+ summaries := make([]Summary, 0, len(grouped))
+ for _, summary := range grouped {
+ completeSummary(summary) //补全时延指标和缺失时间戳信息
+ summaries = append(summaries, *summary)
+ }
+ //对整理结果进行排序,先按发送方、再按接收方、再按消息 ID、最后按消息类型排序,保证输出的稳定性和可读性。
+ sort.Slice(summaries, func(i, j int) bool {
+ if summaries[i].From != summaries[j].From {
+ return summaries[i].From < summaries[j].From
+ }
+ if summaries[i].To != summaries[j].To {
+ return summaries[i].To < summaries[j].To
+ }
+ if summaries[i].MessageID != summaries[j].MessageID {
+ return summaries[i].MessageID < summaries[j].MessageID
+ }
+ return summaries[i].MessageType < summaries[j].MessageType
+ })
+
+ return summaries
+}
+
+// WriteSummariesJSONL 将整理结果写成 JSONL 汇总文件。
+func WriteSummariesJSONL(path string, summaries []Summary) error {
+ if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+ return fmt.Errorf("latencylog: create summary dir for %s: %w", path, err)
+ }
+
+ file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
+ if err != nil {
+ return fmt.Errorf("latencylog: open summary file %s: %w", path, err)
+ }
+ defer file.Close()
+
+ writer := bufio.NewWriter(file)
+ for _, summary := range summaries { //将每条整理结果编码成 JSONL 行并写入文件
+ line, err := json.Marshal(summary)
+ if err != nil {
+ return fmt.Errorf("latencylog: encode summary for message %d: %w", summary.MessageID, err)
+ }
+ if _, err := writer.Write(append(line, '\n')); err != nil {
+ return fmt.Errorf("latencylog: write summary file %s: %w", path, err)
+ }
+ }
+
+ if err := writer.Flush(); err != nil { //将缓冲区内容写入文件
+ return fmt.Errorf("latencylog: flush summary file %s: %w", path, err)
+ }
+
+ return nil
+}
+
+// completeSummary 根据事件时间戳计算时延指标,并找出缺失的时间戳。
+func completeSummary(summary *Summary) {
+ summary.MissingTimestamps = missingTimestampNames(summary.Timestamps)
+
+ if value := subtractIfPresent(summary.Timestamps, EventATXSched, EventAAppPrepBegin); value != nil {
+ summary.AProcessingLatencyNS = value
+ }
+ if value := subtractIfPresent(summary.Timestamps, EventATXSoftware, EventATXSched); value != nil {
+ summary.AQueueLatencyNS = value
+ }
+ if value := subtractIfPresent(summary.Timestamps, EventBAppRecv, EventATXSoftware); value != nil {
+ summary.ABTransportPropagationBQueueLatencyNS = value
+ }
+ if value := subtractIfPresent(summary.Timestamps, EventBAppRecv, EventBRXSoftware); value != nil {
+ summary.BKernelReceivePathLatencyNS = value
+ }
+ if value := subtractIfPresent(summary.Timestamps, EventBPersistEnd, EventBAppRecv); value != nil {
+ summary.BProcessingLatencyNS = value
+ }
+ if value := subtractIfPresent(summary.Timestamps, EventBPersistEnd, EventAAppPrepBegin); value != nil {
+ summary.EndToEndLatencyNS = value
+ }
+}
+
+// 返回 requiredTimestampNames 中哪些在给定的 timestamps 中缺失。
+func missingTimestampNames(timestamps map[string]int64) []string {
+ var missing []string
+ for _, name := range requiredTimestampNames {
+ if _, ok := timestamps[name]; !ok {
+ missing = append(missing, name)
+ }
+ }
+
+ return missing
+}
+
+// 如果 timestamps 中同时存在 endName 和 beginName,则返回它们的差值;否则返回 nil。
+func subtractIfPresent(timestamps map[string]int64, endName, beginName string) *int64 {
+ end, ok := timestamps[endName]
+ if !ok {
+ return nil
+ }
+ begin, ok := timestamps[beginName]
+ if !ok {
+ return nil
+ }
+
+ value := end - begin
+ return &value
+}
+
+// 判断事件是否是业务相关的时延事件(其中一项)
+func IsBusinessEvent(event Event) bool {
+ switch event.Event {
+ case EventAAppPrepBegin,
+ EventATXSched,
+ EventATXSoftware,
+ EventATXHardware,
+ EventBRXHardware,
+ EventBRXSoftware,
+ EventBAppRecv,
+ EventBPersistBegin,
+ EventBPersistEnd:
+ return true
+ default:
+ return false
+ }
+}
diff --git a/cmd/internal/latencylog/summary_chart.go b/cmd/internal/latencylog/summary_chart.go
new file mode 100644
index 0000000..1953fe9
--- /dev/null
+++ b/cmd/internal/latencylog/summary_chart.go
@@ -0,0 +1,440 @@
+package latencylog
+
+import (
+ "bufio"
+ "fmt"
+ "html/template"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+const summaryChartHTMLTemplate = `
+
+
+
+
+ Latency Summary Chart
+
+
+
+
+ Latency Summary
+ A simple per-message end-to-end latency chart generated from summarized JSONL records.
+
+
+
+
Messages
+
{{.TotalMessages}}
+
+
+
With End-To-End
+
{{.MessagesWithEndToEnd}}
+
+
+
Average End-To-End
+
{{.AverageEndToEnd}}
+
+
+
Max End-To-End
+
{{.MaxEndToEnd}}
+
+
+
+
+ {{range .Legend}}
+
+
+ {{.Label}}
+
+ {{end}}
+
+
+ {{if .Rows}}
+
+ {{range .Rows}}
+
+
+
{{.Title}}
+
{{.EndToEnd}}
+
+ {{.Subtitle}}
+
+ {{range .Segments}}
+
+ {{end}}
+
+ {{if .Segments}}
+
+ {{range .Segments}}
+
+
+ {{.Label}} {{.Value}}
+
+ {{end}}
+
+ {{end}}
+ {{if .MissingTimestamps}}
+ Missing timestamps: {{.MissingTimestamps}}
+ {{end}}
+
+ {{end}}
+
+ {{else}}
+ No summarized messages were available for chart rendering.
+ {{end}}
+
+
+
+`
+
+type summaryChartPage struct {
+ TotalMessages int
+ MessagesWithEndToEnd int
+ AverageEndToEnd string
+ MaxEndToEnd string
+ Legend []summaryChartLegendItem
+ Rows []summaryChartRow
+}
+
+type summaryChartLegendItem struct {
+ Label string
+ Color string
+}
+
+type summaryChartRow struct {
+ Title string
+ Subtitle string
+ EndToEnd string
+ MissingTimestamps string
+ Segments []summaryChartSegment
+}
+
+type summaryChartSegment struct {
+ Label string
+ Value string
+ Color string
+ WidthPercent float64
+}
+
+type summaryChartSegmentMetric struct {
+ label string
+ value *int64
+ color string
+}
+
+// WriteSummariesHTMLChart 将整理结果写成一个可直接在浏览器中打开的简单 HTML 图表。
+func WriteSummariesHTMLChart(path string, summaries []Summary) error {
+ if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+ return fmt.Errorf("latencylog: create chart dir for %s: %w", path, err)
+ }
+
+ file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
+ if err != nil {
+ return fmt.Errorf("latencylog: open chart file %s: %w", path, err)
+ }
+ defer file.Close()
+
+ page := buildSummaryChartPage(summaries)
+ tmpl, err := template.New("summary-chart").Parse(summaryChartHTMLTemplate)
+ if err != nil {
+ return fmt.Errorf("latencylog: parse chart template: %w", err)
+ }
+
+ writer := bufio.NewWriter(file)
+ if err := tmpl.Execute(writer, page); err != nil {
+ return fmt.Errorf("latencylog: render chart %s: %w", path, err)
+ }
+ if err := writer.Flush(); err != nil {
+ return fmt.Errorf("latencylog: flush chart %s: %w", path, err)
+ }
+
+ return nil
+}
+
+func buildSummaryChartPage(summaries []Summary) summaryChartPage {
+ page := summaryChartPage{
+ TotalMessages: len(summaries),
+ Legend: []summaryChartLegendItem{
+ {Label: "A processing", Color: "var(--a-proc)"},
+ {Label: "A queue", Color: "var(--a-queue)"},
+ {Label: "Transport + B queue", Color: "var(--transport)"},
+ {Label: "B processing", Color: "var(--b-proc)"},
+ {Label: "Unknown / missing", Color: "var(--unknown)"},
+ },
+ Rows: make([]summaryChartRow, 0, len(summaries)),
+ }
+
+ var (
+ endToEndValues []int64
+ totalEndToEnd int64
+ maxEndToEnd int64
+ )
+
+ for _, summary := range summaries {
+ page.Rows = append(page.Rows, buildSummaryChartRow(summary))
+
+ if summary.EndToEndLatencyNS == nil {
+ continue
+ }
+ endToEnd := *summary.EndToEndLatencyNS
+ endToEndValues = append(endToEndValues, endToEnd)
+ totalEndToEnd += endToEnd
+ if endToEnd > maxEndToEnd {
+ maxEndToEnd = endToEnd
+ }
+ }
+
+ page.MessagesWithEndToEnd = len(endToEndValues)
+ page.AverageEndToEnd = "n/a"
+ page.MaxEndToEnd = "n/a"
+ if len(endToEndValues) > 0 {
+ page.AverageEndToEnd = formatLatencyNS(totalEndToEnd / int64(len(endToEndValues)))
+ page.MaxEndToEnd = formatLatencyNS(maxEndToEnd)
+ }
+
+ return page
+}
+
+func buildSummaryChartRow(summary Summary) summaryChartRow {
+ row := summaryChartRow{
+ Title: buildSummaryChartTitle(summary),
+ Subtitle: buildSummaryChartSubtitle(summary),
+ EndToEnd: "End-to-end: n/a",
+ MissingTimestamps: strings.Join(summary.MissingTimestamps, ", "),
+ }
+
+ if summary.EndToEndLatencyNS == nil || *summary.EndToEndLatencyNS <= 0 {
+ return row
+ }
+
+ total := *summary.EndToEndLatencyNS
+ row.EndToEnd = fmt.Sprintf("End-to-end: %s", formatLatencyNS(total))
+
+ metrics := []summaryChartSegmentMetric{
+ {label: "A processing", value: summary.AProcessingLatencyNS, color: "var(--a-proc)"},
+ {label: "A queue", value: summary.AQueueLatencyNS, color: "var(--a-queue)"},
+ {label: "Transport + B queue", value: summary.ABTransportPropagationBQueueLatencyNS, color: "var(--transport)"},
+ {label: "B processing", value: summary.BProcessingLatencyNS, color: "var(--b-proc)"},
+ }
+
+ var knownTotal int64
+ for _, metric := range metrics {
+ if metric.value == nil || *metric.value <= 0 {
+ continue
+ }
+ knownTotal += *metric.value
+ }
+
+ scaleTotal := total
+ if knownTotal > scaleTotal {
+ scaleTotal = knownTotal
+ }
+ if scaleTotal <= 0 {
+ return row
+ }
+
+ for _, metric := range metrics {
+ if metric.value == nil || *metric.value <= 0 {
+ continue
+ }
+ row.Segments = append(row.Segments, summaryChartSegment{
+ Label: metric.label,
+ Value: formatLatencyNS(*metric.value),
+ Color: metric.color,
+ WidthPercent: float64(*metric.value) * 100 / float64(scaleTotal),
+ })
+ }
+
+ if remaining := total - knownTotal; remaining > 0 {
+ row.Segments = append(row.Segments, summaryChartSegment{
+ Label: "Unknown / missing",
+ Value: formatLatencyNS(remaining),
+ Color: "var(--unknown)",
+ WidthPercent: float64(remaining) * 100 / float64(scaleTotal),
+ })
+ }
+
+ return row
+}
+
+func buildSummaryChartTitle(summary Summary) string {
+ if summary.MessageType == protocol.MessageTypeFile && summary.FileName != "" {
+ return fmt.Sprintf("%s #%d (%s)", summary.MessageType, summary.MessageID, summary.FileName)
+ }
+
+ return fmt.Sprintf("%s #%d", summary.MessageType, summary.MessageID)
+}
+
+func buildSummaryChartSubtitle(summary Summary) string {
+ parts := []string{
+ fmt.Sprintf("%s -> %s", summary.From, summary.To),
+ fmt.Sprintf("%d bytes", summary.BodySize),
+ }
+
+ if summary.MessageType == protocol.MessageTypeFile && summary.FileName != "" {
+ parts = append(parts, fmt.Sprintf("file: %s", summary.FileName))
+ }
+
+ return strings.Join(parts, " | ")
+}
+
+func formatLatencyNS(ns int64) string {
+ return fmt.Sprintf("%.3f ms", float64(ns)/1_000_000)
+}
diff --git a/cmd/internal/latencylog/summary_chart_test.go b/cmd/internal/latencylog/summary_chart_test.go
new file mode 100644
index 0000000..a7a7d0d
--- /dev/null
+++ b/cmd/internal/latencylog/summary_chart_test.go
@@ -0,0 +1,67 @@
+package latencylog
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+func TestWriteSummariesHTMLChart(t *testing.T) {
+ aProcessing := int64(20_000_000)
+ aQueue := int64(10_000_000)
+ transport := int64(40_000_000)
+ bProcessing := int64(30_000_000)
+ endToEnd := int64(100_000_000)
+
+ summaries := []Summary{
+ {
+ MessageType: protocol.MessageTypeText,
+ MessageID: 7,
+ From: "peer-a",
+ To: "peer-b",
+ BodySize: 5,
+ AProcessingLatencyNS: &aProcessing,
+ AQueueLatencyNS: &aQueue,
+ ABTransportPropagationBQueueLatencyNS: &transport,
+ BProcessingLatencyNS: &bProcessing,
+ EndToEndLatencyNS: &endToEnd,
+ },
+ {
+ MessageType: protocol.MessageTypeFile,
+ MessageID: 8,
+ From: "peer-b",
+ To: "peer-a",
+ FileName: "payload.bin",
+ BodySize: 128,
+ MissingTimestamps: []string{EventBRXSoftware},
+ },
+ }
+
+ path := filepath.Join(t.TempDir(), "charts", "latency-summary.html")
+ if err := WriteSummariesHTMLChart(path, summaries); err != nil {
+ t.Fatalf("WriteSummariesHTMLChart() error = %v", err)
+ }
+
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatalf("os.ReadFile() error = %v", err)
+ }
+
+ content := string(data)
+ for _, want := range []string{
+ "Latency Summary",
+ "text #7",
+ "peer-a -> peer-b | 5 bytes",
+ "End-to-end: 100.000 ms",
+ "A processing 20.000 ms",
+ "file #8 (payload.bin)",
+ "Missing timestamps: B_RX_SOFTWARE",
+ } {
+ if !strings.Contains(content, want) {
+ t.Fatalf("chart content missing %q\n%s", want, content)
+ }
+ }
+}
diff --git a/cmd/internal/latencylog/summary_test.go b/cmd/internal/latencylog/summary_test.go
new file mode 100644
index 0000000..a7a0eb6
--- /dev/null
+++ b/cmd/internal/latencylog/summary_test.go
@@ -0,0 +1,144 @@
+package latencylog
+
+import (
+ "bufio"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "reflect"
+ "testing"
+
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+func TestSummarizeEventsComputesLatencyMetrics(t *testing.T) {
+ events := []Event{
+ {TsUnixNano: 100, Event: EventAAppPrepBegin, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 120, Event: EventATXSched, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 140, Event: EventATXSoftware, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 180, Event: EventBRXSoftware, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 220, Event: EventBAppRecv, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 230, Event: EventBPersistBegin, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 260, Event: EventBPersistEnd, MessageType: protocol.MessageTypeText, MessageID: 1, From: "peer-a", To: "peer-b"},
+ }
+
+ summaries := SummarizeEvents(events)
+ if len(summaries) != 1 {
+ t.Fatalf("summary count = %d, want 1", len(summaries))
+ }
+
+ summary := summaries[0]
+ if got := ptrValue(summary.AProcessingLatencyNS); got != 20 {
+ t.Fatalf("AProcessingLatencyNS = %d, want 20", got)
+ }
+ if got := ptrValue(summary.AQueueLatencyNS); got != 20 {
+ t.Fatalf("AQueueLatencyNS = %d, want 20", got)
+ }
+ if got := ptrValue(summary.ABTransportPropagationBQueueLatencyNS); got != 80 {
+ t.Fatalf("ABTransportPropagationBQueueLatencyNS = %d, want 80", got)
+ }
+ if got := ptrValue(summary.BKernelReceivePathLatencyNS); got != 40 {
+ t.Fatalf("BKernelReceivePathLatencyNS = %d, want 40", got)
+ }
+ if got := ptrValue(summary.BProcessingLatencyNS); got != 40 {
+ t.Fatalf("BProcessingLatencyNS = %d, want 40", got)
+ }
+ if got := ptrValue(summary.EndToEndLatencyNS); got != 160 {
+ t.Fatalf("EndToEndLatencyNS = %d, want 160", got)
+ }
+ if got := summary.Timestamps[EventBRXSoftware]; got != 180 {
+ t.Fatalf("timestamps[%q] = %d, want 180", EventBRXSoftware, got)
+ }
+ if len(summary.MissingTimestamps) != 0 {
+ t.Fatalf("MissingTimestamps = %v, want empty", summary.MissingTimestamps)
+ }
+}
+
+func TestSummarizeEventsReportsMissingTimestamps(t *testing.T) {
+ events := []Event{
+ {TsUnixNano: 100, Event: EventAAppPrepBegin, MessageType: protocol.MessageTypeText, MessageID: 2, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 240, Event: EventBPersistEnd, MessageType: protocol.MessageTypeText, MessageID: 2, From: "peer-a", To: "peer-b"},
+ }
+
+ summaries := SummarizeEvents(events)
+ if len(summaries) != 1 {
+ t.Fatalf("summary count = %d, want 1", len(summaries))
+ }
+
+ wantMissing := []string{EventATXSched, EventATXSoftware, EventBRXSoftware, EventBAppRecv}
+ if !reflect.DeepEqual(summaries[0].MissingTimestamps, wantMissing) {
+ t.Fatalf("MissingTimestamps = %v, want %v", summaries[0].MissingTimestamps, wantMissing)
+ }
+ if summaries[0].AProcessingLatencyNS != nil {
+ t.Fatalf("AProcessingLatencyNS = %v, want nil", ptrValue(summaries[0].AProcessingLatencyNS))
+ }
+ if summaries[0].EndToEndLatencyNS == nil {
+ t.Fatal("EndToEndLatencyNS = nil, want non-nil because endpoints are present")
+ }
+}
+
+func TestLoadAndWriteSummaryFiles(t *testing.T) {
+ rawPath := filepath.Join(t.TempDir(), "raw.jsonl")
+ rawLogger, err := NewJSONLLogger(rawPath)
+ if err != nil {
+ t.Fatalf("NewJSONLLogger() error = %v", err)
+ }
+ t.Cleanup(func() {
+ _ = rawLogger.Close()
+ })
+
+ for _, event := range []Event{
+ {TsUnixNano: 100, Event: EventAAppPrepBegin, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 120, Event: EventATXSched, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 140, Event: EventATXSoftware, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 180, Event: EventBRXSoftware, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 220, Event: EventBAppRecv, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
+ {TsUnixNano: 260, Event: EventBPersistEnd, MessageType: protocol.MessageTypeText, MessageID: 3, From: "peer-a", To: "peer-b"},
+ } {
+ if err := rawLogger.LogEvent(event); err != nil {
+ t.Fatalf("LogEvent() error = %v", err)
+ }
+ }
+
+ events, err := LoadEventsFromFile(rawPath)
+ if err != nil {
+ t.Fatalf("LoadEventsFromFile() error = %v", err)
+ }
+
+ summaryPath := filepath.Join(t.TempDir(), "summary.jsonl")
+ if err := WriteSummariesJSONL(summaryPath, SummarizeEvents(events)); err != nil {
+ t.Fatalf("WriteSummariesJSONL() error = %v", err)
+ }
+
+ file, err := os.Open(summaryPath)
+ if err != nil {
+ t.Fatalf("os.Open() error = %v", err)
+ }
+ defer file.Close()
+
+ scanner := bufio.NewScanner(file)
+ if !scanner.Scan() {
+ t.Fatal("expected one summary line, got none")
+ }
+
+ var summary Summary
+ if err := json.Unmarshal(scanner.Bytes(), &summary); err != nil {
+ t.Fatalf("json.Unmarshal() error = %v", err)
+ }
+ if summary.MessageID != 3 {
+ t.Fatalf("MessageID = %d, want 3", summary.MessageID)
+ }
+ if got := ptrValue(summary.BKernelReceivePathLatencyNS); got != 40 {
+ t.Fatalf("BKernelReceivePathLatencyNS = %d, want 40", got)
+ }
+ if got := ptrValue(summary.EndToEndLatencyNS); got != 160 {
+ t.Fatalf("EndToEndLatencyNS = %d, want 160", got)
+ }
+}
+
+func ptrValue(value *int64) int64 {
+ if value == nil {
+ return 0
+ }
+ return *value
+}
diff --git a/cmd/internal/peer/client.go b/cmd/internal/peer/client.go
new file mode 100644
index 0000000..1c586fa
--- /dev/null
+++ b/cmd/internal/peer/client.go
@@ -0,0 +1,179 @@
+package peer
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "path/filepath"
+ "sync/atomic"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ "omnisocketgo/cmd/internal/protocol"
+ "omnisocketgo/cmd/internal/transport"
+)
+
+var dialServer = net.Dial
+
+type clientOptions struct {
+ logger latencylog.Logger
+}
+
+// Option 用于配置 Client 的可选行为,例如时延日志。
+type Option func(*clientOptions)
+
+// WithLogger 为 client 注入时延日志记录器。
+func WithLogger(logger latencylog.Logger) Option {
+ return func(options *clientOptions) {
+ options.logger = logger
+ }
+}
+
+// Client 表示一个已经连接到 server 的 peer。
+type Client struct {
+ id string
+ conn *transport.TCPConn
+ logger latencylog.Logger
+
+ nextID uint64
+}
+
+// Dial 连接到 server,并立即发送 register 消息完成身份注册。
+func Dial(serverAddr, peerID string, opts ...Option) (*Client, error) {
+ options := clientOptions{
+ logger: latencylog.NoopLogger{},
+ }
+ for _, opt := range opts {
+ opt(&options)
+ }
+ if options.logger == nil {
+ options.logger = latencylog.NoopLogger{}
+ }
+
+ rawConn, err := dialServer("tcp", serverAddr) //使用 net.Dial 连接到 serverAddr 指定的 TCP 地址,返回一个 net.Conn。
+ if err != nil {
+ return nil, fmt.Errorf("peer: dial server: %w", err)
+ }
+
+ conn, err := transport.NewTCPConn(
+ rawConn,
+ transport.WithLogger(options.logger, latencylog.NodeRolePeer, peerID),
+ )
+ if err != nil {
+ _ = rawConn.Close()
+ return nil, fmt.Errorf("peer: create transport conn: %w", err)
+ }
+ client := &Client{
+ id: peerID,
+ conn: conn,
+ logger: options.logger,
+ }
+
+ if err := conn.Send(protocol.Message{ //向 server 发送一条 register 消息,完成身份注册。
+ Type: protocol.MessageTypeRegister,
+ From: peerID,
+ To: protocol.ServerPeerID,
+ }); err != nil {
+ _ = conn.Close()
+ return nil, fmt.Errorf("peer: register with server: %w", err)
+ }
+
+ return client, nil
+}
+
+// ID 返回当前 client 的 peer 标识。
+func (c *Client) ID() string {
+ return c.id
+}
+
+// SendText 向目标 peer 发送一条文本消息。
+func (c *Client) SendText(to, body string) error {
+ msg := protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: c.nextMessageID(),
+ From: c.id,
+ To: to,
+ }
+ // 记录 A 端应用开始准备消息的时间点。
+ latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
+
+ msg.Body = []byte(body)
+
+ return c.conn.Send(msg)
+}
+
+// SendFile 向目标 peer 发送一条文件消息。
+func (c *Client) SendFile(to, fileName string, body []byte) error {
+ msg := protocol.Message{
+ Type: protocol.MessageTypeFile,
+ ID: c.nextMessageID(),
+ From: c.id,
+ To: to,
+ FileName: fileName,
+ }
+ // 记录 A 端应用开始准备消息的时间点。
+ latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
+
+ bodyCopy := make([]byte, len(body))
+ copy(bodyCopy, body)
+
+ msg.Body = bodyCopy
+
+ return c.conn.Send(msg)
+}
+
+// SendFilePath 从本地文件读取内容并发送给目标 peer。
+func (c *Client) SendFilePath(to, path string) error {
+ msg := protocol.Message{
+ Type: protocol.MessageTypeFile,
+ ID: c.nextMessageID(),
+ From: c.id,
+ To: to,
+ FileName: filepath.Base(path),
+ }
+ // 记录 A 端应用开始准备消息的时间点。
+ latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
+
+ body, err := os.ReadFile(path)
+ if err != nil {
+ return fmt.Errorf("peer: read file %s: %w", path, err)
+ }
+
+ msg.Body = body
+
+ return c.conn.Send(msg)
+}
+
+// Receive 读取一条来自 server 的消息。
+func (c *Client) Receive() (protocol.Message, error) {
+ msg, err := c.conn.Receive() //从底层 TCP 连接读取一条消息,返回一个 protocol.Message 结构体。
+ if err != nil {
+ return protocol.Message{}, fmt.Errorf("peer: receive from server: %w", err)
+ }
+
+ latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg)
+
+ return msg, nil
+}
+
+// ReceiveLoop 持续接收 server 消息并交给 handler 处理。
+func (c *Client) ReceiveLoop(handler func(protocol.Message) error) error {
+ return c.conn.ReceiveLoop(func(msg protocol.Message) error {
+ switch msg.Type {
+ case protocol.MessageTypeText, protocol.MessageTypeFile, protocol.MessageTypeError:
+ // 记录 B 端应用真正读到完整消息的时间点。
+ latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg)
+ return handler(msg)
+ default:
+ return fmt.Errorf("peer: unexpected message type from server: %s", msg.Type)
+ }
+ })
+}
+
+// Close 关闭与 server 的连接。
+func (c *Client) Close() error {
+ return c.conn.Close()
+}
+
+func (c *Client) nextMessageID() uint64 {
+ return atomic.AddUint64(&c.nextID, 1)
+}
diff --git a/cmd/internal/peer/client_linux_test.go b/cmd/internal/peer/client_linux_test.go
new file mode 100644
index 0000000..f5432af
--- /dev/null
+++ b/cmd/internal/peer/client_linux_test.go
@@ -0,0 +1,188 @@
+//go:build linux
+
+package peer
+
+import (
+ "net"
+ "strings"
+ "sync"
+ "testing"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ "omnisocketgo/cmd/internal/protocol"
+ "omnisocketgo/cmd/internal/server"
+)
+
+func TestClientsExchangeMessagesWithLinuxTimestamps(t *testing.T) {
+ hub := server.NewHub()
+ serverAddr, cleanup := startRealHubServer(t, hub)
+ defer cleanup()
+
+ peerALogger := &recordingLogger{}
+ peerA, err := Dial(serverAddr, "peer-a", WithLogger(peerALogger))
+ if err != nil {
+ t.Fatalf("Dial(peer-a) error = %v", err)
+ }
+ defer func() { _ = peerA.Close() }()
+
+ peerBLogger := &recordingLogger{}
+ peerB, err := Dial(serverAddr, "peer-b", WithLogger(peerBLogger))
+ if err != nil {
+ t.Fatalf("Dial(peer-b) error = %v", err)
+ }
+ defer func() { _ = peerB.Close() }()
+
+ inboxDir := t.TempDir()
+
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
+
+ if err := peerA.SendText("peer-b", "hello"); err != nil {
+ t.Fatalf("SendText() error = %v", err)
+ }
+ textMsg, err := peerB.Receive()
+ if err != nil {
+ t.Fatalf("peerB.Receive(text) error = %v", err)
+ }
+ if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil {
+ t.Fatalf("peerB.PersistMessage(text) error = %v", err)
+ }
+
+ if err := peerA.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}); err != nil {
+ t.Fatalf("SendFile() error = %v", err)
+ }
+ fileMsg, err := peerB.Receive()
+ if err != nil {
+ t.Fatalf("peerB.Receive(file) error = %v", err)
+ }
+ if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil {
+ t.Fatalf("peerB.PersistMessage(file) error = %v", err)
+ }
+
+ waitFor(t, func() bool { return hasMessageEvents(peerALogger.Events(), 1, latencylog.EventAAppPrepBegin, latencylog.EventATXSched, latencylog.EventATXSoftware) }, "peer-a text kernel timestamps")
+ waitFor(t, func() bool { return hasMessageEvents(peerALogger.Events(), 2, latencylog.EventAAppPrepBegin, latencylog.EventATXSched, latencylog.EventATXSoftware) }, "peer-a file kernel timestamps")
+ waitFor(t, func() bool { return hasMessageEvents(peerBLogger.Events(), 1, latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd) }, "peer-b text receive timestamps")
+ waitFor(t, func() bool { return hasMessageEvents(peerBLogger.Events(), 2, latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd) }, "peer-b file receive timestamps")
+}
+
+func startRealHubServer(t *testing.T, hub *server.Hub) (string, func()) {
+ t.Helper()
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("net.Listen() error = %v", err)
+ }
+
+ var (
+ wg sync.WaitGroup
+ stop = make(chan struct{})
+ errOnce sync.Once
+ )
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for {
+ conn, acceptErr := listener.Accept()
+ if acceptErr != nil {
+ select {
+ case <-stop:
+ return
+ default:
+ }
+ if strings.Contains(acceptErr.Error(), "closed") {
+ return
+ }
+ t.Errorf("listener.Accept() error = %v", acceptErr)
+ return
+ }
+
+ wg.Add(1)
+ go func(rawConn net.Conn) {
+ defer wg.Done()
+ if serveErr := hub.ServeConn(rawConn); serveErr != nil && !isExpectedHubServeExit(serveErr) {
+ errOnce.Do(func() {
+ t.Logf("hub.ServeConn() ended with %v", serveErr)
+ })
+ }
+ }(conn)
+ }
+ }()
+
+ cleanup := func() {
+ close(stop)
+ _ = listener.Close()
+ wg.Wait()
+ }
+
+ return listener.Addr().String(), cleanup
+}
+
+func hasMessageEvents(events []latencylog.Event, messageID uint64, wantEvents ...string) bool {
+ seen := make(map[string]bool, len(wantEvents))
+ for _, event := range events {
+ if event.MessageID != messageID {
+ continue
+ }
+ if event.TsUnixNano <= 0 {
+ return false
+ }
+ seen[event.Event] = true
+ }
+
+ for _, wantEvent := range wantEvents {
+ if !seen[wantEvent] {
+ return false
+ }
+ }
+
+ return true
+}
+
+func isExpectedHubServeExit(err error) bool {
+ if err == nil {
+ return true
+ }
+
+ message := err.Error()
+ return strings.Contains(message, "closed") || strings.Contains(message, "protocol: read frame: EOF")
+}
+
+func TestLinuxTimestampedReceivePreservesBusinessMessageShape(t *testing.T) {
+ hub := server.NewHub()
+ serverAddr, cleanup := startRealHubServer(t, hub)
+ defer cleanup()
+
+ peerA, err := Dial(serverAddr, "peer-a")
+ if err != nil {
+ t.Fatalf("Dial(peer-a) error = %v", err)
+ }
+ defer func() { _ = peerA.Close() }()
+
+ peerB, err := Dial(serverAddr, "peer-b")
+ if err != nil {
+ t.Fatalf("Dial(peer-b) error = %v", err)
+ }
+ defer func() { _ = peerB.Close() }()
+
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
+
+ want := protocol.Message{
+ Type: protocol.MessageTypeFile,
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "payload.bin",
+ Body: []byte{0xde, 0xad, 0xbe, 0xef},
+ }
+ if err := peerA.SendFile(want.To, want.FileName, want.Body); err != nil {
+ t.Fatalf("SendFile() error = %v", err)
+ }
+
+ got, err := peerB.Receive()
+ if err != nil {
+ t.Fatalf("peerB.Receive() error = %v", err)
+ }
+ if got.Type != want.Type || got.ID != want.ID || got.From != want.From || got.To != want.To || got.FileName != want.FileName || string(got.Body) != string(want.Body) {
+ t.Fatalf("received message mismatch: got %+v want %+v", got, want)
+ }
+}
diff --git a/cmd/internal/peer/client_test.go b/cmd/internal/peer/client_test.go
new file mode 100644
index 0000000..cdf7cef
--- /dev/null
+++ b/cmd/internal/peer/client_test.go
@@ -0,0 +1,770 @@
+package peer
+
+import (
+ "bytes"
+ "encoding/json"
+ "net"
+ "os"
+ "path/filepath"
+ "reflect"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ "omnisocketgo/cmd/internal/protocol"
+ "omnisocketgo/cmd/internal/server"
+ "omnisocketgo/cmd/internal/transport"
+)
+
+type recordingLogger struct {
+ mu sync.Mutex
+ events []latencylog.Event
+}
+
+func (l *recordingLogger) LogEvent(event latencylog.Event) error {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ l.events = append(l.events, event)
+ return nil
+}
+
+func (l *recordingLogger) Events() []latencylog.Event {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ return append([]latencylog.Event(nil), l.events...)
+}
+
+type failingLogger struct{}
+
+func (failingLogger) LogEvent(latencylog.Event) error {
+ return net.ErrClosed
+}
+
+func TestDialRegistersPeer(t *testing.T) {
+ hub := server.NewHub()
+ cleanup := stubDialToHub(t, hub)
+ defer cleanup()
+
+ client, err := Dial("ignored", "peer-a")
+ if err != nil {
+ t.Fatalf("Dial() error = %v", err)
+ }
+ defer func() { _ = client.Close() }()
+
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
+}
+
+func TestClientsExchangeTextAndFileMessages(t *testing.T) {
+ hub := server.NewHub()
+ cleanup := stubDialToHub(t, hub)
+ defer cleanup()
+
+ peerA, err := Dial("ignored", "peer-a")
+ if err != nil {
+ t.Fatalf("Dial(peer-a) error = %v", err)
+ }
+ defer func() { _ = peerA.Close() }()
+
+ peerB, err := Dial("ignored", "peer-b")
+ if err != nil {
+ t.Fatalf("Dial(peer-b) error = %v", err)
+ }
+ defer func() { _ = peerB.Close() }()
+
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
+
+ received := make(chan protocol.Message, 2)
+ receiveErr := make(chan error, 1)
+ go func() {
+ for i := 0; i < 2; i++ {
+ msg, err := peerB.Receive()
+ if err != nil {
+ receiveErr <- err
+ return
+ }
+ received <- msg
+ }
+ receiveErr <- nil
+ }()
+
+ if err := peerA.SendText("peer-b", "hello"); err != nil {
+ t.Fatalf("SendText() error = %v", err)
+ }
+ fileBody := []byte{0x01, 0x02, 0x03}
+ if err := peerA.SendFile("peer-b", "payload.bin", fileBody); err != nil {
+ t.Fatalf("SendFile() error = %v", err)
+ }
+
+ if err := <-receiveErr; err != nil {
+ t.Fatalf("peerB.Receive() error = %v", err)
+ }
+
+ gotFirst := <-received
+ wantFirst := protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ }
+ if !reflect.DeepEqual(gotFirst, wantFirst) {
+ t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, wantFirst)
+ }
+
+ gotSecond := <-received
+ wantSecond := protocol.Message{
+ Type: protocol.MessageTypeFile,
+ ID: 2,
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "payload.bin",
+ Body: fileBody,
+ }
+ if !reflect.DeepEqual(gotSecond, wantSecond) {
+ t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, wantSecond)
+ }
+}
+
+func TestClientReceivesServerErrorForUnknownTarget(t *testing.T) {
+ hub := server.NewHub()
+ cleanup := stubDialToHub(t, hub)
+ defer cleanup()
+
+ client, err := Dial("ignored", "peer-a")
+ if err != nil {
+ t.Fatalf("Dial() error = %v", err)
+ }
+ defer func() { _ = client.Close() }()
+
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
+
+ if err := client.SendText("missing-peer", "hello"); err != nil {
+ t.Fatalf("SendText() error = %v", err)
+ }
+
+ got, err := client.Receive()
+ if err != nil {
+ t.Fatalf("Receive() error = %v", err)
+ }
+ if got.Type != protocol.MessageTypeError {
+ t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError)
+ }
+ if string(got.Body) != "unknown target: missing-peer" {
+ t.Fatalf("error body = %q, want unknown target message", got.Body)
+ }
+}
+
+func TestClientReceiveLoopHandlesForwardedMessages(t *testing.T) {
+ hub := server.NewHub()
+ cleanup := stubDialToHub(t, hub)
+ defer cleanup()
+
+ peerA, err := Dial("ignored", "peer-a")
+ if err != nil {
+ t.Fatalf("Dial(peer-a) error = %v", err)
+ }
+ defer func() { _ = peerA.Close() }()
+
+ peerB, err := Dial("ignored", "peer-b")
+ if err != nil {
+ t.Fatalf("Dial(peer-b) error = %v", err)
+ }
+ defer func() { _ = peerB.Close() }()
+
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
+
+ var (
+ mu sync.Mutex
+ got []protocol.Message
+ )
+ loopErr := make(chan error, 1)
+ go func() {
+ loopErr <- peerB.ReceiveLoop(func(msg protocol.Message) error {
+ mu.Lock()
+ defer mu.Unlock()
+ got = append(got, msg)
+ if len(got) == 1 {
+ return peerB.Close()
+ }
+ return nil
+ })
+ }()
+
+ if err := peerA.SendText("peer-b", "hello"); err != nil {
+ t.Fatalf("SendText() error = %v", err)
+ }
+
+ err = <-loopErr
+ if err == nil || (!strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "use of closed network connection")) {
+ t.Fatalf("ReceiveLoop() error = %v, want close-related error", err)
+ }
+
+ mu.Lock()
+ defer mu.Unlock()
+ want := []protocol.Message{
+ {
+ Type: protocol.MessageTypeText,
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ },
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
+ }
+}
+
+func TestClientSendLogsLatencyEvents(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(*testing.T) string
+ send func(*Client, string) error
+ wantMsg protocol.Message
+ wantEvents []string
+ }{
+ {
+ name: "text",
+ send: func(client *Client, _ string) error {
+ return client.SendText("peer-b", "hello")
+ },
+ wantMsg: protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ },
+ wantEvents: []string{
+ latencylog.EventAAppPrepBegin,
+ latencylog.EventSendHandoffBegin,
+ latencylog.EventATXSched,
+ latencylog.EventATXSoftware,
+ latencylog.EventSendHandoffEnd,
+ },
+ },
+ {
+ name: "file-bytes",
+ send: func(client *Client, _ string) error {
+ return client.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03})
+ },
+ wantMsg: protocol.Message{
+ Type: protocol.MessageTypeFile,
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "payload.bin",
+ Body: []byte{0x01, 0x02, 0x03},
+ },
+ wantEvents: []string{
+ latencylog.EventAAppPrepBegin,
+ latencylog.EventSendHandoffBegin,
+ latencylog.EventATXSched,
+ latencylog.EventATXSoftware,
+ latencylog.EventSendHandoffEnd,
+ },
+ },
+ {
+ name: "file-path",
+ setup: func(t *testing.T) string {
+ t.Helper()
+
+ path := filepath.Join(t.TempDir(), "payload.bin")
+ if err := os.WriteFile(path, []byte{0x01, 0x02, 0x03}, 0o644); err != nil {
+ t.Fatalf("os.WriteFile() error = %v", err)
+ }
+
+ return path
+ },
+ send: func(client *Client, path string) error {
+ return client.SendFilePath("peer-b", path)
+ },
+ wantMsg: protocol.Message{
+ Type: protocol.MessageTypeFile,
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "payload.bin",
+ Body: []byte{0x01, 0x02, 0x03},
+ },
+ wantEvents: []string{
+ latencylog.EventAAppPrepBegin,
+ latencylog.EventSendHandoffBegin,
+ latencylog.EventATXSched,
+ latencylog.EventATXSoftware,
+ latencylog.EventSendHandoffEnd,
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ inputPath := ""
+ if tt.setup != nil {
+ inputPath = tt.setup(t)
+ }
+
+ logger := &recordingLogger{}
+ clientConn, receiver := newClientTransportPair(
+ t,
+ []transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-a")},
+ nil,
+ )
+ client := &Client{
+ id: "peer-a",
+ conn: clientConn,
+ logger: logger,
+ }
+
+ sendErr := make(chan error, 1)
+ go func() {
+ sendErr <- tt.send(client, inputPath)
+ }()
+
+ got, err := receiver.Receive()
+ if err != nil {
+ t.Fatalf("receiver.Receive() error = %v", err)
+ }
+ if err := <-sendErr; err != nil {
+ t.Fatalf("send() error = %v", err)
+ }
+ if !reflect.DeepEqual(got, tt.wantMsg) {
+ t.Fatalf("message mismatch: got %+v want %+v", got, tt.wantMsg)
+ }
+
+ events := logger.Events()
+ if len(events) != len(tt.wantEvents) {
+ t.Fatalf("event count = %d, want %d", len(events), len(tt.wantEvents))
+ }
+ for i, wantEvent := range tt.wantEvents {
+ if events[i].Event != wantEvent {
+ t.Fatalf("event[%d] = %q, want %q", i, events[i].Event, wantEvent)
+ }
+ if events[i].MessageID != tt.wantMsg.ID || events[i].From != tt.wantMsg.From || events[i].To != tt.wantMsg.To {
+ t.Fatalf("event[%d] metadata mismatch: %+v", i, events[i])
+ }
+ }
+ })
+ }
+}
+
+func TestClientReceiveLogsOnlyBusinessMessages(t *testing.T) {
+ logger := &recordingLogger{}
+ clientConn, sender := newClientTransportPair(
+ t,
+ []transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-b")},
+ nil,
+ )
+ client := &Client{
+ id: "peer-b",
+ conn: clientConn,
+ logger: logger,
+ }
+
+ textMsg := protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 21,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ }
+ sendErr := make(chan error, 1)
+ go func() {
+ sendErr <- sender.Send(textMsg)
+ }()
+ if _, err := client.Receive(); err != nil {
+ t.Fatalf("client.Receive(text) error = %v", err)
+ }
+ if err := <-sendErr; err != nil {
+ t.Fatalf("sender.Send(text) error = %v", err)
+ }
+
+ errorMsg := protocol.Message{
+ Type: protocol.MessageTypeError,
+ ID: 22,
+ From: protocol.ServerPeerID,
+ To: "peer-b",
+ Body: []byte("failure"),
+ }
+ sendErr = make(chan error, 1)
+ go func() {
+ sendErr <- sender.Send(errorMsg)
+ }()
+ if _, err := client.Receive(); err != nil {
+ t.Fatalf("client.Receive(error) error = %v", err)
+ }
+ if err := <-sendErr; err != nil {
+ t.Fatalf("sender.Send(error) error = %v", err)
+ }
+
+ events := logger.Events()
+ if len(events) != 2 {
+ t.Fatalf("event count = %d, want 2", len(events))
+ }
+ if events[0].Event != latencylog.EventBRXSoftware {
+ t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBRXSoftware)
+ }
+ if events[1].Event != latencylog.EventBAppRecv {
+ t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBAppRecv)
+ }
+ if events[0].MessageID != textMsg.ID || events[1].MessageID != textMsg.ID {
+ t.Fatalf("message IDs = %d,%d, want %d", events[0].MessageID, events[1].MessageID, textMsg.ID)
+ }
+}
+
+func TestClientPersistTextMessageWritesInboxFileAndLogs(t *testing.T) {
+ inboxDir := t.TempDir()
+ logger := &recordingLogger{}
+ client := &Client{
+ id: "peer-b",
+ logger: logger,
+ }
+
+ msg := protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 31,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ }
+
+ path, err := client.PersistMessage(msg, inboxDir)
+ if err != nil {
+ t.Fatalf("PersistMessage() error = %v", err)
+ }
+
+ if path != filepath.Join(inboxDir, textInboxFileName) {
+ t.Fatalf("path = %q, want %q", path, filepath.Join(inboxDir, textInboxFileName))
+ }
+
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatalf("os.ReadFile() error = %v", err)
+ }
+
+ var record textInboxRecord
+ if err := json.Unmarshal(bytes.TrimSpace(data), &record); err != nil {
+ t.Fatalf("json.Unmarshal() error = %v", err)
+ }
+ if record.MessageID != msg.ID || record.From != msg.From || record.To != msg.To || record.Body != "hello" {
+ t.Fatalf("record mismatch: got %+v want message %+v", record, msg)
+ }
+
+ events := logger.Events()
+ if len(events) != 2 {
+ t.Fatalf("event count = %d, want 2", len(events))
+ }
+ if events[0].Event != latencylog.EventBPersistBegin {
+ t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin)
+ }
+ if events[1].Event != latencylog.EventBPersistEnd {
+ t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBPersistEnd)
+ }
+}
+
+func TestClientPersistFileMessageWritesInboxFileAndLogs(t *testing.T) {
+ inboxDir := t.TempDir()
+ logger := &recordingLogger{}
+ client := &Client{
+ id: "peer-b",
+ logger: logger,
+ }
+
+ msg := protocol.Message{
+ Type: protocol.MessageTypeFile,
+ ID: 32,
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "payload.bin",
+ Body: []byte{0x01, 0x02, 0x03},
+ }
+
+ path, err := client.PersistMessage(msg, inboxDir)
+ if err != nil {
+ t.Fatalf("PersistMessage() error = %v", err)
+ }
+
+ wantPath := filepath.Join(inboxDir, "peer-a-32-payload.bin")
+ if path != wantPath {
+ t.Fatalf("path = %q, want %q", path, wantPath)
+ }
+
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatalf("os.ReadFile() error = %v", err)
+ }
+ if !reflect.DeepEqual(data, msg.Body) {
+ t.Fatalf("file body mismatch: got %v want %v", data, msg.Body)
+ }
+
+ events := logger.Events()
+ if len(events) != 2 {
+ t.Fatalf("event count = %d, want 2", len(events))
+ }
+ if events[0].Event != latencylog.EventBPersistBegin {
+ t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin)
+ }
+ if events[1].Event != latencylog.EventBPersistEnd {
+ t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBPersistEnd)
+ }
+}
+
+func TestClientPersistMessageReturnsErrorOnWriteFailure(t *testing.T) {
+ blocker := filepath.Join(t.TempDir(), "blocker")
+ if err := os.WriteFile(blocker, []byte("not a directory"), 0o644); err != nil {
+ t.Fatalf("os.WriteFile() error = %v", err)
+ }
+
+ logger := &recordingLogger{}
+ client := &Client{
+ id: "peer-b",
+ logger: logger,
+ }
+
+ msg := protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 33,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ }
+
+ if _, err := client.PersistMessage(msg, blocker); err == nil {
+ t.Fatal("PersistMessage() error = nil, want non-nil")
+ }
+
+ events := logger.Events()
+ if len(events) != 1 {
+ t.Fatalf("event count = %d, want 1", len(events))
+ }
+ if events[0].Event != latencylog.EventBPersistBegin {
+ t.Fatalf("event = %q, want %q", events[0].Event, latencylog.EventBPersistBegin)
+ }
+}
+
+func TestClientIgnoresLoggerFailure(t *testing.T) {
+ clientConn, receiver := newClientTransportPair(
+ t,
+ []transport.Option{transport.WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")},
+ nil,
+ )
+ client := &Client{
+ id: "peer-a",
+ conn: clientConn,
+ logger: failingLogger{},
+ }
+
+ sendErr := make(chan error, 1)
+ go func() {
+ sendErr <- client.SendText("peer-b", "hello")
+ }()
+
+ got, err := receiver.Receive()
+ if err != nil {
+ t.Fatalf("receiver.Receive() error = %v", err)
+ }
+ if err := <-sendErr; err != nil {
+ t.Fatalf("SendText() error = %v, want nil even if logger fails", err)
+ }
+ if string(got.Body) != "hello" {
+ t.Fatalf("body = %q, want hello", got.Body)
+ }
+}
+
+func TestClientPersistIgnoresLoggerFailure(t *testing.T) {
+ client := &Client{
+ id: "peer-b",
+ logger: failingLogger{},
+ }
+
+ msg := protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 34,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ }
+
+ path, err := client.PersistMessage(msg, t.TempDir())
+ if err != nil {
+ t.Fatalf("PersistMessage() error = %v, want nil even if logger fails", err)
+ }
+ if path == "" {
+ t.Fatal("PersistMessage() path = empty, want non-empty")
+ }
+}
+
+func TestClientsExchangeMessagesWithLatencyLogs(t *testing.T) {
+ hub := server.NewHub()
+ cleanup := stubDialToHub(t, hub)
+ defer cleanup()
+
+ peerALogger := &recordingLogger{}
+ peerA, err := Dial("ignored", "peer-a", WithLogger(peerALogger))
+ if err != nil {
+ t.Fatalf("Dial(peer-a) error = %v", err)
+ }
+ defer func() { _ = peerA.Close() }()
+
+ peerBLogger := &recordingLogger{}
+ peerB, err := Dial("ignored", "peer-b", WithLogger(peerBLogger))
+ if err != nil {
+ t.Fatalf("Dial(peer-b) error = %v", err)
+ }
+ defer func() { _ = peerB.Close() }()
+
+ inboxDir := t.TempDir()
+
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
+
+ if err := peerA.SendText("peer-b", "hello"); err != nil {
+ t.Fatalf("SendText() error = %v", err)
+ }
+ textMsg, err := peerB.Receive()
+ if err != nil {
+ t.Fatalf("peerB.Receive(text) error = %v", err)
+ }
+ if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil {
+ t.Fatalf("peerB.PersistMessage(text) error = %v", err)
+ }
+
+ if err := peerA.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}); err != nil {
+ t.Fatalf("SendFile() error = %v", err)
+ }
+ fileMsg, err := peerB.Receive()
+ if err != nil {
+ t.Fatalf("peerB.Receive(file) error = %v", err)
+ }
+ if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil {
+ t.Fatalf("peerB.PersistMessage(file) error = %v", err)
+ }
+
+ waitFor(t, func() bool { return len(peerALogger.Events()) == 10 }, "peer-a latency events")
+ waitFor(t, func() bool { return len(peerBLogger.Events()) == 8 }, "peer-b latency events")
+
+ assertEventSequencesByMessage(t, peerALogger.Events(), map[uint64][]string{
+ 1: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd},
+ 2: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd},
+ })
+ assertEventSequencesByMessage(t, peerBLogger.Events(), map[uint64][]string{
+ 1: {latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd},
+ 2: {latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd},
+ })
+}
+
+func stubDialToHub(t *testing.T, hub *server.Hub) func() {
+ t.Helper()
+
+ originalDial := dialServer
+ serverAddr, cleanup := startRealHubServer(t, hub)
+
+ dialServer = func(network, addr string) (net.Conn, error) {
+ return net.Dial(network, serverAddr)
+ }
+
+ return func() {
+ dialServer = originalDial
+ cleanup()
+ }
+}
+
+func waitFor(t *testing.T, condition func() bool, description string) {
+ t.Helper()
+
+ deadline := time.Now().Add(500 * time.Millisecond)
+ for time.Now().Before(deadline) {
+ if condition() {
+ return
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+
+ t.Fatalf("timed out waiting for %s", description)
+}
+
+func assertEventSequencesByMessage(t *testing.T, events []latencylog.Event, want map[uint64][]string) {
+ t.Helper()
+
+ grouped := make(map[uint64][]latencylog.Event)
+ for _, event := range events {
+ grouped[event.MessageID] = append(grouped[event.MessageID], event)
+ if event.TsUnixNano <= 0 {
+ t.Fatalf("event timestamp must be positive: %+v", event)
+ }
+ }
+
+ if len(grouped) != len(want) {
+ t.Fatalf("message group count = %d, want %d", len(grouped), len(want))
+ }
+
+ for messageID, wantEvents := range want {
+ gotEvents := grouped[messageID]
+ if len(gotEvents) != len(wantEvents) {
+ t.Fatalf("message %d event count = %d, want %d", messageID, len(gotEvents), len(wantEvents))
+ }
+ for i, wantEvent := range wantEvents {
+ if gotEvents[i].Event != wantEvent {
+ t.Fatalf("message %d event[%d] = %q, want %q", messageID, i, gotEvents[i].Event, wantEvent)
+ }
+ }
+ }
+}
+
+func newClientTransportPair(t *testing.T, clientOpts []transport.Option, peerOpts []transport.Option) (*transport.TCPConn, *transport.TCPConn) {
+ t.Helper()
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("net.Listen() error = %v", err)
+ }
+
+ type acceptResult struct {
+ conn net.Conn
+ err error
+ }
+
+ accepted := make(chan acceptResult, 1)
+ go func() {
+ conn, acceptErr := listener.Accept()
+ accepted <- acceptResult{conn: conn, err: acceptErr}
+ }()
+
+ clientSide, err := net.Dial("tcp", listener.Addr().String())
+ if err != nil {
+ _ = listener.Close()
+ t.Fatalf("net.Dial() error = %v", err)
+ }
+
+ result := <-accepted
+ if err := listener.Close(); err != nil {
+ t.Fatalf("listener.Close() error = %v", err)
+ }
+ if result.err != nil {
+ _ = clientSide.Close()
+ t.Fatalf("listener.Accept() error = %v", result.err)
+ }
+
+ clientConn, err := transport.NewTCPConn(clientSide, clientOpts...)
+ if err != nil {
+ _ = clientSide.Close()
+ _ = result.conn.Close()
+ t.Fatalf("transport.NewTCPConn(client) error = %v", err)
+ }
+ peerConn, err := transport.NewTCPConn(result.conn, peerOpts...)
+ if err != nil {
+ _ = clientConn.Close()
+ _ = result.conn.Close()
+ t.Fatalf("transport.NewTCPConn(peer) error = %v", err)
+ }
+
+ t.Cleanup(func() {
+ _ = clientConn.Close()
+ _ = peerConn.Close()
+ })
+
+ return clientConn, peerConn
+}
diff --git a/cmd/internal/peer/persist.go b/cmd/internal/peer/persist.go
new file mode 100644
index 0000000..cdd982c
--- /dev/null
+++ b/cmd/internal/peer/persist.go
@@ -0,0 +1,99 @@
+package peer
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+const textInboxFileName = "messages.log"
+
+type textInboxRecord struct {
+ MessageType protocol.MessageType `json:"message_type"`
+ MessageID uint64 `json:"message_id"`
+ From string `json:"from"`
+ To string `json:"to"`
+ Body string `json:"body"`
+}
+
+// PersistMessage 将收到的业务消息写入本地磁盘,并记录处理完成节点。
+func (c *Client) PersistMessage(msg protocol.Message, inboxDir string) (string, error) {
+ if !latencylog.IsBusinessMessage(msg) {
+ return "", fmt.Errorf("peer: cannot persist message type %s", msg.Type)
+ }
+ if inboxDir == "" {
+ return "", fmt.Errorf("peer: inbox directory is required")
+ }
+
+ latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistBegin, msg)
+
+ if err := os.MkdirAll(inboxDir, 0o755); err != nil {
+ return "", fmt.Errorf("peer: create inbox dir %s: %w", inboxDir, err)
+ }
+
+ path, err := persistMessageToDisk(msg, inboxDir)
+ if err != nil {
+ return "", err
+ }
+
+ latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistEnd, msg)
+
+ return path, nil
+}
+
+// persistMessageToDisk 根据消息类型将消息内容写入磁盘,文本消息追加到文本日志文件,文件消息写成独立文件。
+func persistMessageToDisk(msg protocol.Message, inboxDir string) (string, error) {
+ switch msg.Type {
+ case protocol.MessageTypeText:
+ return persistTextMessage(msg, inboxDir)
+ case protocol.MessageTypeFile:
+ return persistFileMessage(msg, inboxDir)
+ default:
+ return "", fmt.Errorf("peer: cannot persist unsupported message type %s", msg.Type)
+ }
+}
+
+// registerPeer 验证 peer ID 的合法性和唯一性,并将其与连接关联起来。
+func persistTextMessage(msg protocol.Message, inboxDir string) (string, error) {
+ record := textInboxRecord{
+ MessageType: msg.Type,
+ MessageID: msg.ID,
+ From: msg.From,
+ To: msg.To,
+ Body: string(msg.Body),
+ }
+
+ line, err := json.Marshal(record)
+ if err != nil {
+ return "", fmt.Errorf("peer: encode text inbox record: %w", err)
+ }
+
+ path := filepath.Join(inboxDir, textInboxFileName)
+ file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
+ if err != nil {
+ return "", fmt.Errorf("peer: open text inbox %s: %w", path, err)
+ }
+ defer file.Close()
+
+ if _, err := file.Write(append(line, '\n')); err != nil {
+ return "", fmt.Errorf("peer: append text inbox %s: %w", path, err)
+ }
+
+ return path, nil
+}
+
+// persistFileMessage 将文件消息的内容写成独立文件,文件名包含发送方、消息 ID 和原始文件名,保证唯一性和可读性。
+func persistFileMessage(msg protocol.Message, inboxDir string) (string, error) {
+ fileName := filepath.Base(msg.FileName)
+ path := filepath.Join(inboxDir, fmt.Sprintf("%s-%d-%s", msg.From, msg.ID, fileName))
+
+ if err := os.WriteFile(path, msg.Body, 0o644); err != nil {
+ return "", fmt.Errorf("peer: write received file %s: %w", path, err)
+ }
+
+ return path, nil
+}
diff --git a/cmd/internal/protocol/codec.go b/cmd/internal/protocol/codec.go
new file mode 100644
index 0000000..fef658b
--- /dev/null
+++ b/cmd/internal/protocol/codec.go
@@ -0,0 +1,279 @@
+package protocol
+
+import (
+ "encoding/binary"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "unicode/utf8"
+)
+
+// MaxFrameSize 用于限制单个帧的最大长度,
+// 避免异常对端通过伪造超大长度值导致接收方无上限分配内存。
+const MaxFrameSize = 8 * 1024 * 1024 // 先临时设置传输的视频帧不超过8MB
+
+var (
+ ErrInvalidFrameLength = errors.New("protocol: invalid frame length") // 表示帧长度非法,例如长度为 0。
+ ErrFrameTooLarge = errors.New("protocol: frame too large") // 表示帧长度超过允许的上限。
+ ErrInvalidMessageType = errors.New("protocol: invalid message type") // 表示消息类型不是当前协议支持的类型。
+ ErrMissingFrom = errors.New("protocol: missing from") // 表示消息缺少发送方标识。
+ ErrMissingTo = errors.New("protocol: missing to") // 表示消息缺少接收方标识。
+ ErrMissingFileName = errors.New("protocol: missing file name") // 表示 file 消息缺少文件名。
+ ErrUnexpectedFileName = errors.New("protocol: unexpected file name") // 表示 text 消息错误地携带了文件名。
+ ErrInvalidTextBody = errors.New("protocol: invalid text body") // 表示 text 消息正文不是合法 UTF-8。
+ ErrUnexpectedBody = errors.New("protocol: unexpected body") // 表示某些控制消息不允许携带正文。
+ ErrInvalidRegisterTarget = errors.New("protocol: invalid register target") // 表示 register 消息没有发往 server。
+ ErrInvalidErrorSource = errors.New("protocol: invalid error source") // 表示 error 消息不是由 server 发出。
+ ErrInvalidHeaderLength = errors.New("protocol: invalid header length") // 表示 header 长度字段为 0、越界或无法完整切分。
+ ErrInvalidHeaderJSON = errors.New("protocol: invalid header json") // 表示 header JSON 无法解析,可能是格式错误或缺少必要字段。
+ ErrInvalidContentLength = errors.New("protocol: invalid content length") // 表示头部记录的正文长度与实际正文不一致。
+)
+
+// 应用层消息:[4字节 frameLength][4字节 headerLen][header JSON(下面自定义的Message头)][body bytes]
+// 写了 tag:JSON 字段名是你指定的 type;不写 tag:JSON 字段名默认是 Go 字段名 Type
+type messageHeader struct {
+ Type MessageType `json:"type"`
+ ID uint64 `json:"id"`
+ From string `json:"from"`
+ To string `json:"to"`
+ FileName string `json:"file_name,omitempty"`
+ ContentLength int `json:"content_length"`
+}
+
+// EncodeMessage 将逻辑消息编码为帧内字节格式:
+// 1. 4 字节大端序 header 长度
+// 2. header JSON
+// 3. 原始 body 字节
+func EncodeMessage(msg Message) ([]byte, error) {
+ if err := validateMessage(msg); err != nil {
+ return nil, err
+ }
+
+ header := messageHeader{
+ Type: msg.Type,
+ ID: msg.ID,
+ From: msg.From,
+ To: msg.To,
+ FileName: msg.FileName,
+ ContentLength: len(msg.Body),
+ }
+
+ headerPayload, err := json.Marshal(header)
+ if err != nil {
+ return nil, fmt.Errorf("protocol: encode header: %w", err)
+ }
+ // 创建一个新的字节切片来存储完整的帧内容,避免直接在 headerPayload 上修改导致数据混乱。
+ payload := make([]byte, 4+len(headerPayload)+len(msg.Body))
+ // 在 payload 前 4 字节写入 header 长度,后续内容依次是 header JSON(第五个字节开始) 和 body。
+ binary.BigEndian.PutUint32(payload[:4], uint32(len(headerPayload)))
+ copy(payload[4:], headerPayload)
+ copy(payload[4+len(headerPayload):], msg.Body)
+
+ //检查整个帧长度是否合法,避免上层调用者构造的消息过大导致发送失败。
+ if len(payload) > MaxFrameSize {
+ return nil, ErrFrameTooLarge
+ }
+
+ return payload, nil
+}
+
+// DecodeMessage 将帧内字节格式还原为 Message。
+func DecodeMessage(data []byte) (Message, error) {
+ if len(data) > MaxFrameSize {
+ return Message{}, ErrFrameTooLarge
+ }
+ if len(data) < 4 {
+ return Message{}, ErrInvalidHeaderLength
+ }
+
+ headerLen := int(binary.BigEndian.Uint32(data[:4]))
+ if headerLen == 0 || headerLen > len(data)-4 {
+ return Message{}, ErrInvalidHeaderLength
+ }
+
+ headerPayload := data[4 : 4+headerLen]
+ body := data[4+headerLen:]
+
+ var header messageHeader
+ if err := json.Unmarshal(headerPayload, &header); err != nil {
+ return Message{}, fmt.Errorf("protocol: decode header: %w", errors.Join(ErrInvalidHeaderJSON, err))
+ }
+
+ if header.ContentLength < 0 || header.ContentLength != len(body) {
+ return Message{}, ErrInvalidContentLength
+ }
+
+ bodyCopy := make([]byte, len(body))
+ copy(bodyCopy, body)
+
+ msg := Message{
+ Type: header.Type,
+ ID: header.ID,
+ From: header.From,
+ To: header.To,
+ FileName: header.FileName,
+ Body: bodyCopy,
+ }
+
+ if err := validateMessage(msg); err != nil {
+ return Message{}, err
+ }
+
+ return msg, nil
+}
+
+// WriteFrame 向流中写入一个带长度前缀的帧。
+// TCP帧格式如下:
+// 1. 4 字节大端序长度
+// 2. 后续 payload 内容
+//
+// TCP 是字节流协议,没有天然的消息边界。
+// 增加显式长度前缀后,接收方就知道一条完整消息应该读取多少字节,
+// 从而解决粘包和拆包问题。
+func WriteFrame(w io.Writer, payload []byte) error {
+ size := len(payload)
+ //空帧
+ if size == 0 {
+ return ErrInvalidFrameLength
+ }
+ //帧过大
+ if size > MaxFrameSize {
+ return ErrFrameTooLarge
+ }
+
+ var header [4]byte
+ binary.BigEndian.PutUint32(header[:], uint32(size))
+
+ // 先写长度头,接收方才能根据长度一次性读取完整消息体。
+ if err := writeFull(w, header[:]); err != nil {
+ return err
+ }
+
+ return writeFull(w, payload)
+}
+
+// ReadFrame 从流中读取一个完整的长度前缀帧。
+// 它会先读取固定 4 字节长度头,校验长度是否合法,
+// 再使用 io.ReadFull 按长度读取完整消息体,
+// 这样即使底层 TCP 发生分段读取,也不会把半条消息暴露给上层。
+func ReadFrame(r io.Reader) ([]byte, error) {
+ var header [4]byte
+ if _, err := io.ReadFull(r, header[:]); err != nil {
+ return nil, err
+ }
+
+ size := binary.BigEndian.Uint32(header[:])
+ // 长度为 0 的帧被认为是非法输入,而不是合法的空消息。
+ if size == 0 {
+ return nil, ErrInvalidFrameLength
+ }
+ // 长度超过上限的帧会被拒绝,避免接收方无上限分配内存。
+ if size > MaxFrameSize {
+ return nil, ErrFrameTooLarge
+ }
+
+ payload := make([]byte, int(size))
+ if _, err := io.ReadFull(r, payload); err != nil {
+ return nil, err
+ }
+
+ return payload, nil
+}
+
+// WriteMessage 是给上层直接使用的完整发送路径:
+// 把一条结构化消息完整编码并发送出去”的总入口。
+// Message -> header+body -> 长度前缀帧 -> io.Writer。
+func WriteMessage(w io.Writer, msg Message) error {
+ payload, err := EncodeMessage(msg)
+ if err != nil {
+ return fmt.Errorf("protocol: encode message: %w", err)
+ }
+
+ if err := WriteFrame(w, payload); err != nil {
+ return fmt.Errorf("protocol: write frame: %w", err)
+ }
+
+ return nil
+}
+
+// ReadMessage 是给上层直接使用的完整接收路径:
+// io.Reader -> 长度前缀帧 -> header+body -> Message。
+func ReadMessage(r io.Reader) (Message, error) {
+ payload, err := ReadFrame(r)
+ if err != nil {
+ return Message{}, fmt.Errorf("protocol: read frame: %w", err)
+ }
+
+ msg, err := DecodeMessage(payload)
+ if err != nil {
+ return Message{}, fmt.Errorf("protocol: decode message: %w", err)
+ }
+
+ return msg, nil
+}
+
+// validateMessage 检查 Message 传输的类型(只接受 text 和 file )。
+func validateMessage(msg Message) error {
+ if msg.From == "" {
+ return ErrMissingFrom
+ }
+ if msg.To == "" {
+ return ErrMissingTo
+ }
+
+ switch msg.Type {
+ case MessageTypeText:
+ if msg.FileName != "" {
+ return ErrUnexpectedFileName
+ }
+ if !utf8.Valid(msg.Body) {
+ return ErrInvalidTextBody
+ }
+ case MessageTypeFile:
+ if msg.FileName == "" {
+ return ErrMissingFileName
+ }
+ case MessageTypeRegister:
+ if msg.To != ServerPeerID {
+ return ErrInvalidRegisterTarget
+ }
+ if msg.FileName != "" {
+ return ErrUnexpectedFileName
+ }
+ if len(msg.Body) != 0 {
+ return ErrUnexpectedBody
+ }
+ case MessageTypeError:
+ if msg.From != ServerPeerID {
+ return ErrInvalidErrorSource
+ }
+ if msg.FileName != "" {
+ return ErrUnexpectedFileName
+ }
+ if !utf8.Valid(msg.Body) {
+ return ErrInvalidTextBody
+ }
+ default:
+ return ErrInvalidMessageType
+ }
+
+ return nil
+}
+
+// writeFull 会持续写入,直到所有字节都写完或者底层返回错误。
+// 这样可以避免某些 Writer 发生部分写入时破坏帧格式。
+func writeFull(w io.Writer, data []byte) error {
+ for len(data) > 0 {
+ n, err := w.Write(data)
+ if err != nil {
+ return err
+ }
+ if n == 0 {
+ return io.ErrShortWrite
+ }
+ data = data[n:]
+ }
+
+ return nil
+}
diff --git a/cmd/internal/protocol/codec_test.go b/cmd/internal/protocol/codec_test.go
new file mode 100644
index 0000000..b229b36
--- /dev/null
+++ b/cmd/internal/protocol/codec_test.go
@@ -0,0 +1,507 @@
+package protocol
+
+import (
+ "bytes"
+ "encoding/binary"
+ "encoding/json"
+ "errors"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+// TestEncodeDecodeMessageTextASCII 验证 ASCII 文本可以按 text 消息往返编解码。
+func TestEncodeDecodeMessageTextASCII(t *testing.T) {
+ original := Message{
+ Type: MessageTypeText,
+ ID: 42,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ }
+
+ data, err := EncodeMessage(original)
+ if err != nil {
+ t.Fatalf("EncodeMessage() error = %v", err)
+ }
+
+ decoded, err := DecodeMessage(data)
+ if err != nil {
+ t.Fatalf("DecodeMessage() error = %v", err)
+ }
+
+ if !reflect.DeepEqual(decoded, original) {
+ t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
+ }
+}
+
+// TestEncodeDecodeMessageTextUTF8 验证 text 消息允许合法 UTF-8,
+// 从而天然兼容 ASCII 之外的普通文本。
+func TestEncodeDecodeMessageTextUTF8(t *testing.T) {
+ original := Message{
+ Type: MessageTypeText,
+ ID: 43,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("你好, world"),
+ }
+
+ data, err := EncodeMessage(original)
+ if err != nil {
+ t.Fatalf("EncodeMessage() error = %v", err)
+ }
+
+ decoded, err := DecodeMessage(data)
+ if err != nil {
+ t.Fatalf("DecodeMessage() error = %v", err)
+ }
+
+ if !reflect.DeepEqual(decoded, original) {
+ t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
+ }
+}
+
+// TestEncodeDecodeMessageFile 验证 file 消息会保留文件名和原始二进制正文。
+func TestEncodeDecodeMessageFile(t *testing.T) {
+ original := Message{
+ Type: MessageTypeFile,
+ ID: 44,
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "data.bin",
+ Body: []byte{0x00, 0xff, 0x10, 0x7f},
+ }
+
+ data, err := EncodeMessage(original)
+ if err != nil {
+ t.Fatalf("EncodeMessage() error = %v", err)
+ }
+
+ decoded, err := DecodeMessage(data)
+ if err != nil {
+ t.Fatalf("DecodeMessage() error = %v", err)
+ }
+
+ if !reflect.DeepEqual(decoded, original) {
+ t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
+ }
+}
+
+// TestEncodeDecodeMessageRegister 验证 register 控制消息也能正常编解码。
+func TestEncodeDecodeMessageRegister(t *testing.T) {
+ original := Message{
+ Type: MessageTypeRegister,
+ ID: 45,
+ From: "peer-a",
+ To: ServerPeerID,
+ Body: []byte{},
+ }
+
+ data, err := EncodeMessage(original)
+ if err != nil {
+ t.Fatalf("EncodeMessage() error = %v", err)
+ }
+
+ decoded, err := DecodeMessage(data)
+ if err != nil {
+ t.Fatalf("DecodeMessage() error = %v", err)
+ }
+
+ if !reflect.DeepEqual(decoded, original) {
+ t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
+ }
+}
+
+// TestEncodeDecodeMessageError 验证 error 控制消息会保留 UTF-8 错误文本。
+func TestEncodeDecodeMessageError(t *testing.T) {
+ original := Message{
+ Type: MessageTypeError,
+ ID: 46,
+ From: ServerPeerID,
+ To: "peer-a",
+ Body: []byte("unknown target"),
+ }
+
+ data, err := EncodeMessage(original)
+ if err != nil {
+ t.Fatalf("EncodeMessage() error = %v", err)
+ }
+
+ decoded, err := DecodeMessage(data)
+ if err != nil {
+ t.Fatalf("DecodeMessage() error = %v", err)
+ }
+
+ if !reflect.DeepEqual(decoded, original) {
+ t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
+ }
+}
+
+// TestWriteReadFrame 单独验证最底层的长度前缀帧逻辑,
+// 不依赖 Message 结构,方便确认 TCP 粘包拆包问题是否被正确处理。
+func TestWriteReadFrame(t *testing.T) {
+ var buf bytes.Buffer
+ payload := []byte("header+body")
+
+ if err := WriteFrame(&buf, payload); err != nil {
+ t.Fatalf("WriteFrame() error = %v", err)
+ }
+
+ got, err := ReadFrame(&buf)
+ if err != nil {
+ t.Fatalf("ReadFrame() error = %v", err)
+ }
+
+ if !bytes.Equal(got, payload) {
+ t.Fatalf("payload mismatch: got %q want %q", got, payload)
+ }
+}
+
+// TestWriteReadMessageAllowsEmptyBody 验证空文本和空文件都可以正常通过协议层,
+// 因为外层帧非空的前提下,空正文是合法业务内容。
+func TestWriteReadMessageAllowsEmptyBody(t *testing.T) {
+ tests := []struct {
+ name string
+ message Message
+ }{
+ {
+ name: "empty text",
+ message: Message{
+ Type: MessageTypeText,
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte(""),
+ },
+ },
+ {
+ name: "empty file",
+ message: Message{
+ Type: MessageTypeFile,
+ ID: 2,
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "empty.txt",
+ Body: []byte{},
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var buf bytes.Buffer
+
+ if err := WriteMessage(&buf, tt.message); err != nil {
+ t.Fatalf("WriteMessage() error = %v", err)
+ }
+
+ got, err := ReadMessage(&buf)
+ if err != nil {
+ t.Fatalf("ReadMessage() error = %v", err)
+ }
+
+ if !reflect.DeepEqual(got, tt.message) {
+ t.Fatalf("round trip mismatch: got %+v want %+v", got, tt.message)
+ }
+ })
+ }
+}
+
+// TestWriteReadMessageRejectsInvalidMessages 验证协议层会在编码前拦住明显非法的消息。
+func TestWriteReadMessageRejectsInvalidMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ message Message
+ wantErr error
+ }{
+ {
+ name: "invalid type",
+ message: Message{
+ Type: MessageType("unknown"),
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ },
+ wantErr: ErrInvalidMessageType,
+ },
+ {
+ name: "missing from",
+ message: Message{
+ Type: MessageTypeText,
+ ID: 2,
+ To: "peer-b",
+ },
+ wantErr: ErrMissingFrom,
+ },
+ {
+ name: "missing to",
+ message: Message{
+ Type: MessageTypeText,
+ ID: 3,
+ From: "peer-a",
+ },
+ wantErr: ErrMissingTo,
+ },
+ {
+ name: "text with file name",
+ message: Message{
+ Type: MessageTypeText,
+ ID: 4,
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "bad.txt",
+ Body: []byte("hello"),
+ },
+ wantErr: ErrUnexpectedFileName,
+ },
+ {
+ name: "text with invalid utf8",
+ message: Message{
+ Type: MessageTypeText,
+ ID: 5,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte{0xff, 0xfe},
+ },
+ wantErr: ErrInvalidTextBody,
+ },
+ {
+ name: "file without file name",
+ message: Message{
+ Type: MessageTypeFile,
+ ID: 6,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte{0x01},
+ },
+ wantErr: ErrMissingFileName,
+ },
+ {
+ name: "register with wrong target",
+ message: Message{
+ Type: MessageTypeRegister,
+ ID: 7,
+ From: "peer-a",
+ To: "peer-b",
+ },
+ wantErr: ErrInvalidRegisterTarget,
+ },
+ {
+ name: "register with body",
+ message: Message{
+ Type: MessageTypeRegister,
+ ID: 8,
+ From: "peer-a",
+ To: ServerPeerID,
+ Body: []byte("unexpected"),
+ },
+ wantErr: ErrUnexpectedBody,
+ },
+ {
+ name: "error with wrong source",
+ message: Message{
+ Type: MessageTypeError,
+ ID: 9,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("bad"),
+ },
+ wantErr: ErrInvalidErrorSource,
+ },
+ {
+ name: "error with file name",
+ message: Message{
+ Type: MessageTypeError,
+ ID: 10,
+ From: ServerPeerID,
+ To: "peer-a",
+ FileName: "bad.txt",
+ Body: []byte("bad"),
+ },
+ wantErr: ErrUnexpectedFileName,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := EncodeMessage(tt.message)
+ if !errors.Is(err, tt.wantErr) {
+ t.Fatalf("EncodeMessage() error = %v, want %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+// TestReadFrameRejectsInvalidLength 验证长度为 0 的帧会被当成非法输入,
+// 而不是被当成一条合法的空消息。
+func TestReadFrameRejectsInvalidLength(t *testing.T) {
+ var buf bytes.Buffer
+
+ if err := binary.Write(&buf, binary.BigEndian, uint32(0)); err != nil {
+ t.Fatalf("binary.Write() error = %v", err)
+ }
+
+ _, err := ReadFrame(&buf)
+ if !errors.Is(err, ErrInvalidFrameLength) {
+ t.Fatalf("ReadFrame() error = %v, want %v", err, ErrInvalidFrameLength)
+ }
+}
+
+// TestReadFrameRejectsTooLargeFrame 验证超大帧会在分配消息体前被拒绝,
+// 从而保证最大长度限制真正生效。
+func TestReadFrameRejectsTooLargeFrame(t *testing.T) {
+ var buf bytes.Buffer
+
+ if err := binary.Write(&buf, binary.BigEndian, uint32(MaxFrameSize+1)); err != nil {
+ t.Fatalf("binary.Write() error = %v", err)
+ }
+
+ _, err := ReadFrame(&buf)
+ if !errors.Is(err, ErrFrameTooLarge) {
+ t.Fatalf("ReadFrame() error = %v, want %v", err, ErrFrameTooLarge)
+ }
+}
+
+// TestWriteFrameRejectsEmptyPayload 验证写入端和读取端的约束保持一致:
+// 既然读取端不接受 0 长度帧,写入端也不应该产生这种帧。
+func TestWriteFrameRejectsEmptyPayload(t *testing.T) {
+ var buf bytes.Buffer
+
+ err := WriteFrame(&buf, nil)
+ if !errors.Is(err, ErrInvalidFrameLength) {
+ t.Fatalf("WriteFrame() error = %v, want %v", err, ErrInvalidFrameLength)
+ }
+}
+
+// TestDecodeMessageRejectsInvalidHeaderLength 验证无法切出完整头部时会被立即拒绝。
+func TestDecodeMessageRejectsInvalidHeaderLength(t *testing.T) {
+ tests := []struct {
+ name string
+ data []byte
+ }{
+ {
+ name: "too short for header len",
+ data: []byte{0x00, 0x00, 0x00},
+ },
+ {
+ name: "zero header len",
+ data: []byte{0x00, 0x00, 0x00, 0x00},
+ },
+ {
+ name: "header len exceeds payload",
+ data: []byte{0x00, 0x00, 0x00, 0x10, '{', '}'},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := DecodeMessage(tt.data)
+ if !errors.Is(err, ErrInvalidHeaderLength) {
+ t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderLength)
+ }
+ })
+ }
+}
+
+// TestDecodeMessageRejectsInvalidHeaderJSON 验证头部 JSON 非法时能返回明确错误。
+func TestDecodeMessageRejectsInvalidHeaderJSON(t *testing.T) {
+ data := append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)
+
+ _, err := DecodeMessage(data)
+ if !errors.Is(err, ErrInvalidHeaderJSON) {
+ t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderJSON)
+ }
+}
+
+// TestDecodeMessageRejectsContentLengthMismatch 验证头部声明长度和实际正文不一致时会失败。
+func TestDecodeMessageRejectsContentLengthMismatch(t *testing.T) {
+ headerPayload, err := json.Marshal(messageHeader{
+ Type: MessageTypeText,
+ ID: 7,
+ From: "peer-a",
+ To: "peer-b",
+ ContentLength: 10,
+ })
+ if err != nil {
+ t.Fatalf("json.Marshal() error = %v", err)
+ }
+
+ var data bytes.Buffer
+ if err := binary.Write(&data, binary.BigEndian, uint32(len(headerPayload))); err != nil {
+ t.Fatalf("binary.Write() error = %v", err)
+ }
+ if _, err := data.Write(headerPayload); err != nil {
+ t.Fatalf("data.Write(headerPayload) error = %v", err)
+ }
+ if _, err := data.Write([]byte("hello")); err != nil {
+ t.Fatalf("data.Write(body) error = %v", err)
+ }
+
+ _, err = DecodeMessage(data.Bytes())
+ if !errors.Is(err, ErrInvalidContentLength) {
+ t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidContentLength)
+ }
+}
+
+// TestReadMultipleMessages 模拟同一条流中连续写入 text 和 file,
+// 验证读取端每次都能严格停在当前帧边界,不会串包。
+func TestReadMultipleMessages(t *testing.T) {
+ var buf bytes.Buffer
+
+ first := Message{
+ Type: MessageTypeText,
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ }
+
+ second := Message{
+ Type: MessageTypeFile,
+ ID: 2,
+ From: "peer-b",
+ To: "peer-a",
+ FileName: "payload.bin",
+ Body: []byte{0x01, 0x02, 0x03},
+ }
+
+ if err := WriteMessage(&buf, first); err != nil {
+ t.Fatalf("WriteMessage(first) error = %v", err)
+ }
+ if err := WriteMessage(&buf, second); err != nil {
+ t.Fatalf("WriteMessage(second) error = %v", err)
+ }
+
+ gotFirst, err := ReadMessage(&buf)
+ if err != nil {
+ t.Fatalf("ReadMessage(first) error = %v", err)
+ }
+ gotSecond, err := ReadMessage(&buf)
+ if err != nil {
+ t.Fatalf("ReadMessage(second) error = %v", err)
+ }
+
+ if !reflect.DeepEqual(gotFirst, first) {
+ t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, first)
+ }
+ if !reflect.DeepEqual(gotSecond, second) {
+ t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, second)
+ }
+}
+
+// TestReadMessageWrapsDecodeError 验证 ReadMessage 在返回错误时会保留解码阶段上下文。
+func TestReadMessageWrapsDecodeError(t *testing.T) {
+ var buf bytes.Buffer
+
+ if err := WriteFrame(&buf, append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)); err != nil {
+ t.Fatalf("WriteFrame() error = %v", err)
+ }
+
+ _, err := ReadMessage(&buf)
+ if err == nil {
+ t.Fatal("ReadMessage() error = nil, want non-nil")
+ }
+ if !strings.Contains(err.Error(), "decode message") {
+ t.Fatalf("ReadMessage() error = %v, want wrapped decode error", err)
+ }
+}
diff --git a/cmd/internal/protocol/message.go b/cmd/internal/protocol/message.go
new file mode 100644
index 0000000..5f5d28b
--- /dev/null
+++ b/cmd/internal/protocol/message.go
@@ -0,0 +1,33 @@
+package protocol
+
+// MessageType 表示一条消息的传输类型。
+// v1 只区分普通文本和文件两类负载。
+type MessageType string
+
+const (
+ // MessageTypeText 表示正文按 UTF-8 文本解释,天然兼容 ASCII。
+ MessageTypeText MessageType = "text"
+ // MessageTypeFile 表示正文是原始文件字节。
+ MessageTypeFile MessageType = "file"
+ // MessageTypeRegister 表示 peer 向 server 显式注册自己的身份。
+ MessageTypeRegister MessageType = "register"
+ // MessageTypeError 表示 server 向 peer 返回错误信息。
+ MessageTypeError MessageType = "error"
+)
+
+// ServerPeerID 是协议中约定的 server 端固定标识。
+const ServerPeerID = "server"
+
+// Message 是 peer 和 server 共用的传输消息结构。
+// 头部元信息会被编码为 JSON,Body 则作为原始字节拼接在头部之后。
+type Message struct {
+ Type MessageType `json:"type"` // 消息类型,只允许 text 或 file。
+ ID uint64 `json:"id"` // 由发送方生成,用于追踪消息。
+ From string `json:"from"` // 发送方标识。
+ To string `json:"to"` // 接收方标识。
+
+ // FileName 仅在 Type 为 file 时使用。
+ FileName string `json:"file_name,omitempty"`
+ // Body 是真正传输的正文内容,不进入头部 JSON。
+ Body []byte `json:"-"`
+}
diff --git a/cmd/internal/server/hub.go b/cmd/internal/server/hub.go
new file mode 100644
index 0000000..ed11bc3
--- /dev/null
+++ b/cmd/internal/server/hub.go
@@ -0,0 +1,192 @@
+package server
+
+import (
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ "omnisocketgo/cmd/internal/protocol"
+ "omnisocketgo/cmd/internal/transport"
+)
+
+const gracefulRejectCloseTimeout = 100 * time.Millisecond
+
+// Hub 管理已注册 peer 的连接,并负责在它们之间转发消息。
+type Hub struct {
+ mu sync.RWMutex
+ peers map[string]*transport.TCPConn
+ logger latencylog.Logger
+}
+
+// Option 用于配置 Hub 的可选行为,例如时延日志。
+type Option func(*Hub)
+
+// WithLogger 为 hub 注入时延日志记录器。
+func WithLogger(logger latencylog.Logger) Option {
+ return func(hub *Hub) {
+ hub.logger = logger
+ }
+}
+
+// NewHub 创建一个空的连接中心。
+func NewHub(opts ...Option) *Hub {
+ hub := &Hub{
+ peers: make(map[string]*transport.TCPConn),
+ logger: latencylog.NoopLogger{},
+ }
+
+ for _, opt := range opts {
+ opt(hub)
+ }
+
+ if hub.logger == nil {
+ hub.logger = latencylog.NoopLogger{}
+ }
+
+ return hub
+}
+
+// HasPeer 返回给定 ID 是否已经注册到 hub。
+func (h *Hub) HasPeer(peerID string) bool {
+ h.mu.RLock()
+ defer h.mu.RUnlock()
+
+ _, ok := h.peers[peerID]
+ return ok
+}
+
+// ServeConn 处理一条新接入的底层 TCP 连接。
+// 连接上的第一条消息必须是 register,之后才允许转发 text/file。
+func (h *Hub) ServeConn(rawConn net.Conn) error {
+ conn, err := transport.NewTCPConn(rawConn)
+ if err != nil {
+ _ = rawConn.Close()
+ return fmt.Errorf("server: create transport conn: %w", err)
+ }
+
+ peerID, gracefulClose, err := h.registerConn(conn)
+ if err != nil {
+ h.closeConn(conn, gracefulClose)
+ return err
+ }
+ defer h.unregister(peerID, conn)
+
+ if err := h.receivePeerLoop(peerID, conn); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// registerConn 从新连接上读取第一条消息,验证它是 register 消息,并把连接注册到 hub。
+func (h *Hub) registerConn(conn *transport.TCPConn) (string, bool, error) {
+ msg, err := conn.Receive()
+ if err != nil {
+ return "", false, fmt.Errorf("server: receive register: %w", err)
+ }
+
+ if msg.Type != protocol.MessageTypeRegister {
+ if sendErr := sendServerError(conn, msg.From, "first message must be register"); sendErr != nil {
+ return "", false, fmt.Errorf("server: reject unregistered peer: %w", sendErr)
+ }
+ return "", true, fmt.Errorf("server: first message must be register, got %s", msg.Type)
+ }
+
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ if _, exists := h.peers[msg.From]; exists {
+ if sendErr := sendServerError(conn, msg.From, fmt.Sprintf("duplicate peer id: %s", msg.From)); sendErr != nil {
+ return "", false, fmt.Errorf("server: duplicate peer id %s: %w", msg.From, sendErr)
+ }
+ return "", true, fmt.Errorf("server: duplicate peer id: %s", msg.From)
+ }
+
+ h.peers[msg.From] = conn
+ return msg.From, false, nil
+}
+
+// handlePeerMessage 验证消息类型并执行相应的转发或错误响应。
+func (h *Hub) handlePeerMessage(peerID string, conn *transport.TCPConn, msg protocol.Message) (bool, error) {
+ switch msg.Type {
+ case protocol.MessageTypeText, protocol.MessageTypeFile: //只允许已注册的 peer 发送文本或文件消息,其他类型都视为协议错误。
+ msg.From = peerID
+ targetConn, ok := h.lookup(msg.To)
+ if !ok {
+ return false, sendServerError(conn, peerID, fmt.Sprintf("unknown target: %s", msg.To))
+ }
+ if err := targetConn.Send(msg); err != nil { //转发消息,如果发送失败,说明目标连接可能已经不可用,此时从 hub 中注销该连接并关闭它,并向发送方返回错误响应。
+ h.unregister(msg.To, targetConn)
+ _ = targetConn.Close()
+ return false, sendServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", msg.To))
+ }
+ return false, nil
+ case protocol.MessageTypeRegister, protocol.MessageTypeError: //已注册的 peer 不允许再发送 register 或 error 消息,这些都视为协议错误。
+ if err := sendServerError(conn, peerID, "registered peers can only send text or file messages"); err != nil {
+ return false, fmt.Errorf("server: send protocol error: %w", err)
+ }
+ return true, fmt.Errorf("server: unexpected message type from peer %s: %s", peerID, msg.Type)
+ default: // 其他任何消息类型都视为协议错误。
+ if err := sendServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); err != nil {
+ return false, fmt.Errorf("server: send unsupported type error: %w", err)
+ }
+ return true, fmt.Errorf("server: unsupported message type: %s", msg.Type)
+ }
+}
+
+func (h *Hub) receivePeerLoop(peerID string, conn *transport.TCPConn) error {
+ for {
+ msg, err := conn.Receive()
+ if err != nil {
+ _ = conn.Close()
+ return fmt.Errorf("transport: receive loop read: %w", err)
+ }
+
+ gracefulClose, err := h.handlePeerMessage(peerID, conn, msg)
+ if err != nil {
+ h.closeConn(conn, gracefulClose)
+ return fmt.Errorf("transport: receive loop handler: %w", err)
+ }
+ }
+}
+
+// lookup 在 hub 中查找目标 peer 的连接。
+func (h *Hub) lookup(peerID string) (*transport.TCPConn, bool) {
+ h.mu.RLock()
+ defer h.mu.RUnlock()
+
+ conn, ok := h.peers[peerID]
+ return conn, ok
+}
+
+// unregister 从 hub 中移除指定 peer 的连接,通常在连接关闭或发生错误时调用。
+func (h *Hub) unregister(peerID string, conn *transport.TCPConn) {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ current, ok := h.peers[peerID]
+ if ok && current == conn {
+ delete(h.peers, peerID)
+ }
+}
+
+func (h *Hub) closeConn(conn *transport.TCPConn, graceful bool) {
+ if graceful {
+ _ = conn.CloseGracefully(gracefulRejectCloseTimeout)
+ return
+ }
+
+ _ = conn.Close()
+}
+
+// sendServerError 是一个辅助函数,用于向指定 peer 发送错误消息。
+func sendServerError(conn *transport.TCPConn, to, message string) error {
+ return conn.Send(protocol.Message{
+ Type: protocol.MessageTypeError,
+ From: protocol.ServerPeerID,
+ To: to,
+ Body: []byte(message),
+ })
+}
diff --git a/cmd/internal/server/hub_test.go b/cmd/internal/server/hub_test.go
new file mode 100644
index 0000000..2cee305
--- /dev/null
+++ b/cmd/internal/server/hub_test.go
@@ -0,0 +1,398 @@
+package server
+
+import (
+ "net"
+ "reflect"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ "omnisocketgo/cmd/internal/protocol"
+ "omnisocketgo/cmd/internal/transport"
+)
+
+type recordingLogger struct {
+ mu sync.Mutex
+ events []latencylog.Event
+}
+
+func (l *recordingLogger) LogEvent(event latencylog.Event) error {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ l.events = append(l.events, event)
+ return nil
+}
+
+func (l *recordingLogger) Events() []latencylog.Event {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ return append([]latencylog.Event(nil), l.events...)
+}
+
+func TestServeConnRegistersPeer(t *testing.T) {
+ hub := NewHub()
+ client, done := startHubConn(t, hub)
+
+ if err := client.Send(protocol.Message{
+ Type: protocol.MessageTypeRegister,
+ From: "peer-a",
+ To: protocol.ServerPeerID,
+ }); err != nil {
+ t.Fatalf("Send(register) error = %v", err)
+ }
+
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
+
+ if err := client.Close(); err != nil {
+ t.Fatalf("client.Close() error = %v", err)
+ }
+
+ err := <-done
+ if err == nil || !strings.Contains(err.Error(), "receive loop read") {
+ t.Fatalf("ServeConn() error = %v, want read-loop shutdown error", err)
+ }
+}
+
+func TestServeConnRejectsDuplicatePeerID(t *testing.T) {
+ hub := NewHub()
+
+ first, firstDone := startHubConn(t, hub)
+ registerPeer(t, first, "peer-a")
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
+
+ second, secondDone := startHubConn(t, hub)
+ registerPeer(t, second, "peer-a")
+
+ got, err := second.Receive()
+ if err != nil {
+ t.Fatalf("second.Receive() error = %v", err)
+ }
+ if got.Type != protocol.MessageTypeError {
+ t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
+ }
+ if string(got.Body) != "duplicate peer id: peer-a" {
+ t.Fatalf("error body = %q, want duplicate peer message", got.Body)
+ }
+
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "original peer-a to remain registered")
+
+ if err := first.Close(); err != nil {
+ t.Fatalf("first.Close() error = %v", err)
+ }
+
+ if err := <-secondDone; err == nil || !strings.Contains(err.Error(), "duplicate peer id") {
+ t.Fatalf("second ServeConn() error = %v, want duplicate peer id error", err)
+ }
+ if err := <-firstDone; err == nil || !strings.Contains(err.Error(), "receive loop read") {
+ t.Fatalf("first ServeConn() error = %v, want read-loop shutdown error", err)
+ }
+}
+
+func TestServeConnForwardsMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ msg protocol.Message
+ }{
+ {
+ name: "text",
+ msg: protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 1,
+ From: "spoofed",
+ To: "peer-b",
+ Body: []byte("hello"),
+ },
+ },
+ {
+ name: "file",
+ msg: protocol.Message{
+ Type: protocol.MessageTypeFile,
+ ID: 2,
+ From: "spoofed",
+ To: "peer-b",
+ FileName: "payload.bin",
+ Body: []byte{0x01, 0x02, 0x03},
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ hub := NewHub()
+
+ sender, senderDone := startHubConn(t, hub)
+ receiver, receiverDone := startHubConn(t, hub)
+ registerPeer(t, sender, "peer-a")
+ registerPeer(t, receiver, "peer-b")
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
+
+ if err := sender.Send(tt.msg); err != nil {
+ t.Fatalf("sender.Send() error = %v", err)
+ }
+
+ got, err := receiver.Receive()
+ if err != nil {
+ t.Fatalf("receiver.Receive() error = %v", err)
+ }
+
+ want := tt.msg
+ want.From = "peer-a"
+ if !reflect.DeepEqual(got, want) {
+ t.Fatalf("forwarded message mismatch: got %+v want %+v", got, want)
+ }
+
+ _ = sender.Close()
+ _ = receiver.Close()
+ <-senderDone
+ <-receiverDone
+ })
+ }
+}
+
+func TestServeConnReturnsErrorForUnknownTarget(t *testing.T) {
+ hub := NewHub()
+
+ client, done := startHubConn(t, hub)
+ registerPeer(t, client, "peer-a")
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
+
+ if err := client.Send(protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 1,
+ From: "peer-a",
+ To: "missing-peer",
+ Body: []byte("hello"),
+ }); err != nil {
+ t.Fatalf("Send(text) error = %v", err)
+ }
+
+ got, err := client.Receive()
+ if err != nil {
+ t.Fatalf("Receive() error = %v", err)
+ }
+ if got.Type != protocol.MessageTypeError {
+ t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
+ }
+ if string(got.Body) != "unknown target: missing-peer" {
+ t.Fatalf("error body = %q, want unknown target message", got.Body)
+ }
+ if !hub.HasPeer("peer-a") {
+ t.Fatal("peer-a should remain registered after unknown target error")
+ }
+
+ _ = client.Close()
+ <-done
+}
+
+func TestServeConnRejectsRegisterAfterRegistration(t *testing.T) {
+ hub := NewHub()
+
+ client, done := startHubConn(t, hub)
+ registerPeer(t, client, "peer-a")
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
+
+ if err := client.Send(protocol.Message{
+ Type: protocol.MessageTypeRegister,
+ From: "peer-a",
+ To: protocol.ServerPeerID,
+ }); err != nil {
+ t.Fatalf("Send(register again) error = %v", err)
+ }
+
+ got, err := client.Receive()
+ if err != nil {
+ t.Fatalf("Receive() error = %v", err)
+ }
+ if got.Type != protocol.MessageTypeError {
+ t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
+ }
+ if string(got.Body) != "registered peers can only send text or file messages" {
+ t.Fatalf("error body = %q, want registered-peer protocol error", got.Body)
+ }
+
+ if err := <-done; err == nil || !strings.Contains(err.Error(), "unexpected message type from peer peer-a: register") {
+ t.Fatalf("ServeConn() error = %v, want unexpected register message error", err)
+ }
+}
+
+func TestServeConnUnregistersPeerOnClose(t *testing.T) {
+ hub := NewHub()
+
+ client, done := startHubConn(t, hub)
+ registerPeer(t, client, "peer-a")
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
+
+ if err := client.Close(); err != nil {
+ t.Fatalf("client.Close() error = %v", err)
+ }
+ <-done
+
+ waitFor(t, func() bool { return !hub.HasPeer("peer-a") }, "peer-a to be unregistered")
+}
+
+func TestServeConnDoesNotEmitEndpointLatencyEventsOnForward(t *testing.T) {
+ logger := &recordingLogger{}
+ hub := NewHub(WithLogger(logger))
+
+ sender, senderDone := startHubConn(t, hub)
+ receiver, receiverDone := startHubConn(t, hub)
+ registerPeer(t, sender, "peer-a")
+ registerPeer(t, receiver, "peer-b")
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
+
+ msg := protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 11,
+ From: "spoofed",
+ To: "peer-b",
+ Body: []byte("hello"),
+ }
+ if err := sender.Send(msg); err != nil {
+ t.Fatalf("sender.Send() error = %v", err)
+ }
+
+ got, err := receiver.Receive()
+ if err != nil {
+ t.Fatalf("receiver.Receive() error = %v", err)
+ }
+ msg.From = "peer-a"
+ if !reflect.DeepEqual(got, msg) {
+ t.Fatalf("forwarded message mismatch: got %+v want %+v", got, msg)
+ }
+
+ events := logger.Events()
+ if len(events) != 0 {
+ t.Fatalf("event count = %d, want 0 because server is a black-box relay", len(events))
+ }
+
+ _ = sender.Close()
+ _ = receiver.Close()
+ <-senderDone
+ <-receiverDone
+}
+
+func TestServeConnDoesNotLogLatencyEventsForUnknownTarget(t *testing.T) {
+ logger := &recordingLogger{}
+ hub := NewHub(WithLogger(logger))
+
+ client, done := startHubConn(t, hub)
+ registerPeer(t, client, "peer-a")
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
+
+ if err := client.Send(protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 15,
+ From: "peer-a",
+ To: "missing-peer",
+ Body: []byte("hello"),
+ }); err != nil {
+ t.Fatalf("Send(text) error = %v", err)
+ }
+
+ got, err := client.Receive()
+ if err != nil {
+ t.Fatalf("Receive() error = %v", err)
+ }
+ if got.Type != protocol.MessageTypeError {
+ t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
+ }
+ if events := logger.Events(); len(events) != 0 {
+ t.Fatalf("event count = %d, want 0 for unknown target path", len(events))
+ }
+
+ _ = client.Close()
+ <-done
+}
+
+func TestServeConnDoesNotLogLatencyEventsForDuplicateRegister(t *testing.T) {
+ logger := &recordingLogger{}
+ hub := NewHub(WithLogger(logger))
+
+ first, firstDone := startHubConn(t, hub)
+ registerPeer(t, first, "peer-a")
+ waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
+
+ second, secondDone := startHubConn(t, hub)
+ registerPeer(t, second, "peer-a")
+
+ got, err := second.Receive()
+ if err != nil {
+ t.Fatalf("second.Receive() error = %v", err)
+ }
+ if got.Type != protocol.MessageTypeError {
+ t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError)
+ }
+ if events := logger.Events(); len(events) != 0 {
+ t.Fatalf("event count = %d, want 0 for duplicate register path", len(events))
+ }
+
+ _ = first.Close()
+ <-secondDone
+ <-firstDone
+}
+
+func startHubConn(t *testing.T, hub *Hub) (*transport.TCPConn, <-chan error) {
+ t.Helper()
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("net.Listen() error = %v", err)
+ }
+ done := make(chan error, 1)
+
+ go func() {
+ serverSide, acceptErr := listener.Accept()
+ if acceptErr != nil {
+ done <- acceptErr
+ return
+ }
+ done <- hub.ServeConn(serverSide)
+ }()
+
+ clientSide, err := net.Dial("tcp", listener.Addr().String())
+ if err != nil {
+ _ = listener.Close()
+ t.Fatalf("net.Dial() error = %v", err)
+ }
+ if err := listener.Close(); err != nil {
+ t.Fatalf("listener.Close() error = %v", err)
+ }
+
+ conn, err := transport.NewTCPConn(clientSide)
+ if err != nil {
+ _ = clientSide.Close()
+ t.Fatalf("transport.NewTCPConn() error = %v", err)
+ }
+
+ return conn, done
+}
+
+func registerPeer(t *testing.T, conn *transport.TCPConn, peerID string) {
+ t.Helper()
+
+ if err := conn.Send(protocol.Message{
+ Type: protocol.MessageTypeRegister,
+ From: peerID,
+ To: protocol.ServerPeerID,
+ }); err != nil {
+ t.Fatalf("Send(register %s) error = %v", peerID, err)
+ }
+}
+
+func waitFor(t *testing.T, condition func() bool, description string) {
+ t.Helper()
+
+ deadline := time.Now().Add(500 * time.Millisecond)
+ for time.Now().Before(deadline) {
+ if condition() {
+ return
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+
+ t.Fatalf("timed out waiting for %s", description)
+}
diff --git a/cmd/internal/transport/tcp.go b/cmd/internal/transport/tcp.go
new file mode 100644
index 0000000..d0a0f8b
--- /dev/null
+++ b/cmd/internal/transport/tcp.go
@@ -0,0 +1,151 @@
+package transport
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "syscall"
+ "time"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+// TCPConn 是对单条活跃 TCP 连接的轻量封装。
+// 它负责把协议层的单条消息读写,提升为可复用的收发接口。
+type TCPConn struct {
+ conn net.Conn
+ raw syscall.RawConn // 连接对应的底层 syscall 句柄,用于 Linux socket timestamping 收发。
+
+ logger latencylog.Logger
+ nodeRole string // 日志中记录的节点角色,例如 "server" 或 "peer"
+ nodeID string // 日志中记录的节点 ID,例如 peer 的 ID 或 server 的 "hub"
+ writeMu sync.Mutex // 保护 Send 方法的互斥锁,确保同一时刻只有一条完整协议消息被写入连接,防止多条消息字节交叉
+ closeOnce sync.Once // 保护 Close 方法的 sync.Once,确保连接只被关闭一次
+ closeErr error // 连接关闭时的错误,如果连接成功关闭则为 nil,重复调用 Close 时会返回同样的错误
+}
+
+// Option 用于为 TCPConn 注入可选行为,例如时延日志。
+type Option func(*TCPConn)
+
+// WithLogger 为连接发送路径注入业务消息日志上下文。
+func WithLogger(logger latencylog.Logger, nodeRole, nodeID string) Option {
+ return func(conn *TCPConn) {
+ conn.logger = logger
+ conn.nodeRole = nodeRole
+ conn.nodeID = nodeID
+ }
+}
+
+// NewTCPConn 用已有的 net.Conn 创建 transport 连接封装。
+func NewTCPConn(conn net.Conn, opts ...Option) (*TCPConn, error) {
+ tcpConn := &TCPConn{
+ conn: conn,
+ logger: latencylog.NoopLogger{},
+ }
+
+ for _, opt := range opts {
+ opt(tcpConn)
+ }
+
+ if tcpConn.logger == nil {
+ tcpConn.logger = latencylog.NoopLogger{}
+ }
+
+ if err := tcpConn.initLinuxTimestamping(); err != nil {
+ return nil, err
+ }
+
+ return tcpConn, nil
+}
+
+// Send 将一条协议消息完整写入底层连接。
+// 多个 goroutine 可以并发调用,内部会串行化写入。
+func (c *TCPConn) Send(msg protocol.Message) error {
+ c.writeMu.Lock() //“同一时刻只能有一条完整协议消息往连接里写,防止多条消息字节交叉
+ defer c.writeMu.Unlock()
+ latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
+
+ if err := c.sendMessageLinux(msg); err != nil {
+ return fmt.Errorf("transport: send message: %w", err)
+ }
+ //记录发送完成的时延日志事件,事件类型为 EventSendHandoffEnd,包含消息的基本信息(类型、ID、来源、目标)。
+ latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
+
+ return nil
+}
+
+// Receive 从底层连接读取一条完整协议消息。
+// 同一条连接应只由单个 reader 持续调用该方法。
+func (c *TCPConn) Receive() (protocol.Message, error) {
+ msg, err := c.receiveMessageLinux()
+ if err != nil {
+ return protocol.Message{}, fmt.Errorf("transport: receive message: %w", err)
+ }
+
+ return msg, nil
+}
+
+// ReceiveLoop 持续读取消息并交给 handler 处理。
+// 读取错误、handler 错误或连接关闭都会结束循环,并关闭连接。
+func (c *TCPConn) ReceiveLoop(handler func(protocol.Message) error) error {
+ for {
+ msg, err := c.Receive()
+ if err != nil {
+ _ = c.Close()
+ return fmt.Errorf("transport: receive loop read: %w", err)
+ }
+
+ if err := handler(msg); err != nil {
+ _ = c.Close()
+ return fmt.Errorf("transport: receive loop handler: %w", err)
+ }
+ }
+}
+
+// CloseGracefully 在支持 half-close 的连接上先关闭写方向,给对端留出读取最终响应的机会,
+// 然后在短暂等待后再彻底关闭连接。
+func (c *TCPConn) CloseGracefully(drainTimeout time.Duration) error {
+ if closeWriter, ok := c.conn.(interface{ CloseWrite() error }); ok {
+ if err := closeWriter.CloseWrite(); err != nil && !errors.Is(err, net.ErrClosed) {
+ return c.Close()
+ }
+
+ if drainTimeout > 0 {
+ _ = c.conn.SetReadDeadline(time.Now().Add(drainTimeout))
+ defer func() {
+ _ = c.conn.SetReadDeadline(time.Time{})
+ }()
+
+ var buf [256]byte
+ for {
+ _, err := c.conn.Read(buf[:])
+ switch {
+ case err == nil:
+ continue
+ case errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
+ return c.Close()
+ default:
+ var netErr net.Error
+ if errors.As(err, &netErr) && netErr.Timeout() {
+ return c.Close()
+ }
+ return c.Close()
+ }
+ }
+ }
+ }
+
+ return c.Close()
+}
+
+// Close 关闭底层连接,并保证重复调用是安全的。
+func (c *TCPConn) Close() error {
+ c.closeOnce.Do(func() {
+ c.closeErr = c.conn.Close()
+ })
+
+ return c.closeErr
+}
diff --git a/cmd/internal/transport/tcp_linux.go b/cmd/internal/transport/tcp_linux.go
new file mode 100644
index 0000000..4264310
--- /dev/null
+++ b/cmd/internal/transport/tcp_linux.go
@@ -0,0 +1,462 @@
+//go:build linux
+
+package transport
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "syscall"
+ "time"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+const (
+ linuxTimestampControlBufferSize = 256 // 控制消息缓冲区。
+ linuxTXTimestampWaitTimeout = 250 * time.Millisecond // 等待 TX 时间戳的上限。
+ linuxTXTimestampPollInterval = time.Millisecond // 轮询 errqueue 的间隔。
+
+ linuxSOTimestampingNew = 0x41
+ linuxSCMTimestampingNew = linuxSOTimestampingNew
+ linuxSOEEOriginTimestamping = 4 // timestamping errqueue 事件。
+ linuxSCMTstampSnd = 0 // 对应 A_TX_SOFTWARE。
+ linuxSCMTstampSched = 1 // 对应 A_TX_SCHED。
+
+ linuxSOFTimestampingTXSoftware = 1 << 1 // 打开 TX software timestamp。
+ linuxSOFTimestampingRXSoftware = 1 << 3 // 打开 RX software timestamp。
+ linuxSOFTimestampingSoftware = 1 << 4 // software timestamp 总开关。
+ linuxSOFTimestampingOptID = 1 << 7 // 给时间戳关联 ID。
+ linuxSOFTimestampingTXSched = 1 << 8 // 打开 TX sched timestamp。
+ linuxSOFTimestampingOptTSONLY = 1 << 11 // 只回时间戳。
+ linuxSOFTimestampingOptIDTCP = 1 << 16 // 让 TCP 也带 timestamp ID。
+)
+
+// 拿到底层 fd,并打开 Linux timestamping。
+func (c *TCPConn) initLinuxTimestamping() error {
+ sysConn, ok := c.conn.(interface {
+ SyscallConn() (syscall.RawConn, error)
+ })
+ if !ok {
+ return fmt.Errorf("transport: connection does not support SyscallConn")
+ }
+
+ rawConn, err := sysConn.SyscallConn()
+ if err != nil || rawConn == nil {
+ if err != nil {
+ return fmt.Errorf("transport: get syscall conn: %w", err)
+ }
+ return fmt.Errorf("transport: missing syscall conn")
+ }
+
+ //socket是否可以成功打开 timestamping 取决于内核版本和配置,尝试多个 flag 组合直到成功或遇到非 EINVAL 错误。
+ if err := enableLinuxTimestamping(rawConn); err != nil {
+ return fmt.Errorf("transport: enable linux timestamping: %w", err)
+ }
+ //成功打开 timestamping 后,rawConn 就可以用来收 TX/RX 时间戳了。
+ c.raw = rawConn
+ return nil
+}
+
+// 给 socket开权限打开TX software timestamping。
+func enableLinuxTimestamping(rawConn syscall.RawConn) error {
+ flagCandidates := []int{ //不同linux版本可能支持不同的 flag 组合,尝试多个组合直到成功。
+ linuxSOFTimestampingTXSched |
+ linuxSOFTimestampingTXSoftware |
+ linuxSOFTimestampingRXSoftware |
+ linuxSOFTimestampingSoftware |
+ linuxSOFTimestampingOptID | //TCP 协议栈给每个时间戳生成一个序列号
+ linuxSOFTimestampingOptIDTCP |
+ linuxSOFTimestampingOptTSONLY,
+ linuxSOFTimestampingTXSched |
+ linuxSOFTimestampingTXSoftware |
+ linuxSOFTimestampingRXSoftware |
+ linuxSOFTimestampingSoftware |
+ linuxSOFTimestampingOptID |
+ linuxSOFTimestampingOptTSONLY,
+ linuxSOFTimestampingTXSched |
+ linuxSOFTimestampingTXSoftware |
+ linuxSOFTimestampingRXSoftware |
+ linuxSOFTimestampingSoftware |
+ linuxSOFTimestampingOptTSONLY,
+ }
+
+ var lastErr error
+ for _, flags := range flagCandidates { //尝试不同的 flag 组合,直到成功或遇到非 EINVAL 错误。
+ // 内核根据 fd 找到对应的内存结构体(Socket 缓冲区)
+ err := rawConn.Control(func(fd uintptr) { //Control 方法保证在回调里 fd 是有效的,可以安全地调用 syscall.SetsockoptInt。
+ lastErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags)
+ })
+ if err != nil {
+ return err
+ }
+ if lastErr == nil {
+ return nil
+ }
+ if !errors.Is(lastErr, syscall.EINVAL) {
+ return lastErr
+ }
+ }
+
+ return lastErr
+}
+
+// sendMessageLinux 编码消息、写完整帧,再记录 TX 时间戳。
+func (c *TCPConn) sendMessageLinux(msg protocol.Message) error {
+ payload, err := protocol.EncodeMessage(msg)
+ if err != nil {
+ return fmt.Errorf("protocol: encode message: %w", err)
+ }
+
+ //编码后的消息 payload 前面加 4 字节长度,构成完整帧。
+ frame := make([]byte, 4+len(payload))
+ binary.BigEndian.PutUint32(frame[:4], uint32(len(payload)))
+ copy(frame[4:], payload)
+
+ if err := c.writeFrameLinux(frame); err != nil {
+ return fmt.Errorf("protocol: write frame: %w", err)
+ }
+ //记录发送延时日志
+ c.logTXTimestampEvents(msg)
+ return nil
+}
+
+// writeFrameLinux 用 sendmsg 写完整帧。
+func (c *TCPConn) writeFrameLinux(frame []byte) error {
+ written := 0
+ var opErr error
+
+ err := c.raw.Write(func(fd uintptr) bool {
+ if written >= len(frame) {
+ return true
+ }
+
+ n, sendErr := syscall.SendmsgN(int(fd), frame[written:], nil, nil, 0)
+ switch {
+ case sendErr == nil:
+ if n <= 0 {
+ opErr = io.ErrShortWrite
+ return true
+ }
+ written += n
+ return written >= len(frame)
+ case errors.Is(sendErr, syscall.EAGAIN), errors.Is(sendErr, syscall.EWOULDBLOCK):
+ return false
+ default:
+ opErr = sendErr
+ return true
+ }
+ })
+ if err != nil {
+ return err
+ }
+ if opErr != nil {
+ return opErr
+ }
+ if written != len(frame) {
+ return io.ErrShortWrite
+ }
+
+ return nil
+}
+
+// 把 A_TX_SCHED / A_TX_SOFTWARE 写入日志。(发送过程中)
+func (c *TCPConn) logTXTimestampEvents(msg protocol.Message) {
+ timestamps := c.collectTXTimestampEvents()
+
+ if ts, ok := timestamps[latencylog.EventATXSched]; ok {
+ latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSched, ts, msg)
+ }
+ if ts, ok := timestamps[latencylog.EventATXSoftware]; ok {
+ latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSoftware, ts, msg)
+ }
+}
+
+// 在 errqueue 里等两类 TX 时间戳。
+func (c *TCPConn) collectTXTimestampEvents() map[string]int64 {
+ timestamps := make(map[string]int64, 2)
+ //设置合理等待上限
+ deadline := time.Now().Add(linuxTXTimestampWaitTimeout)
+
+ //轮询 errqueue 直到拿到两类时间戳,或超时,或遇到非 EAGAIN 错误。
+ for len(timestamps) < 2 && time.Now().Before(deadline) {
+ eventName, ts, err := c.recvTXTimestampOnce()
+ if err != nil {
+ if isWouldBlock(err) {
+ time.Sleep(linuxTXTimestampPollInterval)
+ continue
+ }
+ break
+ }
+ if eventName == "" || ts <= 0 {
+ continue
+ }
+ if _, exists := timestamps[eventName]; !exists {
+ timestamps[eventName] = ts
+ }
+ }
+
+ return timestamps
+}
+
+// recvTXTimestampOnce 从 errqueue 读一次时间戳事件。
+func (c *TCPConn) recvTXTimestampOnce() (string, int64, error) {
+ var (
+ eventName string // 事件名,例如 A_TX_SCHED 或 A_TX_SOFTWARE。
+ tsUnixNS int64 // 时间戳的 UnixNano 表示。
+ opErr error
+ )
+
+ err := c.raw.Control(func(fd uintptr) {
+ //设置足够大的 oob buffer 来接收控制消息,调用 recvmsg 从 errqueue 读一条消息。
+ oob := make([]byte, linuxTimestampControlBufferSize)
+ //recvmsg 的 flags 里必须带 MSG_ERRQUEUE,才能从 errqueue 里读消息,非阻塞模式下如果没有消息可读会返回 EAGAIN。
+ _, oobn, _, _, recvErr := syscall.Recvmsg(int(fd), nil, oob, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT)
+ if recvErr != nil {
+ opErr = recvErr
+ return
+ }
+ //解析控制消息,看看是不是我们关心的 TX 时间戳事件,如果是就拿到事件名和时间戳。
+ eventName, tsUnixNS = parseTXTimestampControlMessages(oob[:oobn])
+ })
+ if err != nil {
+ return "", 0, err
+ }
+ if opErr != nil {
+ return "", 0, opErr
+ }
+
+ return eventName, tsUnixNS, nil //如果成功拿到时间戳事件,eventName 会是 A_TX_SCHED 或 A_TX_SOFTWARE 之一,tsUnixNS 是对应的时间戳;如果没有拿到事件或时间戳无效,eventName 会是空字符串,tsUnixNS 会是 0。
+}
+
+// 把底层时间戳映射成日志事件名。
+func parseTXTimestampControlMessages(oob []byte) (string, int64) {
+ if len(oob) == 0 {
+ return "", 0
+ }
+ //解析控制消息,看看是不是我们关心的 TX 时间戳事件,如果是就拿到事件名和时间戳。
+ controlMessages, err := syscall.ParseSocketControlMessage(oob)
+ if err != nil {
+ return "", 0
+ }
+
+ var (
+ tsUnixNS int64 //时间戳的 UnixNano 表示。
+ tsKind uint32 //extended err里,告诉我们这个时间戳是 sched 还是 software。
+ hasTS bool // 是否拿到时间戳了。
+ hasKind bool // 是否拿到时间戳类型了。
+ )
+ //一个 recvmsg 可能会收到多个控制消息,循环找我们关心的时间戳事件,拿到时间戳和事件类型。
+ for _, controlMessage := range controlMessages {
+ switch {
+ case controlMessage.Header.Level == syscall.SOL_SOCKET && controlMessage.Header.Type == linuxSCMTimestampingNew:
+ if ts := parseSCMTimestampingData(controlMessage.Data); ts > 0 {
+ tsUnixNS = ts
+ hasTS = true
+ }
+ case isSocketExtendedErr(controlMessage): //判断时间戳是否进入了errqueue,
+ if info, ok := parseSocketExtendedErrInfo(controlMessage.Data); ok {
+ tsKind = info //时间戳类型被内核放在 extended err 的附加信息里,解析出来。
+ hasKind = true
+ }
+ }
+ }
+
+ if !hasTS || !hasKind {
+ return "", 0
+ }
+
+ switch tsKind { //把内核的时间戳类型映射成日志事件名。(记录时只关心 sched 和 software 两类时间戳)
+ case linuxSCMTstampSched:
+ return latencylog.EventATXSched, tsUnixNS
+ case linuxSCMTstampSnd:
+ return latencylog.EventATXSoftware, tsUnixNS
+ default:
+ return "", 0
+ }
+}
+
+// 判断控制消息是否来自 socket extended err。
+// 内核产生的时间戳并不会混合在普通的数据流里,而是被包装成一种特殊的“错误消息”丢进 Error Queue。
+func isSocketExtendedErr(controlMessage syscall.SocketControlMessage) bool {
+ switch {
+ case controlMessage.Header.Level == syscall.SOL_IP && controlMessage.Header.Type == syscall.IP_RECVERR:
+ return true
+ case controlMessage.Header.Level == syscall.SOL_IPV6 && controlMessage.Header.Type == syscall.IPV6_RECVERR:
+ return true
+ default:
+ return false
+ }
+}
+
+// 从 socket extended err 的数据里取 origin timestamping 信息。
+func parseSocketExtendedErrInfo(data []byte) (uint32, bool) {
+ if len(data) < 16 {
+ return 0, false
+ }
+ if data[4] != linuxSOEEOriginTimestamping {
+ return 0, false
+ }
+
+ return binary.NativeEndian.Uint32(data[8:12]), true
+}
+
+// 读一条完整消息,并记录 B_RX_SOFTWARE。
+func (c *TCPConn) receiveMessageLinux() (protocol.Message, error) {
+ payload, rxTimestamp, err := c.readFrameLinux()
+ if err != nil {
+ return protocol.Message{}, fmt.Errorf("protocol: read frame: %w", err)
+ }
+
+ msg, err := protocol.DecodeMessage(payload)
+ if err != nil {
+ return protocol.Message{}, fmt.Errorf("protocol: decode message: %w", err)
+ }
+
+ if rxTimestamp > 0 {
+ latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventBRXSoftware, rxTimestamp, msg)
+ }
+
+ return msg, nil
+}
+
+// readFrameLinux 先读 4 字节长度,再读整条 payload。
+func (c *TCPConn) readFrameLinux() ([]byte, int64, error) {
+ var frameHeader [4]byte
+ rxTimestamp, err := c.readFullLinux(frameHeader[:])
+ if err != nil {
+ return nil, rxTimestamp, err
+ }
+
+ size := binary.BigEndian.Uint32(frameHeader[:])
+ switch {
+ case size == 0:
+ return nil, rxTimestamp, protocol.ErrInvalidFrameLength
+ case size > protocol.MaxFrameSize:
+ return nil, rxTimestamp, protocol.ErrFrameTooLarge
+ }
+
+ payload := make([]byte, int(size))
+ bodyTimestamp, err := c.readFullLinux(payload)
+ if rxTimestamp == 0 {
+ rxTimestamp = bodyTimestamp
+ }
+ if err != nil {
+ return nil, rxTimestamp, err
+ }
+
+ return payload, rxTimestamp, nil
+}
+
+// 读满 buf,并保留首个 RX_SOFTWARE(返回进入tcp协议栈的时间戳)。
+func (c *TCPConn) readFullLinux(buf []byte) (int64, error) {
+ if len(buf) == 0 {
+ return 0, nil
+ }
+
+ var (
+ offset int
+ firstRXTime int64
+ )
+
+ for offset < len(buf) {
+ n, rxTimestamp, err := c.recvmsgLinux(buf[offset:])
+ if firstRXTime == 0 && rxTimestamp > 0 {
+ firstRXTime = rxTimestamp
+ }
+ if err != nil {
+ if errors.Is(err, io.EOF) && offset > 0 {
+ return firstRXTime, io.ErrUnexpectedEOF
+ }
+ return firstRXTime, err
+ }
+
+ offset += n
+ }
+
+ return firstRXTime, nil
+}
+
+// recvmsgLinux 用 recvmsg 同时读取数据和控制消息。
+func (c *TCPConn) recvmsgLinux(buf []byte) (int, int64, error) {
+ var (
+ n int
+ rxTimeNS int64
+ opErr error
+ )
+
+ err := c.raw.Read(func(fd uintptr) bool {
+ oob := make([]byte, linuxTimestampControlBufferSize)
+ readN, oobN, _, _, recvErr := syscall.Recvmsg(int(fd), buf, oob, 0)
+ switch {
+ case recvErr == nil:
+ if readN == 0 {
+ opErr = io.EOF
+ return true
+ }
+ n = readN
+ rxTimeNS = parseRXTimestampControlMessages(oob[:oobN])
+ return true
+ case errors.Is(recvErr, syscall.EAGAIN), errors.Is(recvErr, syscall.EWOULDBLOCK):
+ return false
+ default:
+ opErr = recvErr
+ return true
+ }
+ })
+ if err != nil {
+ return 0, 0, err
+ }
+ if opErr != nil {
+ return 0, 0, opErr
+ }
+
+ return n, rxTimeNS, nil
+}
+
+// 从控制消息里取 RX_SOFTWARE。
+func parseRXTimestampControlMessages(oob []byte) int64 {
+ if len(oob) == 0 {
+ return 0
+ }
+
+ controlMessages, err := syscall.ParseSocketControlMessage(oob)
+ if err != nil {
+ return 0
+ }
+
+ for _, controlMessage := range controlMessages {
+ if controlMessage.Header.Level != syscall.SOL_SOCKET || controlMessage.Header.Type != linuxSCMTimestampingNew {
+ continue
+ }
+
+ if ts := parseSCMTimestampingData(controlMessage.Data); ts > 0 {
+ return ts
+ }
+ }
+
+ return 0
+}
+
+// 取第一个非零 timespec。
+func parseSCMTimestampingData(data []byte) int64 {
+ const timespec64Size = 16
+
+ for offset := 0; offset+timespec64Size <= len(data); offset += timespec64Size {
+ sec := int64(binary.NativeEndian.Uint64(data[offset : offset+8]))
+ nsec := int64(binary.NativeEndian.Uint64(data[offset+8 : offset+16]))
+ if sec == 0 && nsec == 0 {
+ continue
+ }
+ return sec*int64(time.Second) + nsec
+ }
+
+ return 0
+}
+
+// 判断错误是否是 EAGAIN 或 EWOULDBLOCK。
+func isWouldBlock(err error) bool {
+ return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EWOULDBLOCK)
+}
diff --git a/cmd/internal/transport/tcp_linux_test.go b/cmd/internal/transport/tcp_linux_test.go
new file mode 100644
index 0000000..5d99c20
--- /dev/null
+++ b/cmd/internal/transport/tcp_linux_test.go
@@ -0,0 +1,140 @@
+//go:build linux
+
+package transport
+
+import (
+ "net"
+ "reflect"
+ "testing"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+func TestLinuxTimestampingRecordsKernelEvents(t *testing.T) {
+ tests := []struct {
+ name string
+ msg protocol.Message
+ }{
+ {
+ name: "text",
+ msg: protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 41,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello over tcp"),
+ },
+ },
+ {
+ name: "file",
+ msg: protocol.Message{
+ Type: protocol.MessageTypeFile,
+ ID: 42,
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "payload.bin",
+ Body: []byte{0x00, 0x01, 0x02, 0xff},
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ clientConn, serverConn := newTCPPair(t)
+
+ senderLogger := &recordingLogger{}
+ receiverLogger := &recordingLogger{}
+ sender, err := NewTCPConn(
+ clientConn,
+ WithLogger(senderLogger, latencylog.NodeRolePeer, "peer-a"),
+ )
+ if err != nil {
+ t.Fatalf("NewTCPConn(sender) error = %v", err)
+ }
+ receiver, err := NewTCPConn(
+ serverConn,
+ WithLogger(receiverLogger, latencylog.NodeRolePeer, "peer-b"),
+ )
+ if err != nil {
+ t.Fatalf("NewTCPConn(receiver) error = %v", err)
+ }
+ t.Cleanup(func() {
+ _ = sender.Close()
+ _ = receiver.Close()
+ })
+
+ sendErr := make(chan error, 1)
+ go func() {
+ sendErr <- sender.Send(tt.msg)
+ }()
+
+ got, err := receiver.Receive()
+ if err != nil {
+ t.Fatalf("Receive() error = %v", err)
+ }
+ if err := <-sendErr; err != nil {
+ t.Fatalf("Send() error = %v", err)
+ }
+ if !reflect.DeepEqual(got, tt.msg) {
+ t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
+ }
+
+ assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSched, tt.msg.ID)
+ assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSoftware, tt.msg.ID)
+ assertHasEvent(t, receiverLogger.Events(), latencylog.EventBRXSoftware, tt.msg.ID)
+ })
+ }
+}
+
+func newTCPPair(t *testing.T) (net.Conn, net.Conn) {
+ t.Helper()
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("net.Listen() error = %v", err)
+ }
+
+ type acceptResult struct {
+ conn net.Conn
+ err error
+ }
+
+ accepted := make(chan acceptResult, 1)
+ go func() {
+ conn, acceptErr := listener.Accept()
+ accepted <- acceptResult{conn: conn, err: acceptErr}
+ }()
+
+ clientConn, err := net.Dial("tcp", listener.Addr().String())
+ if err != nil {
+ _ = listener.Close()
+ t.Fatalf("net.Dial() error = %v", err)
+ }
+
+ result := <-accepted
+ if err := listener.Close(); err != nil {
+ t.Fatalf("listener.Close() error = %v", err)
+ }
+ if result.err != nil {
+ _ = clientConn.Close()
+ t.Fatalf("listener.Accept() error = %v", result.err)
+ }
+
+ return clientConn, result.conn
+}
+
+func assertHasEvent(t *testing.T, events []latencylog.Event, wantEvent string, wantMessageID uint64) {
+ t.Helper()
+
+ for _, event := range events {
+ if event.Event == wantEvent && event.MessageID == wantMessageID {
+ if event.TsUnixNano <= 0 {
+ t.Fatalf("event %s timestamp must be positive: %+v", wantEvent, event)
+ }
+ return
+ }
+ }
+
+ t.Fatalf("missing event %s for message %d in %+v", wantEvent, wantMessageID, events)
+}
diff --git a/cmd/internal/transport/tcp_test.go b/cmd/internal/transport/tcp_test.go
new file mode 100644
index 0000000..8d7d5ca
--- /dev/null
+++ b/cmd/internal/transport/tcp_test.go
@@ -0,0 +1,416 @@
+package transport
+
+import (
+ "errors"
+ "io"
+ "reflect"
+ "strings"
+ "sync"
+ "testing"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+type recordingLogger struct {
+ mu sync.Mutex
+ events []latencylog.Event
+}
+
+func (l *recordingLogger) LogEvent(event latencylog.Event) error {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ l.events = append(l.events, event)
+ return nil
+}
+
+func (l *recordingLogger) Events() []latencylog.Event {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ return append([]latencylog.Event(nil), l.events...)
+}
+
+type failingLogger struct{}
+
+func (failingLogger) LogEvent(latencylog.Event) error {
+ return errors.New("log failed")
+}
+
+// TestSendReceiveMessage 验证 transport 可以在单条连接上正常收发 text 和 file 消息。
+func TestSendReceiveMessage(t *testing.T) {
+ tests := []struct {
+ name string
+ msg protocol.Message
+ }{
+ {
+ name: "text",
+ msg: protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ },
+ },
+ {
+ name: "file",
+ msg: protocol.Message{
+ Type: protocol.MessageTypeFile,
+ ID: 2,
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "data.bin",
+ Body: []byte{0x00, 0x10, 0xff},
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ sender, receiver := newTransportConnPair(t, nil, nil)
+ //创建一个容量为1的缓冲通道sendErr,用于接收发送操作的错误结果。
+ sendErr := make(chan error, 1)
+ go func() {
+ sendErr <- sender.Send(tt.msg) //发送消息,并将结果(错误或nil)发送到sendErr通道。
+ }()
+
+ got, err := receiver.Receive()
+ if err != nil {
+ t.Fatalf("Receive() error = %v", err)
+ }
+ if err := <-sendErr; err != nil { //接受发送结果,如果发送过程中发生错误,则测试失败。
+ t.Fatalf("Send() error = %v", err)
+ }
+
+ if !reflect.DeepEqual(got, tt.msg) {
+ t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
+ }
+ })
+ }
+}
+
+func TestSendLogsHandoffEvents(t *testing.T) {
+ logger := &recordingLogger{}
+ sender, receiver := newTransportConnPair(
+ t,
+ []Option{WithLogger(logger, latencylog.NodeRolePeer, "peer-a")},
+ nil,
+ )
+
+ msg := protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 7,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ }
+
+ sendErr := make(chan error, 1)
+ go func() {
+ sendErr <- sender.Send(msg)
+ }()
+
+ got, err := receiver.Receive()
+ if err != nil {
+ t.Fatalf("Receive() error = %v", err)
+ }
+ if err := <-sendErr; err != nil {
+ t.Fatalf("Send() error = %v", err)
+ }
+ if !reflect.DeepEqual(got, msg) {
+ t.Fatalf("message mismatch: got %+v want %+v", got, msg)
+ }
+
+ events := logger.Events()
+ if len(events) != 4 {
+ t.Fatalf("event count = %d, want 4", len(events))
+ }
+ if events[0].Event != latencylog.EventSendHandoffBegin {
+ t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin)
+ }
+ if events[1].Event != latencylog.EventATXSched {
+ t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventATXSched)
+ }
+ if events[2].Event != latencylog.EventATXSoftware {
+ t.Fatalf("third event = %q, want %q", events[2].Event, latencylog.EventATXSoftware)
+ }
+ if events[3].Event != latencylog.EventSendHandoffEnd {
+ t.Fatalf("fourth event = %q, want %q", events[3].Event, latencylog.EventSendHandoffEnd)
+ }
+ for i, event := range events {
+ if event.MessageID != msg.ID {
+ t.Fatalf("event[%d] message ID = %d, want %d", i, event.MessageID, msg.ID)
+ }
+ }
+ if events[0].NodeRole != latencylog.NodeRolePeer || events[0].NodeID != "peer-a" {
+ t.Fatalf("node info = (%s,%s), want (%s,%s)", events[0].NodeRole, events[0].NodeID, latencylog.NodeRolePeer, "peer-a")
+ }
+ if events[0].TsUnixNano <= 0 || events[1].TsUnixNano <= 0 || events[2].TsUnixNano <= 0 || events[3].TsUnixNano <= 0 {
+ t.Fatalf("timestamps must be positive: %+v", events)
+ }
+}
+
+func TestSendIgnoresLoggerFailure(t *testing.T) {
+ sender, receiver := newTransportConnPair(
+ t,
+ []Option{WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")},
+ nil,
+ )
+
+ msg := protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 9,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ }
+
+ sendErr := make(chan error, 1)
+ go func() {
+ sendErr <- sender.Send(msg)
+ }()
+
+ got, err := receiver.Receive()
+ if err != nil {
+ t.Fatalf("Receive() error = %v", err)
+ }
+ if err := <-sendErr; err != nil {
+ t.Fatalf("Send() error = %v, want nil even if logger fails", err)
+ }
+ if !reflect.DeepEqual(got, msg) {
+ t.Fatalf("message mismatch: got %+v want %+v", got, msg)
+ }
+}
+
+// TestReceiveLoopDeliversMessages 验证 ReceiveLoop 会逐条交付连续到达的消息。
+func TestReceiveLoopDeliversMessages(t *testing.T) {
+ sender, receiver := newTransportConnPair(t, nil, nil)
+
+ want := []protocol.Message{
+ {
+ Type: protocol.MessageTypeText,
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ },
+ {
+ Type: protocol.MessageTypeFile,
+ ID: 2,
+ From: "peer-a",
+ To: "peer-b",
+ FileName: "payload.bin",
+ Body: []byte{0x01, 0x02, 0x03},
+ },
+ }
+
+ var (
+ mu sync.Mutex
+ got []protocol.Message
+ )
+ loopErr := make(chan error, 1)
+ go func() {
+ loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
+ mu.Lock()
+ defer mu.Unlock()
+ got = append(got, msg)
+ return nil
+ })
+ }()
+
+ for _, msg := range want {
+ if err := sender.Send(msg); err != nil {
+ t.Fatalf("Send() error = %v", err)
+ }
+ }
+ if err := sender.Close(); err != nil {
+ t.Fatalf("sender.Close() error = %v", err)
+ }
+
+ err := <-loopErr
+ if err == nil {
+ t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close")
+ }
+ if !strings.Contains(err.Error(), "receive loop read") {
+ t.Fatalf("ReceiveLoop() error = %v, want read context", err)
+ }
+
+ mu.Lock()
+ defer mu.Unlock()
+ if !reflect.DeepEqual(got, want) {
+ t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
+ }
+}
+
+// TestConcurrentSendKeepsMessagesIntact 验证并发发送时消息不会因为写入交叉而损坏。
+func TestConcurrentSendKeepsMessagesIntact(t *testing.T) {
+ sender, receiver := newTransportConnPair(t, nil, nil)
+ // 发送方将多条消息并发发送到接收方,接收方通过 ReceiveLoop 逐条读取并验证每条消息的完整性和正确性。
+ want := []protocol.Message{
+ {Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("one")},
+ {Type: protocol.MessageTypeText, ID: 2, From: "peer-a", To: "peer-b", Body: []byte("two")},
+ {Type: protocol.MessageTypeText, ID: 3, From: "peer-a", To: "peer-b", Body: []byte("three")},
+ {Type: protocol.MessageTypeText, ID: 4, From: "peer-a", To: "peer-b", Body: []byte("four")},
+ }
+
+ received := make(chan protocol.Message, len(want))
+ readErr := make(chan error, 1)
+ go func() { //异步地运行一个 goroutine
+ for range want {
+ msg, err := receiver.Receive()
+ if err != nil {
+ readErr <- err
+ return
+ }
+ received <- msg
+ }
+ readErr <- nil
+ }()
+
+ var wg sync.WaitGroup
+ for _, msg := range want {
+ msg := msg
+ wg.Add(1)
+ go func() { //异步处理
+ defer wg.Done()
+ if err := sender.Send(msg); err != nil {
+ t.Errorf("Send() error = %v", err)
+ }
+ }()
+ }
+ wg.Wait()
+
+ if err := <-readErr; err != nil {
+ t.Fatalf("Receive() error = %v", err)
+ }
+
+ gotByID := make(map[uint64]protocol.Message, len(want))
+ for range want {
+ msg := <-received
+ gotByID[msg.ID] = msg
+ }
+
+ for _, msg := range want {
+ got, ok := gotByID[msg.ID]
+ if !ok {
+ t.Fatalf("missing message with ID %d", msg.ID)
+ }
+ if !reflect.DeepEqual(got, msg) {
+ t.Fatalf("message mismatch for ID %d: got %+v want %+v", msg.ID, got, msg)
+ }
+ }
+}
+
+// TestReceiveLoopStopsOnHandlerError 验证 handler 返回错误时 ReceiveLoop 会退出并关闭连接。
+func TestReceiveLoopStopsOnHandlerError(t *testing.T) {
+ sender, receiver := newTransportConnPair(t, nil, nil)
+
+ wantErr := errors.New("stop loop")
+ loopErr := make(chan error, 1)
+ go func() {
+ loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
+ return wantErr
+ })
+ }()
+
+ first := protocol.Message{
+ Type: protocol.MessageTypeText,
+ ID: 1,
+ From: "peer-a",
+ To: "peer-b",
+ Body: []byte("hello"),
+ }
+ if err := sender.Send(first); err != nil {
+ t.Fatalf("Send(first) error = %v", err)
+ }
+
+ err := <-loopErr
+ if !errors.Is(err, wantErr) {
+ t.Fatalf("ReceiveLoop() error = %v, want %v", err, wantErr)
+ }
+ if !strings.Contains(err.Error(), "receive loop handler") {
+ t.Fatalf("ReceiveLoop() error = %v, want handler context", err)
+ }
+}
+
+// TestReceiveLoopStopsOnReadError 验证对端关闭时 ReceiveLoop 会以读取错误退出。
+func TestReceiveLoopStopsOnReadError(t *testing.T) {
+ sender, receiver := newTransportConnPair(t, nil, nil)
+
+ loopErr := make(chan error, 1)
+ go func() {
+ loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
+ return nil
+ })
+ }()
+
+ if err := sender.Close(); err != nil {
+ t.Fatalf("sender.Close() error = %v", err)
+ }
+
+ err := <-loopErr
+ if err == nil {
+ t.Fatal("ReceiveLoop() error = nil, want non-nil")
+ }
+ if !strings.Contains(err.Error(), "receive loop read") {
+ t.Fatalf("ReceiveLoop() error = %v, want read context", err)
+ }
+}
+
+// TestCloseIsIdempotent 验证 Close 可以安全地被重复调用。
+func TestCloseIsIdempotent(t *testing.T) {
+ conn, peer := newTransportConnPair(t, nil, nil)
+
+ if err := conn.Close(); err != nil {
+ t.Fatalf("Close(first) error = %v", err)
+ }
+ if err := conn.Close(); err != nil {
+ t.Fatalf("Close(second) error = %v, want nil", err)
+ }
+ if err := peer.Close(); err != nil && !strings.Contains(err.Error(), "closed") {
+ t.Fatalf("peer.Close() error = %v", err)
+ }
+}
+
+// TestReceiveReturnsWrappedReadError 验证 Receive 在底层读取失败时会保留 transport 上下文。
+func TestReceiveReturnsWrappedReadError(t *testing.T) {
+ conn, peer := newTransportConnPair(t, nil, nil)
+ go func() {
+ _ = peer.Close()
+ }()
+
+ _, err := conn.Receive()
+ if err == nil {
+ t.Fatal("Receive() error = nil, want non-nil")
+ }
+ if !strings.Contains(err.Error(), "transport: receive message") {
+ t.Fatalf("Receive() error = %v, want wrapped receive error", err)
+ }
+ if !errors.Is(err, io.EOF) && !strings.Contains(err.Error(), "closed") {
+ t.Fatalf("Receive() error = %v, want underlying read failure", err)
+ }
+}
+
+func newTransportConnPair(t *testing.T, senderOpts []Option, receiverOpts []Option) (*TCPConn, *TCPConn) {
+ t.Helper()
+
+ left, right := newTCPPair(t)
+
+ sender, err := NewTCPConn(left, senderOpts...)
+ if err != nil {
+ t.Fatalf("NewTCPConn(sender) error = %v", err)
+ }
+ receiver, err := NewTCPConn(right, receiverOpts...)
+ if err != nil {
+ t.Fatalf("NewTCPConn(receiver) error = %v", err)
+ }
+
+ t.Cleanup(func() {
+ _ = sender.Close()
+ _ = receiver.Close()
+ })
+
+ return sender, receiver
+}
diff --git a/cmd/latencysummary/main.go b/cmd/latencysummary/main.go
new file mode 100644
index 0000000..76c4df1
--- /dev/null
+++ b/cmd/latencysummary/main.go
@@ -0,0 +1,59 @@
+package main
+
+import (
+ "flag"
+ "log"
+ "path/filepath"
+ "strings"
+
+ "omnisocketgo/cmd/internal/latencylog"
+)
+
+type stringListFlag []string
+
+func (f *stringListFlag) String() string {
+ return ""
+}
+
+func (f *stringListFlag) Set(value string) error {
+ *f = append(*f, value)
+ return nil
+}
+
+func main() {
+ var inputPaths stringListFlag
+ outputPath := flag.String("output", "latency-summary.jsonl", "output JSONL file for summarized latency metrics")
+ flag.Var(&inputPaths, "input", "raw latency JSONL file path; can be provided multiple times")
+ flag.Parse()
+
+ if len(inputPaths) == 0 {
+ log.Fatal("at least one -input raw latency log file is required")
+ }
+
+ events, err := latencylog.LoadEventsFromFiles(inputPaths)
+ if err != nil {
+ log.Fatalf("load raw latency logs: %v", err)
+ }
+
+ summaries := latencylog.SummarizeEvents(events)
+ if err := latencylog.WriteSummariesJSONL(*outputPath, summaries); err != nil {
+ log.Fatalf("write latency summary: %v", err)
+ }
+
+ chartPath := replaceFileExt(*outputPath, ".html")
+ if err := latencylog.WriteSummariesHTMLChart(chartPath, summaries); err != nil {
+ log.Fatalf("write latency chart: %v", err)
+ }
+
+ log.Printf("wrote %d summarized message records to %s", len(summaries), *outputPath)
+ log.Printf("wrote simple latency chart to %s", chartPath)
+}
+
+func replaceFileExt(path, ext string) string {
+ currentExt := filepath.Ext(path)
+ if currentExt == "" {
+ return path + ext
+ }
+
+ return strings.TrimSuffix(path, currentExt) + ext
+}
diff --git a/cmd/peer/interactive.go b/cmd/peer/interactive.go
new file mode 100644
index 0000000..8137bc0
--- /dev/null
+++ b/cmd/peer/interactive.go
@@ -0,0 +1,89 @@
+package main
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "strings"
+)
+
+const (
+ interactiveCommandHelp = "help"
+ interactiveCommandQuit = "quit"
+ interactiveCommandText = "text"
+ interactiveCommandFile = "file"
+)
+
+// 交互式命令行界面,允许用户在连接建立后反复发送文本或文件消息。
+var errEmptyInteractiveCommand = errors.New("interactive command is empty")
+
+type interactiveCommand struct {
+ name string
+ to string
+ value string
+}
+
+// 解析用户输入的交互式命令,支持发送文本或文件消息,以及查看帮助和退出。
+func parseInteractiveCommand(line string) (interactiveCommand, error) {
+ commandName, rest, ok := cutInteractiveField(strings.TrimSpace(line))
+ if !ok {
+ return interactiveCommand{}, errEmptyInteractiveCommand
+ }
+
+ switch strings.ToLower(commandName) {
+ case "help", "h", "?":
+ return interactiveCommand{name: interactiveCommandHelp}, nil
+ case "quit", "exit":
+ return interactiveCommand{name: interactiveCommandQuit}, nil
+ case interactiveCommandText:
+ to, body, err := parseInteractiveTargetValue(rest, interactiveCommandText)
+ if err != nil {
+ return interactiveCommand{}, err
+ }
+ return interactiveCommand{name: interactiveCommandText, to: to, value: body}, nil
+ case interactiveCommandFile:
+ to, path, err := parseInteractiveTargetValue(rest, interactiveCommandFile)
+ if err != nil {
+ return interactiveCommand{}, err
+ }
+ return interactiveCommand{name: interactiveCommandFile, to: to, value: path}, nil
+ default:
+ return interactiveCommand{}, fmt.Errorf("unknown command %q; type help for usage", commandName)
+ }
+}
+
+func parseInteractiveTargetValue(rest, commandName string) (string, string, error) {
+ to, value, ok := cutInteractiveField(strings.TrimSpace(rest))
+ if !ok {
+ return "", "", fmt.Errorf("%s command requires a target peer and payload", commandName)
+ }
+ if strings.TrimSpace(value) == "" {
+ return "", "", fmt.Errorf("%s command requires a non-empty payload", commandName)
+ }
+
+ return to, strings.TrimSpace(value), nil
+}
+
+func cutInteractiveField(input string) (string, string, bool) {
+ trimmed := strings.TrimSpace(input)
+ if trimmed == "" {
+ return "", "", false
+ }
+
+ for i, r := range trimmed {
+ if r == ' ' || r == '\t' {
+ return trimmed[:i], strings.TrimSpace(trimmed[i+1:]), true
+ }
+ }
+
+ return trimmed, "", true
+}
+
+// 打印交互式命令帮助信息,列出可用的命令和用法说明。
+func printInteractiveHelp(w io.Writer) {
+ _, _ = fmt.Fprintln(w, "interactive mode commands:")
+ _, _ = fmt.Fprintln(w, " help show this help")
+ _, _ = fmt.Fprintln(w, " text send one text message over the existing connection")
+ _, _ = fmt.Fprintln(w, " file send one file over the existing connection")
+ _, _ = fmt.Fprintln(w, " quit exit this peer process")
+}
diff --git a/cmd/peer/interactive_test.go b/cmd/peer/interactive_test.go
new file mode 100644
index 0000000..74abf87
--- /dev/null
+++ b/cmd/peer/interactive_test.go
@@ -0,0 +1,87 @@
+package main
+
+import "testing"
+
+func TestParseInteractiveCommand(t *testing.T) {
+ tests := []struct {
+ name string
+ line string
+ want interactiveCommand
+ wantErr string
+ }{
+ {
+ name: "text command preserves spaces in body",
+ line: "text peer-b hello over the same connection",
+ want: interactiveCommand{
+ name: interactiveCommandText,
+ to: "peer-b",
+ value: "hello over the same connection",
+ },
+ },
+ {
+ name: "file command preserves spaces in path",
+ line: "file peer-b /tmp/demo payload.bin",
+ want: interactiveCommand{
+ name: interactiveCommandFile,
+ to: "peer-b",
+ value: "/tmp/demo payload.bin",
+ },
+ },
+ {
+ name: "help alias",
+ line: "?",
+ want: interactiveCommand{
+ name: interactiveCommandHelp,
+ },
+ },
+ {
+ name: "quit alias",
+ line: "exit",
+ want: interactiveCommand{
+ name: interactiveCommandQuit,
+ },
+ },
+ {
+ name: "empty command",
+ line: " ",
+ wantErr: errEmptyInteractiveCommand.Error(),
+ },
+ {
+ name: "text requires payload",
+ line: "text peer-b",
+ wantErr: "text command requires a non-empty payload",
+ },
+ {
+ name: "file requires target and payload",
+ line: "file",
+ wantErr: "file command requires a target peer and payload",
+ },
+ {
+ name: "unknown command",
+ line: "ping peer-b",
+ wantErr: `unknown command "ping"; type help for usage`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := parseInteractiveCommand(tt.line)
+ if tt.wantErr != "" {
+ if err == nil {
+ t.Fatalf("parseInteractiveCommand(%q) error = nil, want %q", tt.line, tt.wantErr)
+ }
+ if err.Error() != tt.wantErr {
+ t.Fatalf("parseInteractiveCommand(%q) error = %q, want %q", tt.line, err.Error(), tt.wantErr)
+ }
+ return
+ }
+
+ if err != nil {
+ t.Fatalf("parseInteractiveCommand(%q) error = %v", tt.line, err)
+ }
+ if got != tt.want {
+ t.Fatalf("parseInteractiveCommand(%q) = %+v, want %+v", tt.line, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/cmd/peer/main.go b/cmd/peer/main.go
new file mode 100644
index 0000000..ab3244b
--- /dev/null
+++ b/cmd/peer/main.go
@@ -0,0 +1,173 @@
+package main
+
+import (
+ "bufio"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "os"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ peerpkg "omnisocketgo/cmd/internal/peer"
+ "omnisocketgo/cmd/internal/protocol"
+)
+
+func main() {
+ peerID := flag.String("id", "peer-a", "peer identity") // peer 的标识
+ serverAddr := flag.String("server", "127.0.0.1:9000", "server address") // server 的地址
+ targetPeer := flag.String("to", "", "optional target peer for one outgoing message") // 可选的目标 peer 标识
+ text := flag.String("text", "", "optional text to send after connecting") // 可选的文本消息内容
+ filePath := flag.String("file", "", "optional file path to send after connecting")
+ inboxDir := flag.String("inbox-dir", "inbox", "directory used to persist received text and file messages")
+ logPath := flag.String("latency-log", "", "optional JSONL file path for latency timestamp logs")
+ interactive := flag.Bool("interactive", true, "enable interactive REPL for repeated text/file sends on the same connection")
+ flag.Parse()
+
+ clientOptions := make([]peerpkg.Option, 0, 1)
+ if *logPath != "" {
+ logger, err := latencylog.NewJSONLLogger(*logPath)
+ if err != nil {
+ log.Fatalf("create latency logger %s: %v", *logPath, err)
+ }
+ defer logger.Close()
+ clientOptions = append(clientOptions, peerpkg.WithLogger(logger))
+ }
+
+ client, err := peerpkg.Dial(*serverAddr, *peerID, clientOptions...)
+ if err != nil {
+ log.Fatalf("dial server %s: %v", *serverAddr, err)
+ }
+ defer client.Close()
+
+ log.Printf("connected to %s as %s", *serverAddr, client.ID())
+
+ receiveErr := make(chan error, 1)
+ go func() {
+ receiveErr <- client.ReceiveLoop(func(msg protocol.Message) error {
+ switch msg.Type {
+ case protocol.MessageTypeText:
+ path, err := client.PersistMessage(msg, *inboxDir)
+ if err != nil {
+ return err
+ }
+ log.Printf("received text from %s to %s and persisted to %s", msg.From, msg.To, path)
+ case protocol.MessageTypeFile:
+ path, err := client.PersistMessage(msg, *inboxDir)
+ if err != nil {
+ return err
+ }
+ log.Printf("received file from %s to %s: %s (%d bytes) -> %s", msg.From, msg.To, msg.FileName, len(msg.Body), path)
+ case protocol.MessageTypeError:
+ log.Printf("received %s from %s to %s: %s", msg.Type, msg.From, msg.To, string(msg.Body))
+ default:
+ log.Printf("received unexpected message type %s from %s", msg.Type, msg.From)
+ }
+ return nil
+ })
+ }()
+
+ if *text != "" && *filePath != "" {
+ log.Fatal("only one of -text or -file may be specified")
+ }
+
+ if (*text != "" || *filePath != "") && *targetPeer == "" {
+ log.Fatal("flag -to is required when sending text or file")
+ }
+
+ //如果指定了目标 peer 和文本消息内容,则向目标 peer 发送一条文本消息,如果发送失败,打印错误日志并退出。
+ if *targetPeer != "" && *text != "" {
+ if err := client.SendText(*targetPeer, *text); err != nil {
+ log.Fatalf("send text to %s: %v", *targetPeer, err)
+ }
+ log.Printf("sent text to %s", *targetPeer)
+ }
+
+ if *targetPeer != "" && *filePath != "" {
+ if err := client.SendFilePath(*targetPeer, *filePath); err != nil {
+ log.Fatalf("send file %s to %s: %v", *filePath, *targetPeer, err)
+ }
+ log.Printf("sent file %s to %s", *filePath, *targetPeer)
+ }
+
+ if *interactive {
+ if err := runInteractiveShell(client, os.Stdin, os.Stdout, receiveErr); err != nil {
+ log.Printf("interactive shell ended: %v", err)
+ }
+ return
+ }
+
+ if err := <-receiveErr; err != nil {
+ log.Printf("receive loop ended: %v", err)
+ }
+}
+
+func runInteractiveShell(client *peerpkg.Client, in io.Reader, out io.Writer, receiveErr <-chan error) error {
+ printInteractiveHelp(out)
+ lines, inputErr := readInteractiveLines(in, out, fmt.Sprintf("%s> ", client.ID()))
+
+ for {
+ select {
+ case err := <-receiveErr:
+ return err
+ case line, ok := <-lines:
+ if !ok {
+ return <-inputErr
+ }
+
+ command, err := parseInteractiveCommand(line)
+ if err != nil {
+ if err == errEmptyInteractiveCommand {
+ continue
+ }
+ log.Printf("interactive command error: %v", err)
+ continue
+ }
+
+ switch command.name {
+ case interactiveCommandHelp:
+ printInteractiveHelp(out)
+ case interactiveCommandQuit:
+ return nil
+ case interactiveCommandText:
+ if err := client.SendText(command.to, command.value); err != nil {
+ log.Printf("send text to %s: %v", command.to, err)
+ continue
+ }
+ log.Printf("sent text to %s", command.to)
+ case interactiveCommandFile:
+ if err := client.SendFilePath(command.to, command.value); err != nil {
+ log.Printf("send file %s to %s: %v", command.value, command.to, err)
+ continue
+ }
+ log.Printf("sent file %s to %s", command.value, command.to)
+ }
+ }
+ }
+}
+
+func readInteractiveLines(in io.Reader, out io.Writer, prompt string) (<-chan string, <-chan error) {
+ lines := make(chan string)
+ errs := make(chan error, 1)
+
+ go func() {
+ defer close(lines)
+
+ scanner := bufio.NewScanner(in)
+ scanner.Buffer(make([]byte, 0, 1024), 1024*1024)
+
+ for {
+ if _, err := fmt.Fprint(out, prompt); err != nil {
+ errs <- err
+ return
+ }
+ if !scanner.Scan() {
+ errs <- scanner.Err()
+ return
+ }
+ lines <- scanner.Text()
+ }
+ }()
+
+ return lines, errs
+}
diff --git a/cmd/server/main.go b/cmd/server/main.go
new file mode 100644
index 0000000..390e145
--- /dev/null
+++ b/cmd/server/main.go
@@ -0,0 +1,49 @@
+package main
+
+import (
+ "flag"
+ "log"
+ "net"
+
+ "omnisocketgo/cmd/internal/latencylog"
+ "omnisocketgo/cmd/internal/server"
+)
+
+func main() {
+ listenAddr := flag.String("listen", ":9000", "server listen address") //监听地址
+ logPath := flag.String("latency-log", "", "optional JSONL file path for latency timestamp logs")
+ flag.Parse() //查看命令行参数
+
+ hubOptions := make([]server.Option, 0, 1)
+ if *logPath != "" {
+ logger, err := latencylog.NewJSONLLogger(*logPath)
+ if err != nil {
+ log.Fatalf("create latency logger %s: %v", *logPath, err)
+ }
+ defer logger.Close()
+ hubOptions = append(hubOptions, server.WithLogger(logger))
+ }
+
+ listener, err := net.Listen("tcp", *listenAddr) //开启tcp监听器,监听来自客户端的连接请求
+ if err != nil {
+ log.Fatalf("listen on %s: %v", *listenAddr, err)
+ }
+ defer listener.Close() //确保在 main 函数退出时关闭监听器
+
+ hub := server.NewHub(hubOptions...) //创建一个新的 Hub 实例,负责管理客户端连接和消息转发
+ log.Printf("server listening on %s", listener.Addr())
+
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ log.Printf("accept connection: %v", err)
+ continue
+ }
+
+ go func(rawConn net.Conn) {
+ if err := hub.ServeConn(rawConn); err != nil {
+ log.Printf("connection closed: %v", err)
+ }
+ }(conn)
+ }
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..c8fbf8c
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,3 @@
+module omnisocketgo
+
+go 1.22
diff --git a/latencysummary b/latencysummary
new file mode 100755
index 0000000..e8ecfe9
Binary files /dev/null and b/latencysummary differ