feat:KCP协议

This commit is contained in:
nnbcccscdscdsc
2026-03-24 21:09:06 +08:00
parent 290ba18962
commit be013b701b
20 changed files with 2284 additions and 16 deletions

View File

@@ -0,0 +1,256 @@
package transport
import (
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"sync"
kcp "github.com/xtaci/kcp-go/v5"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
const (
kcpNoDelayNodelay = 1
kcpNoDelayInterval = 10
kcpNoDelayResend = 2
kcpNoDelayNC = 1
kcpWindowSize = 256
kcpMTU = 1400
)
// KCPConn 是对单条活跃 KCP 会话的轻量封装。
type KCPConn struct {
session *kcp.UDPSession
logger latencylog.Logger
nodeRole string
nodeID string
writeMu sync.Mutex
closeOnce sync.Once
closeErr error
}
// KCPOption 用于为 KCPConn 注入可选行为。
type KCPOption func(*KCPConn)
// WithKCPLogger 为 KCP 连接发送路径注入业务消息日志上下文。
func WithKCPLogger(logger latencylog.Logger, nodeRole, nodeID string) KCPOption {
return func(conn *KCPConn) {
conn.logger = logger
conn.nodeRole = nodeRole
conn.nodeID = nodeID
}
}
// NewKCPConn 用已有的 KCP 会话创建 transport 连接封装。
func NewKCPConn(session *kcp.UDPSession, opts ...KCPOption) (*KCPConn, error) {
if session == nil {
return nil, fmt.Errorf("transport: nil kcp session")
}
conn := &KCPConn{
session: session,
logger: latencylog.NoopLogger{},
}
for _, opt := range opts {
opt(conn)
}
if conn.logger == nil {
conn.logger = latencylog.NoopLogger{}
}
configureKCPSession(session)
return conn, nil
}
// Send 将一条协议消息完整写入底层 KCP 会话。
func (c *KCPConn) Send(msg protocol.Message) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
if err := protocol.WriteMessage(c.session, msg); err != nil {
return fmt.Errorf("transport: kcp send message: %w", err)
}
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
return nil
}
// Receive 从底层 KCP 会话读取一条完整协议消息。
func (c *KCPConn) Receive() (protocol.Message, error) {
msg, err := protocol.ReadMessage(c.session)
if err != nil {
return protocol.Message{}, fmt.Errorf("transport: kcp receive message: %w", err)
}
return msg, nil
}
// ReceiveLoop 持续读取消息并交给 handler 处理。
func (c *KCPConn) ReceiveLoop(handler func(protocol.Message) error) error {
for {
msg, err := c.Receive()
if err != nil {
_ = c.Close()
return fmt.Errorf("transport: kcp receive loop read: %w", err)
}
if err := handler(msg); err != nil {
_ = c.Close()
return fmt.Errorf("transport: kcp receive loop handler: %w", err)
}
}
}
// Close 关闭底层 KCP 会话,并保证重复调用是安全的。
func (c *KCPConn) Close() error {
c.closeOnce.Do(func() {
c.closeErr = c.session.Close()
})
return c.closeErr
}
// DialKCPSession 创建一条主动发起的 KCP 会话,并按项目默认参数配置底层 UDP socket。
func DialKCPSession(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.UDPSession, error) {
packetConn, remoteAddr, err := dialKCPPacketConn(serverAddr, bindIP, bindDevice, logger, nodeRole, nodeID)
if err != nil {
return nil, err
}
convID, err := generateKCPConversationID()
if err != nil {
_ = packetConn.Close()
return nil, fmt.Errorf("transport: generate kcp conversation id: %w", err)
}
session, err := kcp.NewConn4(convID, remoteAddr, nil, 0, 0, true, packetConn)
if err != nil {
_ = packetConn.Close()
return nil, fmt.Errorf("transport: create kcp session: %w", err)
}
return session, nil
}
// ListenKCPSessions 在给定地址上启动 KCP listener并返回 listener 与底层 packetConn。
func ListenKCPSessions(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.Listener, net.PacketConn, error) {
packetConn, err := listenKCPPacketConn(listenAddr, bindDevice, logger, nodeRole, nodeID)
if err != nil {
return nil, nil, err
}
listener, err := kcp.ServeConn(nil, 0, 0, packetConn)
if err != nil {
_ = packetConn.Close()
return nil, nil, fmt.Errorf("transport: serve kcp listener: %w", err)
}
return listener, packetConn, nil
}
func configureKCPSession(session *kcp.UDPSession) {
session.SetStreamMode(true)
session.SetNoDelay(kcpNoDelayNodelay, kcpNoDelayInterval, kcpNoDelayResend, kcpNoDelayNC)
session.SetWindowSize(kcpWindowSize, kcpWindowSize)
session.SetACKNoDelay(true)
session.SetWriteDelay(false)
session.SetMtu(kcpMTU)
}
func generateKCPConversationID() (uint32, error) {
var convID uint32
if err := binary.Read(rand.Reader, binary.LittleEndian, &convID); err != nil {
return 0, err
}
return convID, nil
}
func listenKCPPacketConn(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
return nil, fmt.Errorf("transport: resolve kcp listen addr %s: %w", listenAddr, err)
}
rawConn, err := listenUDPConn("udp", udpAddr, bindDevice)
if err != nil {
return nil, fmt.Errorf("transport: listen udp for kcp on %s: %w", listenAddr, err)
}
packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID)
if err != nil {
_ = rawConn.Close()
return nil, err
}
return packetConn, nil
}
func dialKCPPacketConn(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, *net.UDPAddr, error) {
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
return nil, nil, fmt.Errorf("transport: resolve kcp server addr %s: %w", serverAddr, err)
}
localAddr := &net.UDPAddr{Port: 0}
if bindIP != "" {
ip := net.ParseIP(bindIP)
if ip == nil {
return nil, nil, fmt.Errorf("transport: invalid bind ip %q", bindIP)
}
localAddr.IP = ip
}
network := "udp"
if remoteAddr.IP.To4() != nil {
network = "udp4"
}
rawConn, err := listenUDPConn(network, localAddr, bindDevice)
if err != nil {
return nil, nil, fmt.Errorf("transport: listen udp for kcp dial to %s: %w", serverAddr, err)
}
packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID)
if err != nil {
_ = rawConn.Close()
return nil, nil, err
}
return packetConn, remoteAddr, nil
}
func listenUDPConn(network string, localAddr *net.UDPAddr, bindDevice string) (*net.UDPConn, error) {
listenConfig := net.ListenConfig{}
if bindDevice != "" {
control, err := udpBindDeviceControl(bindDevice)
if err != nil {
return nil, err
}
listenConfig.Control = control
}
packetConn, err := listenConfig.ListenPacket(context.Background(), network, udpListenAddr(localAddr))
if err != nil {
return nil, err
}
udpConn, ok := packetConn.(*net.UDPConn)
if !ok {
_ = packetConn.Close()
return nil, fmt.Errorf("transport: expected *net.UDPConn, got %T", packetConn)
}
return udpConn, nil
}
func udpListenAddr(addr *net.UDPAddr) string {
if addr == nil {
return ":0"
}
return addr.String()
}

