This commit is contained in:
nnbcccscdscdsc
2026-03-23 20:18:53 +08:00
commit 4824675244
28 changed files with 5569 additions and 0 deletions

View File

@@ -0,0 +1,166 @@
package latencylog
import (
"encoding/json"
"os"
"path/filepath"
"sync"
"time"
"omnisocketgo/cmd/internal/protocol"
)
// Node role identifiers stored in the NodeRole field of an Event.
const (
NodeRolePeer = "peer" // client (peer) node
NodeRoleServer = "server" // cloud relay node
)
// Event name constants for the recorded message events.
const (
EventAAppPrepBegin = "A_APP_PREP_BEGIN" // A-side application starts preparing this message
EventATXSched = "A_TX_SCHED" // A side, before entering the Linux qdisc
EventATXSoftware = "A_TX_SOFTWARE" // A side, about to hand the data to the NIC driver
EventATXHardware = "A_TX_HARDWARE" // A-side NIC actually transmits onto the physical medium
EventBRXHardware = "B_RX_HARDWARE" // B-side NIC actually receives from the physical medium
EventBRXSoftware = "B_RX_SOFTWARE" // B-side driver hands the data to the Linux receive stack
EventBAppRecv = "B_APP_RECV" // B-side application has actually read the complete message
EventBPersistBegin = "B_PERSIST_BEGIN" // B side starts writing to disk
EventBPersistEnd = "B_PERSIST_END" // B side finished writing to disk
EventSendHandoffBegin = "send_handoff_begin" // debug event: application starts handing the message to the transport layer
EventSendHandoffEnd = "send_handoff_end" // debug event: application finished handing the message to the transport layer
)
// Event is a single latency timestamp log record. It is serialized as one
// JSON object using the field names given in the struct tags.
type Event struct {
TsUnixNano int64 `json:"ts_unix_nano"`
NodeRole string `json:"node_role"`
NodeID string `json:"node_id"`
Event string `json:"event"`
MessageType protocol.MessageType `json:"message_type"`
MessageID uint64 `json:"message_id"`
From string `json:"from"`
To string `json:"to"`
FileName string `json:"file_name,omitempty"`
BodySize int `json:"body_size"`
}
// Logger receives events and writes them to an external medium.
type Logger interface {
LogEvent(Event) error
}
// NoopLogger is the default no-op Logger implementation.
type NoopLogger struct{}
// LogEvent discards the event and always returns nil.
func (NoopLogger) LogEvent(Event) error {
return nil
}
// JSONLLogger appends events to a log file in JSONL form (one JSON object
// per line).
type JSONLLogger struct {
mu sync.Mutex // held by LogEvent while writing a line to file
closeOnce sync.Once // makes Close call file.Close at most once
closeErr error // result of that single file.Close call, returned by every Close
file *os.File // destination file opened by NewJSONLLogger
}
// NewJSONLLogger creates a JSONL file logger that is safe for concurrent use.
//
// The parent directory of path is created when missing (mode 0o755) and the
// file itself is opened write-only in append mode, being created with mode
// 0o644 if it does not exist yet. Errors from either step are returned as-is.
func NewJSONLLogger(path string) (*JSONLLogger, error) {
	if mkdirErr := os.MkdirAll(filepath.Dir(path), 0o755); mkdirErr != nil {
		return nil, mkdirErr
	}

	const openFlags = os.O_CREATE | os.O_WRONLY | os.O_APPEND
	f, openErr := os.OpenFile(path, openFlags, 0o644)
	if openErr != nil {
		return nil, openErr
	}

	logger := &JSONLLogger{file: f}
	return logger, nil
}
// LogEvent appends one event to the log file as a single JSON line.
//
// The event is encoded before the mutex is taken, so only the file write
// itself is serialized between concurrent callers. It returns the marshal
// error or the write error, whichever occurs first.
func (l *JSONLLogger) LogEvent(event Event) error {
	encoded, marshalErr := json.Marshal(event)
	if marshalErr != nil {
		return marshalErr
	}
	encoded = append(encoded, '\n')

	l.mu.Lock()
	defer l.mu.Unlock()

	_, writeErr := l.file.Write(encoded)
	return writeErr
}
// Close closes the underlying file. Calling it repeatedly is safe: the file is
// closed only on the first call and every call returns that first result.
func (l *JSONLLogger) Close() error {
	closeFile := func() {
		l.closeErr = l.file.Close()
	}
	l.closeOnce.Do(closeFile)
	return l.closeErr
}
// IsBusinessMessage reports whether msg is a business message that takes part
// in the A-C-B latency analysis, i.e. whether it is a text or a file message.
func IsBusinessMessage(msg protocol.Message) bool {
	kind := msg.Type
	return kind == protocol.MessageTypeText || kind == protocol.MessageTypeFile
}
// NewMessageEvent builds an event for msg stamped with the current UTC time
// expressed in Unix nanoseconds.
func NewMessageEvent(nodeRole, nodeID, eventName string, msg protocol.Message) Event {
	now := time.Now().UTC().UnixNano()
	return NewMessageEventAt(now, nodeRole, nodeID, eventName, msg)
}
// NewMessageEventAt builds an event for msg using the supplied UnixNano
// timestamp. The message metadata is copied into the event; the body itself is
// not stored, only its length in bytes.
func NewMessageEventAt(tsUnixNano int64, nodeRole, nodeID, eventName string, msg protocol.Message) Event {
	var ev Event

	// Where and when the event was observed.
	ev.TsUnixNano = tsUnixNano
	ev.NodeRole = nodeRole
	ev.NodeID = nodeID
	ev.Event = eventName

	// Which message it refers to.
	ev.MessageType = msg.Type
	ev.MessageID = msg.ID
	ev.From = msg.From
	ev.To = msg.To
	ev.FileName = msg.FileName
	ev.BodySize = len(msg.Body)

	return ev
}
// LogBestEffort writes one event and silently ignores any failure so that
// logging can never interrupt the main send/receive flow. A nil logger is a
// no-op.
func LogBestEffort(logger Logger, event Event) {
	if logger != nil {
		// The error is dropped on purpose: latency logging is best effort.
		_ = logger.LogEvent(event)
	}
}
// LogMessageEvent builds an event stamped with the current time and writes it
// on a best-effort basis. Messages that are not business messages are ignored.
func LogMessageEvent(logger Logger, nodeRole, nodeID, eventName string, msg protocol.Message) {
	if IsBusinessMessage(msg) {
		event := NewMessageEvent(nodeRole, nodeID, eventName, msg)
		LogBestEffort(logger, event)
	}
}
// LogMessageEventAt writes, on a best-effort basis, an event for msg that
// carries the caller-supplied UnixNano timestamp. Messages that are not
// business messages are ignored.
func LogMessageEventAt(logger Logger, nodeRole, nodeID, eventName string, tsUnixNano int64, msg protocol.Message) {
	if IsBusinessMessage(msg) {
		event := NewMessageEventAt(tsUnixNano, nodeRole, nodeID, eventName, msg)
		LogBestEffort(logger, event)
	}
}

View File

@@ -0,0 +1,131 @@
package latencylog
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"sync"
"testing"
"omnisocketgo/cmd/internal/protocol"
)
// TestJSONLLoggerWritesOneEventPerLine checks that one logged event ends up as
// exactly one JSON line that decodes back into an identical Event.
func TestJSONLLoggerWritesOneEventPerLine(t *testing.T) {
	logPath := filepath.Join(t.TempDir(), "latency.jsonl")
	logger, err := NewJSONLLogger(logPath)
	if err != nil {
		t.Fatalf("NewJSONLLogger() error = %v", err)
	}
	t.Cleanup(func() { _ = logger.Close() })

	want := Event{
		TsUnixNano:  123,
		NodeRole:    NodeRolePeer,
		NodeID:      "peer-a",
		Event:       EventAAppPrepBegin,
		MessageType: protocol.MessageTypeText,
		MessageID:   1,
		From:        "peer-a",
		To:          "peer-b",
		BodySize:    5,
	}
	if err := logger.LogEvent(want); err != nil {
		t.Fatalf("LogEvent() error = %v", err)
	}

	f, err := os.Open(logPath)
	if err != nil {
		t.Fatalf("os.Open() error = %v", err)
	}
	defer f.Close()

	lines := bufio.NewScanner(f)

	// First line: must exist and round-trip to the event that was logged.
	if !lines.Scan() {
		t.Fatal("expected one JSONL line, got none")
	}
	var got Event
	if err := json.Unmarshal(lines.Bytes(), &got); err != nil {
		t.Fatalf("json.Unmarshal() error = %v", err)
	}
	if got != want {
		t.Fatalf("event mismatch: got %+v want %+v", got, want)
	}

	// There must be no second line and no scanner failure.
	if lines.Scan() {
		t.Fatal("expected exactly one JSONL line")
	}
	if err := lines.Err(); err != nil {
		t.Fatalf("scanner.Err() = %v", err)
	}
}
// TestJSONLLoggerHandlesConcurrentWrites logs from many goroutines at once and
// checks that every event is written as its own decodable line.
func TestJSONLLoggerHandlesConcurrentWrites(t *testing.T) {
	logPath := filepath.Join(t.TempDir(), "latency.jsonl")
	logger, err := NewJSONLLogger(logPath)
	if err != nil {
		t.Fatalf("NewJSONLLogger() error = %v", err)
	}
	t.Cleanup(func() { _ = logger.Close() })

	const total = 32

	var writers sync.WaitGroup
	writers.Add(total)
	for n := 1; n <= total; n++ {
		// seq is passed by value so each goroutine logs its own sequence number.
		go func(seq int) {
			defer writers.Done()
			logErr := logger.LogEvent(Event{
				TsUnixNano:  int64(seq),
				NodeRole:    NodeRoleServer,
				NodeID:      protocol.ServerPeerID,
				Event:       EventBAppRecv,
				MessageType: protocol.MessageTypeFile,
				MessageID:   uint64(seq),
				From:        "peer-a",
				To:          "peer-b",
				FileName:    "payload.bin",
				BodySize:    3,
			})
			if logErr != nil {
				t.Errorf("LogEvent() error = %v", logErr)
			}
		}(n)
	}
	writers.Wait()

	f, err := os.Open(logPath)
	if err != nil {
		t.Fatalf("os.Open() error = %v", err)
	}
	defer f.Close()

	var lineCount int
	seenIDs := make(map[uint64]bool, total)
	lines := bufio.NewScanner(f)
	for lines.Scan() {
		var decoded Event
		if err := json.Unmarshal(lines.Bytes(), &decoded); err != nil {
			t.Fatalf("json.Unmarshal() error = %v", err)
		}
		lineCount++
		seenIDs[decoded.MessageID] = true
	}
	if err := lines.Err(); err != nil {
		t.Fatalf("scanner.Err() = %v", err)
	}

	if lineCount != total {
		t.Fatalf("line count = %d, want %d", lineCount, total)
	}
	if len(seenIDs) != total {
		t.Fatalf("unique message count = %d, want %d", len(seenIDs), total)
	}
}

View File

@@ -0,0 +1,254 @@
package latencylog
import (
"bufio"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"omnisocketgo/cmd/internal/protocol"
)
// requiredTimestampNames lists the event names that missingTimestampNames
// looks for in a message's timestamps; any name absent there is reported as
// missing, in the order given here.
var requiredTimestampNames = []string{
EventAAppPrepBegin, // A-side application starts preparing this message
EventATXSched, // A side, before entering the Linux qdisc
EventATXSoftware, // A side, about to hand the data to the NIC driver
EventBRXSoftware, // B-side NIC driver hands the data to the Linux receive stack
EventBAppRecv, // B-side application has actually read the complete message
EventBPersistEnd, // B side finished writing to disk
}
// Summary is the consolidated latency result for a single message.
// The latency fields are pointers so that a metric whose timestamps are
// unavailable stays nil and is omitted from the JSON output.
type Summary struct {
MessageType protocol.MessageType `json:"message_type"` // message type
MessageID uint64 `json:"message_id"` // message ID
From string `json:"from"` // sender
To string `json:"to"` // receiver
FileName string `json:"file_name,omitempty"` // file name (file messages only)
BodySize int `json:"body_size"` // message body size in bytes
Timestamps map[string]int64 `json:"timestamps"` // event timestamps: key is the event name, value is the UnixNano timestamp
AProcessingLatencyNS *int64 `json:"a_processing_latency_ns,omitempty"` // A processing latency: A_TX_SCHED - A_APP_PREP_BEGIN
AQueueLatencyNS *int64 `json:"a_queue_latency_ns,omitempty"` // A queueing latency: A_TX_SOFTWARE - A_TX_SCHED
ABTransportPropagationBQueueLatencyNS *int64 `json:"a_b_transport_propagation_b_queue_latency_ns,omitempty"` // A-B transport latency + B queueing latency: B_APP_RECV - A_TX_SOFTWARE
BKernelReceivePathLatencyNS *int64 `json:"b_kernel_receive_path_latency_ns,omitempty"` // approximate B kernel receive path: B_APP_RECV - B_RX_SOFTWARE
BProcessingLatencyNS *int64 `json:"b_processing_latency_ns,omitempty"` // B processing latency: B_PERSIST_END - B_APP_RECV
EndToEndLatencyNS *int64 `json:"end_to_end_latency_ns,omitempty"` // end-to-end latency: B_PERSIST_END - A_APP_PREP_BEGIN
MissingTimestamps []string `json:"missing_timestamps,omitempty"` // names from requiredTimestampNames that have no timestamp among the raw events
}
// messageKey is the map key SummarizeEvents uses to group raw events that
// belong to the same message.
type messageKey struct {
MessageType protocol.MessageType // message type
MessageID uint64 // message ID
From string // sender
To string // receiver
}
// LoadEventsFromFiles loads events from several raw JSONL log files and
// returns them concatenated in the order of paths. Loading stops at the first
// file that fails, in which case no events are returned.
func LoadEventsFromFiles(paths []string) ([]Event, error) {
	var all []Event
	for i := range paths {
		loaded, err := LoadEventsFromFile(paths[i])
		if err != nil {
			return nil, err
		}
		all = append(all, loaded...)
	}
	return all, nil
}
// LoadEventsFromFile loads events from a single raw JSONL log file.
//
// Empty lines are skipped. A line that is not a valid JSON event aborts the
// load and is reported as an error, as are open and scan failures.
func LoadEventsFromFile(path string) ([]Event, error) {
	f, openErr := os.Open(path)
	if openErr != nil {
		return nil, fmt.Errorf("latencylog: open raw log %s: %w", path, openErr)
	}
	defer f.Close()

	var (
		loaded []Event
		lines  = bufio.NewScanner(f)
	)
	for lines.Scan() {
		raw := lines.Bytes()
		if len(raw) == 0 {
			continue
		}

		var ev Event
		if decodeErr := json.Unmarshal(raw, &ev); decodeErr != nil {
			// A malformed JSONL line invalidates the whole file.
			return nil, fmt.Errorf("latencylog: decode event from %s: %w", path, decodeErr)
		}
		loaded = append(loaded, ev)
	}
	if scanErr := lines.Err(); scanErr != nil {
		return nil, fmt.Errorf("latencylog: scan raw log %s: %w", path, scanErr)
	}
	return loaded, nil
}
// SummarizeEvents groups raw events by message and turns each group into a
// latency Summary.
//
// Events that are not business events are ignored. Events are grouped by
// (message type, message ID, from, to). Within a group the earliest timestamp
// seen for each event name is kept, the first non-empty file name is kept, and
// the body size is taken from the most recent event that reports a positive
// size. The result is sorted by sender, receiver, message ID and finally
// message type so the output is stable and readable.
func SummarizeEvents(events []Event) []Summary {
	byMessage := make(map[messageKey]*Summary)
	for _, ev := range events {
		if !IsBusinessEvent(ev) {
			continue
		}

		key := messageKey{
			MessageType: ev.MessageType,
			MessageID:   ev.MessageID,
			From:        ev.From,
			To:          ev.To,
		}
		entry := byMessage[key]
		if entry == nil {
			// First event for this message: seed the summary from it.
			entry = &Summary{
				MessageType: ev.MessageType,
				MessageID:   ev.MessageID,
				From:        ev.From,
				To:          ev.To,
				FileName:    ev.FileName,
				BodySize:    ev.BodySize,
				Timestamps:  make(map[string]int64),
			}
			byMessage[key] = entry
		}

		if entry.FileName == "" {
			entry.FileName = ev.FileName
		}
		if ev.BodySize > 0 {
			entry.BodySize = ev.BodySize
		}

		// Duplicate event names keep the earliest timestamp.
		earliest, recorded := entry.Timestamps[ev.Event]
		if !recorded || ev.TsUnixNano < earliest {
			entry.Timestamps[ev.Event] = ev.TsUnixNano
		}
	}

	result := make([]Summary, 0, len(byMessage))
	for _, entry := range byMessage {
		// Fill in the latency metrics and the missing-timestamp list.
		completeSummary(entry)
		result = append(result, *entry)
	}

	sort.Slice(result, func(i, j int) bool {
		a, b := &result[i], &result[j]
		switch {
		case a.From != b.From:
			return a.From < b.From
		case a.To != b.To:
			return a.To < b.To
		case a.MessageID != b.MessageID:
			return a.MessageID < b.MessageID
		default:
			return a.MessageType < b.MessageType
		}
	})
	return result
}
// WriteSummariesJSONL writes the summaries to path as a JSONL file, one JSON
// object per line.
//
// The parent directory is created when missing and an existing file is
// truncated. Every failure is wrapped with a "latencylog:" prefix, including a
// failure to close the file: on a file that has just been written, a close
// error can mean the data never reached the disk, so it must not be dropped.
// The close error is only reported when no earlier error occurred.
func WriteSummariesJSONL(path string, summaries []Summary) (retErr error) {
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return fmt.Errorf("latencylog: create summary dir for %s: %w", path, err)
	}
	file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
	if err != nil {
		return fmt.Errorf("latencylog: open summary file %s: %w", path, err)
	}
	defer func() {
		// Keep the first error; surface the close error only on an otherwise
		// successful write.
		if closeErr := file.Close(); closeErr != nil && retErr == nil {
			retErr = fmt.Errorf("latencylog: close summary file %s: %w", path, closeErr)
		}
	}()

	writer := bufio.NewWriter(file)
	for _, summary := range summaries {
		// Encode each summary as one JSONL line.
		line, err := json.Marshal(summary)
		if err != nil {
			return fmt.Errorf("latencylog: encode summary for message %d: %w", summary.MessageID, err)
		}
		if _, err := writer.Write(append(line, '\n')); err != nil {
			return fmt.Errorf("latencylog: write summary file %s: %w", path, err)
		}
	}
	// Push whatever is still buffered into the file before it is closed.
	if err := writer.Flush(); err != nil {
		return fmt.Errorf("latencylog: flush summary file %s: %w", path, err)
	}
	return nil
}
// completeSummary derives the latency metrics of summary from its event
// timestamps and records which required timestamps are missing. A metric is
// only assigned when both of its timestamps are present; otherwise the field
// is left as it was.
func completeSummary(summary *Summary) {
	summary.MissingTimestamps = missingTimestampNames(summary.Timestamps)

	// Each metric is "end event minus begin event", stored through target.
	metrics := []struct {
		target     **int64
		end, begin string
	}{
		{&summary.AProcessingLatencyNS, EventATXSched, EventAAppPrepBegin},
		{&summary.AQueueLatencyNS, EventATXSoftware, EventATXSched},
		{&summary.ABTransportPropagationBQueueLatencyNS, EventBAppRecv, EventATXSoftware},
		{&summary.BKernelReceivePathLatencyNS, EventBAppRecv, EventBRXSoftware},
		{&summary.BProcessingLatencyNS, EventBPersistEnd, EventBAppRecv},
		{&summary.EndToEndLatencyNS, EventBPersistEnd, EventAAppPrepBegin},
	}
	for _, m := range metrics {
		if delta := subtractIfPresent(summary.Timestamps, m.end, m.begin); delta != nil {
			*m.target = delta
		}
	}
}
// missingTimestampNames returns the entries of requiredTimestampNames that
// have no value in timestamps, in the order of requiredTimestampNames. The
// result is nil when nothing is missing.
func missingTimestampNames(timestamps map[string]int64) []string {
	var absent []string
	for i := range requiredTimestampNames {
		required := requiredTimestampNames[i]
		if _, present := timestamps[required]; present {
			continue
		}
		absent = append(absent, required)
	}
	return absent
}
// subtractIfPresent returns a pointer to timestamps[endName] minus
// timestamps[beginName] when both names are present, and nil otherwise.
func subtractIfPresent(timestamps map[string]int64, endName, beginName string) *int64 {
	end, hasEnd := timestamps[endName]
	begin, hasBegin := timestamps[beginName]
	if !hasEnd || !hasBegin {
		return nil
	}
	diff := end - begin
	return &diff
}
// IsBusinessEvent reports whether event carries one of the business latency
// event names. The send_handoff_* debug events and any unknown name yield
// false.
func IsBusinessEvent(event Event) bool {
	switch event.Event {
	case EventAAppPrepBegin, EventATXSched, EventATXSoftware, EventATXHardware:
		// Sender-side (A) events.
		return true
	case EventBRXHardware, EventBRXSoftware, EventBAppRecv, EventBPersistBegin, EventBPersistEnd:
		// Receiver-side (B) events.
		return true
	}
	return false
}

View File

