feat:KCP协议
This commit is contained in:
@@ -18,6 +18,7 @@ var dialServer = dialServerWithOptions
|
||||
type clientOptions struct {
|
||||
logger latencylog.Logger
|
||||
txTimestampDebugLogger transport.TXTimestampDebugLogger
|
||||
kcpPacketDebugLogger transport.KCPPacketDebugLogger
|
||||
bindIP string
|
||||
bindDevice string
|
||||
}
|
||||
@@ -39,6 +40,13 @@ func WithTXTimestampDebugLogger(logger transport.TXTimestampDebugLogger) Option
|
||||
}
|
||||
}
|
||||
|
||||
// WithKCPPacketDebugLogger 为 KCP UDP packet timestamp 调试日志注入记录器。
|
||||
func WithKCPPacketDebugLogger(logger transport.KCPPacketDebugLogger) Option {
|
||||
return func(options *clientOptions) {
|
||||
options.kcpPacketDebugLogger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// WithBindIP 指定拨号时使用的本地源 IP。
|
||||
func WithBindIP(ip string) Option {
|
||||
return func(options *clientOptions) {
|
||||
|
||||
184
cmd/internal/peer/kcp_client.go
Normal file
184
cmd/internal/peer/kcp_client.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
// KCPClient 表示一个通过 KCP 连接到 server 的 peer。
|
||||
type KCPClient struct {
|
||||
id string
|
||||
conn *transport.KCPConn
|
||||
logger latencylog.Logger
|
||||
|
||||
nextID uint64
|
||||
}
|
||||
|
||||
// DialKCP 通过 KCP 连接到 server,并发送 register 消息完成身份注册。
|
||||
func DialKCP(serverAddr, peerID string, opts ...Option) (*KCPClient, error) {
|
||||
options := clientOptions{
|
||||
logger: latencylog.NoopLogger{},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
if options.logger == nil {
|
||||
options.logger = latencylog.NoopLogger{}
|
||||
}
|
||||
|
||||
session, err := transport.DialKCPSession(
|
||||
serverAddr,
|
||||
options.bindIP,
|
||||
options.bindDevice,
|
||||
options.kcpPacketDebugLogger,
|
||||
latencylog.NodeRolePeer,
|
||||
peerID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer: dial kcp server: %w", err)
|
||||
}
|
||||
|
||||
conn, err := transport.NewKCPConn(
|
||||
session,
|
||||
transport.WithKCPLogger(options.logger, latencylog.NodeRolePeer, peerID),
|
||||
)
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
return nil, fmt.Errorf("peer: create kcp transport conn: %w", err)
|
||||
}
|
||||
|
||||
client := &KCPClient{
|
||||
id: peerID,
|
||||
conn: conn,
|
||||
logger: options.logger,
|
||||
}
|
||||
|
||||
if err := conn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeRegister,
|
||||
From: peerID,
|
||||
To: protocol.ServerPeerID,
|
||||
}); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("peer: register with kcp server: %w", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// ID 返回当前 client 的 peer 标识。
|
||||
func (c *KCPClient) ID() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
// SendText 向目标 peer 发送一条文本消息。
|
||||
func (c *KCPClient) SendText(to, body string) error {
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: c.nextMessageID(),
|
||||
From: c.id,
|
||||
To: to,
|
||||
}
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
|
||||
msg.Body = []byte(body)
|
||||
return c.conn.Send(msg)
|
||||
}
|
||||
|
||||
// SendFile 向目标 peer 发送一条文件消息。
|
||||
func (c *KCPClient) SendFile(to, fileName string, body []byte) error {
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: c.nextMessageID(),
|
||||
From: c.id,
|
||||
To: to,
|
||||
FileName: fileName,
|
||||
}
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
|
||||
|
||||
bodyCopy := make([]byte, len(body))
|
||||
copy(bodyCopy, body)
|
||||
msg.Body = bodyCopy
|
||||
|
||||
return c.conn.Send(msg)
|
||||
}
|
||||
|
||||
// SendFilePath 从本地文件读取内容并发送给目标 peer。
|
||||
func (c *KCPClient) SendFilePath(to, path string) error {
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: c.nextMessageID(),
|
||||
From: c.id,
|
||||
To: to,
|
||||
FileName: filepath.Base(path),
|
||||
}
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventAAppPrepBegin, msg)
|
||||
|
||||
body, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("peer: read file %s: %w", path, err)
|
||||
}
|
||||
msg.Body = body
|
||||
|
||||
return c.conn.Send(msg)
|
||||
}
|
||||
|
||||
// Receive 读取一条来自 server 的消息。
|
||||
func (c *KCPClient) Receive() (protocol.Message, error) {
|
||||
msg, err := c.conn.Receive()
|
||||
if err != nil {
|
||||
return protocol.Message{}, fmt.Errorf("peer: receive from kcp server: %w", err)
|
||||
}
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// ReceiveLoop 持续接收 server 消息并交给 handler 处理。
|
||||
func (c *KCPClient) ReceiveLoop(handler func(protocol.Message) error) error {
|
||||
return c.conn.ReceiveLoop(func(msg protocol.Message) error {
|
||||
switch msg.Type {
|
||||
case protocol.MessageTypeText, protocol.MessageTypeFile, protocol.MessageTypeError:
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBAppRecv, msg)
|
||||
return handler(msg)
|
||||
default:
|
||||
return fmt.Errorf("peer: unexpected message type from kcp server: %s", msg.Type)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// PersistMessage 将收到的业务消息写入本地磁盘。
|
||||
func (c *KCPClient) PersistMessage(msg protocol.Message, inboxDir string) (string, error) {
|
||||
if !latencylog.IsBusinessMessage(msg) {
|
||||
return "", fmt.Errorf("peer: cannot persist message type %s", msg.Type)
|
||||
}
|
||||
if inboxDir == "" {
|
||||
return "", fmt.Errorf("peer: inbox directory is required")
|
||||
}
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistBegin, msg)
|
||||
if err := os.MkdirAll(inboxDir, 0o755); err != nil {
|
||||
return "", fmt.Errorf("peer: create inbox dir %s: %w", inboxDir, err)
|
||||
}
|
||||
|
||||
path, err := persistMessageToDisk(msg, inboxDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, latencylog.NodeRolePeer, c.id, latencylog.EventBPersistEnd, msg)
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// Close 关闭与 server 的 KCP 会话。
|
||||
func (c *KCPClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *KCPClient) nextMessageID() uint64 {
|
||||
return atomic.AddUint64(&c.nextID, 1)
|
||||
}
|
||||
263
cmd/internal/peer/kcp_client_test.go
Normal file
263
cmd/internal/peer/kcp_client_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
kcp "github.com/xtaci/kcp-go/v5"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/server"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
func TestKCPDialRegistersPeer(t *testing.T) {
|
||||
hub := server.NewKCPHub()
|
||||
serverAddr, cleanup := startRealKCPHubServer(t, hub)
|
||||
defer cleanup()
|
||||
|
||||
client, err := DialKCP(serverAddr, "peer-a")
|
||||
if err != nil {
|
||||
t.Fatalf("DialKCP() error = %v", err)
|
||||
}
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
||||
}
|
||||
|
||||
func TestKCPDialRejectsInvalidBindIP(t *testing.T) {
|
||||
_, err := DialKCP("127.0.0.1:9002", "peer-a", WithBindIP("not-an-ip"))
|
||||
if err == nil {
|
||||
t.Fatal("DialKCP() error = nil, want invalid bind ip error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `invalid bind ip "not-an-ip"`) {
|
||||
t.Fatalf("DialKCP() error = %v, want invalid bind ip error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKCPClientsExchangeTextAndFileMessages(t *testing.T) {
|
||||
hub := server.NewKCPHub()
|
||||
serverAddr, cleanup := startRealKCPHubServer(t, hub)
|
||||
defer cleanup()
|
||||
|
||||
peerA, err := DialKCP(serverAddr, "peer-a")
|
||||
if err != nil {
|
||||
t.Fatalf("DialKCP(peer-a) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerA.Close() }()
|
||||
|
||||
peerB, err := DialKCP(serverAddr, "peer-b")
|
||||
if err != nil {
|
||||
t.Fatalf("DialKCP(peer-b) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerB.Close() }()
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
|
||||
|
||||
received := make(chan protocol.Message, 2)
|
||||
receiveErr := make(chan error, 1)
|
||||
go func() {
|
||||
for i := 0; i < 2; i++ {
|
||||
msg, err := peerB.Receive()
|
||||
if err != nil {
|
||||
receiveErr <- err
|
||||
return
|
||||
}
|
||||
received <- msg
|
||||
}
|
||||
receiveErr <- nil
|
||||
}()
|
||||
|
||||
if err := peerA.SendText("peer-b", "hello over kcp"); err != nil {
|
||||
t.Fatalf("SendText() error = %v", err)
|
||||
}
|
||||
fileBody := []byte{0x01, 0x02, 0x03}
|
||||
if err := peerA.SendFile("peer-b", "payload.bin", fileBody); err != nil {
|
||||
t.Fatalf("SendFile() error = %v", err)
|
||||
}
|
||||
|
||||
if err := <-receiveErr; err != nil {
|
||||
t.Fatalf("peerB.Receive() error = %v", err)
|
||||
}
|
||||
|
||||
gotFirst := <-received
|
||||
wantFirst := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello over kcp"),
|
||||
}
|
||||
if !reflect.DeepEqual(gotFirst, wantFirst) {
|
||||
t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, wantFirst)
|
||||
}
|
||||
|
||||
gotSecond := <-received
|
||||
wantSecond := protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: fileBody,
|
||||
}
|
||||
if !reflect.DeepEqual(gotSecond, wantSecond) {
|
||||
t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, wantSecond)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKCPClientReceivesServerErrorForUnknownTarget(t *testing.T) {
|
||||
hub := server.NewKCPHub()
|
||||
serverAddr, cleanup := startRealKCPHubServer(t, hub)
|
||||
defer cleanup()
|
||||
|
||||
client, err := DialKCP(serverAddr, "peer-a")
|
||||
if err != nil {
|
||||
t.Fatalf("DialKCP() error = %v", err)
|
||||
}
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
|
||||
|
||||
if err := client.SendText("missing-peer", "hello"); err != nil {
|
||||
t.Fatalf("SendText() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := client.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if got.Type != protocol.MessageTypeError {
|
||||
t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError)
|
||||
}
|
||||
if string(got.Body) != "unknown target: missing-peer" {
|
||||
t.Fatalf("error body = %q, want unknown target message", got.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKCPClientsExchangeMessagesWithLatencyLogs(t *testing.T) {
|
||||
hub := server.NewKCPHub()
|
||||
serverAddr, cleanup := startRealKCPHubServer(t, hub)
|
||||
defer cleanup()
|
||||
|
||||
peerALogger := &recordingLogger{}
|
||||
peerA, err := DialKCP(serverAddr, "peer-a", WithLogger(peerALogger))
|
||||
if err != nil {
|
||||
t.Fatalf("DialKCP(peer-a) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerA.Close() }()
|
||||
|
||||
peerBLogger := &recordingLogger{}
|
||||
peerB, err := DialKCP(serverAddr, "peer-b", WithLogger(peerBLogger))
|
||||
if err != nil {
|
||||
t.Fatalf("DialKCP(peer-b) error = %v", err)
|
||||
}
|
||||
defer func() { _ = peerB.Close() }()
|
||||
|
||||
inboxDir := t.TempDir()
|
||||
|
||||
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
|
||||
|
||||
if err := peerA.SendText("peer-b", "hello"); err != nil {
|
||||
t.Fatalf("SendText() error = %v", err)
|
||||
}
|
||||
textMsg, err := peerB.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("peerB.Receive(text) error = %v", err)
|
||||
}
|
||||
if _, err := peerB.PersistMessage(textMsg, inboxDir); err != nil {
|
||||
t.Fatalf("peerB.PersistMessage(text) error = %v", err)
|
||||
}
|
||||
|
||||
filePath := filepath.Join(t.TempDir(), "payload.bin")
|
||||
if err := os.WriteFile(filePath, []byte{0x01, 0x02, 0x03}, 0o644); err != nil {
|
||||
t.Fatalf("os.WriteFile() error = %v", err)
|
||||
}
|
||||
if err := peerA.SendFilePath("peer-b", filePath); err != nil {
|
||||
t.Fatalf("SendFilePath() error = %v", err)
|
||||
}
|
||||
fileMsg, err := peerB.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("peerB.Receive(file) error = %v", err)
|
||||
}
|
||||
if _, err := peerB.PersistMessage(fileMsg, inboxDir); err != nil {
|
||||
t.Fatalf("peerB.PersistMessage(file) error = %v", err)
|
||||
}
|
||||
|
||||
waitFor(t, func() bool { return len(peerALogger.Events()) == 6 }, "peer-a latency events")
|
||||
waitFor(t, func() bool { return len(peerBLogger.Events()) == 6 }, "peer-b latency events")
|
||||
|
||||
assertEventSequencesByMessage(t, peerALogger.Events(), map[uint64][]string{
|
||||
1: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventSendHandoffEnd},
|
||||
2: {latencylog.EventAAppPrepBegin, latencylog.EventSendHandoffBegin, latencylog.EventSendHandoffEnd},
|
||||
})
|
||||
assertEventSequencesByMessage(t, peerBLogger.Events(), map[uint64][]string{
|
||||
1: {latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd},
|
||||
2: {latencylog.EventBAppRecv, latencylog.EventBPersistBegin, latencylog.EventBPersistEnd},
|
||||
})
|
||||
}
|
||||
|
||||
func startRealKCPHubServer(t *testing.T, hub *server.KCPHub) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
listener, packetConn, err := transport.ListenKCPSessions("127.0.0.1:0", "", nil, latencylog.NodeRoleServer, "hub")
|
||||
if err != nil {
|
||||
t.Fatalf("ListenKCPSessions() error = %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
stop = make(chan struct{})
|
||||
)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
session, acceptErr := listener.AcceptKCP()
|
||||
if acceptErr != nil {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
}
|
||||
if strings.Contains(acceptErr.Error(), "closed") {
|
||||
return
|
||||
}
|
||||
t.Errorf("AcceptKCP() error = %v", acceptErr)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(sess *kcp.UDPSession) {
|
||||
defer wg.Done()
|
||||
if serveErr := hub.ServeSession(sess); serveErr != nil && !isExpectedKCPHubServeExit(serveErr) {
|
||||
t.Logf("hub.ServeSession() ended with %v", serveErr)
|
||||
}
|
||||
}(session)
|
||||
}
|
||||
}()
|
||||
|
||||
cleanup := func() {
|
||||
close(stop)
|
||||
_ = listener.Close()
|
||||
_ = packetConn.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
return listener.Addr().String(), cleanup
|
||||
}
|
||||
|
||||
func isExpectedKCPHubServeExit(err error) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
message := err.Error()
|
||||
return strings.Contains(message, "closed") || strings.Contains(message, "broken pipe") || strings.Contains(message, "io: read/write on closed pipe")
|
||||
}
|
||||
173
cmd/internal/server/kcp_hub.go
Normal file
173
cmd/internal/server/kcp_hub.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
kcp "github.com/xtaci/kcp-go/v5"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
// KCPOption 用于配置 KCPHub 的可选行为。
|
||||
type KCPOption func(*KCPHub)
|
||||
|
||||
// WithKCPLogger 为 KCP hub 注入时延日志记录器。
|
||||
func WithKCPLogger(logger latencylog.Logger) KCPOption {
|
||||
return func(hub *KCPHub) {
|
||||
hub.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// KCPHub 管理已注册 peer 的 KCP 会话,并负责在它们之间转发消息。
|
||||
type KCPHub struct {
|
||||
mu sync.RWMutex
|
||||
peers map[string]*transport.KCPConn
|
||||
logger latencylog.Logger
|
||||
}
|
||||
|
||||
// NewKCPHub 创建一个空的 KCP 连接中心。
|
||||
func NewKCPHub(opts ...KCPOption) *KCPHub {
|
||||
hub := &KCPHub{
|
||||
peers: make(map[string]*transport.KCPConn),
|
||||
logger: latencylog.NoopLogger{},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(hub)
|
||||
}
|
||||
if hub.logger == nil {
|
||||
hub.logger = latencylog.NoopLogger{}
|
||||
}
|
||||
return hub
|
||||
}
|
||||
|
||||
// HasPeer 返回给定 ID 是否已经注册到 hub。
|
||||
func (h *KCPHub) HasPeer(peerID string) bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
_, ok := h.peers[peerID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ServeSession 处理一条新接入的 KCP 会话。
|
||||
func (h *KCPHub) ServeSession(session *kcp.UDPSession) error {
|
||||
conn, err := transport.NewKCPConn(
|
||||
session,
|
||||
transport.WithKCPLogger(h.logger, latencylog.NodeRoleServer, "hub"),
|
||||
)
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
return fmt.Errorf("server: create kcp transport conn: %w", err)
|
||||
}
|
||||
|
||||
peerID, err := h.registerConn(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
defer h.unregister(peerID, conn)
|
||||
|
||||
return h.receivePeerLoop(peerID, conn)
|
||||
}
|
||||
|
||||
// 注册新连接时,KCPHub 期望第一条消息是一个 register 消息,包含 peer 的 ID
|
||||
func (h *KCPHub) registerConn(conn *transport.KCPConn) (string, error) {
|
||||
msg, err := conn.Receive()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("server: receive kcp register: %w", err)
|
||||
}
|
||||
|
||||
if msg.Type != protocol.MessageTypeRegister {
|
||||
if sendErr := sendKCPServerError(conn, msg.From, "first message must be register"); sendErr != nil {
|
||||
return "", fmt.Errorf("server: reject unregistered kcp peer: %w", sendErr)
|
||||
}
|
||||
return "", fmt.Errorf("server: first kcp message must be register, got %s", msg.Type)
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if _, exists := h.peers[msg.From]; exists {
|
||||
if sendErr := sendKCPServerError(conn, msg.From, fmt.Sprintf("duplicate peer id: %s", msg.From)); sendErr != nil {
|
||||
return "", fmt.Errorf("server: duplicate kcp peer id %s: %w", msg.From, sendErr)
|
||||
}
|
||||
return "", fmt.Errorf("server: duplicate kcp peer id: %s", msg.From)
|
||||
}
|
||||
|
||||
h.peers[msg.From] = conn
|
||||
return msg.From, nil
|
||||
}
|
||||
|
||||
// handlePeerMessage 处理已注册 peer 发来的消息,并将其转发给目标 peer。
|
||||
func (h *KCPHub) handlePeerMessage(peerID string, conn *transport.KCPConn, msg protocol.Message) error {
|
||||
switch msg.Type {
|
||||
case protocol.MessageTypeText, protocol.MessageTypeFile:
|
||||
msg.From = peerID
|
||||
targetConn, ok := h.lookup(msg.To)
|
||||
if !ok {
|
||||
return sendKCPServerError(conn, peerID, fmt.Sprintf("unknown target: %s", msg.To))
|
||||
}
|
||||
if err := targetConn.Send(msg); err != nil {
|
||||
h.unregister(msg.To, targetConn)
|
||||
_ = targetConn.Close()
|
||||
return sendKCPServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", msg.To))
|
||||
}
|
||||
return nil
|
||||
case protocol.MessageTypeRegister, protocol.MessageTypeError:
|
||||
if err := sendKCPServerError(conn, peerID, "registered peers can only send text or file messages"); err != nil {
|
||||
return fmt.Errorf("server: send kcp protocol error: %w", err)
|
||||
}
|
||||
return fmt.Errorf("server: unexpected kcp message type from peer %s: %s", peerID, msg.Type)
|
||||
default:
|
||||
if err := sendKCPServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); err != nil {
|
||||
return fmt.Errorf("server: send unsupported kcp type error: %w", err)
|
||||
}
|
||||
return fmt.Errorf("server: unsupported kcp message type: %s", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// receivePeerLoop 持续读取 peer 发来的消息,并交给 handlePeerMessage 处理,直到连接出错。
|
||||
func (h *KCPHub) receivePeerLoop(peerID string, conn *transport.KCPConn) error {
|
||||
for {
|
||||
msg, err := conn.Receive()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("transport: kcp receive loop read: %w", err)
|
||||
}
|
||||
|
||||
if err := h.handlePeerMessage(peerID, conn, msg); err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("transport: kcp receive loop handler: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *KCPHub) lookup(peerID string) (*transport.KCPConn, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
conn, ok := h.peers[peerID]
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
func (h *KCPHub) unregister(peerID string, conn *transport.KCPConn) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
current, ok := h.peers[peerID]
|
||||
if ok && current == conn {
|
||||
delete(h.peers, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
func sendKCPServerError(conn *transport.KCPConn, to, message string) error {
|
||||
return conn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeError,
|
||||
From: protocol.ServerPeerID,
|
||||
To: to,
|
||||
Body: []byte(message),
|
||||
})
|
||||
}
|
||||
256
cmd/internal/transport/kcp.go
Normal file
256
cmd/internal/transport/kcp.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
kcp "github.com/xtaci/kcp-go/v5"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
kcpNoDelayNodelay = 1
|
||||
kcpNoDelayInterval = 10
|
||||
kcpNoDelayResend = 2
|
||||
kcpNoDelayNC = 1
|
||||
kcpWindowSize = 256
|
||||
kcpMTU = 1400
|
||||
)
|
||||
|
||||
// KCPConn 是对单条活跃 KCP 会话的轻量封装。
|
||||
type KCPConn struct {
|
||||
session *kcp.UDPSession
|
||||
|
||||
logger latencylog.Logger
|
||||
nodeRole string
|
||||
nodeID string
|
||||
|
||||
writeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// KCPOption 用于为 KCPConn 注入可选行为。
|
||||
type KCPOption func(*KCPConn)
|
||||
|
||||
// WithKCPLogger 为 KCP 连接发送路径注入业务消息日志上下文。
|
||||
func WithKCPLogger(logger latencylog.Logger, nodeRole, nodeID string) KCPOption {
|
||||
return func(conn *KCPConn) {
|
||||
conn.logger = logger
|
||||
conn.nodeRole = nodeRole
|
||||
conn.nodeID = nodeID
|
||||
}
|
||||
}
|
||||
|
||||
// NewKCPConn 用已有的 KCP 会话创建 transport 连接封装。
|
||||
func NewKCPConn(session *kcp.UDPSession, opts ...KCPOption) (*KCPConn, error) {
|
||||
if session == nil {
|
||||
return nil, fmt.Errorf("transport: nil kcp session")
|
||||
}
|
||||
|
||||
conn := &KCPConn{
|
||||
session: session,
|
||||
logger: latencylog.NoopLogger{},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(conn)
|
||||
}
|
||||
if conn.logger == nil {
|
||||
conn.logger = latencylog.NoopLogger{}
|
||||
}
|
||||
|
||||
configureKCPSession(session)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Send 将一条协议消息完整写入底层 KCP 会话。
|
||||
func (c *KCPConn) Send(msg protocol.Message) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
|
||||
if err := protocol.WriteMessage(c.session, msg); err != nil {
|
||||
return fmt.Errorf("transport: kcp send message: %w", err)
|
||||
}
|
||||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receive 从底层 KCP 会话读取一条完整协议消息。
|
||||
func (c *KCPConn) Receive() (protocol.Message, error) {
|
||||
msg, err := protocol.ReadMessage(c.session)
|
||||
if err != nil {
|
||||
return protocol.Message{}, fmt.Errorf("transport: kcp receive message: %w", err)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// ReceiveLoop 持续读取消息并交给 handler 处理。
|
||||
func (c *KCPConn) ReceiveLoop(handler func(protocol.Message) error) error {
|
||||
for {
|
||||
msg, err := c.Receive()
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("transport: kcp receive loop read: %w", err)
|
||||
}
|
||||
|
||||
if err := handler(msg); err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("transport: kcp receive loop handler: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭底层 KCP 会话,并保证重复调用是安全的。
|
||||
func (c *KCPConn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
c.closeErr = c.session.Close()
|
||||
})
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
// DialKCPSession 创建一条主动发起的 KCP 会话,并按项目默认参数配置底层 UDP socket。
|
||||
func DialKCPSession(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.UDPSession, error) {
|
||||
packetConn, remoteAddr, err := dialKCPPacketConn(serverAddr, bindIP, bindDevice, logger, nodeRole, nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
convID, err := generateKCPConversationID()
|
||||
if err != nil {
|
||||
_ = packetConn.Close()
|
||||
return nil, fmt.Errorf("transport: generate kcp conversation id: %w", err)
|
||||
}
|
||||
|
||||
session, err := kcp.NewConn4(convID, remoteAddr, nil, 0, 0, true, packetConn)
|
||||
if err != nil {
|
||||
_ = packetConn.Close()
|
||||
return nil, fmt.Errorf("transport: create kcp session: %w", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// ListenKCPSessions 在给定地址上启动 KCP listener,并返回 listener 与底层 packetConn。
|
||||
func ListenKCPSessions(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.Listener, net.PacketConn, error) {
|
||||
packetConn, err := listenKCPPacketConn(listenAddr, bindDevice, logger, nodeRole, nodeID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
listener, err := kcp.ServeConn(nil, 0, 0, packetConn)
|
||||
if err != nil {
|
||||
_ = packetConn.Close()
|
||||
return nil, nil, fmt.Errorf("transport: serve kcp listener: %w", err)
|
||||
}
|
||||
|
||||
return listener, packetConn, nil
|
||||
}
|
||||
|
||||
func configureKCPSession(session *kcp.UDPSession) {
|
||||
session.SetStreamMode(true)
|
||||
session.SetNoDelay(kcpNoDelayNodelay, kcpNoDelayInterval, kcpNoDelayResend, kcpNoDelayNC)
|
||||
session.SetWindowSize(kcpWindowSize, kcpWindowSize)
|
||||
session.SetACKNoDelay(true)
|
||||
session.SetWriteDelay(false)
|
||||
session.SetMtu(kcpMTU)
|
||||
}
|
||||
|
||||
func generateKCPConversationID() (uint32, error) {
|
||||
var convID uint32
|
||||
if err := binary.Read(rand.Reader, binary.LittleEndian, &convID); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return convID, nil
|
||||
}
|
||||
|
||||
func listenKCPPacketConn(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("transport: resolve kcp listen addr %s: %w", listenAddr, err)
|
||||
}
|
||||
|
||||
rawConn, err := listenUDPConn("udp", udpAddr, bindDevice)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("transport: listen udp for kcp on %s: %w", listenAddr, err)
|
||||
}
|
||||
|
||||
packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID)
|
||||
if err != nil {
|
||||
_ = rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func dialKCPPacketConn(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, *net.UDPAddr, error) {
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("transport: resolve kcp server addr %s: %w", serverAddr, err)
|
||||
}
|
||||
|
||||
localAddr := &net.UDPAddr{Port: 0}
|
||||
if bindIP != "" {
|
||||
ip := net.ParseIP(bindIP)
|
||||
if ip == nil {
|
||||
return nil, nil, fmt.Errorf("transport: invalid bind ip %q", bindIP)
|
||||
}
|
||||
localAddr.IP = ip
|
||||
}
|
||||
|
||||
network := "udp"
|
||||
if remoteAddr.IP.To4() != nil {
|
||||
network = "udp4"
|
||||
}
|
||||
|
||||
rawConn, err := listenUDPConn(network, localAddr, bindDevice)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("transport: listen udp for kcp dial to %s: %w", serverAddr, err)
|
||||
}
|
||||
|
||||
packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID)
|
||||
if err != nil {
|
||||
_ = rawConn.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return packetConn, remoteAddr, nil
|
||||
}
|
||||
|
||||
func listenUDPConn(network string, localAddr *net.UDPAddr, bindDevice string) (*net.UDPConn, error) {
|
||||
listenConfig := net.ListenConfig{}
|
||||
if bindDevice != "" {
|
||||
control, err := udpBindDeviceControl(bindDevice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listenConfig.Control = control
|
||||
}
|
||||
|
||||
packetConn, err := listenConfig.ListenPacket(context.Background(), network, udpListenAddr(localAddr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
udpConn, ok := packetConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
_ = packetConn.Close()
|
||||
return nil, fmt.Errorf("transport: expected *net.UDPConn, got %T", packetConn)
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
func udpListenAddr(addr *net.UDPAddr) string {
|
||||
if addr == nil {
|
||||
return ":0"
|
||||
}
|
||||
return addr.String()
|
||||
}
|
||||
90
cmd/internal/transport/kcp_linux_test.go
Normal file
90
cmd/internal/transport/kcp_linux_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
//go:build linux
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
func TestKCPLinuxPacketDebugLogsKernelEvents(t *testing.T) {
|
||||
senderPacketLogger := &recordingKCPPacketDebugLogger{}
|
||||
receiverPacketLogger := &recordingKCPPacketDebugLogger{}
|
||||
|
||||
sender, accepted, cleanup := newKCPConnPair(t, nil, nil, senderPacketLogger, receiverPacketLogger)
|
||||
defer cleanup()
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello kcp linux"),
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(msg)
|
||||
}()
|
||||
|
||||
receiver := awaitAcceptedKCPConn(t, accepted)
|
||||
if _, err := receiver.Receive(); err != nil {
|
||||
t.Fatalf("receiver.Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("sender.Send() error = %v", err)
|
||||
}
|
||||
|
||||
waitForKCPPacketRecords(t, senderPacketLogger, func(records []KCPPacketDebugRecord) bool {
|
||||
return hasKCPPacketEvent(records, latencylog.EventATXSched) && hasKCPPacketEvent(records, latencylog.EventATXSoftware)
|
||||
}, "sender tx kernel timestamp records")
|
||||
waitForKCPPacketRecords(t, receiverPacketLogger, func(records []KCPPacketDebugRecord) bool {
|
||||
return hasKCPPacketEvent(records, latencylog.EventBRXSoftware)
|
||||
}, "receiver rx kernel timestamp records")
|
||||
|
||||
senderRecords := senderPacketLogger.Records()
|
||||
receiverRecords := receiverPacketLogger.Records()
|
||||
|
||||
assertKCPPacketRecord(t, senderRecords, latencylog.EventATXSched, true)
|
||||
assertKCPPacketRecord(t, senderRecords, latencylog.EventATXSoftware, true)
|
||||
assertKCPPacketRecord(t, receiverRecords, latencylog.EventBRXSoftware, false)
|
||||
}
|
||||
|
||||
func hasKCPPacketEvent(records []KCPPacketDebugRecord, wantEvent string) bool {
|
||||
for _, record := range records {
|
||||
if record.Event == wantEvent {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func assertKCPPacketRecord(t *testing.T, records []KCPPacketDebugRecord, wantEvent string, wantUDPTXID bool) {
|
||||
t.Helper()
|
||||
|
||||
for _, record := range records {
|
||||
if record.Event != wantEvent {
|
||||
continue
|
||||
}
|
||||
if record.TSUnixNano <= 0 {
|
||||
t.Fatalf("record %s timestamp must be positive: %+v", wantEvent, record)
|
||||
}
|
||||
if record.PacketBytes <= 0 {
|
||||
t.Fatalf("record %s packet bytes must be positive: %+v", wantEvent, record)
|
||||
}
|
||||
if record.KCPConv == nil {
|
||||
t.Fatalf("record %s missing kcp_conv: %+v", wantEvent, record)
|
||||
}
|
||||
if wantUDPTXID && record.UDPTXID == nil {
|
||||
t.Fatalf("record %s missing udp_tx_id: %+v", wantEvent, record)
|
||||
}
|
||||
if !wantUDPTXID && record.UDPTXID != nil {
|
||||
t.Fatalf("record %s unexpected udp_tx_id: %+v", wantEvent, record)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
t.Fatalf("missing KCP packet debug event %s in %+v", wantEvent, records)
|
||||
}
|
||||
91
cmd/internal/transport/kcp_packet_conn.go
Normal file
91
cmd/internal/transport/kcp_packet_conn.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newKCPPacketConn(conn *net.UDPConn, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
|
||||
return newPlatformKCPPacketConn(conn, logger, nodeRole, nodeID)
|
||||
}
|
||||
|
||||
type kcpPacketConnBase struct {
|
||||
conn *net.UDPConn
|
||||
logger KCPPacketDebugLogger
|
||||
nodeRole string
|
||||
nodeID string
|
||||
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (c *kcpPacketConnBase) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *kcpPacketConnBase) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
close(c.closed)
|
||||
c.closeErr = c.conn.Close()
|
||||
})
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
func (c *kcpPacketConnBase) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *kcpPacketConnBase) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *kcpPacketConnBase) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *kcpPacketConnBase) SetReadBuffer(bytes int) error {
|
||||
return c.conn.SetReadBuffer(bytes)
|
||||
}
|
||||
|
||||
func (c *kcpPacketConnBase) SetWriteBuffer(bytes int) error {
|
||||
return c.conn.SetWriteBuffer(bytes)
|
||||
}
|
||||
|
||||
func (c *kcpPacketConnBase) logKCPPacketDebugRecord(record KCPPacketDebugRecord) {
|
||||
if c.logger == nil {
|
||||
return
|
||||
}
|
||||
_ = c.logger.LogKCPPacketDebugRecord(record)
|
||||
}
|
||||
|
||||
func (c *kcpPacketConnBase) newKCPPacketDebugRecord(event string, remoteAddr net.Addr, packetBytes int, tsUnixNano int64, udpTxID *uint32, kcpConv *uint32) KCPPacketDebugRecord {
|
||||
record := KCPPacketDebugRecord{
|
||||
Event: event,
|
||||
NodeRole: c.nodeRole,
|
||||
NodeID: c.nodeID,
|
||||
LocalAddr: "",
|
||||
RemoteAddr: "",
|
||||
PacketBytes: packetBytes,
|
||||
UDPTXID: udpTxID,
|
||||
KCPConv: kcpConv,
|
||||
TSUnixNano: tsUnixNano,
|
||||
}
|
||||
if localAddr := c.conn.LocalAddr(); localAddr != nil {
|
||||
record.LocalAddr = localAddr.String()
|
||||
}
|
||||
if remoteAddr != nil {
|
||||
record.RemoteAddr = remoteAddr.String()
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
func parseKCPConversationID(packet []byte) *uint32 {
|
||||
if len(packet) < 4 {
|
||||
return nil
|
||||
}
|
||||
conv := binary.LittleEndian.Uint32(packet[:4])
|
||||
return &conv
|
||||
}
|
||||
330
cmd/internal/transport/kcp_packet_conn_linux.go
Normal file
330
cmd/internal/transport/kcp_packet_conn_linux.go
Normal file
@@ -0,0 +1,330 @@
|
||||
//go:build linux
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
)
|
||||
|
||||
type kcpPendingPacketDebug struct {
|
||||
remoteAddr net.Addr
|
||||
packetBytes int
|
||||
kcpConv *uint32
|
||||
timestamps map[string]int64
|
||||
}
|
||||
|
||||
type platformKCPPacketConn struct {
|
||||
*kcpPacketConnBase
|
||||
|
||||
raw syscall.RawConn
|
||||
|
||||
writeMu sync.Mutex
|
||||
pendingMu sync.Mutex
|
||||
pendingTX map[uint32]kcpPendingPacketDebug
|
||||
nextTXID uint32
|
||||
}
|
||||
|
||||
func newPlatformKCPPacketConn(conn *net.UDPConn, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
|
||||
packetConn := &platformKCPPacketConn{
|
||||
kcpPacketConnBase: &kcpPacketConnBase{
|
||||
conn: conn,
|
||||
logger: logger,
|
||||
nodeRole: nodeRole,
|
||||
nodeID: nodeID,
|
||||
closed: make(chan struct{}),
|
||||
},
|
||||
pendingTX: make(map[uint32]kcpPendingPacketDebug),
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
if err := packetConn.initLinuxTimestamping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go packetConn.collectTXErrqueueLoop()
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) Close() error {
|
||||
return c.kcpPacketConnBase.Close()
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
if c.raw == nil {
|
||||
return c.conn.ReadFrom(p)
|
||||
}
|
||||
|
||||
for {
|
||||
n, addr, rxTimestamp, err := c.recvmsgRaw(p)
|
||||
if err != nil {
|
||||
if isWouldBlock(err) {
|
||||
time.Sleep(linuxDataPollInterval)
|
||||
continue
|
||||
}
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if rxTimestamp > 0 {
|
||||
c.logKCPPacketDebugRecord(c.newKCPPacketDebugRecord(
|
||||
latencylog.EventBRXSoftware,
|
||||
addr,
|
||||
n,
|
||||
rxTimestamp,
|
||||
nil,
|
||||
parseKCPConversationID(p[:n]),
|
||||
))
|
||||
}
|
||||
|
||||
return n, addr, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
if c.raw == nil {
|
||||
return c.conn.WriteTo(p, addr)
|
||||
}
|
||||
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("transport: kcp packet write target must be UDPAddr, got %T", addr)
|
||||
}
|
||||
|
||||
expectedTXID := c.nextExpectedTXID()
|
||||
for {
|
||||
err := c.sendmsgRaw(p, udpAddr)
|
||||
if err != nil {
|
||||
if isWouldBlock(err) {
|
||||
time.Sleep(linuxDataPollInterval)
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
c.storePendingTX(expectedTXID, udpAddr, len(p), parseKCPConversationID(p))
|
||||
return len(p), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) initLinuxTimestamping() error {
|
||||
rawConn, err := c.conn.SyscallConn()
|
||||
if err != nil || rawConn == nil {
|
||||
if err != nil {
|
||||
return fmt.Errorf("transport: kcp get syscall conn: %w", err)
|
||||
}
|
||||
return fmt.Errorf("transport: kcp missing syscall conn")
|
||||
}
|
||||
|
||||
if err := configureLinuxSocketWriteBuffer(rawConn); err != nil {
|
||||
return fmt.Errorf("transport: kcp configure socket write buffer: %w", err)
|
||||
}
|
||||
|
||||
flagCandidates := []int{
|
||||
linuxSOFTimestampingTXSched |
|
||||
linuxSOFTimestampingTXSoftware |
|
||||
linuxSOFTimestampingRXSoftware |
|
||||
linuxSOFTimestampingSoftware |
|
||||
linuxSOFTimestampingOptID |
|
||||
linuxSOFTimestampingOptTSONLY,
|
||||
linuxSOFTimestampingTXSched |
|
||||
linuxSOFTimestampingTXSoftware |
|
||||
linuxSOFTimestampingRXSoftware |
|
||||
linuxSOFTimestampingSoftware |
|
||||
linuxSOFTimestampingOptTSONLY,
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, flags := range flagCandidates {
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
lastErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if lastErr == nil {
|
||||
c.raw = rawConn
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(lastErr, syscall.EINVAL) {
|
||||
return lastErr
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) recvmsgRaw(buf []byte) (int, net.Addr, int64, error) {
|
||||
var (
|
||||
n int
|
||||
rxTimeNS int64
|
||||
from syscall.Sockaddr
|
||||
opErr error
|
||||
)
|
||||
|
||||
err := c.raw.Control(func(fd uintptr) {
|
||||
oob := make([]byte, linuxTimestampControlBufferSize)
|
||||
readN, oobN, _, sa, recvErr := syscall.Recvmsg(int(fd), buf, oob, 0)
|
||||
if recvErr != nil {
|
||||
opErr = recvErr
|
||||
return
|
||||
}
|
||||
n = readN
|
||||
from = sa
|
||||
rxTimeNS = parseRXTimestampControlMessages(oob[:oobN])
|
||||
})
|
||||
if err != nil {
|
||||
return 0, nil, 0, err
|
||||
}
|
||||
if opErr != nil {
|
||||
return 0, nil, 0, opErr
|
||||
}
|
||||
|
||||
return n, sockaddrToUDPAddr(from), rxTimeNS, nil
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) sendmsgRaw(payload []byte, addr *net.UDPAddr) error {
|
||||
var opErr error
|
||||
sa := udpAddrToSockaddr(addr)
|
||||
if sa == nil {
|
||||
return fmt.Errorf("transport: invalid udp addr %v", addr)
|
||||
}
|
||||
|
||||
err := c.raw.Control(func(fd uintptr) {
|
||||
opErr = syscall.Sendmsg(int(fd), payload, nil, sa, 0)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return opErr
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) collectTXErrqueueLoop() {
|
||||
for {
|
||||
event, err := c.recvTXErrqueueOnce()
|
||||
if err != nil {
|
||||
if isWouldBlock(err) {
|
||||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
time.Sleep(linuxTXTimestampPollInterval)
|
||||
continue
|
||||
}
|
||||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
if event.EventName == "" || event.TSUnixNano <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if event.EventName != latencylog.EventATXSched && event.EventName != latencylog.EventATXSoftware {
|
||||
continue
|
||||
}
|
||||
|
||||
record, complete := c.recordPendingTXEvent(event)
|
||||
if record == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
udpTxID := event.EEData
|
||||
c.logKCPPacketDebugRecord(c.newKCPPacketDebugRecord(
|
||||
event.EventName,
|
||||
record.remoteAddr,
|
||||
record.packetBytes,
|
||||
event.TSUnixNano,
|
||||
&udpTxID,
|
||||
record.kcpConv,
|
||||
))
|
||||
|
||||
if complete {
|
||||
c.pendingMu.Lock()
|
||||
delete(c.pendingTX, event.EEData)
|
||||
c.pendingMu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) recvTXErrqueueOnce() (txTimestampEvent, error) {
|
||||
var (
|
||||
event txTimestampEvent
|
||||
opErr error
|
||||
)
|
||||
|
||||
err := c.raw.Control(func(fd uintptr) {
|
||||
oob := make([]byte, linuxTimestampControlBufferSize)
|
||||
_, oobN, _, _, recvErr := syscall.Recvmsg(int(fd), nil, oob, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT)
|
||||
if recvErr != nil {
|
||||
opErr = recvErr
|
||||
return
|
||||
}
|
||||
event, _ = parseTXTimestampControlMessages(oob[:oobN])
|
||||
})
|
||||
if err != nil {
|
||||
return txTimestampEvent{}, err
|
||||
}
|
||||
if opErr != nil {
|
||||
return txTimestampEvent{}, opErr
|
||||
}
|
||||
return event, nil
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) nextExpectedTXID() uint32 {
|
||||
c.pendingMu.Lock()
|
||||
defer c.pendingMu.Unlock()
|
||||
|
||||
next := c.nextTXID
|
||||
c.nextTXID++
|
||||
return next
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) storePendingTX(txID uint32, remoteAddr net.Addr, packetBytes int, kcpConv *uint32) {
|
||||
c.pendingMu.Lock()
|
||||
defer c.pendingMu.Unlock()
|
||||
|
||||
c.pendingTX[txID] = kcpPendingPacketDebug{
|
||||
remoteAddr: remoteAddr,
|
||||
packetBytes: packetBytes,
|
||||
kcpConv: kcpConv,
|
||||
timestamps: make(map[string]int64, 2),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) recordPendingTXEvent(event txTimestampEvent) (*kcpPendingPacketDebug, bool) {
|
||||
c.pendingMu.Lock()
|
||||
defer c.pendingMu.Unlock()
|
||||
|
||||
record, ok := c.pendingTX[event.EEData]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if existing, exists := record.timestamps[event.EventName]; !exists || event.TSUnixNano < existing {
|
||||
record.timestamps[event.EventName] = event.TSUnixNano
|
||||
}
|
||||
c.pendingTX[event.EEData] = record
|
||||
|
||||
complete := hasCompleteTXTimestampPair(record.timestamps)
|
||||
copyRecord := record
|
||||
return ©Record, complete
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) isClosed() bool {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return true
|
||||
default:
|
||||
}
|
||||
return false
|
||||
}
|
||||
29
cmd/internal/transport/kcp_packet_conn_other.go
Normal file
29
cmd/internal/transport/kcp_packet_conn_other.go
Normal file
@@ -0,0 +1,29 @@
|
||||
//go:build !linux
|
||||
|
||||
package transport
|
||||
|
||||
import "net"
|
||||
|
||||
type platformKCPPacketConn struct {
|
||||
*kcpPacketConnBase
|
||||
}
|
||||
|
||||
func newPlatformKCPPacketConn(conn *net.UDPConn, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
|
||||
return &platformKCPPacketConn{
|
||||
kcpPacketConnBase: &kcpPacketConnBase{
|
||||
conn: conn,
|
||||
logger: logger,
|
||||
nodeRole: nodeRole,
|
||||
nodeID: nodeID,
|
||||
closed: make(chan struct{}),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
return c.conn.ReadFrom(p)
|
||||
}
|
||||
|
||||
func (c *platformKCPPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
return c.conn.WriteTo(p, addr)
|
||||
}
|
||||
76
cmd/internal/transport/kcp_packet_debug.go
Normal file
76
cmd/internal/transport/kcp_packet_debug.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// KCPPacketDebugRecord 是 KCP 底层 UDP packet kernel timestamp 的一条 JSONL 调试记录。
|
||||
type KCPPacketDebugRecord struct {
|
||||
Event string `json:"event"`
|
||||
NodeRole string `json:"node_role,omitempty"`
|
||||
NodeID string `json:"node_id,omitempty"`
|
||||
LocalAddr string `json:"local_addr,omitempty"`
|
||||
RemoteAddr string `json:"remote_addr,omitempty"`
|
||||
PacketBytes int `json:"packet_bytes"`
|
||||
UDPTXID *uint32 `json:"udp_tx_id,omitempty"`
|
||||
KCPConv *uint32 `json:"kcp_conv,omitempty"`
|
||||
TSUnixNano int64 `json:"ts_unix_nano"`
|
||||
}
|
||||
|
||||
// KCPPacketDebugLogger 接收 KCP packet 级调试记录。
|
||||
type KCPPacketDebugLogger interface {
|
||||
LogKCPPacketDebugRecord(record KCPPacketDebugRecord) error
|
||||
}
|
||||
|
||||
// JSONLKCPPacketDebugLogger 以 JSONL 形式追加写 KCP packet 调试日志。
|
||||
type JSONLKCPPacketDebugLogger struct {
|
||||
mu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
file *os.File
|
||||
}
|
||||
|
||||
// NewJSONLKCPPacketDebugLogger 创建一个线程安全的 KCP packet JSONL 日志器。
|
||||
func NewJSONLKCPPacketDebugLogger(path string) (*JSONLKCPPacketDebugLogger, error) {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("transport: create kcp packet debug log dir %s: %w", dir, err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("transport: open kcp packet debug log %s: %w", path, err)
|
||||
}
|
||||
|
||||
return &JSONLKCPPacketDebugLogger{file: file}, nil
|
||||
}
|
||||
|
||||
// LogKCPPacketDebugRecord 以单行 JSON 的形式追加一条 KCP packet 调试记录。
|
||||
func (l *JSONLKCPPacketDebugLogger) LogKCPPacketDebugRecord(record KCPPacketDebugRecord) error {
|
||||
line, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if _, err := l.file.Write(append(line, '\n')); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭底层文件;重复调用是安全的。
|
||||
func (l *JSONLKCPPacketDebugLogger) Close() error {
|
||||
l.closeOnce.Do(func() {
|
||||
l.closeErr = l.file.Close()
|
||||
})
|
||||
|
||||
return l.closeErr
|
||||
}
|
||||
284
cmd/internal/transport/kcp_test.go
Normal file
284
cmd/internal/transport/kcp_test.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
type recordingKCPPacketDebugLogger struct {
|
||||
mu sync.Mutex
|
||||
records []KCPPacketDebugRecord
|
||||
}
|
||||
|
||||
func (l *recordingKCPPacketDebugLogger) LogKCPPacketDebugRecord(record KCPPacketDebugRecord) error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.records = append(l.records, record)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *recordingKCPPacketDebugLogger) Records() []KCPPacketDebugRecord {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
return append([]KCPPacketDebugRecord(nil), l.records...)
|
||||
}
|
||||
|
||||
type kcpAcceptResult struct {
|
||||
conn *KCPConn
|
||||
err error
|
||||
}
|
||||
|
||||
func TestKCPSendReceiveMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg protocol.Message
|
||||
}{
|
||||
{
|
||||
name: "text",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello kcp"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x00, 0x10, 0xff},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sender, accepted, cleanup := newKCPConnPair(
|
||||
t,
|
||||
nil,
|
||||
[]KCPOption{WithKCPLogger(latencylog.NoopLogger{}, latencylog.NodeRolePeer, "peer-b")},
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(tt.msg)
|
||||
}()
|
||||
|
||||
receiver := awaitAcceptedKCPConn(t, accepted)
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("receiver.Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("sender.Send() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKCPSendLogsHandoffEvents(t *testing.T) {
|
||||
logger := &recordingLogger{}
|
||||
sender, accepted, cleanup := newKCPConnPair(
|
||||
t,
|
||||
[]KCPOption{WithKCPLogger(logger, latencylog.NodeRolePeer, "peer-a")},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(msg)
|
||||
}()
|
||||
|
||||
receiver := awaitAcceptedKCPConn(t, accepted)
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("receiver.Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("sender.Send() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
|
||||
}
|
||||
|
||||
events := logger.Events()
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("event count = %d, want 2", len(events))
|
||||
}
|
||||
if events[0].Event != latencylog.EventSendHandoffBegin {
|
||||
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin)
|
||||
}
|
||||
if events[1].Event != latencylog.EventSendHandoffEnd {
|
||||
t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventSendHandoffEnd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKCPReceiveLoopStopsOnClose(t *testing.T) {
|
||||
sender, accepted, cleanup := newKCPConnPair(t, nil, nil, nil, nil)
|
||||
defer cleanup()
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(msg)
|
||||
}()
|
||||
|
||||
receiver := awaitAcceptedKCPConn(t, accepted)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
got []protocol.Message
|
||||
)
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
|
||||
mu.Lock()
|
||||
got = append(got, msg)
|
||||
mu.Unlock()
|
||||
return receiver.Close()
|
||||
})
|
||||
}()
|
||||
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("sender.Send() error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if err == nil || (!strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "pipe")) {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want close-related error", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(got) != 1 || !reflect.DeepEqual(got[0], msg) {
|
||||
t.Fatalf("received messages mismatch: got %+v want [%+v]", got, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKCPCloseIsIdempotent(t *testing.T) {
|
||||
sender, _, cleanup := newKCPConnPair(t, nil, nil, nil, nil)
|
||||
defer cleanup()
|
||||
|
||||
if err := sender.Close(); err != nil {
|
||||
t.Fatalf("Close(first) error = %v", err)
|
||||
}
|
||||
if err := sender.Close(); err != nil {
|
||||
t.Fatalf("Close(second) error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newKCPConnPair(t *testing.T, senderOpts []KCPOption, receiverOpts []KCPOption, senderPacketLogger KCPPacketDebugLogger, receiverPacketLogger KCPPacketDebugLogger) (*KCPConn, <-chan kcpAcceptResult, func()) {
|
||||
t.Helper()
|
||||
|
||||
listener, packetConn, err := ListenKCPSessions("127.0.0.1:0", "", receiverPacketLogger, latencylog.NodeRolePeer, "peer-b")
|
||||
if err != nil {
|
||||
t.Fatalf("ListenKCPSessions() error = %v", err)
|
||||
}
|
||||
|
||||
accepted := make(chan kcpAcceptResult, 1)
|
||||
go func() {
|
||||
session, acceptErr := listener.AcceptKCP()
|
||||
if acceptErr != nil {
|
||||
accepted <- kcpAcceptResult{err: acceptErr}
|
||||
return
|
||||
}
|
||||
|
||||
conn, connErr := NewKCPConn(session, receiverOpts...)
|
||||
accepted <- kcpAcceptResult{conn: conn, err: connErr}
|
||||
}()
|
||||
|
||||
session, err := DialKCPSession(listener.Addr().String(), "", "", senderPacketLogger, latencylog.NodeRolePeer, "peer-a")
|
||||
if err != nil {
|
||||
_ = packetConn.Close()
|
||||
_ = listener.Close()
|
||||
t.Fatalf("DialKCPSession() error = %v", err)
|
||||
}
|
||||
|
||||
sender, err := NewKCPConn(session, senderOpts...)
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
_ = packetConn.Close()
|
||||
_ = listener.Close()
|
||||
t.Fatalf("NewKCPConn(sender) error = %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
_ = sender.Close()
|
||||
select {
|
||||
case result := <-accepted:
|
||||
if result.conn != nil {
|
||||
_ = result.conn.Close()
|
||||
}
|
||||
default:
|
||||
}
|
||||
_ = listener.Close()
|
||||
_ = packetConn.Close()
|
||||
}
|
||||
|
||||
return sender, accepted, cleanup
|
||||
}
|
||||
|
||||
func awaitAcceptedKCPConn(t *testing.T, accepted <-chan kcpAcceptResult) *KCPConn {
|
||||
t.Helper()
|
||||
|
||||
result := <-accepted
|
||||
if result.err != nil {
|
||||
t.Fatalf("AcceptKCP() error = %v", result.err)
|
||||
}
|
||||
if result.conn == nil {
|
||||
t.Fatal("accepted KCP conn = nil")
|
||||
}
|
||||
return result.conn
|
||||
}
|
||||
|
||||
func waitForKCPPacketRecords(t *testing.T, logger *recordingKCPPacketDebugLogger, condition func([]KCPPacketDebugRecord) bool, description string) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
records := logger.Records()
|
||||
if condition(records) {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("timed out waiting for %s", description)
|
||||
}
|
||||
24
cmd/internal/transport/udp_device_linux.go
Normal file
24
cmd/internal/transport/udp_device_linux.go
Normal file
@@ -0,0 +1,24 @@
|
||||
//go:build linux
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// udpBindDeviceControl 返回一个 Control 函数,用于在 Linux 上将 UDP socket 绑定到指定网卡设备。
|
||||
func udpBindDeviceControl(device string) (func(string, string, syscall.RawConn) error, error) {
|
||||
return func(_, _ string, rawConn syscall.RawConn) error {
|
||||
var bindErr error
|
||||
if err := rawConn.Control(func(fd uintptr) {
|
||||
bindErr = syscall.BindToDevice(int(fd), device)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
if bindErr != nil {
|
||||
return fmt.Errorf("transport: bind device %s: %w", device, bindErr)
|
||||
}
|
||||
return nil
|
||||
}, nil
|
||||
}
|
||||
12
cmd/internal/transport/udp_device_other.go
Normal file
12
cmd/internal/transport/udp_device_other.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !linux
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func udpBindDeviceControl(device string) (func(string, string, syscall.RawConn) error, error) {
|
||||
return nil, fmt.Errorf("transport: bind device %s is only supported on linux", device)
|
||||
}
|
||||
@@ -141,10 +141,11 @@ func TestUDPReceiveLoopDeliversMessages(t *testing.T) {
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message, _ *net.UDPAddr) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
got = append(got, msg)
|
||||
if len(got) >= len(want) {
|
||||
return nil
|
||||
done := len(got) >= len(want)
|
||||
mu.Unlock()
|
||||
if done {
|
||||
return receiver.Close()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@@ -156,17 +157,12 @@ func TestUDPReceiveLoopDeliversMessages(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭发送端,ReceiveLoop 会因读取错误退出
|
||||
if err := sender.Close(); err != nil {
|
||||
t.Fatalf("sender.Close() error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if err == nil {
|
||||
t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close")
|
||||
t.Fatal("ReceiveLoop() error = nil, want non-nil after receiver close")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "udp receive loop read") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want read context", err)
|
||||
if !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want close-related error", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
@@ -268,6 +264,7 @@ func newUDPConnPair(t *testing.T, senderOpts []UDPOption, receiverOpts []UDPOpti
|
||||
_ = conn1.Close()
|
||||
t.Fatalf("ListenUDP(2) error = %v", err)
|
||||
}
|
||||
receiverLocalAddr := conn2.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
// 用 Dial 模式连接对端
|
||||
senderRaw, err := net.DialUDP("udp", nil, conn2.LocalAddr().(*net.UDPAddr))
|
||||
@@ -277,14 +274,13 @@ func newUDPConnPair(t *testing.T, senderOpts []UDPOption, receiverOpts []UDPOpti
|
||||
t.Fatalf("DialUDP(sender) error = %v", err)
|
||||
}
|
||||
_ = conn1.Close() // 不再需要 conn1
|
||||
_ = conn2.Close() // 释放 receiver 计划使用的本地地址
|
||||
|
||||
receiverRaw, err := net.DialUDP("udp", conn2.LocalAddr().(*net.UDPAddr), senderRaw.LocalAddr().(*net.UDPAddr))
|
||||
receiverRaw, err := net.DialUDP("udp", receiverLocalAddr, senderRaw.LocalAddr().(*net.UDPAddr))
|
||||
if err != nil {
|
||||
_ = senderRaw.Close()
|
||||
_ = conn2.Close()
|
||||
t.Fatalf("DialUDP(receiver) error = %v", err)
|
||||
}
|
||||
_ = conn2.Close() // 不再需要 conn2
|
||||
|
||||
sender, err := NewUDPConn(senderRaw, nil, senderOpts...)
|
||||
if err != nil {
|
||||
|
||||
86
cmd/kcppeer/interactive.go
Normal file
86
cmd/kcppeer/interactive.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
kcpInteractiveCommandHelp = "help"
|
||||
kcpInteractiveCommandQuit = "quit"
|
||||
kcpInteractiveCommandText = "text"
|
||||
kcpInteractiveCommandFile = "file"
|
||||
)
|
||||
|
||||
var errKCPEmptyInteractiveCommand = errors.New("interactive command is empty")
|
||||
|
||||
type kcpInteractiveCommand struct {
|
||||
name string
|
||||
to string
|
||||
value string
|
||||
}
|
||||
|
||||
func parseKCPInteractiveCommand(line string) (kcpInteractiveCommand, error) {
|
||||
commandName, rest, ok := cutKCPInteractiveField(strings.TrimSpace(line))
|
||||
if !ok {
|
||||
return kcpInteractiveCommand{}, errKCPEmptyInteractiveCommand
|
||||
}
|
||||
|
||||
switch strings.ToLower(commandName) {
|
||||
case "help", "h", "?":
|
||||
return kcpInteractiveCommand{name: kcpInteractiveCommandHelp}, nil
|
||||
case "quit", "exit":
|
||||
return kcpInteractiveCommand{name: kcpInteractiveCommandQuit}, nil
|
||||
case kcpInteractiveCommandText:
|
||||
to, body, err := parseKCPInteractiveTargetValue(rest, kcpInteractiveCommandText)
|
||||
if err != nil {
|
||||
return kcpInteractiveCommand{}, err
|
||||
}
|
||||
return kcpInteractiveCommand{name: kcpInteractiveCommandText, to: to, value: body}, nil
|
||||
case kcpInteractiveCommandFile:
|
||||
to, path, err := parseKCPInteractiveTargetValue(rest, kcpInteractiveCommandFile)
|
||||
if err != nil {
|
||||
return kcpInteractiveCommand{}, err
|
||||
}
|
||||
return kcpInteractiveCommand{name: kcpInteractiveCommandFile, to: to, value: path}, nil
|
||||
default:
|
||||
return kcpInteractiveCommand{}, fmt.Errorf("unknown command %q; type help for usage", commandName)
|
||||
}
|
||||
}
|
||||
|
||||
func parseKCPInteractiveTargetValue(rest, commandName string) (string, string, error) {
|
||||
to, value, ok := cutKCPInteractiveField(strings.TrimSpace(rest))
|
||||
if !ok {
|
||||
return "", "", fmt.Errorf("%s command requires a target peer and payload", commandName)
|
||||
}
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return "", "", fmt.Errorf("%s command requires a non-empty payload", commandName)
|
||||
}
|
||||
|
||||
return to, strings.TrimSpace(value), nil
|
||||
}
|
||||
|
||||
func cutKCPInteractiveField(input string) (string, string, bool) {
|
||||
trimmed := strings.TrimSpace(input)
|
||||
if trimmed == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
for i, r := range trimmed {
|
||||
if r == ' ' || r == '\t' {
|
||||
return trimmed[:i], strings.TrimSpace(trimmed[i+1:]), true
|
||||
}
|
||||
}
|
||||
|
||||
return trimmed, "", true
|
||||
}
|
||||
|
||||
func printKCPInteractiveHelp(w io.Writer) {
|
||||
_, _ = fmt.Fprintln(w, "interactive mode commands (KCP):")
|
||||
_, _ = fmt.Fprintln(w, " help show this help")
|
||||
_, _ = fmt.Fprintln(w, " text <peer> <message> send one text message over the existing KCP session")
|
||||
_, _ = fmt.Fprintln(w, " file <peer> <path> send one file over the existing KCP session")
|
||||
_, _ = fmt.Fprintln(w, " quit exit this peer process")
|
||||
}
|
||||
188
cmd/kcppeer/main.go
Normal file
188
cmd/kcppeer/main.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
peerpkg "omnisocketgo/cmd/internal/peer"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
func main() {
|
||||
peerID := flag.String("id", "peer-a", "peer identity")
|
||||
serverAddr := flag.String("server", "127.0.0.1:9002", "KCP server address")
|
||||
targetPeer := flag.String("to", "", "optional target peer for one outgoing message")
|
||||
text := flag.String("text", "", "optional text to send after connecting")
|
||||
filePath := flag.String("file", "", "optional file path to send after connecting")
|
||||
bindIP := flag.String("bind-ip", "", "optional local source IP used when dialing the server")
|
||||
bindDevice := flag.String("bind-device", "", "optional Linux network device used when dialing the server")
|
||||
inboxDir := flag.String("inbox-dir", "inbox", "directory used to persist received text and file messages")
|
||||
logPath := flag.String("latency-log", "", "optional JSONL file path for latency timestamp logs")
|
||||
kcpTimestampDebugLogPath := flag.String("kcp-ts-debug-log", "", "optional JSONL file path for KCP packet kernel timestamp debug records")
|
||||
interactive := flag.Bool("interactive", true, "enable interactive REPL for repeated text/file sends on the same connection")
|
||||
flag.Parse()
|
||||
|
||||
clientOptions := make([]peerpkg.Option, 0, 5)
|
||||
if *logPath != "" {
|
||||
logger, err := latencylog.NewJSONLLogger(*logPath)
|
||||
if err != nil {
|
||||
log.Fatalf("create latency logger %s: %v", *logPath, err)
|
||||
}
|
||||
defer logger.Close()
|
||||
clientOptions = append(clientOptions, peerpkg.WithLogger(logger))
|
||||
}
|
||||
if *kcpTimestampDebugLogPath != "" {
|
||||
logger, err := transport.NewJSONLKCPPacketDebugLogger(*kcpTimestampDebugLogPath)
|
||||
if err != nil {
|
||||
log.Fatalf("create kcp packet debug logger %s: %v", *kcpTimestampDebugLogPath, err)
|
||||
}
|
||||
defer logger.Close()
|
||||
clientOptions = append(clientOptions, peerpkg.WithKCPPacketDebugLogger(logger))
|
||||
}
|
||||
if *bindIP != "" {
|
||||
clientOptions = append(clientOptions, peerpkg.WithBindIP(*bindIP))
|
||||
}
|
||||
if *bindDevice != "" {
|
||||
clientOptions = append(clientOptions, peerpkg.WithBindDevice(*bindDevice))
|
||||
}
|
||||
|
||||
client, err := peerpkg.DialKCP(*serverAddr, *peerID, clientOptions...)
|
||||
if err != nil {
|
||||
log.Fatalf("dial kcp server %s: %v", *serverAddr, err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
log.Printf("connected to %s as %s (KCP)", *serverAddr, client.ID())
|
||||
|
||||
receiveErr := make(chan error, 1)
|
||||
go func() {
|
||||
receiveErr <- client.ReceiveLoop(func(msg protocol.Message) error {
|
||||
switch msg.Type {
|
||||
case protocol.MessageTypeText:
|
||||
path, err := client.PersistMessage(msg, *inboxDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("received text from %s to %s and persisted to %s", msg.From, msg.To, path)
|
||||
case protocol.MessageTypeFile:
|
||||
path, err := client.PersistMessage(msg, *inboxDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("received file from %s to %s: %s (%d bytes) -> %s", msg.From, msg.To, msg.FileName, len(msg.Body), path)
|
||||
case protocol.MessageTypeError:
|
||||
log.Printf("received %s from %s to %s: %s", msg.Type, msg.From, msg.To, string(msg.Body))
|
||||
default:
|
||||
log.Printf("received unexpected message type %s from %s", msg.Type, msg.From)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
if *text != "" && *filePath != "" {
|
||||
log.Fatal("only one of -text or -file may be specified")
|
||||
}
|
||||
if (*text != "" || *filePath != "") && *targetPeer == "" {
|
||||
log.Fatal("flag -to is required when sending text or file")
|
||||
}
|
||||
|
||||
if *targetPeer != "" && *text != "" {
|
||||
if err := client.SendText(*targetPeer, *text); err != nil {
|
||||
log.Fatalf("send text to %s: %v", *targetPeer, err)
|
||||
}
|
||||
log.Printf("sent text to %s", *targetPeer)
|
||||
}
|
||||
if *targetPeer != "" && *filePath != "" {
|
||||
if err := client.SendFilePath(*targetPeer, *filePath); err != nil {
|
||||
log.Fatalf("send file %s to %s: %v", *filePath, *targetPeer, err)
|
||||
}
|
||||
log.Printf("sent file %s to %s", *filePath, *targetPeer)
|
||||
}
|
||||
|
||||
if *interactive {
|
||||
if err := runKCPInteractiveShell(client, os.Stdin, os.Stdout, receiveErr); err != nil {
|
||||
log.Printf("interactive shell ended: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := <-receiveErr; err != nil {
|
||||
log.Printf("receive loop ended: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func runKCPInteractiveShell(client *peerpkg.KCPClient, in io.Reader, out io.Writer, receiveErr <-chan error) error {
|
||||
printKCPInteractiveHelp(out)
|
||||
lines, inputErr := readKCPInteractiveLines(in, out, fmt.Sprintf("%s> ", client.ID()))
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-receiveErr:
|
||||
return err
|
||||
case line, ok := <-lines:
|
||||
if !ok {
|
||||
return <-inputErr
|
||||
}
|
||||
|
||||
command, err := parseKCPInteractiveCommand(line)
|
||||
if err != nil {
|
||||
if err == errKCPEmptyInteractiveCommand {
|
||||
continue
|
||||
}
|
||||
log.Printf("interactive command error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch command.name {
|
||||
case kcpInteractiveCommandHelp:
|
||||
printKCPInteractiveHelp(out)
|
||||
case kcpInteractiveCommandQuit:
|
||||
return nil
|
||||
case kcpInteractiveCommandText:
|
||||
if err := client.SendText(command.to, command.value); err != nil {
|
||||
log.Printf("send text to %s: %v", command.to, err)
|
||||
continue
|
||||
}
|
||||
log.Printf("sent text to %s", command.to)
|
||||
case kcpInteractiveCommandFile:
|
||||
if err := client.SendFilePath(command.to, command.value); err != nil {
|
||||
log.Printf("send file %s to %s: %v", command.value, command.to, err)
|
||||
continue
|
||||
}
|
||||
log.Printf("sent file %s to %s", command.value, command.to)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readKCPInteractiveLines(in io.Reader, out io.Writer, prompt string) (<-chan string, <-chan error) {
|
||||
lines := make(chan string)
|
||||
errs := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
defer close(lines)
|
||||
|
||||
scanner := bufio.NewScanner(in)
|
||||
scanner.Buffer(make([]byte, 0, 1024), 1024*1024)
|
||||
|
||||
for {
|
||||
if _, err := fmt.Fprint(out, prompt); err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
if !scanner.Scan() {
|
||||
errs <- scanner.Err()
|
||||
return
|
||||
}
|
||||
lines <- scanner.Text()
|
||||
}
|
||||
}()
|
||||
|
||||
return lines, errs
|
||||
}
|
||||
68
cmd/kcpserver/main.go
Normal file
68
cmd/kcpserver/main.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
kcp "github.com/xtaci/kcp-go/v5"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/server"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
func main() {
|
||||
listenAddr := flag.String("listen", ":9002", "KCP server listen address")
|
||||
bindDevice := flag.String("bind-device", "", "optional Linux network device used when listening")
|
||||
logPath := flag.String("latency-log", "", "optional JSONL file path for latency timestamp logs")
|
||||
kcpTimestampDebugLogPath := flag.String("kcp-ts-debug-log", "", "optional JSONL file path for KCP packet kernel timestamp debug records")
|
||||
flag.Parse()
|
||||
|
||||
hubOptions := make([]server.KCPOption, 0, 1)
|
||||
if *logPath != "" {
|
||||
logger, err := latencylog.NewJSONLLogger(*logPath)
|
||||
if err != nil {
|
||||
log.Fatalf("create latency logger %s: %v", *logPath, err)
|
||||
}
|
||||
defer logger.Close()
|
||||
hubOptions = append(hubOptions, server.WithKCPLogger(logger))
|
||||
}
|
||||
|
||||
var packetLogger transport.KCPPacketDebugLogger
|
||||
if *kcpTimestampDebugLogPath != "" {
|
||||
logger, err := transport.NewJSONLKCPPacketDebugLogger(*kcpTimestampDebugLogPath)
|
||||
if err != nil {
|
||||
log.Fatalf("create kcp packet debug logger %s: %v", *kcpTimestampDebugLogPath, err)
|
||||
}
|
||||
defer logger.Close()
|
||||
packetLogger = logger
|
||||
}
|
||||
|
||||
listener, packetConn, err := transport.ListenKCPSessions(*listenAddr, *bindDevice, packetLogger, latencylog.NodeRoleServer, "hub")
|
||||
if err != nil {
|
||||
log.Fatalf("listen kcp on %s: %v", *listenAddr, err)
|
||||
}
|
||||
defer packetConn.Close()
|
||||
defer listener.Close()
|
||||
|
||||
hub := server.NewKCPHub(hubOptions...)
|
||||
log.Printf("kcp server listening on %s", listener.Addr())
|
||||
|
||||
for {
|
||||
session, err := listener.AcceptKCP()
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "closed") {
|
||||
return
|
||||
}
|
||||
log.Printf("accept kcp session: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go func(sess *kcp.UDPSession) {
|
||||
if serveErr := hub.ServeSession(sess); serveErr != nil {
|
||||
log.Printf("kcp session closed: %v", serveErr)
|
||||
}
|
||||
}(session)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user