View File

@@ -0,0 +1,90 @@
//go:build linux
package transport
import (
"testing"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
func TestKCPLinuxPacketDebugLogsKernelEvents(t *testing.T) {
senderPacketLogger := &recordingKCPPacketDebugLogger{}
receiverPacketLogger := &recordingKCPPacketDebugLogger{}
sender, accepted, cleanup := newKCPConnPair(t, nil, nil, senderPacketLogger, receiverPacketLogger)
defer cleanup()
msg := protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-b",
Body: []byte("hello kcp linux"),
}
sendErr := make(chan error, 1)
go func() {
sendErr <- sender.Send(msg)
}()
receiver := awaitAcceptedKCPConn(t, accepted)
if _, err := receiver.Receive(); err != nil {
t.Fatalf("receiver.Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("sender.Send() error = %v", err)
}
waitForKCPPacketRecords(t, senderPacketLogger, func(records []KCPPacketDebugRecord) bool {
return hasKCPPacketEvent(records, latencylog.EventATXSched) && hasKCPPacketEvent(records, latencylog.EventATXSoftware)
}, "sender tx kernel timestamp records")
waitForKCPPacketRecords(t, receiverPacketLogger, func(records []KCPPacketDebugRecord) bool {
return hasKCPPacketEvent(records, latencylog.EventBRXSoftware)
}, "receiver rx kernel timestamp records")
senderRecords := senderPacketLogger.Records()
receiverRecords := receiverPacketLogger.Records()
assertKCPPacketRecord(t, senderRecords, latencylog.EventATXSched, true)
assertKCPPacketRecord(t, senderRecords, latencylog.EventATXSoftware, true)
assertKCPPacketRecord(t, receiverRecords, latencylog.EventBRXSoftware, false)
}
func hasKCPPacketEvent(records []KCPPacketDebugRecord, wantEvent string) bool {
for _, record := range records {
if record.Event == wantEvent {
return true
}
}
return false
}
func assertKCPPacketRecord(t *testing.T, records []KCPPacketDebugRecord, wantEvent string, wantUDPTXID bool) {
t.Helper()
for _, record := range records {
if record.Event != wantEvent {
continue
}
if record.TSUnixNano <= 0 {
t.Fatalf("record %s timestamp must be positive: %+v", wantEvent, record)
}
if record.PacketBytes <= 0 {
t.Fatalf("record %s packet bytes must be positive: %+v", wantEvent, record)
}
if record.KCPConv == nil {
t.Fatalf("record %s missing kcp_conv: %+v", wantEvent, record)
}
if wantUDPTXID && record.UDPTXID == nil {
t.Fatalf("record %s missing udp_tx_id: %+v", wantEvent, record)
}
if !wantUDPTXID && record.UDPTXID != nil {
t.Fatalf("record %s unexpected udp_tx_id: %+v", wantEvent, record)
}
return
}
t.Fatalf("missing KCP packet debug event %s in %+v", wantEvent, records)
}

View File

@@ -0,0 +1,91 @@
package transport
import (
"encoding/binary"
"net"
"sync"
"time"
)
func newKCPPacketConn(conn *net.UDPConn, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
return newPlatformKCPPacketConn(conn, logger, nodeRole, nodeID)
}
type kcpPacketConnBase struct {
conn *net.UDPConn
logger KCPPacketDebugLogger
nodeRole string
nodeID string
closeOnce sync.Once
closeErr error
closed chan struct{}
}
func (c *kcpPacketConnBase) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *kcpPacketConnBase) Close() error {
c.closeOnce.Do(func() {
close(c.closed)
c.closeErr = c.conn.Close()
})
return c.closeErr
}
func (c *kcpPacketConnBase) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *kcpPacketConnBase) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *kcpPacketConnBase) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *kcpPacketConnBase) SetReadBuffer(bytes int) error {
return c.conn.SetReadBuffer(bytes)
}
func (c *kcpPacketConnBase) SetWriteBuffer(bytes int) error {
return c.conn.SetWriteBuffer(bytes)
}
func (c *kcpPacketConnBase) logKCPPacketDebugRecord(record KCPPacketDebugRecord) {
if c.logger == nil {
return
}
_ = c.logger.LogKCPPacketDebugRecord(record)
}
func (c *kcpPacketConnBase) newKCPPacketDebugRecord(event string, remoteAddr net.Addr, packetBytes int, tsUnixNano int64, udpTxID *uint32, kcpConv *uint32) KCPPacketDebugRecord {
record := KCPPacketDebugRecord{
Event: event,
NodeRole: c.nodeRole,
NodeID: c.nodeID,
LocalAddr: "",
RemoteAddr: "",
PacketBytes: packetBytes,
UDPTXID: udpTxID,
KCPConv: kcpConv,
TSUnixNano: tsUnixNano,
}
if localAddr := c.conn.LocalAddr(); localAddr != nil {
record.LocalAddr = localAddr.String()
}
if remoteAddr != nil {
record.RemoteAddr = remoteAddr.String()
}
return record
}
func parseKCPConversationID(packet []byte) *uint32 {
if len(packet) < 4 {
return nil
}
conv := binary.LittleEndian.Uint32(packet[:4])
return &conv
}