@@ -0,0 +1,440 @@
package latencylog
import (
"bufio"
"fmt"
"html/template"
"os"
"path/filepath"
"strings"
"omnisocketgo/cmd/internal/protocol"
)
// summaryChartHTMLTemplate is the html/template source of the self-contained
// latency chart page. It is executed with a summaryChartPage value: the stats
// section shows the page totals, the legend section iterates .Legend, and
// each entry of .Rows becomes one article with a stacked bar built from its
// .Segments (each segment's width comes from .WidthPercent and its colour
// from .Color). When .Rows is empty a placeholder section is rendered instead.
const summaryChartHTMLTemplate = `<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Latency Summary Chart</title>
<style>
:root {
color-scheme: light;
--bg: #f6f7fb;
--panel: #ffffff;
--text: #172033;
--muted: #60708a;
--border: #d9dfeb;
--track: #e8edf5;
--a-proc: #3b82f6;
--a-queue: #14b8a6;
--transport: #f59e0b;
--b-proc: #22c55e;
--unknown: #94a3b8;
}
* { box-sizing: border-box; }
body {
margin: 0;
font-family: "Segoe UI", "Helvetica Neue", Arial, sans-serif;
background: linear-gradient(180deg, #eef3ff 0%, var(--bg) 220px);
color: var(--text);
}
main {
max-width: 1120px;
margin: 0 auto;
padding: 32px 20px 48px;
}
h1 {
margin: 0 0 8px;
font-size: 32px;
line-height: 1.1;
}
.intro {
color: var(--muted);
margin: 0 0 24px;
font-size: 15px;
}
.stats {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
gap: 12px;
margin-bottom: 24px;
}
.stat {
background: var(--panel);
border: 1px solid var(--border);
border-radius: 14px;
padding: 14px 16px;
box-shadow: 0 10px 30px rgba(23, 32, 51, 0.06);
}
.stat-label {
color: var(--muted);
font-size: 12px;
text-transform: uppercase;
letter-spacing: 0.05em;
margin-bottom: 6px;
}
.stat-value {
font-size: 22px;
font-weight: 700;
}
.legend {
display: flex;
flex-wrap: wrap;
gap: 12px;
margin-bottom: 20px;
color: var(--muted);
font-size: 13px;
}
.legend-item {
display: inline-flex;
align-items: center;
gap: 8px;
}
.swatch {
width: 12px;
height: 12px;
border-radius: 999px;
flex: 0 0 auto;
}
.rows {
display: grid;
gap: 14px;
}
.row {
background: var(--panel);
border: 1px solid var(--border);
border-radius: 16px;
padding: 16px;
box-shadow: 0 12px 30px rgba(23, 32, 51, 0.05);
}
.row-head {
display: flex;
justify-content: space-between;
align-items: baseline;
gap: 12px;
margin-bottom: 8px;
}
.row-title {
font-size: 18px;
font-weight: 700;
margin: 0;
}
.row-e2e {
font-size: 16px;
font-weight: 700;
white-space: nowrap;
}
.row-meta {
color: var(--muted);
font-size: 13px;
margin-bottom: 12px;
}
.bar {
height: 24px;
display: flex;
overflow: hidden;
background: var(--track);
border-radius: 999px;
}
.segment {
height: 100%;
}
.segment:first-child {
border-top-left-radius: 999px;
border-bottom-left-radius: 999px;
}
.segment:last-child {
border-top-right-radius: 999px;
border-bottom-right-radius: 999px;
}
.row-legend {
display: flex;
flex-wrap: wrap;
gap: 10px 14px;
margin-top: 10px;
color: var(--muted);
font-size: 12px;
}
.missing {
margin-top: 10px;
color: #a16207;
font-size: 12px;
}
.empty {
color: var(--muted);
font-style: italic;
padding: 24px 16px;
background: var(--panel);
border: 1px dashed var(--border);
border-radius: 16px;
text-align: center;
}
</style>
</head>
<body>
<main>
<h1>Latency Summary</h1>
<p class="intro">A simple per-message end-to-end latency chart generated from summarized JSONL records.</p>
<section class="stats">
<div class="stat">
<div class="stat-label">Messages</div>
<div class="stat-value">{{.TotalMessages}}</div>
</div>
<div class="stat">
<div class="stat-label">With End-To-End</div>
<div class="stat-value">{{.MessagesWithEndToEnd}}</div>
</div>
<div class="stat">
<div class="stat-label">Average End-To-End</div>
<div class="stat-value">{{.AverageEndToEnd}}</div>
</div>
<div class="stat">
<div class="stat-label">Max End-To-End</div>
<div class="stat-value">{{.MaxEndToEnd}}</div>
</div>
</section>
<section class="legend">
{{range .Legend}}
<span class="legend-item">
<span class="swatch" style="background: {{.Color}}"></span>
<span>{{.Label}}</span>
</span>
{{end}}
</section>
{{if .Rows}}
<section class="rows">
{{range .Rows}}
<article class="row">
<div class="row-head">
<h2 class="row-title">{{.Title}}</h2>
<div class="row-e2e">{{.EndToEnd}}</div>
</div>
<div class="row-meta">{{.Subtitle}}</div>
<div class="bar">
{{range .Segments}}
<div class="segment" style="width: {{printf "%.4f" .WidthPercent}}%; background: {{.Color}}" title="{{.Label}}: {{.Value}}"></div>
{{end}}
</div>
{{if .Segments}}
<div class="row-legend">
{{range .Segments}}
<span class="legend-item">
<span class="swatch" style="background: {{.Color}}"></span>
<span>{{.Label}} {{.Value}}</span>
</span>
{{end}}
</div>
{{end}}
{{if .MissingTimestamps}}
<div class="missing">Missing timestamps: {{.MissingTimestamps}}</div>
{{end}}
</article>
{{end}}
</section>
{{else}}
<section class="empty">No summarized messages were available for chart rendering.</section>
{{end}}
</main>
</body>
</html>
`
// summaryChartPage is the root value handed to summaryChartHTMLTemplate.
type summaryChartPage struct {
	TotalMessages        int                      // number of summaries on the page
	MessagesWithEndToEnd int                      // how many of them have an end-to-end latency
	AverageEndToEnd      string                   // formatted average, or "n/a"
	MaxEndToEnd          string                   // formatted maximum, or "n/a"
	Legend               []summaryChartLegendItem // global colour legend
	Rows                 []summaryChartRow        // one row per summary
}

// summaryChartLegendItem is one entry of the page-level colour legend.
//
// Color is a template.CSS rather than a string because the template places it
// inside a style attribute. html/template replaces plain strings that contain
// characters such as '(' with the placeholder "ZgotmplZ" in CSS contexts, which
// would wipe out every "var(--name)" colour. template.CSS marks the value as
// trusted CSS, so it must only ever hold fixed, program-defined values.
type summaryChartLegendItem struct {
	Label string
	Color template.CSS
}

// summaryChartRow is the view model of one message in the chart.
type summaryChartRow struct {
	Title             string
	Subtitle          string
	EndToEnd          string                // e.g. "End-to-end: 1.234 ms" or "End-to-end: n/a"
	MissingTimestamps string                // comma-separated names; empty when none are missing
	Segments          []summaryChartSegment // stacked bar segments, left to right
}

// summaryChartSegment is one coloured slice of a row's stacked bar.
// Color is a template.CSS for the same reason as in summaryChartLegendItem.
type summaryChartSegment struct {
	Label        string
	Value        string // formatted latency
	Color        template.CSS
	WidthPercent float64 // share of the bar width, in percent
}

// summaryChartSegmentMetric pairs a latency metric with its label and colour
// while a row is being built; value is nil when the metric is unavailable.
type summaryChartSegmentMetric struct {
	label string
	value *int64
	color template.CSS
}
// WriteSummariesHTMLChart renders the summaries as a simple HTML chart that
// can be opened directly in a browser and writes it to path.
//
// The template is parsed before the output file is touched, so a template
// error cannot truncate an existing chart. The parent directory is created
// when missing and an existing file is overwritten. Every failure is wrapped
// with a "latencylog:" prefix, including a failure to close the file, which on
// a freshly written file can mean the data never reached the disk. The close
// error is only reported when no earlier error occurred.
func WriteSummariesHTMLChart(path string, summaries []Summary) (retErr error) {
	tmpl, err := template.New("summary-chart").Parse(summaryChartHTMLTemplate)
	if err != nil {
		return fmt.Errorf("latencylog: parse chart template: %w", err)
	}
	page := buildSummaryChartPage(summaries)

	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return fmt.Errorf("latencylog: create chart dir for %s: %w", path, err)
	}
	file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
	if err != nil {
		return fmt.Errorf("latencylog: open chart file %s: %w", path, err)
	}
	defer func() {
		// Keep the first error; surface the close error only on an otherwise
		// successful write.
		if closeErr := file.Close(); closeErr != nil && retErr == nil {
			retErr = fmt.Errorf("latencylog: close chart file %s: %w", path, closeErr)
		}
	}()

	writer := bufio.NewWriter(file)
	if err := tmpl.Execute(writer, page); err != nil {
		return fmt.Errorf("latencylog: render chart %s: %w", path, err)
	}
	if err := writer.Flush(); err != nil {
		return fmt.Errorf("latencylog: flush chart %s: %w", path, err)
	}
	return nil
}
// buildSummaryChartPage converts the summaries into the view model consumed by
// the chart template: one row per summary plus page-level statistics.
//
// The statistics cover every summary that has an end-to-end latency. The
// maximum is seeded from the first such value rather than from zero, so a data
// set whose end-to-end values are all negative (possible when the two hosts'
// clocks are offset) reports its true maximum instead of a made-up
// "0.000 ms". Average and maximum read "n/a" when no summary has an
// end-to-end latency.
func buildSummaryChartPage(summaries []Summary) summaryChartPage {
	page := summaryChartPage{
		TotalMessages: len(summaries),
		Legend: []summaryChartLegendItem{
			{Label: "A processing", Color: "var(--a-proc)"},
			{Label: "A queue", Color: "var(--a-queue)"},
			{Label: "Transport + B queue", Color: "var(--transport)"},
			{Label: "B processing", Color: "var(--b-proc)"},
			{Label: "Unknown / missing", Color: "var(--unknown)"},
		},
		Rows:            make([]summaryChartRow, 0, len(summaries)),
		AverageEndToEnd: "n/a",
		MaxEndToEnd:     "n/a",
	}

	var (
		withEndToEnd  int   // summaries that carry an end-to-end latency
		totalEndToEnd int64 // sum of those latencies, in nanoseconds
		maxEndToEnd   int64 // largest of those latencies; valid once withEndToEnd > 0
	)
	for _, summary := range summaries {
		page.Rows = append(page.Rows, buildSummaryChartRow(summary))
		if summary.EndToEndLatencyNS == nil {
			continue
		}
		endToEnd := *summary.EndToEndLatencyNS
		if withEndToEnd == 0 || endToEnd > maxEndToEnd {
			maxEndToEnd = endToEnd
		}
		withEndToEnd++
		totalEndToEnd += endToEnd
	}

	page.MessagesWithEndToEnd = withEndToEnd
	if withEndToEnd > 0 {
		page.AverageEndToEnd = formatLatencyNS(totalEndToEnd / int64(withEndToEnd))
		page.MaxEndToEnd = formatLatencyNS(maxEndToEnd)
	}
	return page
}
// buildSummaryChartRow builds the chart row for one summary.
//
// Without a positive end-to-end latency the row carries only its texts and no
// bar segments. Otherwise each positive latency metric becomes one segment,
// scaled against the larger of the end-to-end latency and the sum of the known
// metrics, and any part of the end-to-end latency not covered by the known
// metrics is appended as an "Unknown / missing" segment.
func buildSummaryChartRow(summary Summary) summaryChartRow {
	row := summaryChartRow{
		Title:             buildSummaryChartTitle(summary),
		Subtitle:          buildSummaryChartSubtitle(summary),
		EndToEnd:          "End-to-end: n/a",
		MissingTimestamps: strings.Join(summary.MissingTimestamps, ", "),
	}

	endToEnd := summary.EndToEndLatencyNS
	if endToEnd == nil || *endToEnd <= 0 {
		return row
	}
	total := *endToEnd
	row.EndToEnd = "End-to-end: " + formatLatencyNS(total)

	candidates := []summaryChartSegmentMetric{
		{label: "A processing", value: summary.AProcessingLatencyNS, color: "var(--a-proc)"},
		{label: "A queue", value: summary.AQueueLatencyNS, color: "var(--a-queue)"},
		{label: "Transport + B queue", value: summary.ABTransportPropagationBQueueLatencyNS, color: "var(--transport)"},
		{label: "B processing", value: summary.BProcessingLatencyNS, color: "var(--b-proc)"},
	}

	// Keep only the metrics that are present and positive, summing them up in
	// the same pass. The filter reuses the candidates backing array.
	drawable := candidates[:0]
	var knownTotal int64
	for _, metric := range candidates {
		if metric.value != nil && *metric.value > 0 {
			drawable = append(drawable, metric)
			knownTotal += *metric.value
		}
	}

	// Scale against whichever is larger so the widths never exceed 100%.
	scaleTotal := total
	if knownTotal > scaleTotal {
		scaleTotal = knownTotal
	}
	if scaleTotal <= 0 {
		return row
	}
	percentOf := func(ns int64) float64 {
		return float64(ns) * 100 / float64(scaleTotal)
	}

	for _, metric := range drawable {
		row.Segments = append(row.Segments, summaryChartSegment{
			Label:        metric.label,
			Value:        formatLatencyNS(*metric.value),
			Color:        metric.color,
			WidthPercent: percentOf(*metric.value),
		})
	}
	if remaining := total - knownTotal; remaining > 0 {
		row.Segments = append(row.Segments, summaryChartSegment{
			Label:        "Unknown / missing",
			Value:        formatLatencyNS(remaining),
			Color:        "var(--unknown)",
			WidthPercent: percentOf(remaining),
		})
	}
	return row
}
// buildSummaryChartTitle returns the row heading "<type> #<id>", followed by
// " (<file name>)" for file messages that carry a file name.
func buildSummaryChartTitle(summary Summary) string {
	title := fmt.Sprintf("%s #%d", summary.MessageType, summary.MessageID)
	isNamedFile := summary.MessageType == protocol.MessageTypeFile && summary.FileName != ""
	if isNamedFile {
		title += " (" + summary.FileName + ")"
	}
	return title
}
// buildSummaryChartSubtitle returns the row's meta line
// "<from> -> <to> | <n> bytes", followed by " | file: <name>" for file
// messages that carry a file name.
func buildSummaryChartSubtitle(summary Summary) string {
	subtitle := fmt.Sprintf("%s -> %s | %d bytes", summary.From, summary.To, summary.BodySize)
	isNamedFile := summary.MessageType == protocol.MessageTypeFile && summary.FileName != ""
	if isNamedFile {
		subtitle += " | file: " + summary.FileName
	}
	return subtitle
}
// formatLatencyNS renders a nanosecond latency as milliseconds with three
// decimals, for example "1.500 ms".
func formatLatencyNS(ns int64) string {
	const nsPerMillisecond = 1_000_000
	ms := float64(ns) / nsPerMillisecond
	return fmt.Sprintf("%.3f ms", ms)
}

View File

@@ -0,0 +1,67 @@
package latencylog
import (
"os"
"path/filepath"
"strings"
"testing"
"omnisocketgo/cmd/internal/protocol"
)
// TestWriteSummariesHTMLChart renders one complete and one incomplete summary
// and checks that the expected texts appear in the generated HTML file.
func TestWriteSummariesHTMLChart(t *testing.T) {
	ns := func(v int64) *int64 { return &v }

	complete := Summary{
		MessageType:                           protocol.MessageTypeText,
		MessageID:                             7,
		From:                                  "peer-a",
		To:                                    "peer-b",
		BodySize:                              5,
		AProcessingLatencyNS:                  ns(20_000_000),
		AQueueLatencyNS:                       ns(10_000_000),
		ABTransportPropagationBQueueLatencyNS: ns(40_000_000),
		BProcessingLatencyNS:                  ns(30_000_000),
		EndToEndLatencyNS:                     ns(100_000_000),
	}
	incomplete := Summary{
		MessageType:       protocol.MessageTypeFile,
		MessageID:         8,
		From:              "peer-b",
		To:                "peer-a",
		FileName:          "payload.bin",
		BodySize:          128,
		MissingTimestamps: []string{EventBRXSoftware},
	}

	// The "charts" directory does not exist yet; the writer must create it.
	chartPath := filepath.Join(t.TempDir(), "charts", "latency-summary.html")
	if err := WriteSummariesHTMLChart(chartPath, []Summary{complete, incomplete}); err != nil {
		t.Fatalf("WriteSummariesHTMLChart() error = %v", err)
	}

	raw, err := os.ReadFile(chartPath)
	if err != nil {
		t.Fatalf("os.ReadFile() error = %v", err)
	}
	html := string(raw)

	wantSnippets := []string{
		"Latency Summary",
		"text #7",
		"peer-a -&gt; peer-b | 5 bytes",
		"End-to-end: 100.000 ms",
		"A processing 20.000 ms",
		"file #8 (payload.bin)",
		"Missing timestamps: B_RX_SOFTWARE",
	}
	for _, snippet := range wantSnippets {
		if strings.Contains(html, snippet) {
			continue
		}
		t.Fatalf("chart content missing %q\n%s", snippet, html)
	}
}

View File

@@ -0,0 +1,144 @@
package latencylog
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"reflect"
"testing"
"omnisocketgo/cmd/internal/protocol"
)
// TestSummarizeEventsComputesLatencyMetrics feeds a complete event sequence
// for one message and checks every derived latency metric.
func TestSummarizeEventsComputesLatencyMetrics(t *testing.T) {
	// at builds an event of message #1 (peer-a -> peer-b) with the given
	// timestamp and event name.
	at := func(ts int64, name string) Event {
		return Event{
			TsUnixNano:  ts,
			Event:       name,
			MessageType: protocol.MessageTypeText,
			MessageID:   1,
			From:        "peer-a",
			To:          "peer-b",
		}
	}
	events := []Event{
		at(100, EventAAppPrepBegin),
		at(120, EventATXSched),
		at(140, EventATXSoftware),
		at(180, EventBRXSoftware),
		at(220, EventBAppRecv),
		at(230, EventBPersistBegin),
		at(260, EventBPersistEnd),
	}

	summaries := SummarizeEvents(events)
	if n := len(summaries); n != 1 {
		t.Fatalf("summary count = %d, want 1", n)
	}
	s := summaries[0]

	if got := ptrValue(s.AProcessingLatencyNS); got != 20 {
		t.Fatalf("AProcessingLatencyNS = %d, want 20", got)
	}
	if got := ptrValue(s.AQueueLatencyNS); got != 20 {
		t.Fatalf("AQueueLatencyNS = %d, want 20", got)
	}
	if got := ptrValue(s.ABTransportPropagationBQueueLatencyNS); got != 80 {
		t.Fatalf("ABTransportPropagationBQueueLatencyNS = %d, want 80", got)
	}
	if got := ptrValue(s.BKernelReceivePathLatencyNS); got != 40 {
		t.Fatalf("BKernelReceivePathLatencyNS = %d, want 40", got)
	}
	if got := ptrValue(s.BProcessingLatencyNS); got != 40 {
		t.Fatalf("BProcessingLatencyNS = %d, want 40", got)
	}
	if got := ptrValue(s.EndToEndLatencyNS); got != 160 {
		t.Fatalf("EndToEndLatencyNS = %d, want 160", got)
	}

	if got := s.Timestamps[EventBRXSoftware]; got != 180 {
		t.Fatalf("timestamps[%q] = %d, want 180", EventBRXSoftware, got)
	}
	if len(s.MissingTimestamps) != 0 {
		t.Fatalf("MissingTimestamps = %v, want empty", s.MissingTimestamps)
	}
}
// TestSummarizeEventsReportsMissingTimestamps feeds only the two endpoint
// events of a message and checks that the four events in between are reported
// as missing while the end-to-end latency is still computed.
func TestSummarizeEventsReportsMissingTimestamps(t *testing.T) {
	endpoint := func(ts int64, name string) Event {
		return Event{
			TsUnixNano:  ts,
			Event:       name,
			MessageType: protocol.MessageTypeText,
			MessageID:   2,
			From:        "peer-a",
			To:          "peer-b",
		}
	}

	summaries := SummarizeEvents([]Event{
		endpoint(100, EventAAppPrepBegin),
		endpoint(240, EventBPersistEnd),
	})
	if n := len(summaries); n != 1 {
		t.Fatalf("summary count = %d, want 1", n)
	}
	got := summaries[0]

	wantMissing := []string{EventATXSched, EventATXSoftware, EventBRXSoftware, EventBAppRecv}
	if !reflect.DeepEqual(got.MissingTimestamps, wantMissing) {
		t.Fatalf("MissingTimestamps = %v, want %v", got.MissingTimestamps, wantMissing)
	}
	if got.AProcessingLatencyNS != nil {
		t.Fatalf("AProcessingLatencyNS = %v, want nil", ptrValue(got.AProcessingLatencyNS))
	}
	if got.EndToEndLatencyNS == nil {
		t.Fatal("EndToEndLatencyNS = nil, want non-nil because endpoints are present")
	}
}
// TestLoadAndWriteSummaryFiles runs the whole file pipeline: log raw events to
// a JSONL file, load them back, summarize them, write the summary JSONL file
// and check the first summary line.
func TestLoadAndWriteSummaryFiles(t *testing.T) {
	rawPath := filepath.Join(t.TempDir(), "raw.jsonl")
	rawLogger, err := NewJSONLLogger(rawPath)
	if err != nil {
		t.Fatalf("NewJSONLLogger() error = %v", err)
	}
	t.Cleanup(func() { _ = rawLogger.Close() })

	// at builds an event of message #3 (peer-a -> peer-b).
	at := func(ts int64, name string) Event {
		return Event{
			TsUnixNano:  ts,
			Event:       name,
			MessageType: protocol.MessageTypeText,
			MessageID:   3,
			From:        "peer-a",
			To:          "peer-b",
		}
	}
	rawEvents := []Event{
		at(100, EventAAppPrepBegin),
		at(120, EventATXSched),
		at(140, EventATXSoftware),
		at(180, EventBRXSoftware),
		at(220, EventBAppRecv),
		at(260, EventBPersistEnd),
	}
	for _, ev := range rawEvents {
		if err := rawLogger.LogEvent(ev); err != nil {
			t.Fatalf("LogEvent() error = %v", err)
		}
	}

	loaded, err := LoadEventsFromFile(rawPath)
	if err != nil {
		t.Fatalf("LoadEventsFromFile() error = %v", err)
	}

	summaryPath := filepath.Join(t.TempDir(), "summary.jsonl")
	if err := WriteSummariesJSONL(summaryPath, SummarizeEvents(loaded)); err != nil {
		t.Fatalf("WriteSummariesJSONL() error = %v", err)
	}

	f, err := os.Open(summaryPath)
	if err != nil {
		t.Fatalf("os.Open() error = %v", err)
	}
	defer f.Close()

	lines := bufio.NewScanner(f)
	if !lines.Scan() {
		t.Fatal("expected one summary line, got none")
	}
	var decoded Summary
	if err := json.Unmarshal(lines.Bytes(), &decoded); err != nil {
		t.Fatalf("json.Unmarshal() error = %v", err)
	}

	if decoded.MessageID != 3 {
		t.Fatalf("MessageID = %d, want 3", decoded.MessageID)
	}
	if got := ptrValue(decoded.BKernelReceivePathLatencyNS); got != 40 {
		t.Fatalf("BKernelReceivePathLatencyNS = %d, want 40", got)
	}
	if got := ptrValue(decoded.EndToEndLatencyNS); got != 160 {
		t.Fatalf("EndToEndLatencyNS = %d, want 160", got)
	}
}
// ptrValue dereferences value, treating a nil pointer as 0.
func ptrValue(value *int64) int64 {
	if value != nil {
		return *value
	}
	return 0
}

