257 lines
6.8 KiB
Go
257 lines
6.8 KiB
Go
package transport
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"encoding/binary"
|
||
"fmt"
|
||
"net"
|
||
"sync"
|
||
|
||
kcp "github.com/xtaci/kcp-go/v5"
|
||
|
||
"omnisocketgo/cmd/internal/latencylog"
|
||
"omnisocketgo/cmd/internal/protocol"
|
||
)
|
||
|
||
// Default KCP tuning parameters; configureKCPSession applies them to every
// session wrapped by NewKCPConn.
const (
	// Arguments passed, in this order, to (*kcp.UDPSession).SetNoDelay:
	// nodelay, interval, resend, nc.
	kcpNoDelayNodelay  = 1
	kcpNoDelayInterval = 10
	kcpNoDelayResend   = 2
	kcpNoDelayNC       = 1
)

const (
	// kcpWindowSize is used for both the send and receive windows (SetWindowSize).
	kcpWindowSize = 256
	// kcpMTU is the value passed to SetMtu.
	kcpMTU = 1400
)
|
||
|
||
// KCPConn 是对单条活跃 KCP 会话的轻量封装。
|
||
type KCPConn struct {
|
||
session *kcp.UDPSession
|
||
|
||
logger latencylog.Logger
|
||
nodeRole string
|
||
nodeID string
|
||
|
||
writeMu sync.Mutex
|
||
closeOnce sync.Once
|
||
closeErr error
|
||
}
|
||
|
||
// KCPOption 用于为 KCPConn 注入可选行为。
|
||
type KCPOption func(*KCPConn)
|
||
|
||
// WithKCPLogger 为 KCP 连接发送路径注入业务消息日志上下文。
|
||
func WithKCPLogger(logger latencylog.Logger, nodeRole, nodeID string) KCPOption {
|
||
return func(conn *KCPConn) {
|
||
conn.logger = logger
|
||
conn.nodeRole = nodeRole
|
||
conn.nodeID = nodeID
|
||
}
|
||
}
|
||
|
||
// NewKCPConn 用已有的 KCP 会话创建 transport 连接封装。
|
||
func NewKCPConn(session *kcp.UDPSession, opts ...KCPOption) (*KCPConn, error) {
|
||
if session == nil {
|
||
return nil, fmt.Errorf("transport: nil kcp session")
|
||
}
|
||
|
||
conn := &KCPConn{
|
||
session: session,
|
||
logger: latencylog.NoopLogger{},
|
||
}
|
||
for _, opt := range opts {
|
||
opt(conn)
|
||
}
|
||
if conn.logger == nil {
|
||
conn.logger = latencylog.NoopLogger{}
|
||
}
|
||
|
||
configureKCPSession(session)
|
||
return conn, nil
|
||
}
|
||
|
||
// Send 将一条协议消息完整写入底层 KCP 会话。
|
||
func (c *KCPConn) Send(msg protocol.Message) error {
|
||
c.writeMu.Lock()
|
||
defer c.writeMu.Unlock()
|
||
|
||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
|
||
if err := protocol.WriteMessage(c.session, msg); err != nil {
|
||
return fmt.Errorf("transport: kcp send message: %w", err)
|
||
}
|
||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
|
||
return nil
|
||
}
|
||
|
||
// Receive 从底层 KCP 会话读取一条完整协议消息。
|
||
func (c *KCPConn) Receive() (protocol.Message, error) {
|
||
msg, err := protocol.ReadMessage(c.session)
|
||
if err != nil {
|
||
return protocol.Message{}, fmt.Errorf("transport: kcp receive message: %w", err)
|
||
}
|
||
return msg, nil
|
||
}
|
||
|
||
// ReceiveLoop 持续读取消息并交给 handler 处理。
|
||
func (c *KCPConn) ReceiveLoop(handler func(protocol.Message) error) error {
|
||
for {
|
||
msg, err := c.Receive()
|
||
if err != nil {
|
||
_ = c.Close()
|
||
return fmt.Errorf("transport: kcp receive loop read: %w", err)
|
||
}
|
||
|
||
if err := handler(msg); err != nil {
|
||
_ = c.Close()
|
||
return fmt.Errorf("transport: kcp receive loop handler: %w", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
// Close 关闭底层 KCP 会话,并保证重复调用是安全的。
|
||
func (c *KCPConn) Close() error {
|
||
c.closeOnce.Do(func() {
|
||
c.closeErr = c.session.Close()
|
||
})
|
||
return c.closeErr
|
||
}
|
||
|
||
// DialKCPSession 创建一条主动发起的 KCP 会话,并按项目默认参数配置底层 UDP socket。
|
||
func DialKCPSession(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.UDPSession, error) {
|
||
packetConn, remoteAddr, err := dialKCPPacketConn(serverAddr, bindIP, bindDevice, logger, nodeRole, nodeID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
convID, err := generateKCPConversationID()
|
||
if err != nil {
|
||
_ = packetConn.Close()
|
||
return nil, fmt.Errorf("transport: generate kcp conversation id: %w", err)
|
||
}
|
||
|
||
session, err := kcp.NewConn4(convID, remoteAddr, nil, 0, 0, true, packetConn)
|
||
if err != nil {
|
||
_ = packetConn.Close()
|
||
return nil, fmt.Errorf("transport: create kcp session: %w", err)
|
||
}
|
||
|
||
return session, nil
|
||
}
|
||
|
||
// ListenKCPSessions 在给定地址上启动 KCP listener,并返回 listener 与底层 packetConn。
|
||
func ListenKCPSessions(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (*kcp.Listener, net.PacketConn, error) {
|
||
packetConn, err := listenKCPPacketConn(listenAddr, bindDevice, logger, nodeRole, nodeID)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
listener, err := kcp.ServeConn(nil, 0, 0, packetConn)
|
||
if err != nil {
|
||
_ = packetConn.Close()
|
||
return nil, nil, fmt.Errorf("transport: serve kcp listener: %w", err)
|
||
}
|
||
|
||
return listener, packetConn, nil
|
||
}
|
||
|
||
func configureKCPSession(session *kcp.UDPSession) {
|
||
session.SetStreamMode(true)
|
||
session.SetNoDelay(kcpNoDelayNodelay, kcpNoDelayInterval, kcpNoDelayResend, kcpNoDelayNC)
|
||
session.SetWindowSize(kcpWindowSize, kcpWindowSize)
|
||
session.SetACKNoDelay(true)
|
||
session.SetWriteDelay(false)
|
||
session.SetMtu(kcpMTU)
|
||
}
|
||
|
||
// generateKCPConversationID returns a random 32-bit KCP conversation ID.
// The ID is built from 4 bytes of crypto/rand decoded as little-endian.
// On a read failure it returns 0 and the error.
func generateKCPConversationID() (convID uint32, err error) {
	err = binary.Read(rand.Reader, binary.LittleEndian, &convID)
	if err != nil {
		return 0, err
	}
	return convID, nil
}
|
||
|
||
func listenKCPPacketConn(listenAddr, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
|
||
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("transport: resolve kcp listen addr %s: %w", listenAddr, err)
|
||
}
|
||
|
||
rawConn, err := listenUDPConn("udp", udpAddr, bindDevice)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("transport: listen udp for kcp on %s: %w", listenAddr, err)
|
||
}
|
||
|
||
packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID)
|
||
if err != nil {
|
||
_ = rawConn.Close()
|
||
return nil, err
|
||
}
|
||
|
||
return packetConn, nil
|
||
}
|
||
|
||
func dialKCPPacketConn(serverAddr, bindIP, bindDevice string, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, *net.UDPAddr, error) {
|
||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("transport: resolve kcp server addr %s: %w", serverAddr, err)
|
||
}
|
||
|
||
localAddr := &net.UDPAddr{Port: 0}
|
||
if bindIP != "" {
|
||
ip := net.ParseIP(bindIP)
|
||
if ip == nil {
|
||
return nil, nil, fmt.Errorf("transport: invalid bind ip %q", bindIP)
|
||
}
|
||
localAddr.IP = ip
|
||
}
|
||
|
||
network := "udp"
|
||
if remoteAddr.IP.To4() != nil {
|
||
network = "udp4"
|
||
}
|
||
|
||
rawConn, err := listenUDPConn(network, localAddr, bindDevice)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("transport: listen udp for kcp dial to %s: %w", serverAddr, err)
|
||
}
|
||
|
||
packetConn, err := newKCPPacketConn(rawConn, logger, nodeRole, nodeID)
|
||
if err != nil {
|
||
_ = rawConn.Close()
|
||
return nil, nil, err
|
||
}
|
||
|
||
return packetConn, remoteAddr, nil
|
||
}
|
||
|
||
func listenUDPConn(network string, localAddr *net.UDPAddr, bindDevice string) (*net.UDPConn, error) {
|
||
listenConfig := net.ListenConfig{}
|
||
if bindDevice != "" {
|
||
control, err := udpBindDeviceControl(bindDevice)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
listenConfig.Control = control
|
||
}
|
||
|
||
packetConn, err := listenConfig.ListenPacket(context.Background(), network, udpListenAddr(localAddr))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
udpConn, ok := packetConn.(*net.UDPConn)
|
||
if !ok {
|
||
_ = packetConn.Close()
|
||
return nil, fmt.Errorf("transport: expected *net.UDPConn, got %T", packetConn)
|
||
}
|
||
|
||
return udpConn, nil
|
||
}
|
||
|
||
// udpListenAddr renders addr in the "host:port" form accepted by ListenPacket.
// A nil addr becomes ":0": empty host (all local addresses) with an OS-chosen
// port.
func udpListenAddr(addr *net.UDPAddr) string {
	if addr != nil {
		return addr.String()
	}
	return ":0"
}
|