View File

@@ -0,0 +1,330 @@
//go:build linux
package transport
import (
"errors"
"fmt"
"net"
"sync"
"syscall"
"time"
"omnisocketgo/cmd/internal/latencylog"
)
type kcpPendingPacketDebug struct {
remoteAddr net.Addr
packetBytes int
kcpConv *uint32
timestamps map[string]int64
}
type platformKCPPacketConn struct {
*kcpPacketConnBase
raw syscall.RawConn
writeMu sync.Mutex
pendingMu sync.Mutex
pendingTX map[uint32]kcpPendingPacketDebug
nextTXID uint32
}
func newPlatformKCPPacketConn(conn *net.UDPConn, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
packetConn := &platformKCPPacketConn{
kcpPacketConnBase: &kcpPacketConnBase{
conn: conn,
logger: logger,
nodeRole: nodeRole,
nodeID: nodeID,
closed: make(chan struct{}),
},
pendingTX: make(map[uint32]kcpPendingPacketDebug),
}
if logger == nil {
return packetConn, nil
}
if err := packetConn.initLinuxTimestamping(); err != nil {
return nil, err
}
go packetConn.collectTXErrqueueLoop()
return packetConn, nil
}
func (c *platformKCPPacketConn) Close() error {
return c.kcpPacketConnBase.Close()
}
func (c *platformKCPPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
if c.raw == nil {
return c.conn.ReadFrom(p)
}
for {
n, addr, rxTimestamp, err := c.recvmsgRaw(p)
if err != nil {
if isWouldBlock(err) {
time.Sleep(linuxDataPollInterval)
continue
}
return 0, nil, err
}
if rxTimestamp > 0 {
c.logKCPPacketDebugRecord(c.newKCPPacketDebugRecord(
latencylog.EventBRXSoftware,
addr,
n,
rxTimestamp,
nil,
parseKCPConversationID(p[:n]),
))
}
return n, addr, nil
}
}
func (c *platformKCPPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
c.writeMu.Lock()
defer c.writeMu.Unlock()
if c.raw == nil {
return c.conn.WriteTo(p, addr)
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, fmt.Errorf("transport: kcp packet write target must be UDPAddr, got %T", addr)
}
expectedTXID := c.nextExpectedTXID()
for {
err := c.sendmsgRaw(p, udpAddr)
if err != nil {
if isWouldBlock(err) {
time.Sleep(linuxDataPollInterval)
continue
}
return 0, err
}
c.storePendingTX(expectedTXID, udpAddr, len(p), parseKCPConversationID(p))
return len(p), nil
}
}
func (c *platformKCPPacketConn) initLinuxTimestamping() error {
rawConn, err := c.conn.SyscallConn()
if err != nil || rawConn == nil {
if err != nil {
return fmt.Errorf("transport: kcp get syscall conn: %w", err)
}
return fmt.Errorf("transport: kcp missing syscall conn")
}
if err := configureLinuxSocketWriteBuffer(rawConn); err != nil {
return fmt.Errorf("transport: kcp configure socket write buffer: %w", err)
}
flagCandidates := []int{
linuxSOFTimestampingTXSched |
linuxSOFTimestampingTXSoftware |
linuxSOFTimestampingRXSoftware |
linuxSOFTimestampingSoftware |
linuxSOFTimestampingOptID |
linuxSOFTimestampingOptTSONLY,
linuxSOFTimestampingTXSched |
linuxSOFTimestampingTXSoftware |
linuxSOFTimestampingRXSoftware |
linuxSOFTimestampingSoftware |
linuxSOFTimestampingOptTSONLY,
}
var lastErr error
for _, flags := range flagCandidates {
err := rawConn.Control(func(fd uintptr) {
lastErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags)
})
if err != nil {
return err
}
if lastErr == nil {
c.raw = rawConn
return nil
}
if !errors.Is(lastErr, syscall.EINVAL) {
return lastErr
}
}
return lastErr
}
func (c *platformKCPPacketConn) recvmsgRaw(buf []byte) (int, net.Addr, int64, error) {
var (
n int
rxTimeNS int64
from syscall.Sockaddr
opErr error
)
err := c.raw.Control(func(fd uintptr) {
oob := make([]byte, linuxTimestampControlBufferSize)
readN, oobN, _, sa, recvErr := syscall.Recvmsg(int(fd), buf, oob, 0)
if recvErr != nil {
opErr = recvErr
return
}
n = readN
from = sa
rxTimeNS = parseRXTimestampControlMessages(oob[:oobN])
})
if err != nil {
return 0, nil, 0, err
}
if opErr != nil {
return 0, nil, 0, opErr
}
return n, sockaddrToUDPAddr(from), rxTimeNS, nil
}
func (c *platformKCPPacketConn) sendmsgRaw(payload []byte, addr *net.UDPAddr) error {
var opErr error
sa := udpAddrToSockaddr(addr)
if sa == nil {
return fmt.Errorf("transport: invalid udp addr %v", addr)
}
err := c.raw.Control(func(fd uintptr) {
opErr = syscall.Sendmsg(int(fd), payload, nil, sa, 0)
})
if err != nil {
return err
}
return opErr
}
func (c *platformKCPPacketConn) collectTXErrqueueLoop() {
for {
event, err := c.recvTXErrqueueOnce()
if err != nil {
if isWouldBlock(err) {
if c.isClosed() {
return
}
time.Sleep(linuxTXTimestampPollInterval)
continue
}
if c.isClosed() {
return
}
return
}
if event.EventName == "" || event.TSUnixNano <= 0 {
continue
}
if event.EventName != latencylog.EventATXSched && event.EventName != latencylog.EventATXSoftware {
continue
}
record, complete := c.recordPendingTXEvent(event)
if record == nil {
continue
}
udpTxID := event.EEData
c.logKCPPacketDebugRecord(c.newKCPPacketDebugRecord(
event.EventName,
record.remoteAddr,
record.packetBytes,
event.TSUnixNano,
&udpTxID,
record.kcpConv,
))
if complete {
c.pendingMu.Lock()
delete(c.pendingTX, event.EEData)
c.pendingMu.Unlock()
}
}
}
func (c *platformKCPPacketConn) recvTXErrqueueOnce() (txTimestampEvent, error) {
var (
event txTimestampEvent
opErr error
)
err := c.raw.Control(func(fd uintptr) {
oob := make([]byte, linuxTimestampControlBufferSize)
_, oobN, _, _, recvErr := syscall.Recvmsg(int(fd), nil, oob, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT)
if recvErr != nil {
opErr = recvErr
return
}
event, _ = parseTXTimestampControlMessages(oob[:oobN])
})
if err != nil {
return txTimestampEvent{}, err
}
if opErr != nil {
return txTimestampEvent{}, opErr
}
return event, nil
}
func (c *platformKCPPacketConn) nextExpectedTXID() uint32 {
c.pendingMu.Lock()
defer c.pendingMu.Unlock()
next := c.nextTXID
c.nextTXID++
return next
}
func (c *platformKCPPacketConn) storePendingTX(txID uint32, remoteAddr net.Addr, packetBytes int, kcpConv *uint32) {
c.pendingMu.Lock()
defer c.pendingMu.Unlock()
c.pendingTX[txID] = kcpPendingPacketDebug{
remoteAddr: remoteAddr,
packetBytes: packetBytes,
kcpConv: kcpConv,
timestamps: make(map[string]int64, 2),
}
}
func (c *platformKCPPacketConn) recordPendingTXEvent(event txTimestampEvent) (*kcpPendingPacketDebug, bool) {
c.pendingMu.Lock()
defer c.pendingMu.Unlock()
record, ok := c.pendingTX[event.EEData]
if !ok {
return nil, false
}
if existing, exists := record.timestamps[event.EventName]; !exists || event.TSUnixNano < existing {
record.timestamps[event.EventName] = event.TSUnixNano
}
c.pendingTX[event.EEData] = record
complete := hasCompleteTXTimestampPair(record.timestamps)
copyRecord := record
return &copyRecord, complete
}
func (c *platformKCPPacketConn) isClosed() bool {
select {
case <-c.closed:
return true
default:
}
return false
}