179
cmd/internal/peer/client.go Normal file
View File

@@ -0,0 +1,179 @@
package peer
import (
"fmt"
"net"
"os"
"path/filepath"
"sync/atomic"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
// dialServer is the function Dial uses to open the connection to the server.
// NOTE(review): it is a package-level variable, presumably so tests can
// replace it with a fake dialer — confirm.
var dialServer = net.Dial
// clientOptions collects the optional settings that Option values apply
// before Dial creates the Client.
type clientOptions struct {
logger latencylog.Logger
}
// Option configures optional Client behavior, for example latency logging.
type Option func(*clientOptions)
// WithLogger returns an Option that makes the client use logger as its
// latency event logger.
func WithLogger(logger latencylog.Logger) Option {
	setLogger := func(target *clientOptions) {
		target.logger = logger
	}
	return setLogger
}
// Client represents a peer that is already connected to the server.
type Client struct {
id string // peer ID this client registered with
conn *transport.TCPConn // transport connection to the server
logger latencylog.Logger // latency event logger; Dial never leaves it nil
nextID uint64 // NOTE(review): presumably the message ID counter behind nextMessageID (defined outside this view) — confirm
}
// Dial connects to the server at serverAddr over TCP and immediately sends a
// register message so that the server knows the connection as peerID.
//
// Options are applied in order on top of a no-op logger; a nil logger left
// behind by an option falls back to the no-op logger as well. When wrapping
// the connection or registering fails, the connection is closed before the
// error is returned.
func Dial(serverAddr, peerID string, opts ...Option) (*Client, error) {
	cfg := clientOptions{logger: latencylog.NoopLogger{}}
	for _, apply := range opts {
		apply(&cfg)
	}
	if cfg.logger == nil {
		cfg.logger = latencylog.NoopLogger{}
	}

	// dialServer is net.Dial by default and yields a plain net.Conn.
	rawConn, dialErr := dialServer("tcp", serverAddr)
	if dialErr != nil {
		return nil, fmt.Errorf("peer: dial server: %w", dialErr)
	}

	loggerOpt := transport.WithLogger(cfg.logger, latencylog.NodeRolePeer, peerID)
	conn, wrapErr := transport.NewTCPConn(rawConn, loggerOpt)
	if wrapErr != nil {
		_ = rawConn.Close()
		return nil, fmt.Errorf("peer: create transport conn: %w", wrapErr)
	}

	// Announce this peer's identity to the server before handing out the client.
	register := protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: peerID,
		To:   protocol.ServerPeerID,
	}
	if sendErr := conn.Send(register); sendErr != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("peer: register with server: %w", sendErr)
	}

	return &Client{id: peerID, conn: conn, logger: cfg.logger}, nil
}
// ID returns the peer identifier of this client.
func (c *Client) ID() string {
return c.id
}
// SendText sends a text message with the given body to the peer named to.
func (c *Client) SendText(to, body string) error {
	var outgoing protocol.Message
	outgoing.Type = protocol.MessageTypeText
	outgoing.ID = c.nextMessageID()
	outgoing.From = c.id
	outgoing.To = to

	// Record the moment the A-side application starts preparing the message;
	// the body is attached only afterwards.
	latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, outgoing)

	outgoing.Body = []byte(body)
	return c.conn.Send(outgoing)
}
// SendFile sends a file message named fileName with the given contents to the
// peer named to.
//
// The A_APP_PREP_BEGIN latency event is logged before the body is attached.
// The body is copied, so the message does not alias the caller's slice.
func (c *Client) SendFile(to, fileName string, body []byte) error {
	id := c.nextMessageID()
	outgoing := protocol.Message{
		ID:       id,
		Type:     protocol.MessageTypeFile,
		From:     c.id,
		To:       to,
		FileName: fileName,
	}

	latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, outgoing)

	contents := make([]byte, len(body))
	copy(contents, body)
	outgoing.Body = contents
	return c.conn.Send(outgoing)
}
// SendFilePath reads the local file at path and sends its contents to the
// peer named to, using the base name of path as the file name.
//
// The A_APP_PREP_BEGIN latency event is logged before the file is read, so
// the read is part of the measured preparation phase. A read failure is
// returned wrapped and nothing is sent.
func (c *Client) SendFilePath(to, path string) error {
	id := c.nextMessageID()
	outgoing := protocol.Message{
		ID:       id,
		Type:     protocol.MessageTypeFile,
		From:     c.id,
		To:       to,
		FileName: filepath.Base(path),
	}

	latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, outgoing)

	contents, readErr := os.ReadFile(path)
	if readErr != nil {
		return fmt.Errorf("peer: read file %s: %w", path, readErr)
	}
	outgoing.Body = contents
	return c.conn.Send(outgoing)
}
// Receive reads one message from the server.
//
// On success the message is passed to latencylog.LogMessageEvent with the
// B_APP_RECV event before being returned. On failure a zero Message and a
// wrapped error are returned.
func (c *Client) Receive() (protocol.Message, error) {
	incoming, err := c.conn.Receive()
	if err == nil {
		latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, incoming)
		return incoming, nil
	}
	return protocol.Message{}, fmt.Errorf("peer: receive from server: %w", err)
}
// ReceiveLoop keeps receiving messages from the server and hands each one to
// handler.
//
// Only text, file and error messages are accepted: for those the B_APP_RECV
// latency event is logged and handler's result is returned to the underlying
// loop. Any other message type produces an error.
func (c *Client) ReceiveLoop(handler func(protocol.Message) error) error {
	dispatch := func(incoming protocol.Message) error {
		accepted := incoming.Type == protocol.MessageTypeText ||
			incoming.Type == protocol.MessageTypeFile ||
			incoming.Type == protocol.MessageTypeError
		if !accepted {
			return fmt.Errorf("peer: unexpected message type from server: %s", incoming.Type)
		}
		// Marks the point at which the receiving application has the complete message.
		latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, incoming)
		return handler(incoming)
	}
	return c.conn.ReceiveLoop(dispatch)
}
// Close closes the connection to the server and returns the result of
// closing the underlying transport connection.
func (c *Client) Close() error {
return c.conn.Close()
}
// nextMessageID atomically increments the client's message counter and
// returns the new value, so the first ID handed out by a fresh Client is 1.
func (c *Client) nextMessageID() uint64 {
return atomic.AddUint64(&c.nextID, 1)
}

View File

