package transport

import (
	"context"
	"crypto/rand"
	"encoding/binary"
	"fmt"
	"net"
	"sync"
	"time"

	kcp "github.com/xtaci/kcp-go/v5"

	"omnisocketgo/cmd/internal/latencylog"
	"omnisocketgo/cmd/internal/protocol"
)

// KCP tuning parameters applied to every session by configureKCPSession.
// The four kcpNoDelay* values are the (nodelay, interval, resend, nc)
// arguments of UDPSession.SetNoDelay; kcpWindowSize is used for both the send
// and receive window; kcpMTU is passed to SetMtu.
const (
	kcpNoDelayNodelay  = 1
	kcpNoDelayInterval = 10
	kcpNoDelayResend   = 2
	kcpNoDelayNC       = 1
	kcpWindowSize      = 256
	kcpMTU             = 1400
)

// KCPConn is a lightweight wrapper around a single active KCP session.
//
// Send calls are serialized by an internal mutex; Receive takes no lock.
// Close is safe to call more than once and always returns the result of the
// first close.
type KCPConn struct {
	session *kcp.UDPSession

	// Business-message logging context (see WithKCPLogger).
	logger   latencylog.Logger
	nodeRole string
	nodeID   string

	// Session/process statistics sampling (see WithKCPSessionStatsLogger).
	sessionStatsLogger   KCPSessionStatsLogger
	sessionStatsInterval time.Duration
	sessionStatsSampler  *kcpSessionStatsSampler

	writeMu   sync.Mutex // serializes Send
	closeOnce sync.Once  // makes Close idempotent
	closeErr  error      // result of the first session.Close call
}

// KCPOption configures optional behavior on a KCPConn. Options are applied
// in order by NewKCPConn before the session is configured.
type KCPOption func(*KCPConn)

// WithKCPLogger attaches the latency logger and node identity used to log
// business-message events on the send path. A nil logger is replaced with
// latencylog.NoopLogger by NewKCPConn.
func WithKCPLogger(logger latencylog.Logger, nodeRole, nodeID string) KCPOption {
	return func(conn *KCPConn) {
		conn.logger = logger
		conn.nodeRole = nodeRole
		conn.nodeID = nodeID
	}
}

// WithKCPSessionStatsLogger attaches the session-level and process-level
// statistics logger, plus the sampling interval, that NewKCPConn hands to the
// connection's stats sampler. How a nil logger or zero interval is treated is
// decided by newKCPSessionStatsSampler (defined elsewhere in this package).
func WithKCPSessionStatsLogger(logger KCPSessionStatsLogger, interval time.Duration) KCPOption {
	return func(conn *KCPConn) {
		conn.sessionStatsLogger = logger
		conn.sessionStatsInterval = interval
	}
}

// NewKCPConn wraps an existing KCP session in a transport connection.
//
// It applies opts in order, replaces a nil logger with a no-op logger,
// configures the session with the package defaults (see configureKCPSession)
// and creates the stats sampler. It returns an error if session is nil.
func NewKCPConn(session *kcp.UDPSession, opts ...KCPOption) (*KCPConn, error) {
	if session == nil {
		return nil, fmt.Errorf("transport: nil kcp session")
	}

	conn := &KCPConn{
		session: session,
		logger:  latencylog.NoopLogger{},
	}
	for _, opt := range opts {
		opt(conn)
	}
	// An option may have explicitly set a nil logger; keep logging calls safe.
	if conn.logger == nil {
		conn.logger = latencylog.NoopLogger{}
	}

	configureKCPSession(session)
	conn.sessionStatsSampler = newKCPSessionStatsSampler(session, conn.sessionStatsLogger, conn.nodeRole, conn.nodeID, conn.sessionStatsInterval)
	return conn, nil
}

// Send writes one complete protocol message to the underlying KCP session.
//
// Concurrent calls are serialized. A send-handoff-begin event is logged (and
// sampled, if a stats sampler exists) before the write and a matching end
// event after it; the end event is skipped when the write fails.
func (c *KCPConn) Send(msg protocol.Message) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
	if c.sessionStatsSampler != nil {
		c.sessionStatsSampler.SampleEvent(kcpStatsSampleReasonSendHandoffBegin)
	}

	if err := protocol.WriteMessage(c.session, msg); err != nil {
		return fmt.Errorf("transport: kcp send message: %w", err)
	}

	latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
	if c.sessionStatsSampler != nil {
		c.sessionStatsSampler.SampleEvent(kcpStatsSampleReasonSendHandoffEnd)
	}
	return nil
}