View File

@@ -0,0 +1,29 @@
//go:build !linux
package transport
import "net"
type platformKCPPacketConn struct {
*kcpPacketConnBase
}
func newPlatformKCPPacketConn(conn *net.UDPConn, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
return &platformKCPPacketConn{
kcpPacketConnBase: &kcpPacketConnBase{
conn: conn,
logger: logger,
nodeRole: nodeRole,
nodeID: nodeID,
closed: make(chan struct{}),
},
}, nil
}
func (c *platformKCPPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
return c.conn.ReadFrom(p)
}
func (c *platformKCPPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
return c.conn.WriteTo(p, addr)
}

View File

@@ -0,0 +1,76 @@
package transport
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
)
// KCPPacketDebugRecord 是 KCP 底层 UDP packet kernel timestamp 的一条 JSONL 调试记录。
type KCPPacketDebugRecord struct {
Event string `json:"event"`
NodeRole string `json:"node_role,omitempty"`
NodeID string `json:"node_id,omitempty"`
LocalAddr string `json:"local_addr,omitempty"`
RemoteAddr string `json:"remote_addr,omitempty"`
PacketBytes int `json:"packet_bytes"`
UDPTXID *uint32 `json:"udp_tx_id,omitempty"`
KCPConv *uint32 `json:"kcp_conv,omitempty"`
TSUnixNano int64 `json:"ts_unix_nano"`
}
// KCPPacketDebugLogger 接收 KCP packet 级调试记录。
type KCPPacketDebugLogger interface {
LogKCPPacketDebugRecord(record KCPPacketDebugRecord) error
}
// JSONLKCPPacketDebugLogger 以 JSONL 形式追加写 KCP packet 调试日志。
type JSONLKCPPacketDebugLogger struct {
mu sync.Mutex
closeOnce sync.Once
closeErr error
file *os.File
}
// NewJSONLKCPPacketDebugLogger 创建一个线程安全的 KCP packet JSONL 日志器。
func NewJSONLKCPPacketDebugLogger(path string) (*JSONLKCPPacketDebugLogger, error) {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("transport: create kcp packet debug log dir %s: %w", dir, err)
}
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
return nil, fmt.Errorf("transport: open kcp packet debug log %s: %w", path, err)
}
return &JSONLKCPPacketDebugLogger{file: file}, nil
}
// LogKCPPacketDebugRecord 以单行 JSON 的形式追加一条 KCP packet 调试记录。
func (l *JSONLKCPPacketDebugLogger) LogKCPPacketDebugRecord(record KCPPacketDebugRecord) error {
line, err := json.Marshal(record)
if err != nil {
return err
}
l.mu.Lock()
defer l.mu.Unlock()
if _, err := l.file.Write(append(line, '\n')); err != nil {
return err
}
return nil
}
// Close 关闭底层文件;重复调用是安全的。
func (l *JSONLKCPPacketDebugLogger) Close() error {
l.closeOnce.Do(func() {
l.closeErr = l.file.Close()
})
return l.closeErr
}