@@ -0,0 +1,188 @@
//go:build linux
package peer
import (
"net"
"strings"
"sync"
"testing"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/server"
)
// TestClientsExchangeMessagesWithLinuxTimestamps sends one text and one file
// message from peer-a to peer-b through a hub server, persists both on
// peer-b, and then waits until each side's recording logger holds the expected
// latency events for message IDs 1 and 2.
func TestClientsExchangeMessagesWithLinuxTimestamps(t *testing.T) {
hub := server.NewHub()
serverAddr, cleanup := startRealHubServer(t, hub)
defer cleanup()
peerALogger := &recordingLogger{}
peerA, err := Dial(serverAddr, "peer-a", WithLogger(peerALogger))
if err != nil {
t.Fatalf("Dial(peer-a) error = %v", err)
}
defer func() { _ = peerA.Close() }()
peerBLogger := &recordingLogger{}
peerB, err := Dial(serverAddr, "peer-b", WithLogger(peerBLogger))
if err != nil {
t.Fatalf("Dial(peer-b) error = %v", err)
}
defer func() { _ = peerB.Close() }()
inboxDir := t.TempDir()
// Registration is asynchronous with respect to Dial returning, so wait for
// the hub to know both peers before sending.
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
if err := peerA.SendText("peer-b", "hello"); err != nil {
t.Fatalf("SendText() error = %v", err)
}
textMsg, err := peerB.Receive()
if err != nil {
t.Fatalf("peerB.Receive(text) error = %v", err)
}
if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil {
t.Fatalf("peerB.PersistMessage(text) error = %v", err)
}
if err := peerA.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}); err != nil {
t.Fatalf("SendFile() error = %v", err)
}
fileMsg, err := peerB.Receive()
if err != nil {
t.Fatalf("peerB.Receive(file) error = %v", err)
}
if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil {
t.Fatalf("peerB.PersistMessage(file) error = %v", err)
}
// The events are polled for rather than asserted immediately, so they may
// arrive after Send/Receive have returned.
waitFor(t, func() bool { return hasMessageEvents(peerALogger.Events(), 1, latencylog.EventAAppPrepBegin, latencylog.EventATXSched, latencylog.EventATXSoftware) }, "peer-a text kernel timestamps")
waitFor(t, func() bool { return hasMessageEvents(peerALogger.Events(), 2, latencylog.EventAAppPrepBegin, latencylog.EventATXSched, latencylog.EventATXSoftware) }, "peer-a file kernel timestamps")
waitFor(t, func() bool { return hasMessageEvents(peerBLogger.Events(), 1, latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd) }, "peer-b text receive timestamps")
waitFor(t, func() bool { return hasMessageEvents(peerBLogger.Events(), 2, latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd) }, "peer-b file receive timestamps")
}
// startRealHubServer listens on an ephemeral loopback TCP port and serves
// every accepted connection with hub.ServeConn in its own goroutine.
//
// It returns the listener's address and a cleanup function. cleanup signals
// the accept loop to stop, closes the listener, and then blocks until the
// accept goroutine and all per-connection goroutines have returned.
func startRealHubServer(t *testing.T, hub *server.Hub) (string, func()) {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen() error = %v", err)
}
var (
wg      sync.WaitGroup
stop    = make(chan struct{})
errOnce sync.Once
)
wg.Add(1)
go func() {
defer wg.Done()
for {
conn, acceptErr := listener.Accept()
if acceptErr != nil {
// An accept error after cleanup has closed stop is the expected shutdown path.
select {
case <-stop:
return
default:
}
// A "closed" listener error is also treated as a normal shutdown.
if strings.Contains(acceptErr.Error(), "closed") {
return
}
t.Errorf("listener.Accept() error = %v", acceptErr)
return
}
wg.Add(1)
go func(rawConn net.Conn) {
defer wg.Done()
if serveErr := hub.ServeConn(rawConn); serveErr != nil && !isExpectedHubServeExit(serveErr) {
// Only the first unexpected ServeConn error is logged.
errOnce.Do(func() {
t.Logf("hub.ServeConn() ended with %v", serveErr)
})
}
}(conn)
}
}()
cleanup := func() {
close(stop)
_ = listener.Close()
wg.Wait()
}
return listener.Addr().String(), cleanup
}
// hasMessageEvents reports whether events contains every name in wantEvents
// for the message identified by messageID.
//
// Events belonging to other messages are ignored. If any event of the target
// message carries a non-positive timestamp, the result is false regardless of
// which names were seen.
func hasMessageEvents(events []latencylog.Event, messageID uint64, wantEvents ...string) bool {
	observed := make(map[string]struct{}, len(wantEvents))
	for i := range events {
		ev := &events[i]
		if ev.MessageID != messageID {
			continue
		}
		if ev.TsUnixNano <= 0 {
			return false
		}
		observed[ev.Event] = struct{}{}
	}
	for _, name := range wantEvents {
		if _, ok := observed[name]; !ok {
			return false
		}
	}
	return true
}
// isExpectedHubServeExit reports whether err is an acceptable way for
// hub.ServeConn to end in tests: nil, or an error whose text mentions a closed
// connection or an EOF while reading a frame.
func isExpectedHubServeExit(err error) bool {
	if err == nil {
		return true
	}
	text := err.Error()
	for _, marker := range []string{"closed", "protocol: read frame: EOF"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
// TestLinuxTimestampedReceivePreservesBusinessMessageShape sends a file
// message from peer-a to peer-b through a hub server and checks that the
// message peer-b receives has the same type, ID, sender, recipient, file name
// and body as the one that was sent.
func TestLinuxTimestampedReceivePreservesBusinessMessageShape(t *testing.T) {
hub := server.NewHub()
serverAddr, cleanup := startRealHubServer(t, hub)
defer cleanup()
peerA, err := Dial(serverAddr, "peer-a")
if err != nil {
t.Fatalf("Dial(peer-a) error = %v", err)
}
defer func() { _ = peerA.Close() }()
peerB, err := Dial(serverAddr, "peer-b")
if err != nil {
t.Fatalf("Dial(peer-b) error = %v", err)
}
defer func() { _ = peerB.Close() }()
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
// ID is 1 because this is the first message peer-a sends through SendFile.
want := protocol.Message{
Type:     protocol.MessageTypeFile,
ID:       1,
From:     "peer-a",
To:       "peer-b",
FileName: "payload.bin",
Body:     []byte{0xde, 0xad, 0xbe, 0xef},
}
if err := peerA.SendFile(want.To, want.FileName, want.Body); err != nil {
t.Fatalf("SendFile() error = %v", err)
}
got, err := peerB.Receive()
if err != nil {
t.Fatalf("peerB.Receive() error = %v", err)
}
if got.Type != want.Type || got.ID != want.ID || got.From != want.From || got.To != want.To || got.FileName != want.FileName || string(got.Body) != string(want.Body) {
t.Fatalf("received message mismatch: got %+v want %+v", got, want)
}
}

View File

@@ -0,0 +1,770 @@
package peer
import (
"bytes"
"encoding/json"
"net"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"testing"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/server"
"omnisocketgo/cmd/internal/transport"
)
// recordingLogger is a latency logger for tests that keeps every event it is
// given in memory. The mutex guards events so it can be used from several
// goroutines.
type recordingLogger struct {
mu     sync.Mutex
events []latencylog.Event
}
// LogEvent appends event to the recorded list and always returns nil.
func (l *recordingLogger) LogEvent(event latencylog.Event) error {
l.mu.Lock()
defer l.mu.Unlock()
l.events = append(l.events, event)
return nil
}
// Events returns a copy of the events recorded so far, in the order they were
// logged; the caller may keep or modify the returned slice freely.
func (l *recordingLogger) Events() []latencylog.Event {
l.mu.Lock()
defer l.mu.Unlock()
return append([]latencylog.Event(nil), l.events...)
}
// failingLogger is a latency logger for tests whose LogEvent always fails.
type failingLogger struct{}
// LogEvent discards the event and always returns net.ErrClosed.
func (failingLogger) LogEvent(latencylog.Event) error {
return net.ErrClosed
}
// TestDialRegistersPeer checks that after Dial succeeds the hub eventually
// reports the dialing peer as registered.
func TestDialRegistersPeer(t *testing.T) {
hub := server.NewHub()
cleanup := stubDialToHub(t, hub)
defer cleanup()
// The address is irrelevant: stubDialToHub redirects every dial to the hub.
client, err := Dial("ignored", "peer-a")
if err != nil {
t.Fatalf("Dial() error = %v", err)
}
defer func() { _ = client.Close() }()
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
}
// TestClientsExchangeTextAndFileMessages sends a text message followed by a
// file message from peer-a to peer-b through the hub and checks that peer-b
// receives both, in order, with exactly the expected contents.
func TestClientsExchangeTextAndFileMessages(t *testing.T) {
hub := server.NewHub()
cleanup := stubDialToHub(t, hub)
defer cleanup()
peerA, err := Dial("ignored", "peer-a")
if err != nil {
t.Fatalf("Dial(peer-a) error = %v", err)
}
defer func() { _ = peerA.Close() }()
peerB, err := Dial("ignored", "peer-b")
if err != nil {
t.Fatalf("Dial(peer-b) error = %v", err)
}
defer func() { _ = peerB.Close() }()
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
// Both channels are buffered so the receiving goroutine never blocks on the
// test goroutine: received holds both messages, receiveErr the final status.
received := make(chan protocol.Message, 2)
receiveErr := make(chan error, 1)
go func() {
for i := 0; i < 2; i++ {
msg, err := peerB.Receive()
if err != nil {
receiveErr <- err
return
}
received <- msg
}
receiveErr <- nil
}()
if err := peerA.SendText("peer-b", "hello"); err != nil {
t.Fatalf("SendText() error = %v", err)
}
fileBody := []byte{0x01, 0x02, 0x03}
if err := peerA.SendFile("peer-b", "payload.bin", fileBody); err != nil {
t.Fatalf("SendFile() error = %v", err)
}
// Wait for the receiver to finish before draining the messages.
if err := <-receiveErr; err != nil {
t.Fatalf("peerB.Receive() error = %v", err)
}
gotFirst := <-received
wantFirst := protocol.Message{
Type: protocol.MessageTypeText,
ID:   1,
From: "peer-a",
To:   "peer-b",
Body: []byte("hello"),
}
if !reflect.DeepEqual(gotFirst, wantFirst) {
t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, wantFirst)
}
gotSecond := <-received
wantSecond := protocol.Message{
Type:     protocol.MessageTypeFile,
ID:       2,
From:     "peer-a",
To:       "peer-b",
FileName: "payload.bin",
Body:     fileBody,
}
if !reflect.DeepEqual(gotSecond, wantSecond) {
t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, wantSecond)
}
}
// TestClientReceivesServerErrorForUnknownTarget sends a text message to a
// peer that never registered and checks that the sender receives an error
// message whose body names the unknown target.
func TestClientReceivesServerErrorForUnknownTarget(t *testing.T) {
hub := server.NewHub()
cleanup := stubDialToHub(t, hub)
defer cleanup()
client, err := Dial("ignored", "peer-a")
if err != nil {
t.Fatalf("Dial() error = %v", err)
}
defer func() { _ = client.Close() }()
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
if err := client.SendText("missing-peer", "hello"); err != nil {
t.Fatalf("SendText() error = %v", err)
}
got, err := client.Receive()
if err != nil {
t.Fatalf("Receive() error = %v", err)
}
if got.Type != protocol.MessageTypeError {
t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError)
}
if string(got.Body) != "unknown target: missing-peer" {
t.Fatalf("error body = %q, want unknown target message", got.Body)
}
}
// TestClientReceiveLoopHandlesForwardedMessages runs peer-b's ReceiveLoop in a
// goroutine, sends it one text message from peer-a, and has the handler close
// peer-b's connection after that first message. It then checks that the loop
// ended with a close-related error and that exactly the expected message was
// handed to the handler.
func TestClientReceiveLoopHandlesForwardedMessages(t *testing.T) {
hub := server.NewHub()
cleanup := stubDialToHub(t, hub)
defer cleanup()
peerA, err := Dial("ignored", "peer-a")
if err != nil {
t.Fatalf("Dial(peer-a) error = %v", err)
}
defer func() { _ = peerA.Close() }()
peerB, err := Dial("ignored", "peer-b")
if err != nil {
t.Fatalf("Dial(peer-b) error = %v", err)
}
defer func() { _ = peerB.Close() }()
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
// mu guards got, which is written by the handler goroutine and read below.
var (
mu  sync.Mutex
got []protocol.Message
)
loopErr := make(chan error, 1)
go func() {
loopErr <- peerB.ReceiveLoop(func(msg protocol.Message) error {
mu.Lock()
defer mu.Unlock()
got = append(got, msg)
// Closing the connection after the first message is what ends the loop.
if len(got) == 1 {
return peerB.Close()
}
return nil
})
}()
if err := peerA.SendText("peer-b", "hello"); err != nil {
t.Fatalf("SendText() error = %v", err)
}
err = <-loopErr
if err == nil || (!strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "use of closed network connection")) {
t.Fatalf("ReceiveLoop() error = %v, want close-related error", err)
}
mu.Lock()
defer mu.Unlock()
want := []protocol.Message{
{
Type: protocol.MessageTypeText,
ID:   1,
From: "peer-a",
To:   "peer-b",
Body: []byte("hello"),
},
}
if !reflect.DeepEqual(got, want) {
t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
}
}
// TestClientSendLogsLatencyEvents is a table-driven test covering SendText,
// SendFile and SendFilePath. For each case it sends one message over a
// loopback transport pair, checks the message that arrives, and checks that
// the sender's logger recorded exactly the expected events, in order, each
// carrying the message's ID, sender and recipient.
func TestClientSendLogsLatencyEvents(t *testing.T) {
tests := []struct {
name       string
setup      func(*testing.T) string // optional; returns the input path handed to send
send       func(*Client, string) error
wantMsg    protocol.Message
wantEvents []string
}{
{
name: "text",
send: func(client *Client, _ string) error {
return client.SendText("peer-b", "hello")
},
wantMsg: protocol.Message{
Type: protocol.MessageTypeText,
ID:   1,
From: "peer-a",
To:   "peer-b",
Body: []byte("hello"),
},
wantEvents: []string{
latencylog.EventAAppPrepBegin,
latencylog.EventSendHandoffBegin,
latencylog.EventATXSched,
latencylog.EventATXSoftware,
latencylog.EventSendHandoffEnd,
},
},
{
name: "file-bytes",
send: func(client *Client, _ string) error {
return client.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03})
},
wantMsg: protocol.Message{
Type:     protocol.MessageTypeFile,
ID:       1,
From:     "peer-a",
To:       "peer-b",
FileName: "payload.bin",
Body:     []byte{0x01, 0x02, 0x03},
},
wantEvents: []string{
latencylog.EventAAppPrepBegin,
latencylog.EventSendHandoffBegin,
latencylog.EventATXSched,
latencylog.EventATXSoftware,
latencylog.EventSendHandoffEnd,
},
},
{
name: "file-path",
setup: func(t *testing.T) string {
t.Helper()
path := filepath.Join(t.TempDir(), "payload.bin")
if err := os.WriteFile(path, []byte{0x01, 0x02, 0x03}, 0o644); err != nil {
t.Fatalf("os.WriteFile() error = %v", err)
}
return path
},
send: func(client *Client, path string) error {
return client.SendFilePath("peer-b", path)
},
wantMsg: protocol.Message{
Type:     protocol.MessageTypeFile,
ID:       1,
From:     "peer-a",
To:       "peer-b",
FileName: "payload.bin",
Body:     []byte{0x01, 0x02, 0x03},
},
wantEvents: []string{
latencylog.EventAAppPrepBegin,
latencylog.EventSendHandoffBegin,
latencylog.EventATXSched,
latencylog.EventATXSoftware,
latencylog.EventSendHandoffEnd,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
inputPath := ""
if tt.setup != nil {
inputPath = tt.setup(t)
}
// The same logger is given to the transport and to the Client so that
// events from both layers land in one ordered list.
logger := &recordingLogger{}
clientConn, receiver := newClientTransportPair(
t,
[]transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-a")},
nil,
)
// The Client is built directly, bypassing Dial, so no register message is sent.
client := &Client{
id:     "peer-a",
conn:   clientConn,
logger: logger,
}
// Send runs in its own goroutine while this goroutine receives.
sendErr := make(chan error, 1)
go func() {
sendErr <- tt.send(client, inputPath)
}()
got, err := receiver.Receive()
if err != nil {
t.Fatalf("receiver.Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("send() error = %v", err)
}
if !reflect.DeepEqual(got, tt.wantMsg) {
t.Fatalf("message mismatch: got %+v want %+v", got, tt.wantMsg)
}
events := logger.Events()
if len(events) != len(tt.wantEvents) {
t.Fatalf("event count = %d, want %d", len(events), len(tt.wantEvents))
}
for i, wantEvent := range tt.wantEvents {
if events[i].Event != wantEvent {
t.Fatalf("event[%d] = %q, want %q", i, events[i].Event, wantEvent)
}
if events[i].MessageID != tt.wantMsg.ID || events[i].From != tt.wantMsg.From || events[i].To != tt.wantMsg.To {
t.Fatalf("event[%d] metadata mismatch: %+v", i, events[i])
}
}
})
}
}
// TestClientReceiveLogsOnlyBusinessMessages delivers a text message and then
// an error message to a client over a loopback transport pair and checks that
// only two events were logged in total, both for the text message:
// B_RX_SOFTWARE followed by B_APP_RECV.
func TestClientReceiveLogsOnlyBusinessMessages(t *testing.T) {
logger := &recordingLogger{}
clientConn, sender := newClientTransportPair(
t,
[]transport.Option{transport.WithLogger(logger, latencylog.NodeRolePeer, "peer-b")},
nil,
)
client := &Client{
id:     "peer-b",
conn:   clientConn,
logger: logger,
}
textMsg := protocol.Message{
Type: protocol.MessageTypeText,
ID:   21,
From: "peer-a",
To:   "peer-b",
Body: []byte("hello"),
}
sendErr := make(chan error, 1)
go func() {
sendErr <- sender.Send(textMsg)
}()
if _, err := client.Receive(); err != nil {
t.Fatalf("client.Receive(text) error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("sender.Send(text) error = %v", err)
}
errorMsg := protocol.Message{
Type: protocol.MessageTypeError,
ID:   22,
From: protocol.ServerPeerID,
To:   "peer-b",
Body: []byte("failure"),
}
sendErr = make(chan error, 1)
go func() {
sendErr <- sender.Send(errorMsg)
}()
if _, err := client.Receive(); err != nil {
t.Fatalf("client.Receive(error) error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("sender.Send(error) error = %v", err)
}
// Two messages were received, but only the text message may produce events.
events := logger.Events()
if len(events) != 2 {
t.Fatalf("event count = %d, want 2", len(events))
}
if events[0].Event != latencylog.EventBRXSoftware {
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventBRXSoftware)
}
if events[1].Event != latencylog.EventBAppRecv {
t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventBAppRecv)
}
if events[0].MessageID != textMsg.ID || events[1].MessageID != textMsg.ID {
t.Fatalf("message IDs = %d,%d, want %d", events[0].MessageID, events[1].MessageID, textMsg.ID)
}
}
// TestClientPersistTextMessageWritesInboxFileAndLogs persists a text message
// and checks three things: the returned path is the text inbox file, the file
// holds a JSON record matching the message, and the logger saw the persist
// begin and end events in that order.
func TestClientPersistTextMessageWritesInboxFileAndLogs(t *testing.T) {
	dir := t.TempDir()
	rec := &recordingLogger{}
	receiver := &Client{id: "peer-b", logger: rec}
	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   31,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	gotPath, err := receiver.PersistMessage(msg, dir)
	if err != nil {
		t.Fatalf("PersistMessage() error = %v", err)
	}
	if wantPath := filepath.Join(dir, textInboxFileName); gotPath != wantPath {
		t.Fatalf("path = %q, want %q", gotPath, wantPath)
	}

	raw, err := os.ReadFile(gotPath)
	if err != nil {
		t.Fatalf("os.ReadFile() error = %v", err)
	}
	var record textInboxRecord
	if err := json.Unmarshal(bytes.TrimSpace(raw), &record); err != nil {
		t.Fatalf("json.Unmarshal() error = %v", err)
	}
	matches := record.MessageID == msg.ID &&
		record.From == msg.From &&
		record.To == msg.To &&
		record.Body == "hello"
	if !matches {
		t.Fatalf("record mismatch: got %+v want message %+v", record, msg)
	}

	logged := rec.Events()
	if len(logged) != 2 {
		t.Fatalf("event count = %d, want 2", len(logged))
	}
	if got := logged[0].Event; got != latencylog.EventBPersistBegin {
		t.Fatalf("first event = %q, want %q", got, latencylog.EventBPersistBegin)
	}
	if got := logged[1].Event; got != latencylog.EventBPersistEnd {
		t.Fatalf("second event = %q, want %q", got, latencylog.EventBPersistEnd)
	}
}
// TestClientPersistFileMessageWritesInboxFileAndLogs persists a file message
// and checks three things: the returned path is "<from>-<id>-<file name>"
// inside the inbox, the file holds the message body, and the logger saw the
// persist begin and end events in that order.
func TestClientPersistFileMessageWritesInboxFileAndLogs(t *testing.T) {
	dir := t.TempDir()
	rec := &recordingLogger{}
	receiver := &Client{id: "peer-b", logger: rec}
	msg := protocol.Message{
		Type:     protocol.MessageTypeFile,
		ID:       32,
		From:     "peer-a",
		To:       "peer-b",
		FileName: "payload.bin",
		Body:     []byte{0x01, 0x02, 0x03},
	}

	gotPath, err := receiver.PersistMessage(msg, dir)
	if err != nil {
		t.Fatalf("PersistMessage() error = %v", err)
	}
	if wantPath := filepath.Join(dir, "peer-a-32-payload.bin"); gotPath != wantPath {
		t.Fatalf("path = %q, want %q", gotPath, wantPath)
	}

	stored, err := os.ReadFile(gotPath)
	if err != nil {
		t.Fatalf("os.ReadFile() error = %v", err)
	}
	if !reflect.DeepEqual(stored, msg.Body) {
		t.Fatalf("file body mismatch: got %v want %v", stored, msg.Body)
	}

	logged := rec.Events()
	if len(logged) != 2 {
		t.Fatalf("event count = %d, want 2", len(logged))
	}
	if got := logged[0].Event; got != latencylog.EventBPersistBegin {
		t.Fatalf("first event = %q, want %q", got, latencylog.EventBPersistBegin)
	}
	if got := logged[1].Event; got != latencylog.EventBPersistEnd {
		t.Fatalf("second event = %q, want %q", got, latencylog.EventBPersistEnd)
	}
}
// TestClientPersistMessageReturnsErrorOnWriteFailure uses a regular file as
// the inbox "directory" so persisting must fail, then checks that an error is
// returned and that only the persist-begin event was logged.
func TestClientPersistMessageReturnsErrorOnWriteFailure(t *testing.T) {
	notADir := filepath.Join(t.TempDir(), "blocker")
	if err := os.WriteFile(notADir, []byte("not a directory"), 0o644); err != nil {
		t.Fatalf("os.WriteFile() error = %v", err)
	}

	rec := &recordingLogger{}
	receiver := &Client{id: "peer-b", logger: rec}
	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   33,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	_, err := receiver.PersistMessage(msg, notADir)
	if err == nil {
		t.Fatal("PersistMessage() error = nil, want non-nil")
	}

	logged := rec.Events()
	if len(logged) != 1 {
		t.Fatalf("event count = %d, want 1", len(logged))
	}
	if got := logged[0].Event; got != latencylog.EventBPersistBegin {
		t.Fatalf("event = %q, want %q", got, latencylog.EventBPersistBegin)
	}
}
// TestClientIgnoresLoggerFailure checks that SendText still succeeds and
// delivers its body when both the transport's and the client's logger fail on
// every event.
func TestClientIgnoresLoggerFailure(t *testing.T) {
clientConn, receiver := newClientTransportPair(
t,
[]transport.Option{transport.WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")},
nil,
)
client := &Client{
id:     "peer-a",
conn:   clientConn,
logger: failingLogger{},
}
// Send runs in its own goroutine while this goroutine receives.
sendErr := make(chan error, 1)
go func() {
sendErr <- client.SendText("peer-b", "hello")
}()
got, err := receiver.Receive()
if err != nil {
t.Fatalf("receiver.Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("SendText() error = %v, want nil even if logger fails", err)
}
if string(got.Body) != "hello" {
t.Fatalf("body = %q, want hello", got.Body)
}
}
// TestClientPersistIgnoresLoggerFailure checks that PersistMessage still
// succeeds and returns a non-empty path when the client's logger fails on
// every event.
func TestClientPersistIgnoresLoggerFailure(t *testing.T) {
	receiver := &Client{id: "peer-b", logger: failingLogger{}}
	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   34,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	gotPath, err := receiver.PersistMessage(msg, t.TempDir())
	if err != nil {
		t.Fatalf("PersistMessage() error = %v, want nil even if logger fails", err)
	}
	if len(gotPath) == 0 {
		t.Fatal("PersistMessage() path = empty, want non-empty")
	}
}
// TestClientsExchangeMessagesWithLatencyLogs sends one text and one file
// message from peer-a to peer-b through the hub, persists both on peer-b, and
// checks the complete per-message event sequence recorded by each side's
// logger: five events per message on the sender, four on the receiver.
func TestClientsExchangeMessagesWithLatencyLogs(t *testing.T) {
hub := server.NewHub()
cleanup := stubDialToHub(t, hub)
defer cleanup()
peerALogger := &recordingLogger{}
peerA, err := Dial("ignored", "peer-a", WithLogger(peerALogger))
if err != nil {
t.Fatalf("Dial(peer-a) error = %v", err)
}
defer func() { _ = peerA.Close() }()
peerBLogger := &recordingLogger{}
peerB, err := Dial("ignored", "peer-b", WithLogger(peerBLogger))
if err != nil {
t.Fatalf("Dial(peer-b) error = %v", err)
}
defer func() { _ = peerB.Close() }()
inboxDir := t.TempDir()
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
if err := peerA.SendText("peer-b", "hello"); err != nil {
t.Fatalf("SendText() error = %v", err)
}
textMsg, err := peerB.Receive()
if err != nil {
t.Fatalf("peerB.Receive(text) error = %v", err)
}
if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil {
t.Fatalf("peerB.PersistMessage(text) error = %v", err)
}
if err := peerA.SendFile("peer-b", "payload.bin", []byte{0x01, 0x02, 0x03}); err != nil {
t.Fatalf("SendFile() error = %v", err)
}
fileMsg, err := peerB.Receive()
if err != nil {
t.Fatalf("peerB.Receive(file) error = %v", err)
}
if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil {
t.Fatalf("peerB.PersistMessage(file) error = %v", err)
}
// Wait for the exact totals (2 messages x 5 events, 2 messages x 4 events)
// before asserting the per-message sequences.
waitFor(t, func() bool { return len(peerALogger.Events()) == 10 }, "peer-a latency events")
waitFor(t, func() bool { return len(peerBLogger.Events()) == 8 }, "peer-b latency events")
assertEventSequencesByMessage(t, peerALogger.Events(), map[uint64][]string{
1: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd},
2: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventATXSched, latencylog.EventATXSoftware, latencylog.EventSendHandoffEnd},
})
assertEventSequencesByMessage(t, peerBLogger.Events(), map[uint64][]string{
1: {latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd},
2: {latencylog.EventBRXSoftware, latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd},
})
}
// stubDialToHub starts a hub server and replaces the package-level dialServer
// so that every Dial connects to it, whatever address the caller passes. The
// returned function restores the original dialServer and stops the server.
//
// NOTE(review): startRealHubServer appears to be declared in a test file
// guarded by //go:build linux; if this file carries no such constraint, the
// package's tests will not compile on other platforms — confirm.
func stubDialToHub(t *testing.T, hub *server.Hub) func() {
t.Helper()
originalDial := dialServer
serverAddr, cleanup := startRealHubServer(t, hub)
// The addr argument is deliberately ignored in favour of the hub's address.
dialServer = func(network, addr string) (net.Conn, error) {
return net.Dial(network, serverAddr)
}
return func() {
dialServer = originalDial
cleanup()
}
}
// waitFor polls condition every 10ms until it returns true, and fails the
// test with a message naming description if it is still false once 500ms have
// elapsed.
//
// The condition is evaluated one final time after the deadline, so a
// condition that becomes true during the last sleep interval is not reported
// as a timeout.
func waitFor(t *testing.T, condition func() bool, description string) {
	t.Helper()
	const (
		timeout  = 500 * time.Millisecond
		interval = 10 * time.Millisecond
	)
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if condition() {
			return
		}
		time.Sleep(interval)
	}
	// The loop exits right after a sleep without re-checking; give the
	// condition one last chance before declaring a timeout.
	if condition() {
		return
	}
	t.Fatalf("timed out waiting for %s", description)
}
// assertEventSequencesByMessage groups events by message ID and fails the test
// unless the groups match want exactly: the same set of message IDs, and for
// each ID the same event names in the same order. Every event must also carry
// a positive timestamp.
func assertEventSequencesByMessage(t *testing.T, events []latencylog.Event, want map[uint64][]string) {
	t.Helper()

	namesByMessage := make(map[uint64][]string)
	for _, ev := range events {
		if ev.TsUnixNano <= 0 {
			t.Fatalf("event timestamp must be positive: %+v", ev)
		}
		namesByMessage[ev.MessageID] = append(namesByMessage[ev.MessageID], ev.Event)
	}

	if got := len(namesByMessage); got != len(want) {
		t.Fatalf("message group count = %d, want %d", got, len(want))
	}

	for id, wantNames := range want {
		gotNames := namesByMessage[id]
		if len(gotNames) != len(wantNames) {
			t.Fatalf("message %d event count = %d, want %d", id, len(gotNames), len(wantNames))
		}
		for idx := range wantNames {
			if gotNames[idx] != wantNames[idx] {
				t.Fatalf("message %d event[%d] = %q, want %q", id, idx, gotNames[idx], wantNames[idx])
			}
		}
	}
}
// newClientTransportPair creates two transport connections joined by a real
// loopback TCP connection: the dialing side is wrapped with clientOpts and
// returned first, the accepted side is wrapped with peerOpts and returned
// second. Both are closed automatically when the test finishes.
func newClientTransportPair(t *testing.T, clientOpts []transport.Option, peerOpts []transport.Option) (*transport.TCPConn, *transport.TCPConn) {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen() error = %v", err)
}
type acceptResult struct {
conn net.Conn
err  error
}
// Accept runs in a goroutine so the dial below can proceed; the buffered
// channel lets that goroutine finish even if nobody reads the result.
accepted := make(chan acceptResult, 1)
go func() {
conn, acceptErr := listener.Accept()
accepted <- acceptResult{conn: conn, err: acceptErr}
}()
clientSide, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
_ = listener.Close()
t.Fatalf("net.Dial() error = %v", err)
}
result := <-accepted
// Only one connection is needed, so the listener is closed right away.
if err := listener.Close(); err != nil {
t.Fatalf("listener.Close() error = %v", err)
}
if result.err != nil {
_ = clientSide.Close()
t.Fatalf("listener.Accept() error = %v", result.err)
}
clientConn, err := transport.NewTCPConn(clientSide, clientOpts...)
if err != nil {
_ = clientSide.Close()
_ = result.conn.Close()
t.Fatalf("transport.NewTCPConn(client) error = %v", err)
}
peerConn, err := transport.NewTCPConn(result.conn, peerOpts...)
if err != nil {
_ = clientConn.Close()
_ = result.conn.Close()
t.Fatalf("transport.NewTCPConn(peer) error = %v", err)
}
t.Cleanup(func() {
_ = clientConn.Close()
_ = peerConn.Close()
})
return clientConn, peerConn
}

View File

@@ -0,0 +1,99 @@
package peer
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// textInboxFileName is the name of the file, inside the inbox directory, to
// which received text messages are appended.
const textInboxFileName = "messages.log"
// textInboxRecord is the JSON shape of one line in the text inbox file; it
// mirrors the fields of a text message, with the body stored as a string.
type textInboxRecord struct {
MessageType protocol.MessageType `json:"message_type"`
MessageID   uint64               `json:"message_id"`
From        string               `json:"from"`
To          string               `json:"to"`
Body        string               `json:"body"`
}
// PersistMessage writes a received business message to disk under inboxDir
// and returns the path of the file it wrote.
//
// Messages that latencylog.IsBusinessMessage rejects, and an empty inboxDir,
// are refused before anything is logged. Otherwise the B_PERSIST_BEGIN event
// is logged, inboxDir is created if necessary, and the message is written;
// B_PERSIST_END is logged only when the write succeeded.
func (c *Client) PersistMessage(msg protocol.Message, inboxDir string) (string, error) {
	switch {
	case !latencylog.IsBusinessMessage(msg):
		return "", fmt.Errorf("peer: cannot persist message type %s", msg.Type)
	case inboxDir == "":
		return "", fmt.Errorf("peer: inbox directory is required")
	}

	latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistBegin, msg)

	if mkErr := os.MkdirAll(inboxDir, 0o755); mkErr != nil {
		return "", fmt.Errorf("peer: create inbox dir %s: %w", inboxDir, mkErr)
	}
	written, writeErr := persistMessageToDisk(msg, inboxDir)
	if writeErr != nil {
		return "", writeErr
	}

	latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistEnd, msg)
	return written, nil
}
// persistMessageToDisk writes msg under inboxDir according to its type: text
// messages are appended to the text inbox file, file messages are written as
// standalone files. Any other type is an error.
func persistMessageToDisk(msg protocol.Message, inboxDir string) (string, error) {
	if msg.Type == protocol.MessageTypeText {
		return persistTextMessage(msg, inboxDir)
	}
	if msg.Type == protocol.MessageTypeFile {
		return persistFileMessage(msg, inboxDir)
	}
	return "", fmt.Errorf("peer: cannot persist unsupported message type %s", msg.Type)
}
// persistTextMessage appends msg as one JSON line to the text inbox file in
// inboxDir, creating the file if it does not exist, and returns the file's
// path.
//
// The file is closed explicitly rather than with defer so that a failure to
// close a file that was just written is reported instead of being dropped.
func persistTextMessage(msg protocol.Message, inboxDir string) (string, error) {
	record := textInboxRecord{
		MessageType: msg.Type,
		MessageID:   msg.ID,
		From:        msg.From,
		To:          msg.To,
		Body:        string(msg.Body),
	}
	line, err := json.Marshal(record)
	if err != nil {
		return "", fmt.Errorf("peer: encode text inbox record: %w", err)
	}

	path := filepath.Join(inboxDir, textInboxFileName)
	file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
	if err != nil {
		return "", fmt.Errorf("peer: open text inbox %s: %w", path, err)
	}

	if _, err := file.Write(append(line, '\n')); err != nil {
		// The write error is the one worth reporting; closing is best-effort here.
		_ = file.Close()
		return "", fmt.Errorf("peer: append text inbox %s: %w", path, err)
	}
	if err := file.Close(); err != nil {
		return "", fmt.Errorf("peer: close text inbox %s: %w", path, err)
	}
	return path, nil
}
// persistFileMessage writes the body of a file message to its own file in
// inboxDir and returns that file's path. The file is named
// "<sender>-<message ID>-<file name>".
//
// Both the sender and the file name come from the received message, so each is
// reduced to its final path element before being used: a value such as
// "../../x" must not be able to place the file outside inboxDir.
func persistFileMessage(msg protocol.Message, inboxDir string) (string, error) {
	sender := msg.From
	if sender != "" {
		// filepath.Base("") would yield "."; keep an empty sender empty.
		sender = filepath.Base(sender)
	}
	fileName := filepath.Base(msg.FileName)

	path := filepath.Join(inboxDir, fmt.Sprintf("%s-%d-%s", sender, msg.ID, fileName))
	if err := os.WriteFile(path, msg.Body, 0o644); err != nil {
		return "", fmt.Errorf("peer: write received file %s: %w", path, err)
	}
	return path, nil
}

View File

