Files
OmniSocketGo/cmd/internal/transport/kcp.go
nnbcccscdscdsc be013b701b feat:KCP协议
2026-03-24 21:09:06 +08:00

257 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package transport
import (
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"sync"
kcp "github.com/xtaci/kcp-go/v5"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
)
// Arguments handed to UDPSession.SetNoDelay by configureKCPSession, listed in
// that method's parameter order: nodelay, interval, resend, nc.
const (
	kcpNoDelayNodelay  = 1
	kcpNoDelayInterval = 10
	kcpNoDelayResend   = 2
	kcpNoDelayNC       = 1
)

// Window and packet-size settings applied by configureKCPSession.
const (
	// kcpWindowSize is used for both the send and the receive window.
	kcpWindowSize = 256

	// kcpMTU is passed to UDPSession.SetMtu.
	kcpMTU = 1400
)
// KCPConn is a lightweight wrapper around a single active KCP session.
// Build it with NewKCPConn rather than using the zero value.
type KCPConn struct {
	// session is the wrapped KCP session that Send, Receive and Close use.
	session *kcp.UDPSession

	// Latency-log context for Send: the logger plus this node's role and ID.
	logger           latencylog.Logger
	nodeRole, nodeID string

	// writeMu serializes Send so whole messages are written one at a time.
	writeMu sync.Mutex

	// closeOnce and closeErr make Close idempotent and remember its first result.
	closeOnce sync.Once
	closeErr  error
}
// KCPOption customizes a KCPConn while NewKCPConn is constructing it.
type KCPOption func(*KCPConn)