// Receive reads one complete protocol message from the underlying KCP
// session. On success, it samples a receive event if a stats sampler exists.
func (c *KCPConn) Receive() (protocol.Message, error) {
	msg, err := protocol.ReadMessage(c.session)
	if err != nil {
		return protocol.Message{}, fmt.Errorf("transport: kcp receive message: %w", err)
	}
	if c.sessionStatsSampler != nil {
		c.sessionStatsSampler.SampleEvent(kcpStatsSampleReasonReceive)
	}
	return msg, nil
}

// ReceiveLoop reads messages continuously and passes each one to handler.
//
// It returns only on failure: a read error, a handler error, or a nil
// handler. In every case the connection is closed before returning. A nil
// handler is rejected up front instead of panicking after the first message
// has already been consumed.
func (c *KCPConn) ReceiveLoop(handler func(protocol.Message) error) error {
	if handler == nil {
		// The close result is dropped deliberately; the nil handler is the
		// actionable error for the caller.
		_ = c.Close()
		return fmt.Errorf("transport: kcp receive loop: nil handler")
	}
	for {
		msg, err := c.Receive()
		if err != nil {
			// The read error is more useful to the caller than a close error.
			_ = c.Close()
			return fmt.Errorf("transport: kcp receive loop read: %w", err)
		}
		if err := handler(msg); err != nil {
			// The handler error is more useful to the caller than a close error.
			_ = c.Close()
			return fmt.Errorf("transport: kcp receive loop handler: %w", err)
		}
	}
}

// Close stops the stats sampler (if any) and closes the underlying KCP
// session. Repeated calls are safe and return the first close error.
func (c *KCPConn) Close() error {
	c.closeOnce.Do(func() {
		if c.sessionStatsSampler != nil {
			c.sessionStatsSampler.Close()
		}
		c.closeErr = c.session.Close()
	})
	return c.closeErr
}

// DialKCPSession creates an outbound KCP session to serverAddr. The local UDP
// socket is optionally bound to bindIP and/or bindDevice.
//
// A random conversation ID is generated, and the session is created with a
// nil block cipher and zero FEC data/parity shards. The `true` argument to
// kcp.NewConn4 is its ownConn flag, so closing the session is expected to
// close the packet conn too (per kcp-go docs). On any error after the packet
// conn is created, it is closed before returning.
//
// The returned session is not yet tuned; NewKCPConn applies the defaults.
func DialKCPSession(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.UDPSession, error) {
	packetConn, remoteAddr, err := dialKCPPacketConn(serverAddr, bindIP, bindDevice, logger, nodeRole, nodeID)
	if err != nil {
		return nil, err
	}

	convID, err := generateKCPConversationID()
	if err != nil {
		// Best-effort cleanup; the ID generation error is what gets reported.
		_ = packetConn.Close()
		return nil, fmt.Errorf("transport: generate kcp conversation id: %w", err)
	}

	session, err := kcp.NewConn4(convID, remoteAddr, nil, 0, 0, true, packetConn)
	if err != nil {
		// Best-effort cleanup; the session creation error is what gets reported.
		_ = packetConn.Close()
		return nil, fmt.Errorf("transport: create kcp session: %w", err)
	}
	return session, nil
}

// ListenKCPSessions starts a KCP listener on listenAddr, optionally bound to
// bindDevice, and returns both the listener and its underlying packet conn.
// The listener uses a nil block cipher and zero FEC shards.
//
// NOTE(review): the packet conn is returned presumably so the caller can close
// it; kcp.ServeConn listeners are documented as not owning the conn. Confirm
// callers close it.
func ListenKCPSessions(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.Listener, net.PacketConn, error) {
	packetConn, err := listenKCPPacketConn(listenAddr, bindDevice, logger, nodeRole, nodeID)
	if err != nil {
		return nil, nil, err
	}

	listener, err := kcp.ServeConn(nil, 0, 0, packetConn)
	if err != nil {
		// Best-effort cleanup; the serve error is what gets reported.
		_ = packetConn.Close()
		return nil, nil, fmt.Errorf("transport: serve kcp listener: %w", err)
	}
	return listener, packetConn, nil
}