@@ -0,0 +1,279 @@
package protocol
import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"unicode/utf8"
)
// MaxFrameSize caps the length of a single frame, so that a misbehaving peer
// cannot make the receiver allocate unbounded memory by forging a huge length value.
const MaxFrameSize = 8 * 1024 * 1024 // provisional: transmitted video frames must not exceed 8 MB
// Sentinel errors returned by the protocol package.
var (
ErrInvalidFrameLength    = errors.New("protocol: invalid frame length")    // The frame length is invalid, for example 0.
ErrFrameTooLarge         = errors.New("protocol: frame too large")         // The frame length exceeds the allowed maximum.
ErrInvalidMessageType    = errors.New("protocol: invalid message type")    // The message type is not one the protocol supports.
ErrMissingFrom           = errors.New("protocol: missing from")            // The message has no sender identifier.
ErrMissingTo             = errors.New("protocol: missing to")              // The message has no recipient identifier.
ErrMissingFileName       = errors.New("protocol: missing file name")       // A file message has no file name.
ErrUnexpectedFileName    = errors.New("protocol: unexpected file name")    // A text message wrongly carries a file name.
ErrInvalidTextBody       = errors.New("protocol: invalid text body")       // A text message body is not valid UTF-8.
ErrUnexpectedBody        = errors.New("protocol: unexpected body")         // A control message that must not carry a body has one.
ErrInvalidRegisterTarget = errors.New("protocol: invalid register target") // A register message is not addressed to the server.
ErrInvalidErrorSource    = errors.New("protocol: invalid error source")    // An error message was not sent by the server.
ErrInvalidHeaderLength   = errors.New("protocol: invalid header length")   // The header length field is 0, out of range, or the header cannot be fully sliced out.
ErrInvalidHeaderJSON     = errors.New("protocol: invalid header json")     // The header JSON cannot be parsed, e.g. it is malformed or lacks required fields.
ErrInvalidContentLength  = errors.New("protocol: invalid content length")  // The body length recorded in the header differs from the actual body length.
)
// messageHeader is the JSON header placed in front of a message body.
// Application-layer message layout: [4-byte frameLength][4-byte headerLen][header JSON (the messageHeader below)][body bytes]
// With a struct tag, the JSON field name is the one the tag specifies (e.g. type); without one it defaults to the Go field name (e.g. Type).
type messageHeader struct {
Type          MessageType `json:"type"`
ID            uint64      `json:"id"`
From          string      `json:"from"`
To            string      `json:"to"`
FileName      string      `json:"file_name,omitempty"`
ContentLength int         `json:"content_length"`
}
// EncodeMessage encodes a logical message into the byte layout carried inside
// a frame:
//  1. a 4-byte big-endian header length
//  2. the header as JSON
//  3. the raw body bytes
//
// It returns the validation error if msg is invalid, and ErrFrameTooLarge if
// the encoded payload would exceed MaxFrameSize. The size is checked before
// the payload buffer is allocated, so an oversized message is rejected without
// first being copied.
func EncodeMessage(msg Message) ([]byte, error) {
	if err := validateMessage(msg); err != nil {
		return nil, err
	}

	header := messageHeader{
		Type:          msg.Type,
		ID:            msg.ID,
		From:          msg.From,
		To:            msg.To,
		FileName:      msg.FileName,
		ContentLength: len(msg.Body),
	}
	headerPayload, err := json.Marshal(header)
	if err != nil {
		return nil, fmt.Errorf("protocol: encode header: %w", err)
	}

	// Reject oversized messages up front, so the caller gets a clear error
	// here rather than a failed send later.
	payloadSize := 4 + len(headerPayload) + len(msg.Body)
	if payloadSize > MaxFrameSize {
		return nil, ErrFrameTooLarge
	}

	// Build the payload in a fresh buffer rather than growing headerPayload:
	// the first 4 bytes hold the header length, the header JSON starts at the
	// fifth byte, and the body follows it.
	payload := make([]byte, payloadSize)
	binary.BigEndian.PutUint32(payload[:4], uint32(len(headerPayload)))
	copy(payload[4:], headerPayload)
	copy(payload[4+len(headerPayload):], msg.Body)
	return payload, nil
}
// DecodeMessage restores a Message from the byte layout carried inside a
// frame (4-byte big-endian header length, header JSON, body bytes).
//
// It returns ErrFrameTooLarge when data exceeds MaxFrameSize,
// ErrInvalidHeaderLength when the header length is zero or does not fit in
// data, an error wrapping ErrInvalidHeaderJSON when the header cannot be
// parsed, ErrInvalidContentLength when the header's content length differs
// from the actual body length, and the validation error when the decoded
// message is invalid. The returned body is a copy and does not alias data.
func DecodeMessage(data []byte) (Message, error) {
	if len(data) > MaxFrameSize {
		return Message{}, ErrFrameTooLarge
	}
	if len(data) < 4 {
		return Message{}, ErrInvalidHeaderLength
	}

	// Compare the header length as an unsigned value before converting it to
	// int: on platforms where int is 32 bits, a length above MaxInt32 would
	// otherwise become negative, slip past the bounds check, and panic in the
	// slice expression below.
	rawHeaderLen := binary.BigEndian.Uint32(data[:4])
	if rawHeaderLen == 0 || uint64(rawHeaderLen) > uint64(len(data)-4) {
		return Message{}, ErrInvalidHeaderLength
	}
	headerLen := int(rawHeaderLen) // safe: bounded by len(data) above

	headerPayload := data[4 : 4+headerLen]
	body := data[4+headerLen:]

	var header messageHeader
	if err := json.Unmarshal(headerPayload, &header); err != nil {
		return Message{}, fmt.Errorf("protocol: decode header: %w", errors.Join(ErrInvalidHeaderJSON, err))
	}
	if header.ContentLength < 0 || header.ContentLength != len(body) {
		return Message{}, ErrInvalidContentLength
	}

	// Copy the body so the returned message does not share memory with data.
	bodyCopy := make([]byte, len(body))
	copy(bodyCopy, body)

	msg := Message{
		Type:     header.Type,
		ID:       header.ID,
		From:     header.From,
		To:       header.To,
		FileName: header.FileName,
		Body:     bodyCopy,
	}
	if err := validateMessage(msg); err != nil {
		return Message{}, err
	}
	return msg, nil
}
// WriteFrame writes one length-prefixed frame to w.
// Frame layout:
//  1. a 4-byte big-endian length,
//  2. the payload bytes.
//
// TCP is a byte stream without message boundaries. The explicit length prefix
// tells the receiver how many bytes make up one complete message, which
// solves the coalescing/splitting problem.
//
// An empty payload yields ErrInvalidFrameLength and a payload larger than
// MaxFrameSize yields ErrFrameTooLarge; nothing is written in either case.
func WriteFrame(w io.Writer, payload []byte) error {
	switch frameLen := len(payload); {
	case frameLen == 0:
		return ErrInvalidFrameLength
	case frameLen > MaxFrameSize:
		return ErrFrameTooLarge
	}

	var lengthPrefix [4]byte
	binary.BigEndian.PutUint32(lengthPrefix[:], uint32(len(payload)))

	// The length prefix goes first so the receiver knows how much to read.
	for _, chunk := range [][]byte{lengthPrefix[:], payload} {
		if err := writeFull(w, chunk); err != nil {
			return err
		}
	}
	return nil
}
// ReadFrame reads one complete length-prefixed frame from r.
// It reads the fixed 4-byte length prefix, validates it, and then uses
// io.ReadFull to read exactly that many payload bytes, so a partially
// delivered message is never exposed to the caller even when the underlying
// TCP stream is read in segments.
func ReadFrame(r io.Reader) ([]byte, error) {
	var lengthPrefix [4]byte
	if _, err := io.ReadFull(r, lengthPrefix[:]); err != nil {
		return nil, err
	}

	frameLen := binary.BigEndian.Uint32(lengthPrefix[:])
	switch {
	case frameLen == 0:
		// A zero-length frame is invalid input, not a valid empty message.
		return nil, ErrInvalidFrameLength
	case frameLen > MaxFrameSize:
		// Reject oversized frames before allocating, so the receiver's memory use stays bounded.
		return nil, ErrFrameTooLarge
	}

	frame := make([]byte, int(frameLen))
	_, err := io.ReadFull(r, frame)
	if err != nil {
		return nil, err
	}
	return frame, nil
}
// WriteMessage is the complete send path for callers:
// Message -> header+body -> length-prefixed frame -> io.Writer.
// Encoding and write failures are wrapped with the stage that failed.
func WriteMessage(w io.Writer, msg Message) error {
	encoded, encodeErr := EncodeMessage(msg)
	if encodeErr != nil {
		return fmt.Errorf("protocol: encode message: %w", encodeErr)
	}

	writeErr := WriteFrame(w, encoded)
	if writeErr == nil {
		return nil
	}
	return fmt.Errorf("protocol: write frame: %w", writeErr)
}
// ReadMessage is the complete receive path for callers:
// io.Reader -> length-prefixed frame -> header+body -> Message.
// Read and decode failures are wrapped with the stage that failed.
func ReadMessage(r io.Reader) (Message, error) {
	frame, readErr := ReadFrame(r)
	if readErr != nil {
		return Message{}, fmt.Errorf("protocol: read frame: %w", readErr)
	}

	decoded, decodeErr := DecodeMessage(frame)
	if decodeErr != nil {
		return Message{}, fmt.Errorf("protocol: decode message: %w", decodeErr)
	}
	return decoded, nil
}
// validateMessage checks that msg is well-formed for its type.
// From and To must be non-empty for every message. In addition:
//   - text:     no file name, body must be valid UTF-8;
//   - file:     a file name is required;
//   - register: To must be ServerPeerID, no file name, empty body;
//   - error:    From must be ServerPeerID, no file name, body must be valid UTF-8.
//
// Any other type yields ErrInvalidMessageType.
func validateMessage(msg Message) error {
	switch {
	case msg.From == "":
		return ErrMissingFrom
	case msg.To == "":
		return ErrMissingTo
	}

	hasFileName := msg.FileName != ""

	switch msg.Type {
	case MessageTypeText:
		switch {
		case hasFileName:
			return ErrUnexpectedFileName
		case !utf8.Valid(msg.Body):
			return ErrInvalidTextBody
		}
	case MessageTypeFile:
		if !hasFileName {
			return ErrMissingFileName
		}
	case MessageTypeRegister:
		switch {
		case msg.To != ServerPeerID:
			return ErrInvalidRegisterTarget
		case hasFileName:
			return ErrUnexpectedFileName
		case len(msg.Body) != 0:
			return ErrUnexpectedBody
		}
	case MessageTypeError:
		switch {
		case msg.From != ServerPeerID:
			return ErrInvalidErrorSource
		case hasFileName:
			return ErrUnexpectedFileName
		case !utf8.Valid(msg.Body):
			return ErrInvalidTextBody
		}
	default:
		return ErrInvalidMessageType
	}
	return nil
}
// writeFull keeps calling w.Write until every byte of data has been written
// or the writer reports an error. This keeps a writer that performs partial
// writes from corrupting the frame format. A write that makes no progress
// without reporting an error is turned into io.ErrShortWrite.
func writeFull(w io.Writer, data []byte) error {
	remaining := data
	for {
		if len(remaining) == 0 {
			return nil
		}
		written, err := w.Write(remaining)
		switch {
		case err != nil:
			return err
		case written == 0:
			return io.ErrShortWrite
		}
		remaining = remaining[written:]
	}
}

View File

