feat:KCP协议

This commit is contained in:
nnbcccscdscdsc
2026-03-24 21:09:06 +08:00
parent 290ba18962
commit be013b701b
20 changed files with 2284 additions and 16 deletions

View File

@@ -0,0 +1,256 @@
package transport
import (
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"sync"
kcp "github.com/xtaci/kcp-go/v5"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
const (
kcpNoDelayNodelay = 1
kcpNoDelayInterval = 10
kcpNoDelayResend = 2
kcpNoDelayNC = 1
kcpWindowSize = 256
kcpMTU = 1400
)
// KCPConn 是对单条活跃 KCP 会话的轻量封装。
type KCPConn struct {
session *kcp.UDPSession
logger latencylog.Logger
nodeRole string
nodeID string
writeMu sync.Mutex
closeOnce sync.Once
closeErr error
}
// KCPOption 用于为 KCPConn 注入可选行为。
type KCPOption func(*KCPConn)
// WithKCPLogger 为 KCP 连接发送路径注入业务消息日志上下文。
func WithKCPLogger(logger latencylog.Logger, nodeRole, nodeID string) KCPOption {
return func(conn *KCPConn) {
conn.logger = logger
conn.nodeRole = nodeRole
conn.nodeID = nodeID
}
}
// NewKCPConn 用已有的 KCP 会话创建 transport 连接封装。
func NewKCPConn(session *kcp.UDPSession, opts ...KCPOption) (*KCPConn, error) {
if session == nil {
return nil, fmt.Errorf("transport: nil kcp session")
}
conn := &KCPConn{
session: session,
logger: latencylog.NoopLogger{},
}
for _, opt := range opts {
opt(conn)
}
if conn.logger == nil {
conn.logger = latencylog.NoopLogger{}
}
configureKCPSession(session)
return conn, nil
}
// Send 将一条协议消息完整写入底层 KCP 会话。
func (c *KCPConn) Send(msg protocol.Message) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
if err := protocol.WriteMessage(c.session, msg); err != nil {
return fmt.Errorf("transport: kcp send message: %w", err)
}
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
return nil
}
// Receive 从底层 KCP 会话读取一条完整协议消息。
func (c *KCPConn) Receive() (protocol.Message, error) {
msg, err := protocol.ReadMessage(c.session)
if err != nil {
return protocol.Message{}, fmt.Errorf("transport: kcp receive message: %w", err)
}
return msg, nil
}
// ReceiveLoop 持续读取消息并交给 handler 处理。
func (c *KCPConn) ReceiveLoop(handler func(protocol.Message) error) error {
for {
msg, err := c.Receive()
if err != nil {
_ = c.Close()
return fmt.Errorf("transport: kcp receive loop read: %w", err)
}
if err := handler(msg); err != nil {
_ = c.Close()
return fmt.Errorf("transport: kcp receive loop handler: %w", err)
}
}
}
// Close 关闭底层 KCP 会话,并保证重复调用是安全的。
func (c *KCPConn) Close() error {
c.closeOnce.Do(func() {
c.closeErr = c.session.Close()
})
return c.closeErr
}
// DialKCPSession 创建一条主动发起的 KCP 会话,并按项目默认参数配置底层 UDP socket。
func DialKCPSession(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.UDPSession, error) {
packetConn, remoteAddr, err := dialKCPPacketConn(serverAddr, bindIP, bindDevice, logger, nodeRole, nodeID)
if err != nil {
return nil, err
}
convID, err := generateKCPConversationID()
if err != nil {
_ = packetConn.Close()
return nil, fmt.Errorf("transport: generate kcp conversation id: %w", err)
}
session, err := kcp.NewConn4(convID, remoteAddr, nil, 0, 0, true, packetConn)
if err != nil {
_ = packetConn.Close()
return nil, fmt.Errorf("transport: create kcp session: %w", err)
}
return session, nil
}
// ListenKCPSessions 在给定地址上启动 KCP listener并返回 listener 与底层 packetConn。
func ListenKCPSessions(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.Listener, net.PacketConn, error) {
packetConn, err := listenKCPPacketConn(listenAddr, bindDevice, logger, nodeRole, nodeID)
if err != nil {
return nil, nil, err
}
listener, err := kcp.ServeConn(nil, 0, 0, packetConn)
if err != nil {
_ = packetConn.Close()
return nil, nil, fmt.Errorf("transport: serve kcp listener: %w", err)
}
return listener, packetConn, nil
}
func configureKCPSession(session *kcp.UDPSession) {
session.SetStreamMode(true)
session.SetNoDelay(kcpNoDelayNodelay, kcpNoDelayInterval, kcpNoDelayResend, kcpNoDelayNC)
session.SetWindowSize(kcpWindowSize, kcpWindowSize)
session.SetACKNoDelay(true)
session.SetWriteDelay(false)
session.SetMtu(kcpMTU)
}
func generateKCPConversationID() (uint32, error) {
var convID uint32
if err := binary.Read(rand.Reader, binary.LittleEndian, &convID); err != nil {
return 0, err
}
return convID, nil
}
func listenKCPPacketConn(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
return nil, fmt.Errorf("transport: resolve kcp listen addr %s: %w", listenAddr, err)
}
rawConn, err := listenUDPConn("udp", udpAddr, bindDevice)
if err != nil {
return nil, fmt.Errorf("transport: listen udp for kcp on %s: %w", listenAddr, err)
}
packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID)
if err != nil {
_ = rawConn.Close()
return nil, err
}
return packetConn, nil
}
func dialKCPPacketConn(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, *net.UDPAddr, error) {
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
return nil, nil, fmt.Errorf("transport: resolve kcp server addr %s: %w", serverAddr, err)
}
localAddr := &net.UDPAddr{Port: 0}
if bindIP != "" {
ip := net.ParseIP(bindIP)
if ip == nil {
return nil, nil, fmt.Errorf("transport: invalid bind ip %q", bindIP)
}
localAddr.IP = ip
}
network := "udp"
if remoteAddr.IP.To4() != nil {
network = "udp4"
}
rawConn, err := listenUDPConn(network, localAddr, bindDevice)
if err != nil {
return nil, nil, fmt.Errorf("transport: listen udp for kcp dial to %s: %w", serverAddr, err)
}
packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID)
if err != nil {
_ = rawConn.Close()
return nil, nil, err
}
return packetConn, remoteAddr, nil
}
func listenUDPConn(network string, localAddr *net.UDPAddr, bindDevice string) (*net.UDPConn, error) {
listenConfig := net.ListenConfig{}
if bindDevice != "" {
control, err := udpBindDeviceControl(bindDevice)
if err != nil {
return nil, err
}
listenConfig.Control = control
}
packetConn, err := listenConfig.ListenPacket(context.Background(), network, udpListenAddr(localAddr))
if err != nil {
return nil, err
}
udpConn, ok := packetConn.(*net.UDPConn)
if !ok {
_ = packetConn.Close()
return nil, fmt.Errorf("transport: expected *net.UDPConn, got %T", packetConn)
}
return udpConn, nil
}
func udpListenAddr(addr *net.UDPAddr) string {
if addr == nil {
return ":0"
}
return addr.String()
}