// WithKCPLogger returns an option that attaches a latency logger and this
// node's role and ID. Send uses them when it logs its send-handoff
// begin/end events. If logger is nil, NewKCPConn substitutes a no-op logger.
func WithKCPLogger(logger latencylog.Logger, nodeRole, nodeID string) KCPOption {
	apply := func(target *KCPConn) {
		target.logger, target.nodeRole, target.nodeID = logger, nodeRole, nodeID
	}
	return apply
}
// NewKCPConn wraps an existing KCP session in a KCPConn.
//
// The connection starts with a no-op logger, and opts are then applied in
// order. If an option leaves the logger nil, the no-op logger is restored.
// Before returning, the session is reconfigured with the package's default
// KCP tuning via configureKCPSession.
//
// It returns an error if session is nil.
func NewKCPConn(session *kcp.UDPSession, opts ...KCPOption) (*KCPConn, error) {
	if session == nil {
		return nil, fmt.Errorf("transport: nil kcp session")
	}

	wrapped := new(KCPConn)
	wrapped.session = session
	wrapped.logger = latencylog.NoopLogger{}

	for i := range opts {
		opts[i](wrapped)
	}
	if wrapped.logger == nil {
		// An option supplied a nil logger; never leave Send with a nil logger.
		wrapped.logger = latencylog.NoopLogger{}
	}

	configureKCPSession(session)
	return wrapped, nil
}
// Send writes msg to the underlying KCP session as one complete protocol message.
//
// Calls are serialized by writeMu, so concurrent senders cannot interleave
// their writes. A send-handoff-begin event is logged before the write. The
// matching send-handoff-end event is logged only after a successful write; on
// failure the wrapped write error is returned and no end event is logged.
func (c *KCPConn) Send(msg protocol.Message) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	role, id := c.nodeRole, c.nodeID
	latencylog.LogMessageEvent(c.logger, role, id, latencylog.EventSendHandoffBegin, msg)

	writeErr := protocol.WriteMessage(c.session, msg)
	if writeErr != nil {
		return fmt.Errorf("transport: kcp send message: %w", writeErr)
	}

	latencylog.LogMessageEvent(c.logger, role, id, latencylog.EventSendHandoffEnd, msg)
	return nil
}
// Receive reads one complete protocol message from the underlying KCP session.
//
// On failure it returns a zero Message and the read error, wrapped.
// Unlike Send, Receive takes no lock. NOTE(review): presumably a single
// goroutine (e.g. ReceiveLoop) is the only reader; confirm with callers.
func (c *KCPConn) Receive() (protocol.Message, error) {
	msg, err := protocol.ReadMessage(c.session)
	if err == nil {
		return msg, nil
	}
	return protocol.Message{}, fmt.Errorf("transport: kcp receive message: %w", err)
}
// ReceiveLoop keeps reading messages and passes each one to handler.
//
// It only returns on failure. When either Receive or handler returns an
// error, the connection is closed and that error is returned, wrapped to say
// which step failed. It never returns nil.
func (c *KCPConn) ReceiveLoop(handler func(protocol.Message) error) error {
	var (
		msg protocol.Message
		err error
	)
	for {
		if msg, err = c.Receive(); err != nil {
			// Close's own error is dropped: the read error is what the caller needs.
			_ = c.Close()
			return fmt.Errorf("transport: kcp receive loop read: %w", err)
		}
		if err = handler(msg); err != nil {
			_ = c.Close()
			return fmt.Errorf("transport: kcp receive loop handler: %w", err)
		}
	}
}
// Close closes the underlying KCP session and is safe to call repeatedly.
// Only the first call closes the session. Every call, including later
// ones, returns the error produced by that first close.
func (c *KCPConn) Close() error {
	shutdown := func() {
		c.closeErr = c.session.Close()
	}
	c.closeOnce.Do(shutdown)
	return c.closeErr
}
// DialKCPSession creates an outbound KCP session to serverAddr on a UDP socket
// configured with the project's defaults.
//
// The socket comes from dialKCPPacketConn, which applies bindIP, bindDevice
// and the packet debug logger. A random conversation ID is generated, and the
// session is created with:
//   - a nil block cipher,
//   - zero FEC data and parity shards,
//   - ownConn=true, which per kcp-go hands ownership of the socket to the session.
//
// If ID generation or session creation fails, the socket is closed before
// returning. KCP tuning is not applied here; NewKCPConn applies it.
func DialKCPSession(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.UDPSession, error) {
	packetConn, remoteAddr, err := dialKCPPacketConn(serverAddr, bindIP, bindDevice, logger, nodeRole, nodeID)
	if err != nil {
		return nil, err
	}

	convID, idErr := generateKCPConversationID()
	if idErr != nil {
		_ = packetConn.Close()
		return nil, fmt.Errorf("transport: generate kcp conversation id: %w", idErr)
	}

	const (
		dataShards   = 0
		parityShards = 0
		ownConn      = true
	)
	session, sessErr := kcp.NewConn4(convID, remoteAddr, nil, dataShards, parityShards, ownConn, packetConn)
	if sessErr != nil {
		_ = packetConn.Close()
		return nil, fmt.Errorf("transport: create kcp session: %w", sessErr)
	}
	return session, nil
}
// ListenKCPSessions opens a UDP socket on listenAddr and starts a KCP listener
// on it. The socket comes from listenKCPPacketConn, which applies bindDevice
// and the packet debug logger. The listener uses a nil block cipher and zero
// FEC shards.
//
// It returns both the listener and the underlying packet connection. If the
// listener cannot be started, the socket is closed before returning.
// NOTE(review): kcp.ServeConn does not appear to take ownership of the socket,
// so callers presumably must close packetConn themselves after closing the
// listener; confirm against kcp-go.
func ListenKCPSessions(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.Listener, net.PacketConn, error) {
	packetConn, err := listenKCPPacketConn(listenAddr, bindDevice, logger, nodeRole, nodeID)
	if err != nil {
		return nil, nil, err
	}

	listener, serveErr := kcp.ServeConn(nil, 0, 0, packetConn)
	if serveErr != nil {
		_ = packetConn.Close()
		return nil, nil, fmt.Errorf("transport: serve kcp listener: %w", serveErr)
	}
	return listener, packetConn, nil
}
// configureKCPSession applies the package's default tuning to s. It:
//   - enables stream mode,
//   - passes the kcpNoDelay* constants to SetNoDelay,
//   - uses kcpWindowSize for both send and receive windows,
//   - enables ACK no-delay,
//   - disables write delay,
//   - sets the MTU to kcpMTU.
//
// The calls are kept in their original order.
func configureKCPSession(s *kcp.UDPSession) {
	const (
		streamMode = true
		ackNoDelay = true
		writeDelay = false
	)
	s.SetStreamMode(streamMode)
	s.SetNoDelay(kcpNoDelayNodelay, kcpNoDelayInterval, kcpNoDelayResend, kcpNoDelayNC)
	s.SetWindowSize(kcpWindowSize, kcpWindowSize)
	s.SetACKNoDelay(ackNoDelay)
	s.SetWriteDelay(writeDelay)
	// SetMtu's boolean result is intentionally ignored, as before.
	s.SetMtu(kcpMTU)
}
// generateKCPConversationID returns a random 32-bit KCP conversation ID. It
// reads four bytes from crypto/rand and decodes them as little-endian. Any
// error from the random source is returned unchanged, together with 0.
func generateKCPConversationID() (uint32, error) {
	var raw [4]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return 0, err
	}
	return binary.LittleEndian.Uint32(raw[:]), nil
}
// listenKCPPacketConn resolves listenAddr and opens a UDP socket there. It
// uses the "udp" network and applies bindDevice through listenUDPConn. The
// socket is then wrapped with newKCPPacketConn for packet debug logging.
//
// If wrapping fails, the raw socket is closed before returning.
func listenKCPPacketConn(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
	local, err := net.ResolveUDPAddr("udp", listenAddr)
	if err != nil {
		return nil, fmt.Errorf("transport: resolve kcp listen addr %s: %w", listenAddr, err)
	}

	udpConn, err := listenUDPConn("udp", local, bindDevice)
	if err != nil {
		return nil, fmt.Errorf("transport: listen udp for kcp on %s: %w", listenAddr, err)
	}

	wrapped, err := newKCPPacketConn(udpConn, logger, nodeRole, nodeID)
	if err != nil {
		// No wrapper took ownership, so the raw socket is still ours to close.
		_ = udpConn.Close()
		return nil, err
	}
	return wrapped, nil
}
// dialKCPPacketConn resolves serverAddr and opens a local UDP socket for
// talking to it. The socket is wrapped with newKCPPacketConn for packet debug
// logging.
//
// Parameters:
//   - bindIP: if non-empty, must be a literal IP address; the socket is bound
//     to it on an ephemeral port.
//   - bindDevice: if non-empty, is passed to listenUDPConn.
//
// When the server resolves to an IPv4 address the socket is opened on "udp4";
// otherwise "udp" is used.
//
// It returns the wrapped packet connection and the resolved server address.
// It returns an error when:
//   - serverAddr cannot be resolved,
//   - bindIP is not a valid IP,
//   - a specific bindIP belongs to a different address family than the
//     resolved server,
//   - the socket cannot be opened or wrapped.
//
// If wrapping fails, the raw socket is closed.
func dialKCPPacketConn(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, *net.UDPAddr, error) {
	remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
	if err != nil {
		return nil, nil, fmt.Errorf("transport: resolve kcp server addr %s: %w", serverAddr, err)
	}
	remoteIsIPv4 := remoteAddr.IP.To4() != nil

	// Port 0 lets the OS choose an ephemeral local port.
	localAddr := &net.UDPAddr{Port: 0}
	if bindIP != "" {
		ip := net.ParseIP(bindIP)
		if ip == nil {
			return nil, nil, fmt.Errorf("transport: invalid bind ip %q", bindIP)
		}
		// A socket bound to a specific address of one family cannot send to a
		// peer of the other family. Without this check, an IPv4 bindIP with an
		// IPv6 server yields an IPv4-only socket whose every write fails, so the
		// dial "succeeds" but no KCP traffic ever flows.
		// Exemptions:
		//   - unspecified bind IPs (0.0.0.0 / ::), which keep the previous
		//     behaviour;
		//   - servers that resolved without an IP.
		if len(remoteAddr.IP) != 0 && !ip.IsUnspecified() && (ip.To4() != nil) != remoteIsIPv4 {
			return nil, nil, fmt.Errorf("transport: bind ip %q and kcp server addr %s use different address families", bindIP, serverAddr)
		}
		localAddr.IP = ip
	}

	// Force an IPv4 socket for IPv4 servers; otherwise let "udp" pick the family.
	network := "udp"
	if remoteIsIPv4 {
		network = "udp4"
	}

	rawConn, err := listenUDPConn(network, localAddr, bindDevice)
	if err != nil {
		return nil, nil, fmt.Errorf("transport: listen udp for kcp dial to %s: %w", serverAddr, err)
	}

	packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID)
	if err != nil {
		// No wrapper took ownership, so the raw socket is still ours to close.
		_ = rawConn.Close()
		return nil, nil, err
	}
	return packetConn, remoteAddr, nil
}
// listenUDPConn opens a UDP socket on localAddr using the given network
// ("udp", "udp4", ...). localAddr is formatted by udpListenAddr.
//
// If bindDevice is non-empty, the hook returned by udpBindDeviceControl is
// installed as ListenConfig.Control. Presumably that hook restricts the
// socket to that network device; see udpBindDeviceControl.
//
// The result must be a *net.UDPConn. Any other PacketConn is closed and
// reported as an error.
func listenUDPConn(network string, localAddr *net.UDPAddr, bindDevice string) (*net.UDPConn, error) {
	var lc net.ListenConfig
	if bindDevice != "" {
		ctrl, err := udpBindDeviceControl(bindDevice)
		if err != nil {
			return nil, err
		}
		lc.Control = ctrl
	}

	pc, err := lc.ListenPacket(context.Background(), network, udpListenAddr(localAddr))
	if err != nil {
		return nil, err
	}
	if udpConn, ok := pc.(*net.UDPConn); ok {
		return udpConn, nil
	}
	_ = pc.Close()
	return nil, fmt.Errorf("transport: expected *net.UDPConn, got %T", pc)
}
// udpListenAddr formats addr as the address string passed to
// net.ListenConfig.ListenPacket. A nil addr becomes ":0" (any address,
// ephemeral port); otherwise addr.String() is used.
func udpListenAddr(addr *net.UDPAddr) string {
	if addr != nil {
		return addr.String()
	}
	return ":0"
}