View File

@@ -0,0 +1,284 @@
package transport
import (
"reflect"
"strings"
"sync"
"testing"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
type recordingKCPPacketDebugLogger struct {
mu sync.Mutex
records []KCPPacketDebugRecord
}
func (l *recordingKCPPacketDebugLogger) LogKCPPacketDebugRecord(record KCPPacketDebugRecord) error {
l.mu.Lock()
defer l.mu.Unlock()
l.records = append(l.records, record)
return nil
}
func (l *recordingKCPPacketDebugLogger) Records() []KCPPacketDebugRecord {
l.mu.Lock()
defer l.mu.Unlock()
return append([]KCPPacketDebugRecord(nil), l.records...)
}
type kcpAcceptResult struct {
conn *KCPConn
err error
}
func TestKCPSendReceiveMessage(t *testing.T) {
tests := []struct {
name string
msg protocol.Message
}{
{
name: "text",
msg: protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-b",
Body: []byte("hello kcp"),
},
},
{
name: "file",
msg: protocol.Message{
Type: protocol.MessageTypeFile,
ID: 2,
From: "peer-a",
To: "peer-b",
FileName: "payload.bin",
Body: []byte{0x00, 0x10, 0xff},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sender, accepted, cleanup := newKCPConnPair(
t,
nil,
[]KCPOption{WithKCPLogger(latencylog.NoopLogger{}, latencylog.NodeRolePeer, "peer-b")},
nil,
nil,
)
defer cleanup()
sendErr := make(chan error, 1)
go func() {
sendErr <- sender.Send(tt.msg)
}()
receiver := awaitAcceptedKCPConn(t, accepted)
got, err := receiver.Receive()
if err != nil {
t.Fatalf("receiver.Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("sender.Send() error = %v", err)
}
if !reflect.DeepEqual(got, tt.msg) {
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
}
})
}
}
func TestKCPSendLogsHandoffEvents(t *testing.T) {
logger := &recordingLogger{}
sender, accepted, cleanup := newKCPConnPair(
t,
[]KCPOption{WithKCPLogger(logger, latencylog.NodeRolePeer, "peer-a")},
nil,
nil,
nil,
)
defer cleanup()
msg := protocol.Message{
Type: protocol.MessageTypeText,
ID: 7,
From: "peer-a",
To: "peer-b",
Body: []byte("hello"),
}
sendErr := make(chan error, 1)
go func() {
sendErr <- sender.Send(msg)
}()
receiver := awaitAcceptedKCPConn(t, accepted)
got, err := receiver.Receive()
if err != nil {
t.Fatalf("receiver.Receive() error = %v", err)
}
if err := <-sendErr; err != nil {
t.Fatalf("sender.Send() error = %v", err)
}
if !reflect.DeepEqual(got, msg) {
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
}
events := logger.Events()
if len(events) != 2 {
t.Fatalf("event count = %d, want 2", len(events))
}
if events[0].Event != latencylog.EventSendHandoffBegin {
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin)
}
if events[1].Event != latencylog.EventSendHandoffEnd {
t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventSendHandoffEnd)
}
}
func TestKCPReceiveLoopStopsOnClose(t *testing.T) {
sender, accepted, cleanup := newKCPConnPair(t, nil, nil, nil, nil)
defer cleanup()
msg := protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-b",
Body: []byte("hello"),
}
sendErr := make(chan error, 1)
go func() {
sendErr <- sender.Send(msg)
}()
receiver := awaitAcceptedKCPConn(t, accepted)
var (
mu sync.Mutex
got []protocol.Message
)
loopErr := make(chan error, 1)
go func() {
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
mu.Lock()
got = append(got, msg)
mu.Unlock()
return receiver.Close()
})
}()
if err := <-sendErr; err != nil {
t.Fatalf("sender.Send() error = %v", err)
}
err := <-loopErr
if err == nil || (!strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "pipe")) {
t.Fatalf("ReceiveLoop() error = %v, want close-related error", err)
}
mu.Lock()
defer mu.Unlock()
if len(got) != 1 || !reflect.DeepEqual(got[0], msg) {
t.Fatalf("received messages mismatch: got %+v want [%+v]", got, msg)
}
}
func TestKCPCloseIsIdempotent(t *testing.T) {
sender, _, cleanup := newKCPConnPair(t, nil, nil, nil, nil)
defer cleanup()
if err := sender.Close(); err != nil {
t.Fatalf("Close(first) error = %v", err)
}
if err := sender.Close(); err != nil {
t.Fatalf("Close(second) error = %v, want nil", err)
}
}
func newKCPConnPair(t *testing.T, senderOpts []KCPOption, receiverOpts []KCPOption, senderPacketLogger KCPPacketDebugLogger, receiverPacketLogger KCPPacketDebugLogger) (*KCPConn, <-chan kcpAcceptResult, func()) {
t.Helper()
listener, packetConn, err := ListenKCPSessions("127.0.0.1:0", "", receiverPacketLogger, latencylog.NodeRolePeer, "peer-b")
if err != nil {
t.Fatalf("ListenKCPSessions() error = %v", err)
}
accepted := make(chan kcpAcceptResult, 1)
go func() {
session, acceptErr := listener.AcceptKCP()
if acceptErr != nil {
accepted <- kcpAcceptResult{err: acceptErr}
return
}
conn, connErr := NewKCPConn(session, receiverOpts...)
accepted <- kcpAcceptResult{conn: conn, err: connErr}
}()
session, err := DialKCPSession(listener.Addr().String(), "", "", senderPacketLogger, latencylog.NodeRolePeer, "peer-a")
if err != nil {
_ = packetConn.Close()
_ = listener.Close()
t.Fatalf("DialKCPSession() error = %v", err)
}
sender, err := NewKCPConn(session, senderOpts...)
if err != nil {
_ = session.Close()
_ = packetConn.Close()
_ = listener.Close()
t.Fatalf("NewKCPConn(sender) error = %v", err)
}
cleanup := func() {
_ = sender.Close()
select {
case result := <-accepted:
if result.conn != nil {
_ = result.conn.Close()
}
default:
}
_ = listener.Close()
_ = packetConn.Close()
}
return sender, accepted, cleanup
}
func awaitAcceptedKCPConn(t *testing.T, accepted <-chan kcpAcceptResult) *KCPConn {
t.Helper()
result := <-accepted
if result.err != nil {
t.Fatalf("AcceptKCP() error = %v", result.err)
}
if result.conn == nil {
t.Fatal("accepted KCP conn = nil")
}
return result.conn
}
func waitForKCPPacketRecords(t *testing.T, logger *recordingKCPPacketDebugLogger, condition func([]KCPPacketDebugRecord) bool, description string) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
records := logger.Records()
if condition(records) {
return
}
time.Sleep(10 * time.Millisecond)
}
t.Fatalf("timed out waiting for %s", description)
}