@@ -0,0 +1,507 @@
package protocol
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"reflect"
"strings"
"testing"
)
// TestEncodeDecodeMessageTextASCII verifies that ASCII text round-trips
// through EncodeMessage/DecodeMessage as a text message.
func TestEncodeDecodeMessageTextASCII(t *testing.T) {
	want := Message{
		Type: MessageTypeText,
		ID:   42,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	encoded, err := EncodeMessage(want)
	if err != nil {
		t.Fatalf("EncodeMessage() error = %v", err)
	}

	got, err := DecodeMessage(encoded)
	if err != nil {
		t.Fatalf("DecodeMessage() error = %v", err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Fatalf("round trip mismatch: got %+v want %+v", got, want)
	}
}
// TestEncodeDecodeMessageTextUTF8 verifies that text messages accept valid
// UTF-8, so ordinary non-ASCII text works as well.
func TestEncodeDecodeMessageTextUTF8(t *testing.T) {
	want := Message{
		Type: MessageTypeText,
		ID:   43,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("你好, world"),
	}

	encoded, err := EncodeMessage(want)
	if err != nil {
		t.Fatalf("EncodeMessage() error = %v", err)
	}

	got, err := DecodeMessage(encoded)
	if err != nil {
		t.Fatalf("DecodeMessage() error = %v", err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Fatalf("round trip mismatch: got %+v want %+v", got, want)
	}
}
// TestEncodeDecodeMessageFile verifies that file messages keep their file
// name and raw binary body across a round trip.
func TestEncodeDecodeMessageFile(t *testing.T) {
	want := Message{
		Type:     MessageTypeFile,
		ID:       44,
		From:     "peer-a",
		To:       "peer-b",
		FileName: "data.bin",
		Body:     []byte{0x00, 0xff, 0x10, 0x7f},
	}

	encoded, err := EncodeMessage(want)
	if err != nil {
		t.Fatalf("EncodeMessage() error = %v", err)
	}

	got, err := DecodeMessage(encoded)
	if err != nil {
		t.Fatalf("DecodeMessage() error = %v", err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Fatalf("round trip mismatch: got %+v want %+v", got, want)
	}
}
// TestEncodeDecodeMessageRegister verifies that the register control message
// also round-trips through the codec.
func TestEncodeDecodeMessageRegister(t *testing.T) {
	want := Message{
		Type: MessageTypeRegister,
		ID:   45,
		From: "peer-a",
		To:   ServerPeerID,
		Body: []byte{},
	}

	encoded, err := EncodeMessage(want)
	if err != nil {
		t.Fatalf("EncodeMessage() error = %v", err)
	}

	got, err := DecodeMessage(encoded)
	if err != nil {
		t.Fatalf("DecodeMessage() error = %v", err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Fatalf("round trip mismatch: got %+v want %+v", got, want)
	}
}
// TestEncodeDecodeMessageError verifies that the error control message keeps
// its UTF-8 error text across a round trip.
func TestEncodeDecodeMessageError(t *testing.T) {
	want := Message{
		Type: MessageTypeError,
		ID:   46,
		From: ServerPeerID,
		To:   "peer-a",
		Body: []byte("unknown target"),
	}

	encoded, err := EncodeMessage(want)
	if err != nil {
		t.Fatalf("EncodeMessage() error = %v", err)
	}

	got, err := DecodeMessage(encoded)
	if err != nil {
		t.Fatalf("DecodeMessage() error = %v", err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Fatalf("round trip mismatch: got %+v want %+v", got, want)
	}
}
// TestWriteReadFrame exercises the lowest-level length-prefixed framing on
// its own, without the Message structure, to confirm that frame boundaries
// are handled correctly.
func TestWriteReadFrame(t *testing.T) {
	want := []byte("header+body")
	stream := new(bytes.Buffer)

	if err := WriteFrame(stream, want); err != nil {
		t.Fatalf("WriteFrame() error = %v", err)
	}

	got, err := ReadFrame(stream)
	if err != nil {
		t.Fatalf("ReadFrame() error = %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Fatalf("payload mismatch: got %q want %q", got, want)
	}
}
// TestWriteReadMessageAllowsEmptyBody verifies that empty text and empty
// files pass through the protocol layer: as long as the outer frame is
// non-empty, an empty body is valid application content.
func TestWriteReadMessageAllowsEmptyBody(t *testing.T) {
	type testCase struct {
		name    string
		message Message
	}
	cases := []testCase{
		{
			name: "empty text",
			message: Message{
				Type: MessageTypeText,
				ID:   1,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte(""),
			},
		},
		{
			name: "empty file",
			message: Message{
				Type:     MessageTypeFile,
				ID:       2,
				From:     "peer-a",
				To:       "peer-b",
				FileName: "empty.txt",
				Body:     []byte{},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			stream := new(bytes.Buffer)
			if err := WriteMessage(stream, tc.message); err != nil {
				t.Fatalf("WriteMessage() error = %v", err)
			}

			got, err := ReadMessage(stream)
			if err != nil {
				t.Fatalf("ReadMessage() error = %v", err)
			}
			if !reflect.DeepEqual(got, tc.message) {
				t.Fatalf("round trip mismatch: got %+v want %+v", got, tc.message)
			}
		})
	}
}
// TestWriteReadMessageRejectsInvalidMessages verifies that the protocol layer
// stops clearly invalid messages before they are encoded.
func TestWriteReadMessageRejectsInvalidMessages(t *testing.T) {
	type testCase struct {
		name    string
		message Message
		wantErr error
	}
	cases := []testCase{
		{
			name: "invalid type",
			message: Message{
				Type: MessageType("unknown"),
				ID:   1,
				From: "peer-a",
				To:   "peer-b",
			},
			wantErr: ErrInvalidMessageType,
		},
		{
			name: "missing from",
			message: Message{
				Type: MessageTypeText,
				ID:   2,
				To:   "peer-b",
			},
			wantErr: ErrMissingFrom,
		},
		{
			name: "missing to",
			message: Message{
				Type: MessageTypeText,
				ID:   3,
				From: "peer-a",
			},
			wantErr: ErrMissingTo,
		},
		{
			name: "text with file name",
			message: Message{
				Type:     MessageTypeText,
				ID:       4,
				From:     "peer-a",
				To:       "peer-b",
				FileName: "bad.txt",
				Body:     []byte("hello"),
			},
			wantErr: ErrUnexpectedFileName,
		},
		{
			name: "text with invalid utf8",
			message: Message{
				Type: MessageTypeText,
				ID:   5,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte{0xff, 0xfe},
			},
			wantErr: ErrInvalidTextBody,
		},
		{
			name: "file without file name",
			message: Message{
				Type: MessageTypeFile,
				ID:   6,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte{0x01},
			},
			wantErr: ErrMissingFileName,
		},
		{
			name: "register with wrong target",
			message: Message{
				Type: MessageTypeRegister,
				ID:   7,
				From: "peer-a",
				To:   "peer-b",
			},
			wantErr: ErrInvalidRegisterTarget,
		},
		{
			name: "register with body",
			message: Message{
				Type: MessageTypeRegister,
				ID:   8,
				From: "peer-a",
				To:   ServerPeerID,
				Body: []byte("unexpected"),
			},
			wantErr: ErrUnexpectedBody,
		},
		{
			name: "error with wrong source",
			message: Message{
				Type: MessageTypeError,
				ID:   9,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte("bad"),
			},
			wantErr: ErrInvalidErrorSource,
		},
		{
			name: "error with file name",
			message: Message{
				Type:     MessageTypeError,
				ID:       10,
				From:     ServerPeerID,
				To:       "peer-a",
				FileName: "bad.txt",
				Body:     []byte("bad"),
			},
			wantErr: ErrUnexpectedFileName,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := EncodeMessage(tc.message); !errors.Is(err, tc.wantErr) {
				t.Fatalf("EncodeMessage() error = %v, want %v", err, tc.wantErr)
			}
		})
	}
}
// TestReadFrameRejectsInvalidLength verifies that a frame with length 0 is
// treated as invalid input rather than as a valid empty message.
func TestReadFrameRejectsInvalidLength(t *testing.T) {
	stream := new(bytes.Buffer)
	if err := binary.Write(stream, binary.BigEndian, uint32(0)); err != nil {
		t.Fatalf("binary.Write() error = %v", err)
	}

	if _, err := ReadFrame(stream); !errors.Is(err, ErrInvalidFrameLength) {
		t.Fatalf("ReadFrame() error = %v, want %v", err, ErrInvalidFrameLength)
	}
}
// TestReadFrameRejectsTooLargeFrame verifies that an oversized frame is
// rejected before the body buffer is allocated, so the maximum length limit
// is actually enforced.
func TestReadFrameRejectsTooLargeFrame(t *testing.T) {
	stream := new(bytes.Buffer)
	if err := binary.Write(stream, binary.BigEndian, uint32(MaxFrameSize+1)); err != nil {
		t.Fatalf("binary.Write() error = %v", err)
	}

	if _, err := ReadFrame(stream); !errors.Is(err, ErrFrameTooLarge) {
		t.Fatalf("ReadFrame() error = %v, want %v", err, ErrFrameTooLarge)
	}
}
// TestWriteFrameRejectsEmptyPayload verifies that the writer and the reader
// enforce the same constraint: since the reader rejects zero-length frames,
// the writer must not produce them.
func TestWriteFrameRejectsEmptyPayload(t *testing.T) {
	sink := new(bytes.Buffer)
	if err := WriteFrame(sink, nil); !errors.Is(err, ErrInvalidFrameLength) {
		t.Fatalf("WriteFrame() error = %v, want %v", err, ErrInvalidFrameLength)
	}
}
// TestDecodeMessageRejectsInvalidHeaderLength verifies that input from which
// a complete header cannot be sliced is rejected immediately.
func TestDecodeMessageRejectsInvalidHeaderLength(t *testing.T) {
	type testCase struct {
		name string
		data []byte
	}
	cases := []testCase{
		{name: "too short for header len", data: []byte{0x00, 0x00, 0x00}},
		{name: "zero header len", data: []byte{0x00, 0x00, 0x00, 0x00}},
		{name: "header len exceeds payload", data: []byte{0x00, 0x00, 0x00, 0x10, '{', '}'}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := DecodeMessage(tc.data); !errors.Is(err, ErrInvalidHeaderLength) {
				t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderLength)
			}
		})
	}
}
// TestDecodeMessageRejectsInvalidHeaderJSON verifies that malformed header
// JSON produces an error that matches ErrInvalidHeaderJSON.
func TestDecodeMessageRejectsInvalidHeaderJSON(t *testing.T) {
	// 4-byte length prefix (9) followed by nine bytes of invalid JSON.
	payload := append([]byte{0x00, 0x00, 0x00, 0x09}, "{invalid}"...)

	if _, err := DecodeMessage(payload); !errors.Is(err, ErrInvalidHeaderJSON) {
		t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderJSON)
	}
}
// TestDecodeMessageRejectsContentLengthMismatch verifies that decoding fails
// when the length declared in the header differs from the actual body length.
func TestDecodeMessageRejectsContentLengthMismatch(t *testing.T) {
	// The header claims a 10-byte body, but only 5 bytes follow it.
	header := messageHeader{
		Type:          MessageTypeText,
		ID:            7,
		From:          "peer-a",
		To:            "peer-b",
		ContentLength: 10,
	}
	headerJSON, err := json.Marshal(header)
	if err != nil {
		t.Fatalf("json.Marshal() error = %v", err)
	}

	var encoded bytes.Buffer
	if err := binary.Write(&encoded, binary.BigEndian, uint32(len(headerJSON))); err != nil {
		t.Fatalf("binary.Write() error = %v", err)
	}
	if _, err := encoded.Write(headerJSON); err != nil {
		t.Fatalf("data.Write(headerPayload) error = %v", err)
	}
	if _, err := encoded.Write([]byte("hello")); err != nil {
		t.Fatalf("data.Write(body) error = %v", err)
	}

	if _, err = DecodeMessage(encoded.Bytes()); !errors.Is(err, ErrInvalidContentLength) {
		t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidContentLength)
	}
}
// TestReadMultipleMessages writes a text and a file message back to back on
// the same stream and verifies that each read stops exactly at the current
// frame boundary, so messages never bleed into each other.
func TestReadMultipleMessages(t *testing.T) {
	textMsg := Message{
		Type: MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}
	fileMsg := Message{
		Type:     MessageTypeFile,
		ID:       2,
		From:     "peer-b",
		To:       "peer-a",
		FileName: "payload.bin",
		Body:     []byte{0x01, 0x02, 0x03},
	}

	stream := new(bytes.Buffer)
	if err := WriteMessage(stream, textMsg); err != nil {
		t.Fatalf("WriteMessage(first) error = %v", err)
	}
	if err := WriteMessage(stream, fileMsg); err != nil {
		t.Fatalf("WriteMessage(second) error = %v", err)
	}

	gotText, err := ReadMessage(stream)
	if err != nil {
		t.Fatalf("ReadMessage(first) error = %v", err)
	}
	gotFile, err := ReadMessage(stream)
	if err != nil {
		t.Fatalf("ReadMessage(second) error = %v", err)
	}

	if !reflect.DeepEqual(gotText, textMsg) {
		t.Fatalf("first message mismatch: got %+v want %+v", gotText, textMsg)
	}
	if !reflect.DeepEqual(gotFile, fileMsg) {
		t.Fatalf("second message mismatch: got %+v want %+v", gotFile, fileMsg)
	}
}
// TestReadMessageWrapsDecodeError verifies that an error returned by
// ReadMessage keeps the context of the decode stage.
func TestReadMessageWrapsDecodeError(t *testing.T) {
	// A well-formed frame whose payload carries invalid header JSON.
	badPayload := append([]byte{0x00, 0x00, 0x00, 0x09}, "{invalid}"...)

	stream := new(bytes.Buffer)
	if err := WriteFrame(stream, badPayload); err != nil {
		t.Fatalf("WriteFrame() error = %v", err)
	}

	_, err := ReadMessage(stream)
	switch {
	case err == nil:
		t.Fatal("ReadMessage() error = nil, want non-nil")
	case !strings.Contains(err.Error(), "decode message"):
		t.Fatalf("ReadMessage() error = %v, want wrapped decode error", err)
	}
}

View File

@@ -0,0 +1,33 @@
package protocol
// MessageType identifies how a message is interpreted by the protocol layer.
// Besides the two payload types (text and file) there are two control types
// (register and error), all declared below.
type MessageType string
const (
// MessageTypeText marks a body that is interpreted as UTF-8 text, which is naturally ASCII-compatible.
MessageTypeText MessageType = "text"
// MessageTypeFile marks a body that holds raw file bytes.
MessageTypeFile MessageType = "file"
// MessageTypeRegister is sent by a peer to register its identity with the server explicitly.
MessageTypeRegister MessageType = "register"
// MessageTypeError is sent by the server to report an error to a peer.
MessageTypeError MessageType = "error"
)
// ServerPeerID is the fixed identifier the protocol reserves for the server side.
const ServerPeerID = "server"
// Message is the transport message structure shared by peers and the server.
// The header metadata is encoded as JSON; Body is appended after the header as raw bytes.
type Message struct {
Type MessageType `json:"type"` // Message type; one of the MessageType constants (text, file, register, error).
ID uint64 `json:"id"` // Generated by the sender; used to trace the message.
From string `json:"from"` // Sender identifier.
To string `json:"to"` // Receiver identifier.
// FileName is only used when Type is file.
FileName string `json:"file_name,omitempty"`
// Body is the actual payload being transferred; it is not part of the header JSON.
Body []byte `json:"-"`
}

192
cmd/internal/server/hub.go Normal file
View File

@@ -0,0 +1,192 @@
package server
import (
"fmt"
"net"
"sync"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
// gracefulRejectCloseTimeout is the drain timeout closeConn passes to
// CloseGracefully when a connection is closed gracefully after a rejection.
const gracefulRejectCloseTimeout = 100 * time.Millisecond
// Hub manages the connections of registered peers and forwards messages between them.
type Hub struct {
mu sync.RWMutex // guards peers
peers map[string]*transport.TCPConn // registered connections, keyed by peer ID
logger latencylog.Logger // latency logger; NewHub defaults it to latencylog.NoopLogger
}
// Option configures optional Hub behavior, such as latency logging.
type Option func(*Hub)
// WithLogger returns an Option that installs the given latency logger on the hub.
func WithLogger(logger latencylog.Logger) Option {
	return func(h *Hub) { h.logger = logger }
}
// NewHub creates an empty connection hub and applies the given options.
// The logger defaults to latencylog.NoopLogger and is reset to it if an
// option leaves it nil.
func NewHub(opts ...Option) *Hub {
	h := new(Hub)
	h.peers = make(map[string]*transport.TCPConn)
	h.logger = latencylog.NoopLogger{}

	for i := range opts {
		opts[i](h)
	}

	// An option may have set a nil logger; fall back to the no-op one.
	if h.logger == nil {
		h.logger = latencylog.NoopLogger{}
	}
	return h
}
// HasPeer reports whether a peer with the given ID is currently registered with the hub.
func (h *Hub) HasPeer(peerID string) bool {
	_, registered := h.lookup(peerID)
	return registered
}
// ServeConn handles one newly accepted TCP connection.
// The first message on the connection must be a register message; only after
// that are text/file messages forwarded. It blocks until the connection's
// receive loop ends and returns the error that ended it. The peer is
// unregistered when ServeConn returns.
func (h *Hub) ServeConn(rawConn net.Conn) error {
	conn, err := transport.NewTCPConn(rawConn)
	if err != nil {
		_ = rawConn.Close()
		return fmt.Errorf("server: create transport conn: %w", err)
	}

	peerID, graceful, regErr := h.registerConn(conn)
	if regErr != nil {
		h.closeConn(conn, graceful)
		return regErr
	}
	defer h.unregister(peerID, conn)

	return h.receivePeerLoop(peerID, conn)
}
// registerConn reads the first message from a new connection, checks that it
// is a register message, and adds the connection to the hub under msg.From.
//
// It returns the registered peer ID on success. On failure the bool result
// tells the caller whether the connection should be closed gracefully: it is
// true when an error reply was delivered to the peer (so the peer should get
// a chance to read it) and false when receiving or replying itself failed.
//
// The duplicate-ID reply is sent after the hub lock has been released, so a
// peer that is slow to read its rejection cannot stall lookups and
// registrations for every other connection.
func (h *Hub) registerConn(conn *transport.TCPConn) (string, bool, error) {
	msg, err := conn.Receive()
	if err != nil {
		return "", false, fmt.Errorf("server: receive register: %w", err)
	}

	if msg.Type != protocol.MessageTypeRegister {
		if sendErr := sendServerError(conn, msg.From, "first message must be register"); sendErr != nil {
			return "", false, fmt.Errorf("server: reject unregistered peer: %w", sendErr)
		}
		return "", true, fmt.Errorf("server: first message must be register, got %s", msg.Type)
	}

	peerID := msg.From
	if !h.tryRegister(peerID, conn) {
		if sendErr := sendServerError(conn, peerID, fmt.Sprintf("duplicate peer id: %s", peerID)); sendErr != nil {
			return "", false, fmt.Errorf("server: duplicate peer id %s: %w", peerID, sendErr)
		}
		return "", true, fmt.Errorf("server: duplicate peer id: %s", peerID)
	}
	return peerID, false, nil
}

// tryRegister atomically adds conn under peerID and reports whether it was
// added; it returns false without changing the hub when the ID is already taken.
func (h *Hub) tryRegister(peerID string, conn *transport.TCPConn) bool {
	h.mu.Lock()
	defer h.mu.Unlock()

	if _, exists := h.peers[peerID]; exists {
		return false
	}
	h.peers[peerID] = conn
	return true
}
// handlePeerMessage dispatches one message received from a registered peer:
// text/file messages are forwarded to their target, anything else is
// answered with a protocol error.
//
// A non-nil error tells the caller to stop serving the connection; the bool
// result reports whether that close should be graceful (true when an error
// reply was delivered to the peer first).
func (h *Hub) handlePeerMessage(peerID string, conn *transport.TCPConn, msg protocol.Message) (bool, error) {
	switch msg.Type {
	case protocol.MessageTypeText, protocol.MessageTypeFile:
		// Overwrite From with the registered ID so a peer cannot spoof its sender.
		msg.From = peerID
		dest := msg.To

		target, found := h.lookup(dest)
		if !found {
			return false, sendServerError(conn, peerID, fmt.Sprintf("unknown target: %s", dest))
		}

		if sendErr := target.Send(msg); sendErr != nil {
			// The target connection is likely dead: drop it from the hub,
			// close it, and tell the sender that forwarding failed.
			h.unregister(dest, target)
			_ = target.Close()
			return false, sendServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", dest))
		}
		return false, nil

	case protocol.MessageTypeRegister, protocol.MessageTypeError:
		// A registered peer may not send register or error messages again.
		if sendErr := sendServerError(conn, peerID, "registered peers can only send text or file messages"); sendErr != nil {
			return false, fmt.Errorf("server: send protocol error: %w", sendErr)
		}
		return true, fmt.Errorf("server: unexpected message type from peer %s: %s", peerID, msg.Type)

	default:
		// Every other message type is a protocol error as well.
		if sendErr := sendServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); sendErr != nil {
			return false, fmt.Errorf("server: send unsupported type error: %w", sendErr)
		}
		return true, fmt.Errorf("server: unsupported message type: %s", msg.Type)
	}
}
// receivePeerLoop reads messages from a registered peer and hands each one to
// handlePeerMessage until receiving or handling fails. The connection is
// closed before returning: immediately on a read error, and according to the
// handler's graceful flag on a handler error. It never returns nil.
func (h *Hub) receivePeerLoop(peerID string, conn *transport.TCPConn) error {
	for {
		msg, recvErr := conn.Receive()
		if recvErr != nil {
			_ = conn.Close()
			return fmt.Errorf("transport: receive loop read: %w", recvErr)
		}

		if graceful, handleErr := h.handlePeerMessage(peerID, conn, msg); handleErr != nil {
			h.closeConn(conn, graceful)
			return fmt.Errorf("transport: receive loop handler: %w", handleErr)
		}
	}
}
// lookup returns the connection registered for peerID and whether one exists.
func (h *Hub) lookup(peerID string) (*transport.TCPConn, bool) {
	h.mu.RLock()
	target, found := h.peers[peerID]
	h.mu.RUnlock()
	return target, found
}
// unregister removes peerID from the hub, but only if the registered
// connection is conn itself; an entry holding a different connection under
// the same ID is left untouched. It is typically called when a connection
// closes or fails.
func (h *Hub) unregister(peerID string, conn *transport.TCPConn) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if registered, found := h.peers[peerID]; !found || registered != conn {
		return
	}
	delete(h.peers, peerID)
}
// closeConn closes conn, either immediately or, when graceful is true, via
// CloseGracefully with gracefulRejectCloseTimeout. Close errors are ignored.
func (h *Hub) closeConn(conn *transport.TCPConn, graceful bool) {
	if !graceful {
		_ = conn.Close()
		return
	}
	_ = conn.CloseGracefully(gracefulRejectCloseTimeout)
}
// sendServerError sends an error message from the server to the given peer,
// with message as the body text.
func sendServerError(conn *transport.TCPConn, to, message string) error {
	reply := protocol.Message{
		Type: protocol.MessageTypeError,
		From: protocol.ServerPeerID,
		To:   to,
		Body: []byte(message),
	}
	return conn.Send(reply)
}

View File

@@ -0,0 +1,398 @@
package server
import (
"net"
"reflect"
"strings"
"sync"
"testing"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
// recordingLogger is a latencylog.Logger test double that keeps every logged
// event in memory so tests can inspect them.
type recordingLogger struct {
	mu     sync.Mutex
	events []latencylog.Event
}

// LogEvent appends event to the recorded list and always returns nil.
func (l *recordingLogger) LogEvent(event latencylog.Event) error {
	l.mu.Lock()
	l.events = append(l.events, event)
	l.mu.Unlock()
	return nil
}

// Events returns a copy of the events recorded so far.
func (l *recordingLogger) Events() []latencylog.Event {
	l.mu.Lock()
	defer l.mu.Unlock()

	var snapshot []latencylog.Event
	snapshot = append(snapshot, l.events...)
	return snapshot
}
// TestServeConnRegistersPeer verifies that a register message adds the peer
// to the hub and that closing the client ends ServeConn with a read-loop error.
func TestServeConnRegistersPeer(t *testing.T) {
	hub := NewHub()
	client, done := startHubConn(t, hub)

	register := protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: "peer-a",
		To:   protocol.ServerPeerID,
	}
	if err := client.Send(register); err != nil {
		t.Fatalf("Send(register) error = %v", err)
	}

	waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")

	if err := client.Close(); err != nil {
		t.Fatalf("client.Close() error = %v", err)
	}

	if err := <-done; err == nil || !strings.Contains(err.Error(), "receive loop read") {
		t.Fatalf("ServeConn() error = %v, want read-loop shutdown error", err)
	}
}
// TestServeConnRejectsDuplicatePeerID verifies that a second connection
// registering an already-used peer ID receives an error message and is
// rejected, while the original registration stays in place.
func TestServeConnRejectsDuplicatePeerID(t *testing.T) {
	hub := NewHub()
	peerARegistered := func() bool { return hub.HasPeer("peer-a") }

	first, firstDone := startHubConn(t, hub)
	registerPeer(t, first, "peer-a")
	waitFor(t, peerARegistered, "peer-a to be registered")

	second, secondDone := startHubConn(t, hub)
	registerPeer(t, second, "peer-a")

	reply, err := second.Receive()
	if err != nil {
		t.Fatalf("second.Receive() error = %v", err)
	}
	if reply.Type != protocol.MessageTypeError {
		t.Fatalf("got message type %s, want %s", reply.Type, protocol.MessageTypeError)
	}
	if string(reply.Body) != "duplicate peer id: peer-a" {
		t.Fatalf("error body = %q, want duplicate peer message", reply.Body)
	}

	waitFor(t, peerARegistered, "original peer-a to remain registered")

	if err := first.Close(); err != nil {
		t.Fatalf("first.Close() error = %v", err)
	}

	if err := <-secondDone; err == nil || !strings.Contains(err.Error(), "duplicate peer id") {
		t.Fatalf("second ServeConn() error = %v, want duplicate peer id error", err)
	}
	if err := <-firstDone; err == nil || !strings.Contains(err.Error(), "receive loop read") {
		t.Fatalf("first ServeConn() error = %v, want read-loop shutdown error", err)
	}
}
// TestServeConnForwardsMessages verifies that text and file messages are
// forwarded to the target peer and that the hub replaces a spoofed From field
// with the sender's registered ID.
func TestServeConnForwardsMessages(t *testing.T) {
	type testCase struct {
		name string
		msg  protocol.Message
	}
	cases := []testCase{
		{
			name: "text",
			msg: protocol.Message{
				Type: protocol.MessageTypeText,
				ID:   1,
				From: "spoofed",
				To:   "peer-b",
				Body: []byte("hello"),
			},
		},
		{
			name: "file",
			msg: protocol.Message{
				Type:     protocol.MessageTypeFile,
				ID:       2,
				From:     "spoofed",
				To:       "peer-b",
				FileName: "payload.bin",
				Body:     []byte{0x01, 0x02, 0x03},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			hub := NewHub()
			sender, senderDone := startHubConn(t, hub)
			receiver, receiverDone := startHubConn(t, hub)

			registerPeer(t, sender, "peer-a")
			registerPeer(t, receiver, "peer-b")
			bothRegistered := func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }
			waitFor(t, bothRegistered, "both peers to be registered")

			if err := sender.Send(tc.msg); err != nil {
				t.Fatalf("sender.Send() error = %v", err)
			}

			got, err := receiver.Receive()
			if err != nil {
				t.Fatalf("receiver.Receive() error = %v", err)
			}

			// The hub rewrites From to the registered sender ID.
			want := tc.msg
			want.From = "peer-a"
			if !reflect.DeepEqual(got, want) {
				t.Fatalf("forwarded message mismatch: got %+v want %+v", got, want)
			}

			_ = sender.Close()
			_ = receiver.Close()
			<-senderDone
			<-receiverDone
		})
	}
}
// TestServeConnReturnsErrorForUnknownTarget verifies that sending to an
// unregistered target yields an error message and leaves the sender registered.
func TestServeConnReturnsErrorForUnknownTarget(t *testing.T) {
	hub := NewHub()
	client, done := startHubConn(t, hub)
	registerPeer(t, client, "peer-a")
	waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")

	outgoing := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "missing-peer",
		Body: []byte("hello"),
	}
	if err := client.Send(outgoing); err != nil {
		t.Fatalf("Send(text) error = %v", err)
	}

	reply, err := client.Receive()
	if err != nil {
		t.Fatalf("Receive() error = %v", err)
	}
	if reply.Type != protocol.MessageTypeError {
		t.Fatalf("got message type %s, want %s", reply.Type, protocol.MessageTypeError)
	}
	if string(reply.Body) != "unknown target: missing-peer" {
		t.Fatalf("error body = %q, want unknown target message", reply.Body)
	}
	if !hub.HasPeer("peer-a") {
		t.Fatal("peer-a should remain registered after unknown target error")
	}

	_ = client.Close()
	<-done
}
// TestServeConnRejectsRegisterAfterRegistration verifies that a second
// register message from an already registered peer is answered with a
// protocol error and ends ServeConn.
func TestServeConnRejectsRegisterAfterRegistration(t *testing.T) {
	hub := NewHub()
	client, done := startHubConn(t, hub)
	registerPeer(t, client, "peer-a")
	waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")

	again := protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: "peer-a",
		To:   protocol.ServerPeerID,
	}
	if err := client.Send(again); err != nil {
		t.Fatalf("Send(register again) error = %v", err)
	}

	reply, err := client.Receive()
	if err != nil {
		t.Fatalf("Receive() error = %v", err)
	}
	if reply.Type != protocol.MessageTypeError {
		t.Fatalf("got message type %s, want %s", reply.Type, protocol.MessageTypeError)
	}
	if string(reply.Body) != "registered peers can only send text or file messages" {
		t.Fatalf("error body = %q, want registered-peer protocol error", reply.Body)
	}

	if err := <-done; err == nil || !strings.Contains(err.Error(), "unexpected message type from peer peer-a: register") {
		t.Fatalf("ServeConn() error = %v, want unexpected register message error", err)
	}
}
// TestServeConnUnregistersPeerOnClose verifies that a peer is removed from
// the hub once its connection has been closed.
func TestServeConnUnregistersPeerOnClose(t *testing.T) {
	hub := NewHub()
	client, done := startHubConn(t, hub)

	registerPeer(t, client, "peer-a")
	waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")

	if err := client.Close(); err != nil {
		t.Fatalf("client.Close() error = %v", err)
	}
	<-done

	waitFor(t, func() bool { return !hub.HasPeer("peer-a") }, "peer-a to be unregistered")
}
// TestServeConnDoesNotEmitEndpointLatencyEventsOnForward verifies that
// forwarding a message through the hub does not record any latency events on
// the hub's logger.
func TestServeConnDoesNotEmitEndpointLatencyEventsOnForward(t *testing.T) {
	logger := new(recordingLogger)
	hub := NewHub(WithLogger(logger))

	sender, senderDone := startHubConn(t, hub)
	receiver, receiverDone := startHubConn(t, hub)
	registerPeer(t, sender, "peer-a")
	registerPeer(t, receiver, "peer-b")
	bothRegistered := func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }
	waitFor(t, bothRegistered, "both peers to be registered")

	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   11,
		From: "spoofed",
		To:   "peer-b",
		Body: []byte("hello"),
	}
	if err := sender.Send(msg); err != nil {
		t.Fatalf("sender.Send() error = %v", err)
	}

	got, err := receiver.Receive()
	if err != nil {
		t.Fatalf("receiver.Receive() error = %v", err)
	}

	// The hub rewrites From to the registered sender ID.
	msg.From = "peer-a"
	if !reflect.DeepEqual(got, msg) {
		t.Fatalf("forwarded message mismatch: got %+v want %+v", got, msg)
	}

	if recorded := logger.Events(); len(recorded) != 0 {
		t.Fatalf("event count = %d, want 0 because server is a black-box relay", len(recorded))
	}

	_ = sender.Close()
	_ = receiver.Close()
	<-senderDone
	<-receiverDone
}
// TestServeConnDoesNotLogLatencyEventsForUnknownTarget verifies that the
// unknown-target error path does not record any latency events.
func TestServeConnDoesNotLogLatencyEventsForUnknownTarget(t *testing.T) {
	logger := new(recordingLogger)
	hub := NewHub(WithLogger(logger))

	client, done := startHubConn(t, hub)
	registerPeer(t, client, "peer-a")
	waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")

	outgoing := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   15,
		From: "peer-a",
		To:   "missing-peer",
		Body: []byte("hello"),
	}
	if err := client.Send(outgoing); err != nil {
		t.Fatalf("Send(text) error = %v", err)
	}

	reply, err := client.Receive()
	if err != nil {
		t.Fatalf("Receive() error = %v", err)
	}
	if reply.Type != protocol.MessageTypeError {
		t.Fatalf("got message type %s, want %s", reply.Type, protocol.MessageTypeError)
	}

	if recorded := logger.Events(); len(recorded) != 0 {
		t.Fatalf("event count = %d, want 0 for unknown target path", len(recorded))
	}

	_ = client.Close()
	<-done
}
// TestServeConnDoesNotLogLatencyEventsForDuplicateRegister verifies that
// rejecting a duplicate registration does not record any latency events.
func TestServeConnDoesNotLogLatencyEventsForDuplicateRegister(t *testing.T) {
	logger := new(recordingLogger)
	hub := NewHub(WithLogger(logger))

	first, firstDone := startHubConn(t, hub)
	registerPeer(t, first, "peer-a")
	waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")

	second, secondDone := startHubConn(t, hub)
	registerPeer(t, second, "peer-a")

	reply, err := second.Receive()
	if err != nil {
		t.Fatalf("second.Receive() error = %v", err)
	}
	if reply.Type != protocol.MessageTypeError {
		t.Fatalf("got type %s, want %s", reply.Type, protocol.MessageTypeError)
	}

	if recorded := logger.Events(); len(recorded) != 0 {
		t.Fatalf("event count = %d, want 0 for duplicate register path", len(recorded))
	}

	_ = first.Close()
	<-secondDone
	<-firstDone
}
// startHubConn opens a loopback TCP connection, serves its server side with
// hub.ServeConn in a background goroutine, and returns the client side
// wrapped in a transport.TCPConn together with a channel that receives
// ServeConn's result (or the Accept error).
//
// The one-shot listener is closed by the accepting goroutine after Accept has
// returned, so the pending connection cannot be dropped by closing the
// listener before Accept runs. The client connection is also closed through
// t.Cleanup, so tests that fail early or never close it do not leak the
// connection or the serving goroutine.
func startHubConn(t *testing.T, hub *Hub) (*transport.TCPConn, <-chan error) {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("net.Listen() error = %v", err)
	}

	done := make(chan error, 1)
	go func() {
		serverSide, acceptErr := listener.Accept()
		// Only one connection is ever accepted; release the listener now.
		_ = listener.Close()
		if acceptErr != nil {
			done <- acceptErr
			return
		}
		done <- hub.ServeConn(serverSide)
	}()

	clientSide, err := net.Dial("tcp", listener.Addr().String())
	if err != nil {
		// Closing the listener unblocks the pending Accept in the goroutine.
		_ = listener.Close()
		t.Fatalf("net.Dial() error = %v", err)
	}

	conn, err := transport.NewTCPConn(clientSide)
	if err != nil {
		_ = clientSide.Close()
		t.Fatalf("transport.NewTCPConn() error = %v", err)
	}

	// Calling Close again from a test body is fine: the repeated-Close tests
	// in this file rely on it being safe.
	t.Cleanup(func() { _ = conn.Close() })
	return conn, done
}
// registerPeer sends a register message for peerID over conn and fails the
// test if the send fails.
func registerPeer(t *testing.T, conn *transport.TCPConn, peerID string) {
	t.Helper()

	register := protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: peerID,
		To:   protocol.ServerPeerID,
	}
	if err := conn.Send(register); err != nil {
		t.Fatalf("Send(register %s) error = %v", peerID, err)
	}
}
// waitFor polls condition every 10ms for up to 500ms and returns as soon as
// it reports true; otherwise it fails the test, naming description in the
// failure message.
func waitFor(t *testing.T, condition func() bool, description string) {
	t.Helper()

	const (
		timeout      = 500 * time.Millisecond
		pollInterval = 10 * time.Millisecond
	)

	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(pollInterval) {
		if condition() {
			return
		}
	}
	t.Fatalf("timed out waiting for %s", description)
}