// configureKCPSession applies the package's default tuning to session:
// - stream mode;
// - the no-delay parameters;
// - equal send/receive windows;
// - immediate ACKs;
// - no write delay;
// - the default MTU.
func configureKCPSession(session *kcp.UDPSession) {
	session.SetStreamMode(true)
	session.SetNoDelay(kcpNoDelayNodelay, kcpNoDelayInterval, kcpNoDelayResend, kcpNoDelayNC)
	session.SetWindowSize(kcpWindowSize, kcpWindowSize)
	session.SetACKNoDelay(true)
	session.SetWriteDelay(false)
	// SetMtu reports whether the value was accepted; kcpMTU is a fixed
	// constant, so the result is not checked here.
	session.SetMtu(kcpMTU)
}

// generateKCPConversationID returns a random 32-bit KCP conversation ID. It
// reads 4 bytes from crypto/rand and decodes them as little-endian.
func generateKCPConversationID() (uint32, error) {
	var convID uint32
	if err := binary.Read(rand.Reader, binary.LittleEndian, &convID); err != nil {
		return 0, err
	}
	return convID, nil
}

// ResolveUDPListenConfig parses a UDP listen address and returns the socket
// family that should be used for binding it ("udp4" for IPv4, "udp6" for
// IPv6, "udp" when no IP is given), together with the resolved address.
func ResolveUDPListenConfig(listenAddr string) (string, *net.UDPAddr, error) { udpAddr, err := net.ResolveUDPAddr("udp", listenAddr) if err != nil { return "", nil, fmt.Errorf("transport: resolve udp listen addr %s: %w", listenAddr, err) } return udpListenNetwork(udpAddr), udpAddr, nil } func listenKCPPacketConn(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) { network, udpAddr, err := ResolveUDPListenConfig(listenAddr) if err != nil { return nil, fmt.Errorf("transport: resolve kcp listen addr %s: %w", listenAddr, err) } rawConn, err := listenUDPConn(network, udpAddr, bindDevice) if err != nil { return nil, fmt.Errorf("transport: listen %s for kcp on %s: %w", network, udpListenAddr(udpAddr), err) } packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID) if err != nil { _ = rawConn.Close() return nil, err } return packetConn, nil } func dialKCPPacketConn(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, *net.UDPAddr, error) { remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) if err != nil { return nil, nil, fmt.Errorf("transport: resolve kcp server addr %s: %w", serverAddr, err) } localAddr := &net.UDPAddr{Port: 0} if bindIP != "" { ip := net.ParseIP(bindIP) if ip == nil { return nil, nil, fmt.Errorf("transport: invalid bind ip %q", bindIP) } localAddr.IP = ip } network := "udp" if remoteAddr.IP.To4() != nil { network = "udp4" } rawConn, err := listenUDPConn(network, localAddr, bindDevice) if err != nil { return nil, nil, fmt.Errorf("transport: listen udp for kcp dial to %s: %w", serverAddr, err) } packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID) if err != nil { _ = rawConn.Close() return nil, nil, err } return packetConn, remoteAddr, nil } func listenUDPConn(network string, localAddr *net.UDPAddr, bindDevice string) (*net.UDPConn, error) { listenConfig := net.ListenConfig{} if bindDevice != "" { 
control, err := udpBindDeviceControl(bindDevice) if err != nil { return nil, err } listenConfig.Control = control } packetConn, err := listenConfig.ListenPacket(context.Background(), network, udpListenAddr(localAddr)) if err != nil { return nil, err } udpConn, ok := packetConn.(*net.UDPConn) if !ok { _ = packetConn.Close() return nil, fmt.Errorf("transport: expected *net.UDPConn, got %T", packetConn) } return udpConn, nil } func udpListenAddr(addr *net.UDPAddr) string { if addr == nil { return ":0" } return addr.String() } func udpListenNetwork(addr *net.UDPAddr) string { if addr == nil || addr.IP == nil { return "udp" } if addr.IP.To4() != nil { return "udp4" } return "udp6" }