View File

@@ -0,0 +1,24 @@
//go:build linux
package transport
import (
"fmt"
"syscall"
)
// udpBindDeviceControl 返回一个 Control 函数,用于在 Linux 上将 UDP socket 绑定到指定网卡设备。
func udpBindDeviceControl(device string) (func(string, string, syscall.RawConn) error, error) {
return func(_, _ string, rawConn syscall.RawConn) error {
var bindErr error
if err := rawConn.Control(func(fd uintptr) {
bindErr = syscall.BindToDevice(int(fd), device)
}); err != nil {
return err
}
if bindErr != nil {
return fmt.Errorf("transport: bind device %s: %w", device, bindErr)
}
return nil
}, nil
}

View File

@@ -0,0 +1,12 @@
//go:build !linux
package transport
import (
"fmt"
"syscall"
)
func udpBindDeviceControl(device string) (func(string, string, syscall.RawConn) error, error) {
return nil, fmt.Errorf("transport: bind device %s is only supported on linux", device)
}

View File

@@ -141,10 +141,11 @@ func TestUDPReceiveLoopDeliversMessages(t *testing.T) {
go func() {
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message, _ *net.UDPAddr) error {
mu.Lock()
defer mu.Unlock()
got = append(got, msg)
if len(got) >= len(want) {
return nil
done := len(got) >= len(want)
mu.Unlock()
if done {
return receiver.Close()
}
return nil
})
@@ -156,17 +157,12 @@ func TestUDPReceiveLoopDeliversMessages(t *testing.T) {
}
}
// 关闭发送端ReceiveLoop 会因读取错误退出
if err := sender.Close(); err != nil {
t.Fatalf("sender.Close() error = %v", err)
}
err := <-loopErr
if err == nil {
t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close")
t.Fatal("ReceiveLoop() error = nil, want non-nil after receiver close")
}
if !strings.Contains(err.Error(), "udp receive loop read") {
t.Fatalf("ReceiveLoop() error = %v, want read context", err)
if !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "use of closed network connection") {
t.Fatalf("ReceiveLoop() error = %v, want close-related error", err)
}
mu.Lock()
@@ -268,6 +264,7 @@ func newUDPConnPair(t *testing.T, senderOpts []UDPOption, receiverOpts []UDPOpti
_ = conn1.Close()
t.Fatalf("ListenUDP(2) error = %v", err)
}
receiverLocalAddr := conn2.LocalAddr().(*net.UDPAddr)
// 用 Dial 模式连接对端
senderRaw, err := net.DialUDP("udp", nil, conn2.LocalAddr().(*net.UDPAddr))
@@ -277,14 +274,13 @@ func newUDPConnPair(t *testing.T, senderOpts []UDPOption, receiverOpts []UDPOpti
t.Fatalf("DialUDP(sender) error = %v", err)
}
_ = conn1.Close() // 不再需要 conn1
_ = conn2.Close() // 释放 receiver 计划使用的本地地址
receiverRaw, err := net.DialUDP("udp", conn2.LocalAddr().(*net.UDPAddr), senderRaw.LocalAddr().(*net.UDPAddr))
receiverRaw, err := net.DialUDP("udp", receiverLocalAddr, senderRaw.LocalAddr().(*net.UDPAddr))
if err != nil {
_ = senderRaw.Close()
_ = conn2.Close()
t.Fatalf("DialUDP(receiver) error = %v", err)
}
_ = conn2.Close() // 不再需要 conn2
sender, err := NewUDPConn(senderRaw, nil, senderOpts...)
if err != nil {