View File

@@ -0,0 +1,151 @@
package transport
import (
"errors"
"fmt"
"io"
"net"
"sync"
"syscall"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// TCPConn is a thin wrapper around a single live TCP connection.
// It lifts the protocol layer's single-message read/write into a reusable send/receive interface.
type TCPConn struct {
conn net.Conn
raw syscall.RawConn // Underlying syscall handle of the connection, used for Linux socket timestamping on send and receive.
logger latencylog.Logger
nodeRole string // Node role recorded in logs, e.g. "server" or "peer".
nodeID string // Node ID recorded in logs, e.g. a peer's ID or "hub" for the server.
writeMu sync.Mutex // Serializes Send so only one complete protocol message is written at a time and message bytes never interleave.
closeOnce sync.Once // Ensures the underlying connection is closed only once.
closeErr error // Result of closing the connection (nil on success); repeated Close calls return the same value.
}
// Option injects optional behavior into a TCPConn, such as latency logging.
type Option func(*TCPConn)
// WithLogger returns an Option that sets the logger together with the node
// role and node ID that the connection's send path reports in its log events.
func WithLogger(logger latencylog.Logger, nodeRole, nodeID string) Option {
	return func(c *TCPConn) {
		c.logger, c.nodeRole, c.nodeID = logger, nodeRole, nodeID
	}
}
// NewTCPConn wraps an existing net.Conn in a transport connection.
// Options are applied in order; a nil logger is replaced by
// latencylog.NoopLogger. The error from initLinuxTimestamping, if any, is
// returned unchanged; conn is not closed in that case.
func NewTCPConn(conn net.Conn, opts ...Option) (*TCPConn, error) {
	wrapped := new(TCPConn)
	wrapped.conn = conn
	wrapped.logger = latencylog.NoopLogger{}

	for i := range opts {
		opts[i](wrapped)
	}
	if wrapped.logger == nil {
		wrapped.logger = latencylog.NoopLogger{}
	}

	if err := wrapped.initLinuxTimestamping(); err != nil {
		return nil, err
	}
	return wrapped, nil
}
// Send writes one complete protocol message to the underlying connection.
// It may be called from multiple goroutines; writes are serialized
// internally so the bytes of different messages never interleave.
// A send_handoff_begin event is logged before the write and a
// send_handoff_end event after a successful write.
func (c *TCPConn) Send(msg protocol.Message) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)

	sendErr := c.sendMessageLinux(msg)
	if sendErr != nil {
		return fmt.Errorf("transport: send message: %w", sendErr)
	}

	// The end event is only logged when the hand-off succeeded.
	latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
	return nil
}
// Receive reads one complete protocol message from the underlying connection.
// A connection should only ever be read by a single goroutine calling this
// method repeatedly.
func (c *TCPConn) Receive() (protocol.Message, error) {
	received, recvErr := c.receiveMessageLinux()
	if recvErr == nil {
		return received, nil
	}
	return protocol.Message{}, fmt.Errorf("transport: receive message: %w", recvErr)
}
// ReceiveLoop keeps reading messages and passes each one to handler.
// The loop never returns nil: a read error (including the peer closing the
// connection) or a handler error ends it, the connection is closed (the close
// error is ignored), and the wrapped cause is returned.
func (c *TCPConn) ReceiveLoop(handler func(protocol.Message) error) error {
	for {
		msg, readErr := c.Receive()
		if readErr != nil {
			_ = c.Close()
			return fmt.Errorf("transport: receive loop read: %w", readErr)
		}

		handlerErr := handler(msg)
		if handlerErr == nil {
			continue
		}
		_ = c.Close()
		return fmt.Errorf("transport: receive loop handler: %w", handlerErr)
	}
}
// CloseGracefully closes the write direction first on connections that
// support half-close, giving the peer a chance to read the final response,
// then waits up to drainTimeout for the peer to finish before closing the
// connection completely.
//
// If the connection has no CloseWrite method, if CloseWrite fails with
// anything other than net.ErrClosed, or if drainTimeout is not positive, the
// connection is closed immediately. The returned error is always the result
// of Close; CloseWrite and deadline errors are not reported.
func (c *TCPConn) CloseGracefully(drainTimeout time.Duration) error {
	halfCloser, canHalfClose := c.conn.(interface{ CloseWrite() error })
	if !canHalfClose {
		return c.Close()
	}
	if err := halfCloser.CloseWrite(); err != nil && !errors.Is(err, net.ErrClosed) {
		return c.Close()
	}
	if drainTimeout <= 0 {
		return c.Close()
	}

	// Bound the drain with an absolute read deadline; it is cleared again on
	// the way out.
	_ = c.conn.SetReadDeadline(time.Now().Add(drainTimeout))
	defer func() {
		_ = c.conn.SetReadDeadline(time.Time{})
	}()

	var scratch [256]byte
	for {
		_, err := c.conn.Read(scratch[:])
		if err == nil {
			// Discard whatever the peer still sends and keep draining.
			continue
		}
		if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
			// The peer finished or the connection is already closed.
			return c.Close()
		}
		// A deadline expiry or any other read failure also ends the drain.
		return c.Close()
	}
}
// Close closes the underlying connection. It is safe to call repeatedly:
// only the first call actually closes the connection, and every call returns
// the result of that first close.
func (c *TCPConn) Close() error {
	closeUnderlying := func() {
		c.closeErr = c.conn.Close()
	}
	c.closeOnce.Do(closeUnderlying)
	return c.closeErr
}

View File

@@ -0,0 +1,462 @@
//go:build linux
package transport
import (
"encoding/binary"
"errors"
"fmt"
"io"
"syscall"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// Constants for Linux socket timestamping. The numeric values mirror the
// kernel's SO_TIMESTAMPING_NEW / SOF_TIMESTAMPING_* / SCM_TSTAMP_* definitions
// and are declared locally instead of being taken from the syscall package.
const (
linuxTimestampControlBufferSize = 256 // size of the control-message (oob) buffer passed to recvmsg
linuxTXTimestampWaitTimeout = 250 * time.Millisecond // upper bound on waiting for TX timestamps after a send
linuxTXTimestampPollInterval = time.Millisecond // interval between error-queue polls while waiting
linuxSOTimestampingNew = 0x41 // socket option set at SOL_SOCKET level to enable timestamping (SO_TIMESTAMPING_NEW)
linuxSCMTimestampingNew = linuxSOTimestampingNew // control-message type that carries the timestamps; same value as the option
linuxSOEEOriginTimestamping = 4 // extended-error origin value marking a timestamping error-queue event
linuxSCMTstampSnd = 0 // timestamp kind mapped to A_TX_SOFTWARE
linuxSCMTstampSched = 1 // timestamp kind mapped to A_TX_SCHED
linuxSOFTimestampingTXSoftware = 1 << 1 // enable TX software timestamps
linuxSOFTimestampingRXSoftware = 1 << 3 // enable RX software timestamps
linuxSOFTimestampingSoftware = 1 << 4 // master switch for reporting software timestamps
linuxSOFTimestampingOptID = 1 << 7 // associate an ID with each timestamp
linuxSOFTimestampingTXSched = 1 << 8 // enable TX scheduler timestamps
linuxSOFTimestampingOptTSONLY = 1 << 11 // return only the timestamp
linuxSOFTimestampingOptIDTCP = 1 << 16 // make TCP carry the timestamp ID as well
)
// initLinuxTimestamping obtains the connection's raw syscall handle and turns
// on Linux socket timestamping for it. On success the handle is stored in
// c.raw, where the send and receive paths use it; on any failure c.raw is left
// untouched and an error is returned.
func (c *TCPConn) initLinuxTimestamping() error {
	provider, ok := c.conn.(interface {
		SyscallConn() (syscall.RawConn, error)
	})
	if !ok {
		return fmt.Errorf("transport: connection does not support SyscallConn")
	}

	handle, err := provider.SyscallConn()
	switch {
	case err != nil:
		return fmt.Errorf("transport: get syscall conn: %w", err)
	case handle == nil:
		return fmt.Errorf("transport: missing syscall conn")
	}

	// enableLinuxTimestamping tries several flag combinations, stopping at the
	// first success or the first error other than EINVAL.
	if err := enableLinuxTimestamping(handle); err != nil {
		return fmt.Errorf("transport: enable linux timestamping: %w", err)
	}

	c.raw = handle
	return nil
}
// enableLinuxTimestamping enables TX sched/software and RX software
// timestamping on the socket behind rawConn.
//
// Three flag sets are tried from richest to plainest: with OPT_ID and
// OPT_ID_TCP, with OPT_ID only, and with neither. A set rejected with EINVAL
// leads to the next attempt; any other setsockopt error, or an error from
// rawConn.Control itself, is returned at once. If every set is rejected the
// last setsockopt error is returned.
func enableLinuxTimestamping(rawConn syscall.RawConn) error {
	// Flags shared by every candidate.
	const base = linuxSOFTimestampingTXSched |
		linuxSOFTimestampingTXSoftware |
		linuxSOFTimestampingRXSoftware |
		linuxSOFTimestampingSoftware |
		linuxSOFTimestampingOptTSONLY

	candidates := [...]int{
		base | linuxSOFTimestampingOptID | linuxSOFTimestampingOptIDTCP,
		base | linuxSOFTimestampingOptID,
		base,
	}

	var sockErr error
	for _, flags := range candidates {
		// Control keeps the fd valid for the duration of the callback.
		controlErr := rawConn.Control(func(fd uintptr) {
			sockErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags)
		})
		switch {
		case controlErr != nil:
			return controlErr
		case sockErr == nil:
			return nil
		case !errors.Is(sockErr, syscall.EINVAL):
			return sockErr
		}
	}
	return sockErr
}
// sendMessageLinux encodes msg, writes it as one length-prefixed frame, and
// then logs the TX timestamp events collected for it.
//
// The frame is a 4-byte big-endian payload length followed by the payload.
// The payload size is validated before anything is written, using the same
// limits the receiving side enforces in readFrameLinux: an empty payload
// yields protocol.ErrInvalidFrameLength and one larger than
// protocol.MaxFrameSize yields protocol.ErrFrameTooLarge. Without this check
// the length would be silently truncated to 32 bits, or a frame would be sent
// that the peer rejects.
func (c *TCPConn) sendMessageLinux(msg protocol.Message) error {
	payload, err := protocol.EncodeMessage(msg)
	if err != nil {
		return fmt.Errorf("protocol: encode message: %w", err)
	}

	// Compare in uint64 so the check is independent of MaxFrameSize's type and
	// cannot itself overflow.
	switch size := uint64(len(payload)); {
	case size == 0:
		return fmt.Errorf("protocol: write frame: %w", protocol.ErrInvalidFrameLength)
	case size > uint64(protocol.MaxFrameSize):
		return fmt.Errorf("protocol: write frame: %w", protocol.ErrFrameTooLarge)
	}

	// Prefix the encoded payload with its 4-byte length to form the frame.
	frame := make([]byte, 4+len(payload))
	binary.BigEndian.PutUint32(frame[:4], uint32(len(payload)))
	copy(frame[4:], payload)

	if err := c.writeFrameLinux(frame); err != nil {
		return fmt.Errorf("protocol: write frame: %w", err)
	}

	// Record the send-side latency events for this message.
	c.logTXTimestampEvents(msg)
	return nil
}
// writeFrameLinux writes the whole frame with sendmsg on the raw socket.
//
// The callback passed to c.raw.Write returns false on EAGAIN/EWOULDBLOCK or
// after a partial write so that it is invoked again, and true once the frame
// is fully written or a hard error occurred. A successful sendmsg that
// reports no progress is treated as io.ErrShortWrite.
func (c *TCPConn) writeFrameLinux(frame []byte) error {
	var (
		sent    int
		sendErr error
	)
	waitErr := c.raw.Write(func(fd uintptr) bool {
		if sent >= len(frame) {
			return true
		}
		n, err := syscall.SendmsgN(int(fd), frame[sent:], nil, nil, 0)
		if err != nil {
			if errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EWOULDBLOCK) {
				return false
			}
			sendErr = err
			return true
		}
		if n <= 0 {
			sendErr = io.ErrShortWrite
			return true
		}
		sent += n
		return sent >= len(frame)
	})

	switch {
	case waitErr != nil:
		return waitErr
	case sendErr != nil:
		return sendErr
	case sent != len(frame):
		return io.ErrShortWrite
	}
	return nil
}
// logTXTimestampEvents collects the TX timestamps available after a send and
// logs A_TX_SCHED and then A_TX_SOFTWARE for msg, each only if its timestamp
// was obtained.
func (c *TCPConn) logTXTimestampEvents(msg protocol.Message) {
	stamps := c.collectTXTimestampEvents()
	for _, name := range [...]string{latencylog.EventATXSched, latencylog.EventATXSoftware} {
		ts, found := stamps[name]
		if !found {
			continue
		}
		latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, name, ts, msg)
	}
}
// collectTXTimestampEvents polls the socket error queue for the two TX
// timestamp kinds and returns them keyed by log event name.
//
// Polling stops when both kinds were seen, when linuxTXTimestampWaitTimeout
// has elapsed, or on the first error that is not EAGAIN/EWOULDBLOCK. An empty
// queue causes a sleep of linuxTXTimestampPollInterval before the next try.
// Only the first timestamp seen for each event name is kept.
func (c *TCPConn) collectTXTimestampEvents() map[string]int64 {
	const wanted = 2
	collected := make(map[string]int64, wanted)
	giveUpAt := time.Now().Add(linuxTXTimestampWaitTimeout)

	for len(collected) < wanted && time.Now().Before(giveUpAt) {
		name, stamp, err := c.recvTXTimestampOnce()
		if err != nil {
			if !isWouldBlock(err) {
				return collected
			}
			time.Sleep(linuxTXTimestampPollInterval)
			continue
		}
		// Skip queue entries that did not yield a usable timestamp event.
		if name != "" && stamp > 0 {
			if _, seen := collected[name]; !seen {
				collected[name] = stamp
			}
		}
	}
	return collected
}
// recvTXTimestampOnce makes a single non-blocking read from the socket error
// queue and parses the result as a TX timestamp event.
//
// It returns the event name (A_TX_SCHED or A_TX_SOFTWARE) with its timestamp
// in Unix nanoseconds. When the entry read is not a recognized timestamp
// event the name is empty and the timestamp is 0. An empty queue surfaces as
// the recvmsg error (EAGAIN/EWOULDBLOCK).
func (c *TCPConn) recvTXTimestampOnce() (string, int64, error) {
	var (
		name    string
		stampNS int64
		recvErr error
	)
	controlErr := c.raw.Control(func(fd uintptr) {
		control := make([]byte, linuxTimestampControlBufferSize)
		// MSG_ERRQUEUE selects the error queue; MSG_DONTWAIT makes an empty
		// queue return EAGAIN instead of blocking.
		_, controlLen, _, _, err := syscall.Recvmsg(int(fd), nil, control, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT)
		if err != nil {
			recvErr = err
			return
		}
		name, stampNS = parseTXTimestampControlMessages(control[:controlLen])
	})

	if controlErr != nil {
		return "", 0, controlErr
	}
	if recvErr != nil {
		return "", 0, recvErr
	}
	return name, stampNS, nil
}
// parseTXTimestampControlMessages maps the control messages of one
// error-queue read to a log event name and a timestamp in Unix nanoseconds.
//
// It needs both a timestamping control message (for the time) and a socket
// extended-error control message (for the timestamp kind). If either is
// missing, the buffer cannot be parsed, or the kind is neither sched nor
// software send, it returns ("", 0).
func parseTXTimestampControlMessages(oob []byte) (string, int64) {
	if len(oob) == 0 {
		return "", 0
	}
	cmsgs, err := syscall.ParseSocketControlMessage(oob)
	if err != nil {
		return "", 0
	}

	var (
		stampNS  int64  // timestamp in Unix nanoseconds
		kind     uint32 // timestamp kind taken from the extended error
		gotStamp bool
		gotKind  bool
	)
	// One recvmsg may deliver several control messages; pick out the two we need.
	for _, cmsg := range cmsgs {
		hdr := cmsg.Header
		if hdr.Level == syscall.SOL_SOCKET && hdr.Type == linuxSCMTimestampingNew {
			if ts := parseSCMTimestampingData(cmsg.Data); ts > 0 {
				stampNS, gotStamp = ts, true
			}
			continue
		}
		if !isSocketExtendedErr(cmsg) {
			continue
		}
		if info, ok := parseSocketExtendedErrInfo(cmsg.Data); ok {
			kind, gotKind = info, true
		}
	}
	if !(gotStamp && gotKind) {
		return "", 0
	}

	// Only the sched and software-send kinds are logged.
	switch kind {
	case linuxSCMTstampSched:
		return latencylog.EventATXSched, stampNS
	case linuxSCMTstampSnd:
		return latencylog.EventATXSoftware, stampNS
	}
	return "", 0
}
// isSocketExtendedErr reports whether the control message is a socket
// extended error, i.e. level/type SOL_IP/IP_RECVERR or SOL_IPV6/IPV6_RECVERR.
// TX timestamps are not mixed into the normal data stream; they arrive wrapped
// as this kind of "error" message on the socket error queue.
func isSocketExtendedErr(controlMessage syscall.SocketControlMessage) bool {
	level, kind := controlMessage.Header.Level, controlMessage.Header.Type
	fromIPv4 := level == syscall.SOL_IP && kind == syscall.IP_RECVERR
	fromIPv6 := level == syscall.SOL_IPV6 && kind == syscall.IPV6_RECVERR
	return fromIPv4 || fromIPv6
}
// parseSocketExtendedErrInfo extracts the info field from the raw bytes of a
// socket extended error. It succeeds only when at least 16 bytes are present
// and the origin byte (offset 4) marks a timestamping event; the info field
// is the native-endian uint32 at offset 8.
func parseSocketExtendedErrInfo(data []byte) (uint32, bool) {
	const (
		minSize      = 16
		originOffset = 4
		infoOffset   = 8
	)
	if len(data) < minSize || data[originOffset] != linuxSOEEOriginTimestamping {
		return 0, false
	}
	return binary.NativeEndian.Uint32(data[infoOffset : infoOffset+4]), true
}
// receiveMessageLinux reads one complete frame, decodes it into a message,
// and logs B_RX_SOFTWARE for it when a positive RX timestamp was captured
// while reading. Read and decode failures are returned wrapped, with the
// zero Message.
func (c *TCPConn) receiveMessageLinux() (protocol.Message, error) {
	payload, stamp, err := c.readFrameLinux()
	if err != nil {
		return protocol.Message{}, fmt.Errorf("protocol: read frame: %w", err)
	}

	decoded, err := protocol.DecodeMessage(payload)
	if err != nil {
		return protocol.Message{}, fmt.Errorf("protocol: decode message: %w", err)
	}

	if stamp <= 0 {
		return decoded, nil
	}
	latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventBRXSoftware, stamp, decoded)
	return decoded, nil
}
// readFrameLinux reads the 4-byte big-endian length header and then the whole
// payload of that length.
//
// A length of zero yields protocol.ErrInvalidFrameLength and a length above
// protocol.MaxFrameSize yields protocol.ErrFrameTooLarge. The returned
// timestamp is the one captured while reading the header, or, if none was
// captured there, the one captured while reading the payload; it is returned
// alongside errors as well.
func (c *TCPConn) readFrameLinux() ([]byte, int64, error) {
	var header [4]byte
	stamp, err := c.readFullLinux(header[:])
	if err != nil {
		return nil, stamp, err
	}

	length := binary.BigEndian.Uint32(header[:])
	if length == 0 {
		return nil, stamp, protocol.ErrInvalidFrameLength
	}
	if length > protocol.MaxFrameSize {
		return nil, stamp, protocol.ErrFrameTooLarge
	}

	body := make([]byte, int(length))
	bodyStamp, err := c.readFullLinux(body)
	if stamp == 0 {
		stamp = bodyStamp
	}
	if err != nil {
		return nil, stamp, err
	}
	return body, stamp, nil
}
// readFullLinux reads until buf is full and returns the first positive RX
// timestamp reported by any of the underlying reads (0 if none).
//
// An EOF after some bytes were already read is reported as
// io.ErrUnexpectedEOF; any other error is returned as is. An empty buf
// returns (0, nil) without reading.
func (c *TCPConn) readFullLinux(buf []byte) (int64, error) {
	var firstStamp int64
	for filled := 0; filled < len(buf); {
		n, stamp, err := c.recvmsgLinux(buf[filled:])
		if firstStamp == 0 && stamp > 0 {
			firstStamp = stamp
		}
		switch {
		case err == nil:
			filled += n
		case filled > 0 && errors.Is(err, io.EOF):
			return firstStamp, io.ErrUnexpectedEOF
		default:
			return firstStamp, err
		}
	}
	return firstStamp, nil
}
// recvmsgLinux performs one recvmsg on the raw socket, reading data into buf
// and control messages into a scratch buffer in the same call.
//
// It returns the number of bytes read and the RX timestamp parsed from the
// control messages (0 if none). The callback passed to c.raw.Read returns
// false on EAGAIN/EWOULDBLOCK so it is invoked again. A successful read of
// zero bytes is reported as io.EOF.
func (c *TCPConn) recvmsgLinux(buf []byte) (int, int64, error) {
	var (
		got     int
		stampNS int64
		recvErr error
	)
	waitErr := c.raw.Read(func(fd uintptr) bool {
		control := make([]byte, linuxTimestampControlBufferSize)
		readN, controlN, _, _, err := syscall.Recvmsg(int(fd), buf, control, 0)
		if err != nil {
			if errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EWOULDBLOCK) {
				return false
			}
			recvErr = err
			return true
		}
		if readN == 0 {
			recvErr = io.EOF
			return true
		}
		got = readN
		stampNS = parseRXTimestampControlMessages(control[:controlN])
		return true
	})

	if waitErr != nil {
		return 0, 0, waitErr
	}
	if recvErr != nil {
		return 0, 0, recvErr
	}
	return got, stampNS, nil
}
// parseRXTimestampControlMessages returns the first positive timestamp found
// in a SOL_SOCKET timestamping control message within oob, in Unix
// nanoseconds. It returns 0 when oob is empty, cannot be parsed, or carries
// no such timestamp.
func parseRXTimestampControlMessages(oob []byte) int64 {
	if len(oob) == 0 {
		return 0
	}
	cmsgs, err := syscall.ParseSocketControlMessage(oob)
	if err != nil {
		return 0
	}
	for i := range cmsgs {
		hdr := &cmsgs[i].Header
		if hdr.Level == syscall.SOL_SOCKET && hdr.Type == linuxSCMTimestampingNew {
			if stamp := parseSCMTimestampingData(cmsgs[i].Data); stamp > 0 {
				return stamp
			}
		}
	}
	return 0
}
// parseSCMTimestampingData walks data as consecutive 16-byte entries, each a
// native-endian 64-bit seconds value followed by a 64-bit nanoseconds value,
// and returns the first entry that is not all zero, converted to nanoseconds.
// It returns 0 when no complete non-zero entry exists.
func parseSCMTimestampingData(data []byte) int64 {
	const entrySize = 16
	for rest := data; len(rest) >= entrySize; rest = rest[entrySize:] {
		sec := int64(binary.NativeEndian.Uint64(rest[:8]))
		nsec := int64(binary.NativeEndian.Uint64(rest[8:entrySize]))
		if sec != 0 || nsec != 0 {
			return sec*int64(time.Second) + nsec
		}
	}
	return 0
}
// isWouldBlock reports whether err is, or wraps, EAGAIN or EWOULDBLOCK.
func isWouldBlock(err error) bool {
	for _, target := range [...]error{syscall.EAGAIN, syscall.EWOULDBLOCK} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,140 @@
//go:build linux
package transport
import (
"net"
"reflect"
"testing"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// TestLinuxTimestampingRecordsKernelEvents sends a text and a file message
// over a loopback TCP connection and checks that the message arrives intact,
// that the sender logged A_TX_SCHED and A_TX_SOFTWARE for it, and that the
// receiver logged B_RX_SOFTWARE for it.
func TestLinuxTimestampingRecordsKernelEvents(t *testing.T) {
	type testCase struct {
		name string
		msg  protocol.Message
	}
	cases := []testCase{
		{
			name: "text",
			msg: protocol.Message{
				Type: protocol.MessageTypeText,
				ID:   41,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte("hello over tcp"),
			},
		},
		{
			name: "file",
			msg: protocol.Message{
				Type:     protocol.MessageTypeFile,
				ID:       42,
				From:     "peer-a",
				To:       "peer-b",
				FileName: "payload.bin",
				Body:     []byte{0x00, 0x01, 0x02, 0xff},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			clientConn, serverConn := newTCPPair(t)
			txLog, rxLog := &recordingLogger{}, &recordingLogger{}

			sender, err := NewTCPConn(clientConn, WithLogger(txLog, latencylog.NodeRolePeer, "peer-a"))
			if err != nil {
				t.Fatalf("NewTCPConn(sender) error = %v", err)
			}
			receiver, err := NewTCPConn(serverConn, WithLogger(rxLog, latencylog.NodeRolePeer, "peer-b"))
			if err != nil {
				t.Fatalf("NewTCPConn(receiver) error = %v", err)
			}
			t.Cleanup(func() {
				_ = sender.Close()
				_ = receiver.Close()
			})

			// Send in the background so Receive can run concurrently.
			sendDone := make(chan error, 1)
			go func() {
				sendDone <- sender.Send(tc.msg)
			}()

			got, err := receiver.Receive()
			if err != nil {
				t.Fatalf("Receive() error = %v", err)
			}
			if err := <-sendDone; err != nil {
				t.Fatalf("Send() error = %v", err)
			}
			if !reflect.DeepEqual(got, tc.msg) {
				t.Fatalf("message mismatch: got %+v want %+v", got, tc.msg)
			}

			assertHasEvent(t, txLog.Events(), latencylog.EventATXSched, tc.msg.ID)
			assertHasEvent(t, txLog.Events(), latencylog.EventATXSoftware, tc.msg.ID)
			assertHasEvent(t, rxLog.Events(), latencylog.EventBRXSoftware, tc.msg.ID)
		})
	}
}
// newTCPPair returns the two ends of a freshly established loopback TCP
// connection: the dialing side first, the accepted side second. The temporary
// listener is closed before returning. Any setup failure aborts the test.
func newTCPPair(t *testing.T) (net.Conn, net.Conn) {
	t.Helper()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("net.Listen() error = %v", err)
	}

	type accepted struct {
		conn net.Conn
		err  error
	}
	// Buffered so the accepting goroutine can always deliver its result and exit.
	acceptCh := make(chan accepted, 1)
	go func() {
		c, e := ln.Accept()
		acceptCh <- accepted{conn: c, err: e}
	}()

	client, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		_ = ln.Close()
		t.Fatalf("net.Dial() error = %v", err)
	}

	server := <-acceptCh
	if closeErr := ln.Close(); closeErr != nil {
		t.Fatalf("listener.Close() error = %v", closeErr)
	}
	if server.err != nil {
		_ = client.Close()
		t.Fatalf("listener.Accept() error = %v", server.err)
	}
	return client, server.conn
}
// assertHasEvent fails the test unless events contains an entry with the
// given event name and message ID. The first matching entry must also carry a
// positive timestamp.
func assertHasEvent(t *testing.T, events []latencylog.Event, wantEvent string, wantMessageID uint64) {
	t.Helper()
	for i := range events {
		candidate := events[i]
		if candidate.Event != wantEvent || candidate.MessageID != wantMessageID {
			continue
		}
		if candidate.TsUnixNano <= 0 {
			t.Fatalf("event %s timestamp must be positive: %+v", wantEvent, candidate)
		}
		return
	}
	t.Fatalf("missing event %s for message %d in %+v", wantEvent, wantMessageID, events)
}

