Files
OmniSocketGo/cmd/internal/transport/kcp_packet_conn_linux.go

345 lines
7.5 KiB
Go

//go:build linux
package transport
import (
"errors"
"fmt"
"net"
"sync"
"syscall"
"time"
"omnisocketgo/cmd/internal/latencylog"
)
// kcpPendingPacketDebug holds the metadata captured for one outgoing packet
// while its kernel TX timestamps are still outstanding. Entries are stored in
// platformKCPPacketConn.pendingTX, keyed by the local TX ID.
type kcpPendingPacketDebug struct {
	// remoteAddr is the destination the packet was written to.
	remoteAddr net.Addr

	// packetBytes is the length in bytes of the payload passed to WriteTo.
	packetBytes int

	// kcpConv is the KCP conversation ID parsed from the payload. It may be
	// nil (presumably when the payload does not parse as KCP — confirm
	// against parseKCPPacketMetadata).
	kcpConv *uint32

	// segments is a private copy of the KCP segment metadata parsed from the
	// payload.
	segments []KCPPacketDebugSegment

	// timestamps maps a TX event name to the earliest timestamp (Unix ns)
	// observed for that event on this packet.
	timestamps map[string]int64
}
// platformKCPPacketConn is the Linux net.PacketConn wrapper used for KCP.
// When kernel timestamping is enabled, raw is non-nil and reads/writes go
// through recvmsg/sendmsg so software RX/TX timestamps can be logged.
type platformKCPPacketConn struct {
	*kcpPacketConnBase

	// raw is the syscall-level handle of the UDP socket. It is nil when
	// timestamping is disabled (no logger), in which case I/O is delegated to
	// the wrapped conn.
	raw syscall.RawConn

	// writeMu serializes WriteTo so a txID reservation, its send and any
	// rollback happen without another writer interleaving.
	writeMu sync.Mutex

	// pendingMu guards pendingTX and nextTXID.
	pendingMu sync.Mutex

	// pendingTX holds sent packets that are awaiting TX timestamps, keyed by
	// the TX ID reported in the error-queue event's EEData field.
	pendingTX map[uint32]kcpPendingPacketDebug

	// nextTXID is the local mirror of the kernel's per-socket TX ID counter.
	nextTXID uint32
}
// newPlatformKCPPacketConn wraps conn in the Linux-specific KCP packet conn.
//
// With a nil logger the wrapper is a pass-through: raw stays nil and
// ReadFrom/WriteTo delegate to conn. With a logger, kernel timestamping is
// enabled on the socket and a goroutine is started to drain TX timestamps
// from the socket error queue. Any timestamping setup error is returned and no
// goroutine is started. conn is not closed on that path.
func newPlatformKCPPacketConn(conn *net.UDPConn, logger KCPPacketDebugLogger, nodeRole, nodeID string) (net.PacketConn, error) {
	base := &kcpPacketConnBase{
		conn:     conn,
		logger:   logger,
		nodeRole: nodeRole,
		nodeID:   nodeID,
		closed:   make(chan struct{}),
	}
	pc := &platformKCPPacketConn{
		kcpPacketConnBase: base,
		pendingTX:         make(map[uint32]kcpPendingPacketDebug),
	}
	if logger != nil {
		if err := pc.initLinuxTimestamping(); err != nil {
			return nil, err
		}
		go pc.collectTXErrqueueLoop()
	}
	return pc, nil
}
// Close closes the conn through the shared kcpPacketConnBase implementation.
// It does not wait for the TX errqueue collector goroutine started by
// newPlatformKCPPacketConn.
func (c *platformKCPPacketConn) Close() error {
	base := c.kcpPacketConnBase
	return base.Close()
}
// ReadFrom reads one datagram into p.
//
// Without kernel timestamping (raw == nil) it defers to the wrapped UDP conn.
// Otherwise it retries recvmsg via raw.Control, sleeping linuxDataPollInterval
// after every would-block result. It logs an EventBRXSoftware record for each
// datagram that carries a positive RX timestamp.
//
// NOTE(review): the raw path goes through RawConn.Control rather than the
// netpoller, so read deadlines set on the wrapped conn are not consulted here.
func (c *platformKCPPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
	if c.raw == nil {
		return c.conn.ReadFrom(p)
	}

	n, addr, rxTimestamp, err := c.recvmsgRaw(p)
	for err != nil && isWouldBlock(err) {
		time.Sleep(linuxDataPollInterval)
		n, addr, rxTimestamp, err = c.recvmsgRaw(p)
	}
	if err != nil {
		return 0, nil, err
	}
	if rxTimestamp <= 0 {
		return n, addr, nil
	}

	kcpConv, segments := parseKCPPacketMetadata(p[:n])
	record := c.newKCPPacketDebugRecord(latencylog.EventBRXSoftware, addr, n, rxTimestamp, nil, kcpConv, segments)
	c.logKCPPacketDebugRecord(record)
	return n, addr, nil
}
// WriteTo sends p to addr. Writers are serialized by writeMu.
//
// On the raw (timestamping) path addr must be a *net.UDPAddr. The pending TX
// record is reserved before the send, so an errqueue timestamp that arrives
// immediately still finds it. Would-block results are retried after
// linuxDataPollInterval. Any other send failure rolls the reservation back,
// which keeps the local txID mirror aligned with the kernel.
func (c *platformKCPPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	if c.raw == nil {
		return c.conn.WriteTo(p, addr)
	}
	target, isUDP := addr.(*net.UDPAddr)
	if !isUDP {
		return 0, fmt.Errorf("transport: kcp packet write target must be UDPAddr, got %T", addr)
	}

	kcpConv, segments := parseKCPPacketMetadata(p)
	txID := c.reservePendingTX(target, len(p), kcpConv, segments)
	for {
		err := c.sendmsgRaw(p, target)
		switch {
		case err == nil:
			return len(p), nil
		case isWouldBlock(err):
			time.Sleep(linuxDataPollInterval)
		default:
			c.rollbackPendingTX(txID)
			return 0, err
		}
	}
}
// initLinuxTimestamping enables kernel software RX/TX timestamping
// (SO_TIMESTAMPING) on the wrapped UDP socket. On success it stores the raw
// socket handle in c.raw, which switches ReadFrom/WriteTo to the
// recvmsg/sendmsg path.
//
// The socket write buffer is configured first. Flag sets are then tried in
// order of preference:
//   - The first includes SOF_TIMESTAMPING_OPT_ID, so each TX timestamp carries
//     a per-send ID.
//   - The second omits OPT_ID, for kernels that reject the full set with EINVAL.
//
// Any other setsockopt failure aborts immediately. Every returned error is
// wrapped with context; the underlying errno is still reachable via errors.Is.
//
// NOTE(review): without OPT_ID the kernel does not assign per-send IDs
// (ee_data is reported as 0). TX timestamps then cannot be matched to
// individual pendingTX records, so consider disabling TX collection in that
// mode.
func (c *platformKCPPacketConn) initLinuxTimestamping() error {
	rawConn, err := c.conn.SyscallConn()
	if err != nil {
		return fmt.Errorf("transport: kcp get syscall conn: %w", err)
	}
	if rawConn == nil {
		return fmt.Errorf("transport: kcp missing syscall conn")
	}
	if err := configureLinuxSocketWriteBuffer(rawConn); err != nil {
		return fmt.Errorf("transport: kcp configure socket write buffer: %w", err)
	}

	flagCandidates := []int{
		linuxSOFTimestampingTXSched |
			linuxSOFTimestampingTXSoftware |
			linuxSOFTimestampingRXSoftware |
			linuxSOFTimestampingSoftware |
			linuxSOFTimestampingOptID |
			linuxSOFTimestampingOptTSONLY,
		linuxSOFTimestampingTXSched |
			linuxSOFTimestampingTXSoftware |
			linuxSOFTimestampingRXSoftware |
			linuxSOFTimestampingSoftware |
			linuxSOFTimestampingOptTSONLY,
	}

	var setErr error
	for _, flags := range flagCandidates {
		ctrlErr := rawConn.Control(func(fd uintptr) {
			setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags)
		})
		if ctrlErr != nil {
			return fmt.Errorf("transport: kcp enable socket timestamping: %w", ctrlErr)
		}
		if setErr == nil {
			c.raw = rawConn
			return nil
		}
		// EINVAL means this flag combination is unsupported; try the next one.
		if !errors.Is(setErr, syscall.EINVAL) {
			return fmt.Errorf("transport: kcp enable socket timestamping (flags %#x): %w", flags, setErr)
		}
	}
	return fmt.Errorf("transport: kcp enable socket timestamping: no supported flag set: %w", setErr)
}
// recvmsgRaw performs a single recvmsg(2) on the raw socket with no flags.
//
// It returns:
//   - the payload length,
//   - the sender address,
//   - the RX timestamp extracted by parseRXTimestampControlMessages (ReadFrom
//     treats values <= 0 as "no timestamp").
//
// Errors from raw.Control and from recvmsg are returned unchanged, so a
// would-block result reaches the caller. On any error all other results are
// zero.
func (c *platformKCPPacketConn) recvmsgRaw(buf []byte) (int, net.Addr, int64, error) {
	var (
		readN   int
		sender  syscall.Sockaddr
		rxNanos int64
		sysErr  error
	)
	ctrlErr := c.raw.Control(func(fd uintptr) {
		control := make([]byte, linuxTimestampControlBufferSize)
		var controlN int
		readN, controlN, _, sender, sysErr = syscall.Recvmsg(int(fd), buf, control, 0)
		if sysErr == nil {
			rxNanos = parseRXTimestampControlMessages(control[:controlN])
		}
	})
	if ctrlErr != nil {
		return 0, nil, 0, ctrlErr
	}
	if sysErr != nil {
		return 0, nil, 0, sysErr
	}
	return readN, sockaddrToUDPAddr(sender), rxNanos, nil
}
// sendmsgRaw performs a single sendmsg(2) of payload to addr on the raw
// socket, without ancillary data. An address that udpAddrToSockaddr cannot
// convert is rejected before any syscall. Errors from raw.Control and from
// sendmsg are returned unchanged, so the caller can detect would-block.
func (c *platformKCPPacketConn) sendmsgRaw(payload []byte, addr *net.UDPAddr) error {
	dest := udpAddrToSockaddr(addr)
	if dest == nil {
		return fmt.Errorf("transport: invalid udp addr %v", addr)
	}
	var sendErr error
	if err := c.raw.Control(func(fd uintptr) {
		sendErr = syscall.Sendmsg(int(fd), payload, nil, dest, 0)
	}); err != nil {
		return err
	}
	return sendErr
}
// collectTXErrqueueLoop turns kernel TX timestamps read from the socket error
// queue into debug log records. newPlatformKCPPacketConn starts it in its own
// goroutine when a logger is configured.
//
// Only EventATXSched and EventATXSoftware events with a positive timestamp
// are considered. They are logged only for packets that still have a pending
// record. A record is removed once hasCompleteTXTimestampPair reports its
// timestamp set complete.
//
// When the queue is empty the loop sleeps linuxTXTimestampPollInterval. It
// exits if the conn is closed at that point, or as soon as
// recvTXErrqueueOnce fails with anything other than would-block.
func (c *platformKCPPacketConn) collectTXErrqueueLoop() {
	for {
		event, err := c.recvTXErrqueueOnce()
		if err != nil {
			if !isWouldBlock(err) || c.isClosed() {
				return
			}
			time.Sleep(linuxTXTimestampPollInterval)
			continue
		}

		switch {
		case event.EventName == "", event.TSUnixNano <= 0:
			continue
		case event.EventName != latencylog.EventATXSched && event.EventName != latencylog.EventATXSoftware:
			continue
		}

		record, complete := c.recordPendingTXEvent(event)
		if record == nil {
			continue
		}

		txID := event.EEData
		c.logKCPPacketDebugRecord(c.newKCPPacketDebugRecord(
			event.EventName,
			record.remoteAddr,
			record.packetBytes,
			event.TSUnixNano,
			&txID,
			record.kcpConv,
			record.segments,
		))

		if complete {
			c.pendingMu.Lock()
			delete(c.pendingTX, event.EEData)
			c.pendingMu.Unlock()
		}
	}
}
// recvTXErrqueueOnce reads a single message from the socket error queue
// without blocking (MSG_ERRQUEUE|MSG_DONTWAIT). It decodes only the control
// data, since no payload buffer is supplied, into a txTimestampEvent. An empty
// queue surfaces as the recvmsg error, so the caller can back off on
// would-block.
func (c *platformKCPPacketConn) recvTXErrqueueOnce() (txTimestampEvent, error) {
	var (
		parsed txTimestampEvent
		sysErr error
	)
	ctrlErr := c.raw.Control(func(fd uintptr) {
		control := make([]byte, linuxTimestampControlBufferSize)
		_, controlN, _, _, err := syscall.Recvmsg(int(fd), nil, control, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT)
		if err != nil {
			sysErr = err
			return
		}
		// NOTE(review): the parser's second result is discarded. Unparseable
		// messages presumably yield a zero event, which the caller filters
		// out by name/timestamp — confirm.
		parsed, _ = parseTXTimestampControlMessages(control[:controlN])
	})
	switch {
	case ctrlErr != nil:
		return txTimestampEvent{}, ctrlErr
	case sysErr != nil:
		return txTimestampEvent{}, sysErr
	}
	return parsed, nil
}
// reservePendingTX records metadata for a packet that is about to be sent and
// returns the local TX ID assigned to it. IDs are handed out sequentially
// from nextTXID, mirroring the kernel's per-socket TX ID counter. TX
// timestamps read from the error queue can therefore be matched back to this
// record. segments is copied, so later changes to the caller's slice do not
// affect the pending record.
//
// A record is normally removed once both TX timestamps have been seen. That
// may never happen in three cases:
//   - In the timestamping fallback without OPT_ID, every event reports ID 0.
//   - The errqueue collector goroutine has exited on an error.
//   - The kernel presumably dropped the timestamps.
//
// Without a bound, pendingTX would then grow by one entry per sent packet
// forever. Reserving ID n therefore evicts ID n-maxPendingTX. uint32
// subtraction wraps the same way the 32-bit ID counter does, so the bound
// holds across wraparound.
func (c *platformKCPPacketConn) reservePendingTX(remoteAddr net.Addr, packetBytes int, kcpConv *uint32, segments []KCPPacketDebugSegment) uint32 {
	// maxPendingTX caps how many sends may await TX timestamps at once.
	const maxPendingTX = 1 << 15

	c.pendingMu.Lock()
	defer c.pendingMu.Unlock()
	txID := c.nextTXID
	c.nextTXID++
	// Drop the record that fell out of the window; no-op if already completed.
	delete(c.pendingTX, txID-maxPendingTX)
	c.pendingTX[txID] = kcpPendingPacketDebug{
		remoteAddr:  remoteAddr,
		packetBytes: packetBytes,
		kcpConv:     kcpConv,
		segments:    append([]KCPPacketDebugSegment(nil), segments...),
		timestamps:  make(map[string]int64, 2),
	}
	return txID
}
// rollbackPendingTX undoes reservePendingTX for a send that failed. The
// pending record is always dropped. nextTXID is rewound only when txID is
// still the most recent reservation, so a superseded ID is never reissued.
func (c *platformKCPPacketConn) rollbackPendingTX(txID uint32) {
	c.pendingMu.Lock()
	delete(c.pendingTX, txID)
	if txID+1 == c.nextTXID {
		c.nextTXID = txID
	}
	c.pendingMu.Unlock()
}
// recordPendingTXEvent stores event's timestamp on the pending record matching
// event.EEData, keeping the earliest value seen per event name.
//
// It returns nil, false when no such record exists. Otherwise it returns a
// shallow copy of the record (its timestamps map and segments slice are shared
// with the stored entry). The bool reports whether hasCompleteTXTimestampPair
// considers the record's timestamps complete. The record itself is left in
// pendingTX.
func (c *platformKCPPacketConn) recordPendingTXEvent(event txTimestampEvent) (*kcpPendingPacketDebug, bool) {
	c.pendingMu.Lock()
	defer c.pendingMu.Unlock()

	pending, found := c.pendingTX[event.EEData]
	if !found {
		return nil, false
	}
	prev, seen := pending.timestamps[event.EventName]
	if !seen || event.TSUnixNano < prev {
		pending.timestamps[event.EventName] = event.TSUnixNano
	}
	c.pendingTX[event.EEData] = pending
	snapshot := pending
	return &snapshot, hasCompleteTXTimestampPair(pending.timestamps)
}
// isClosed reports, without blocking, whether the closed channel has been
// closed (or has a value ready).
func (c *platformKCPPacketConn) isClosed() bool {
	select {
	case <-c.closed:
		return true
	default:
		return false
	}
}