View File

@@ -0,0 +1,416 @@
package transport
import (
"errors"
"io"
"reflect"
"strings"
"sync"
"testing"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// recordingLogger is a latencylog.Logger test double that keeps every logged
// event in memory so tests can inspect what was logged.
type recordingLogger struct {
mu sync.Mutex // guards events
events []latencylog.Event // events in the order LogEvent received them
}
// LogEvent appends event to the recorded list and always returns nil.
func (l *recordingLogger) LogEvent(event latencylog.Event) error {
l.mu.Lock()
defer l.mu.Unlock()
l.events = append(l.events, event)
return nil
}
// Events returns a copy of the events recorded so far, so the caller can read
// the result without holding the lock.
func (l *recordingLogger) Events() []latencylog.Event {
l.mu.Lock()
defer l.mu.Unlock()
return append([]latencylog.Event(nil), l.events...)
}
// failingLogger is a latencylog.Logger test double whose LogEvent always fails.
type failingLogger struct{}
// LogEvent ignores the event and always returns a "log failed" error.
func (failingLogger) LogEvent(latencylog.Event) error {
return errors.New("log failed")
}
// TestSendReceiveMessage verifies that the transport can send and receive
// both text and file messages over a single connection.
func TestSendReceiveMessage(t *testing.T) {
	type testCase struct {
		name string
		msg  protocol.Message
	}
	cases := []testCase{
		{
			name: "text",
			msg: protocol.Message{
				Type: protocol.MessageTypeText,
				ID:   1,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte("hello"),
			},
		},
		{
			name: "file",
			msg: protocol.Message{
				Type:     protocol.MessageTypeFile,
				ID:       2,
				From:     "peer-a",
				To:       "peer-b",
				FileName: "data.bin",
				Body:     []byte{0x00, 0x10, 0xff},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			sender, receiver := newTransportConnPair(t, nil, nil)

			// Buffered channel of size 1 that carries the result of the
			// background Send (an error or nil).
			sendDone := make(chan error, 1)
			go func() {
				sendDone <- sender.Send(tc.msg)
			}()

			got, err := receiver.Receive()
			if err != nil {
				t.Fatalf("Receive() error = %v", err)
			}
			// Collect the send result; a send failure fails the test.
			if err := <-sendDone; err != nil {
				t.Fatalf("Send() error = %v", err)
			}
			if !reflect.DeepEqual(got, tc.msg) {
				t.Fatalf("message mismatch: got %+v want %+v", got, tc.msg)
			}
		})
	}
}
// TestSendLogsHandoffEvents verifies that a single Send logs exactly four
// events in this order: send_handoff_begin, A_TX_SCHED, A_TX_SOFTWARE,
// send_handoff_end. All must carry the message ID and a positive timestamp,
// and the first must carry the configured node role and node ID.
func TestSendLogsHandoffEvents(t *testing.T) {
	rec := &recordingLogger{}
	senderOpts := []Option{WithLogger(rec, latencylog.NodeRolePeer, "peer-a")}
	sender, receiver := newTransportConnPair(t, senderOpts, nil)

	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   7,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	sendDone := make(chan error, 1)
	go func() {
		sendDone <- sender.Send(msg)
	}()

	got, err := receiver.Receive()
	if err != nil {
		t.Fatalf("Receive() error = %v", err)
	}
	if err := <-sendDone; err != nil {
		t.Fatalf("Send() error = %v", err)
	}
	if !reflect.DeepEqual(got, msg) {
		t.Fatalf("message mismatch: got %+v want %+v", got, msg)
	}

	events := rec.Events()
	if n := len(events); n != 4 {
		t.Fatalf("event count = %d, want 4", n)
	}
	if name, want := events[0].Event, latencylog.EventSendHandoffBegin; name != want {
		t.Fatalf("first event = %q, want %q", name, want)
	}
	if name, want := events[1].Event, latencylog.EventATXSched; name != want {
		t.Fatalf("second event = %q, want %q", name, want)
	}
	if name, want := events[2].Event, latencylog.EventATXSoftware; name != want {
		t.Fatalf("third event = %q, want %q", name, want)
	}
	if name, want := events[3].Event, latencylog.EventSendHandoffEnd; name != want {
		t.Fatalf("fourth event = %q, want %q", name, want)
	}

	for i := range events {
		if id := events[i].MessageID; id != msg.ID {
			t.Fatalf("event[%d] message ID = %d, want %d", i, id, msg.ID)
		}
	}

	first := events[0]
	if first.NodeRole != latencylog.NodeRolePeer || first.NodeID != "peer-a" {
		t.Fatalf("node info = (%s,%s), want (%s,%s)", first.NodeRole, first.NodeID, latencylog.NodeRolePeer, "peer-a")
	}

	for _, event := range events {
		if event.TsUnixNano <= 0 {
			t.Fatalf("timestamps must be positive: %+v", events)
		}
	}
}
func TestSendIgnoresLoggerFailure(t *testing.T) {
sender, receiver := newTransportConnPair(
t,
[]Option{WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")},
nil,
)
msg := protocol.Message{
Type: protocol.MessageTypeText,
ID: 9,
From: "peer-a",
To: "peer-b",
Body: []byte("hello"),
}
sendErr := make(chan error, 1)
go func() {
sendErr <- sender.Send(msg)
}()
got, err := receiver.Receive()
if err != nil {
t.Fatalf("Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("Send() error = %v, want nil even if logger fails", err)
}
if !reflect.DeepEqual(got, msg) {
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
}
}
// TestReceiveLoopDeliversMessages verifies that ReceiveLoop hands messages
// that arrive back to back to the handler one by one and in order, and that
// it ends with a read error once the peer closes the connection.
func TestReceiveLoopDeliversMessages(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	want := []protocol.Message{
		{
			Type: protocol.MessageTypeText,
			ID:   1,
			From: "peer-a",
			To:   "peer-b",
			Body: []byte("hello"),
		},
		{
			Type:     protocol.MessageTypeFile,
			ID:       2,
			From:     "peer-a",
			To:       "peer-b",
			FileName: "payload.bin",
			Body:     []byte{0x01, 0x02, 0x03},
		},
	}

	var (
		gotMu sync.Mutex
		got   []protocol.Message
	)
	collect := func(msg protocol.Message) error {
		gotMu.Lock()
		defer gotMu.Unlock()
		got = append(got, msg)
		return nil
	}

	loopDone := make(chan error, 1)
	go func() {
		loopDone <- receiver.ReceiveLoop(collect)
	}()

	for i := range want {
		if err := sender.Send(want[i]); err != nil {
			t.Fatalf("Send() error = %v", err)
		}
	}
	if err := sender.Close(); err != nil {
		t.Fatalf("sender.Close() error = %v", err)
	}

	loopErr := <-loopDone
	switch {
	case loopErr == nil:
		t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close")
	case !strings.Contains(loopErr.Error(), "receive loop read"):
		t.Fatalf("ReceiveLoop() error = %v, want read context", loopErr)
	}

	gotMu.Lock()
	defer gotMu.Unlock()
	if !reflect.DeepEqual(got, want) {
		t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
	}
}
// TestConcurrentSendKeepsMessagesIntact verifies that messages sent
// concurrently from several goroutines are not corrupted by interleaved
// writes: the receiver must get every message, each intact, in any order.
func TestConcurrentSendKeepsMessagesIntact(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	want := []protocol.Message{
		{Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("one")},
		{Type: protocol.MessageTypeText, ID: 2, From: "peer-a", To: "peer-b", Body: []byte("two")},
		{Type: protocol.MessageTypeText, ID: 3, From: "peer-a", To: "peer-b", Body: []byte("three")},
		{Type: protocol.MessageTypeText, ID: 4, From: "peer-a", To: "peer-b", Body: []byte("four")},
	}

	// A single background reader pulls exactly len(want) messages.
	received := make(chan protocol.Message, len(want))
	readDone := make(chan error, 1)
	go func() {
		for i := 0; i < len(want); i++ {
			msg, err := receiver.Receive()
			if err != nil {
				readDone <- err
				return
			}
			received <- msg
		}
		readDone <- nil
	}()

	// One sending goroutine per message.
	var senders sync.WaitGroup
	senders.Add(len(want))
	for _, msg := range want {
		msg := msg
		go func() {
			defer senders.Done()
			if err := sender.Send(msg); err != nil {
				t.Errorf("Send() error = %v", err)
			}
		}()
	}
	senders.Wait()

	if err := <-readDone; err != nil {
		t.Fatalf("Receive() error = %v", err)
	}

	// Arrival order is not deterministic, so compare by message ID.
	gotByID := make(map[uint64]protocol.Message, len(want))
	for i := 0; i < len(want); i++ {
		msg := <-received
		gotByID[msg.ID] = msg
	}
	for _, expected := range want {
		got, ok := gotByID[expected.ID]
		if !ok {
			t.Fatalf("missing message with ID %d", expected.ID)
		}
		if !reflect.DeepEqual(got, expected) {
			t.Fatalf("message mismatch for ID %d: got %+v want %+v", expected.ID, got, expected)
		}
	}
}
// TestReceiveLoopStopsOnHandlerError verifies that ReceiveLoop exits when the
// handler returns an error, and that the returned error wraps the handler's
// error with handler context.
func TestReceiveLoopStopsOnHandlerError(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	wantErr := errors.New("stop loop")
	rejectAll := func(protocol.Message) error {
		return wantErr
	}

	loopDone := make(chan error, 1)
	go func() {
		loopDone <- receiver.ReceiveLoop(rejectAll)
	}()

	first := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}
	if err := sender.Send(first); err != nil {
		t.Fatalf("Send(first) error = %v", err)
	}

	loopErr := <-loopDone
	if !errors.Is(loopErr, wantErr) {
		t.Fatalf("ReceiveLoop() error = %v, want %v", loopErr, wantErr)
	}
	if !strings.Contains(loopErr.Error(), "receive loop handler") {
		t.Fatalf("ReceiveLoop() error = %v, want handler context", loopErr)
	}
}
// TestReceiveLoopStopsOnReadError verifies that ReceiveLoop exits with a
// read error when the peer closes the connection.
func TestReceiveLoopStopsOnReadError(t *testing.T) {
	sender, receiver := newTransportConnPair(t, nil, nil)

	acceptAll := func(protocol.Message) error {
		return nil
	}
	loopDone := make(chan error, 1)
	go func() {
		loopDone <- receiver.ReceiveLoop(acceptAll)
	}()

	if err := sender.Close(); err != nil {
		t.Fatalf("sender.Close() error = %v", err)
	}

	loopErr := <-loopDone
	switch {
	case loopErr == nil:
		t.Fatal("ReceiveLoop() error = nil, want non-nil")
	case !strings.Contains(loopErr.Error(), "receive loop read"):
		t.Fatalf("ReceiveLoop() error = %v, want read context", loopErr)
	}
}
// TestCloseIsIdempotent verifies that Close can safely be called more than
// once on the same connection.
func TestCloseIsIdempotent(t *testing.T) {
	conn, peer := newTransportConnPair(t, nil, nil)

	firstErr := conn.Close()
	if firstErr != nil {
		t.Fatalf("Close(first) error = %v", firstErr)
	}
	secondErr := conn.Close()
	if secondErr != nil {
		t.Fatalf("Close(second) error = %v, want nil", secondErr)
	}

	// The peer may report that the connection is already closed; only other
	// errors count as failures.
	peerErr := peer.Close()
	if peerErr != nil && !strings.Contains(peerErr.Error(), "closed") {
		t.Fatalf("peer.Close() error = %v", peerErr)
	}
}
// TestReceiveReturnsWrappedReadError verifies that Receive keeps the
// transport context in its error when the underlying read fails, while the
// underlying cause (EOF or a closed connection) stays visible.
func TestReceiveReturnsWrappedReadError(t *testing.T) {
	conn, peer := newTransportConnPair(t, nil, nil)

	go func() {
		_ = peer.Close()
	}()

	_, recvErr := conn.Receive()
	if recvErr == nil {
		t.Fatal("Receive() error = nil, want non-nil")
	}

	text := recvErr.Error()
	if !strings.Contains(text, "transport: receive message") {
		t.Fatalf("Receive() error = %v, want wrapped receive error", recvErr)
	}
	if !errors.Is(recvErr, io.EOF) && !strings.Contains(text, "closed") {
		t.Fatalf("Receive() error = %v, want underlying read failure", recvErr)
	}
}
// newTransportConnPair builds a connected pair of TCPConn values on top of a
// loopback TCP connection, applying senderOpts to the first and receiverOpts
// to the second. Both are closed automatically when the test finishes.
func newTransportConnPair(t *testing.T, senderOpts []Option, receiverOpts []Option) (*TCPConn, *TCPConn) {
	t.Helper()

	clientSide, serverSide := newTCPPair(t)

	sender, senderErr := NewTCPConn(clientSide, senderOpts...)
	if senderErr != nil {
		t.Fatalf("NewTCPConn(sender) error = %v", senderErr)
	}
	receiver, receiverErr := NewTCPConn(serverSide, receiverOpts...)
	if receiverErr != nil {
		t.Fatalf("NewTCPConn(receiver) error = %v", receiverErr)
	}

	t.Cleanup(func() {
		_ = sender.Close()
		_ = receiver.Close()
	})
	return sender, receiver
}