Files
OmniSocketGo/cmd/internal/transport/udp_linux.go
2026-03-24 15:39:00 +08:00

359 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build linux
package transport
import (
	"errors"
	"fmt"
	"net"
	"strconv"
	"syscall"
	"time"

	"omnisocketgo/cmd/internal/latencylog"
	"omnisocketgo/cmd/internal/protocol"
)
// udpReceiveBufferSize is the size of the buffer allocated for each UDP
// datagram read: the protocol's maximum frame size plus 1024 bytes of
// headroom for the protocol header.
const (
	udpReceiveBufferSize = 1024 + protocol.MaxFrameSize
)
// initUDPLinuxTimestamping obtains the socket's raw connection and enables
// Linux timestamping on it.
//
// Flag sets are tried from most to least capable. The first also requests
// SOF_TIMESTAMPING_OPT_ID (UDP uses the generic OPT_ID, not OPT_ID_TCP); the
// second drops it. A candidate rejected with EINVAL moves on to the next one;
// any other failure is returned immediately. All failures are wrapped with
// context, and the underlying errno stays reachable through errors.Is.
//
// On success c.raw is set, which makes the send/receive helpers in this file
// use the raw sendmsg/recvmsg paths.
func (c *UDPConn) initUDPLinuxTimestamping() error {
	rawConn, err := c.conn.SyscallConn()
	if err != nil {
		return fmt.Errorf("transport: udp get syscall conn: %w", err)
	}
	if rawConn == nil {
		return errors.New("transport: udp missing syscall conn")
	}

	flagCandidates := []int{
		linuxSOFTimestampingTXSched |
			linuxSOFTimestampingTXSoftware |
			linuxSOFTimestampingRXSoftware |
			linuxSOFTimestampingSoftware |
			linuxSOFTimestampingOptID |
			linuxSOFTimestampingOptTSONLY,
		linuxSOFTimestampingTXSched |
			linuxSOFTimestampingTXSoftware |
			linuxSOFTimestampingRXSoftware |
			linuxSOFTimestampingSoftware |
			linuxSOFTimestampingOptTSONLY,
	}

	var lastErr error
	for _, flags := range flagCandidates {
		var sockErr error
		ctrlErr := rawConn.Control(func(fd uintptr) {
			sockErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags)
		})
		if ctrlErr != nil {
			return fmt.Errorf("transport: udp control fd: %w", ctrlErr)
		}
		if sockErr == nil {
			c.raw = rawConn
			return nil
		}
		lastErr = fmt.Errorf("transport: udp enable timestamping (flags %#x): %w", flags, sockErr)
		// EINVAL means this flag combination is unsupported; try a smaller set.
		if !errors.Is(sockErr, syscall.EINVAL) {
			return lastErr
		}
	}
	return lastErr
}
// sendMessageLinux encodes msg, transmits it over UDP, and then logs the TX
// timestamps reported for that send.
//
// The destination is c.peerAddr when it is set; otherwise the datagram goes to
// the socket's connected peer. Entries already waiting in the error queue are
// drained (and discarded) before sending, presumably so that timestamps from
// earlier sends are not attributed to this message.
func (c *UDPConn) sendMessageLinux(msg protocol.Message) error {
	payload, err := protocol.EncodeMessage(msg)
	if err != nil {
		return fmt.Errorf("protocol: encode message: %w", err)
	}

	var readIndex int
	c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePreSendDrain, &readIndex)

	var sendErr error
	if peer := c.peerAddr; peer != nil {
		sendErr = c.udpSendTo(payload, peer)
	} else {
		sendErr = c.udpSend(payload)
	}
	if sendErr != nil {
		return sendErr
	}

	c.collectAndLogUDPTXTimestampEvents(msg, &readIndex)
	return nil
}
// sendMessageToLinux encodes msg, sends it to addr over UDP, and then logs
// the TX timestamps reported for that send.
//
// Entries already waiting in the error queue are drained (and discarded)
// before the datagram is sent.
func (c *UDPConn) sendMessageToLinux(msg protocol.Message, addr *net.UDPAddr) error {
	encoded, encodeErr := protocol.EncodeMessage(msg)
	if encodeErr != nil {
		return fmt.Errorf("protocol: encode message: %w", encodeErr)
	}

	var idx int
	c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePreSendDrain, &idx)

	if sendErr := c.udpSendTo(encoded, addr); sendErr != nil {
		return sendErr
	}

	c.collectAndLogUDPTXTimestampEvents(msg, &idx)
	return nil
}
// udpSend writes payload to the socket's connected peer.
//
// Without a raw connection it uses the plain net.UDPConn Write. Otherwise it
// uses sendmsg with no destination address, which targets the connected peer.
func (c *UDPConn) udpSend(payload []byte) error {
	if c.raw == nil {
		_, err := c.conn.Write(payload)
		return err
	}
	return c.udpSendmsgRaw(payload, nil)
}
// udpSendTo writes payload to addr.
//
// The raw sendmsg path is used when a raw connection exists and addr can be
// expressed as a syscall.Sockaddr; in every other case the datagram goes
// through net.UDPConn.WriteToUDP.
func (c *UDPConn) udpSendTo(payload []byte, addr *net.UDPAddr) error {
	viaNet := func() error {
		_, err := c.conn.WriteToUDP(payload, addr)
		return err
	}
	if c.raw == nil {
		return viaNet()
	}
	sa := udpAddrToSockaddr(addr)
	if sa == nil {
		return viaNet()
	}
	return c.udpSendmsgRaw(payload, sa)
}
// udpSendmsgRaw sends payload with sendmsg(2) on the raw socket, addressed to
// `to`, or to the connected peer when `to` is nil.
//
// Would-block errors are retried indefinitely, sleeping linuxDataPollInterval
// between attempts. Control failures and any other sendmsg error are
// returned unchanged.
//
// NOTE(review): because this goes through RawConn.Control rather than
// RawConn.Write, write deadlines set on c.conn are not applied here.
func (c *UDPConn) udpSendmsgRaw(payload []byte, to syscall.Sockaddr) error {
	for {
		var sendErr error
		if err := c.raw.Control(func(fd uintptr) {
			sendErr = syscall.Sendmsg(int(fd), payload, nil, to, 0)
		}); err != nil {
			return err
		}

		switch {
		case sendErr == nil:
			return nil
		case isWouldBlock(sendErr):
			time.Sleep(linuxDataPollInterval)
		default:
			return sendErr
		}
	}
}
// receiveMessageLinux reads one datagram, decodes it into a protocol.Message,
// and returns the message together with the sender's address.
//
// When the read produced a positive RX timestamp, an EventBRXSoftware entry
// is logged at that time for the decoded message. Read and decode failures
// are wrapped with distinct prefixes.
func (c *UDPConn) receiveMessageLinux() (protocol.Message, *net.UDPAddr, error) {
	payload, sender, rxNS, readErr := c.udpRecvFrom()
	if readErr != nil {
		return protocol.Message{}, nil, fmt.Errorf("protocol: udp read: %w", readErr)
	}

	decoded, decodeErr := protocol.DecodeMessage(payload)
	if decodeErr != nil {
		return protocol.Message{}, nil, fmt.Errorf("protocol: decode message: %w", decodeErr)
	}

	if rxNS > 0 {
		latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventBRXSoftware, rxNS, decoded)
	}
	return decoded, sender, nil
}
// udpRecvFrom receives one complete datagram and returns its bytes, the
// source address, and an RX timestamp in nanoseconds.
//
// With a raw connection it delegates to udpRecvmsgRaw. Otherwise it uses
// net.UDPConn.ReadFromUDP and always reports a timestamp of 0.
func (c *UDPConn) udpRecvFrom() ([]byte, *net.UDPAddr, int64, error) {
	if c.raw != nil {
		return c.udpRecvmsgRaw()
	}

	data := make([]byte, udpReceiveBufferSize)
	count, src, err := c.conn.ReadFromUDP(data)
	if err != nil {
		return nil, nil, 0, err
	}
	return data[:count], src, 0, nil
}
// udpRecvmsgRaw receives one datagram with recvmsg(2) on the raw socket.
//
// It returns the payload, the sender address (nil when recvmsg reported none
// or it was not IPv4/IPv6), and the RX timestamp parsed from the control
// messages by parseRXTimestampControlMessages. Would-block errors are retried
// indefinitely, sleeping linuxDataPollInterval between attempts. Control
// failures and any other recvmsg error are returned unchanged.
//
// The data and control buffers are allocated once per call, not once per
// retry. Go puts network sockets in non-blocking mode, so an idle receiver
// loops through many EAGAIN retries. Allocating inside the loop created a
// fresh ~MaxFrameSize buffer on every poll tick.
//
// NOTE(review): polling via RawConn.Control means read deadlines set on
// c.conn are not honored on this path (the ReadFromUDP fallback does honor
// them).
func (c *UDPConn) udpRecvmsgRaw() ([]byte, *net.UDPAddr, int64, error) {
	buf := make([]byte, udpReceiveBufferSize)
	oob := make([]byte, linuxTimestampControlBufferSize)

	for {
		var (
			n, oobN int
			from    syscall.Sockaddr
			recvErr error
		)
		ctrlErr := c.raw.Control(func(fd uintptr) {
			n, oobN, _, from, recvErr = syscall.Recvmsg(int(fd), buf, oob, 0)
		})
		if ctrlErr != nil {
			return nil, nil, 0, ctrlErr
		}
		if recvErr != nil {
			if isWouldBlock(recvErr) {
				time.Sleep(linuxDataPollInterval)
				continue
			}
			return nil, nil, 0, recvErr
		}

		rxTimeNS := parseRXTimestampControlMessages(oob[:oobN])
		return buf[:n], sockaddrToUDPAddr(from), rxTimeNS, nil
	}
}
// collectAndLogUDPTXTimestampEvents gathers the TX timestamps for the send
// that just completed. It logs an EventATXSched entry and then an
// EventATXSoftware entry, each only if that event was collected.
func (c *UDPConn) collectAndLogUDPTXTimestampEvents(msg protocol.Message, readIndex *int) {
	got := c.collectUDPTXTimestampEvents(msg, readIndex)
	if len(got) == 0 {
		return
	}

	schedNS, haveSched := got[latencylog.EventATXSched]
	if haveSched {
		latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSched, schedNS, msg)
	}

	softwareNS, haveSoftware := got[latencylog.EventATXSoftware]
	if haveSoftware {
		latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSoftware, softwareNS, msg)
	}
}
// collectUDPTXTimestampEvents polls the socket's error queue for TX
// timestamps after a send and returns them keyed by event name.
//
// Polling ends when any of these happens:
//   - linuxTXTimestampWaitTimeout elapses (would-block results sleep
//     linuxTXTimestampPollInterval and retry);
//   - a receive error other than would-block occurs;
//   - hasCompleteTXTimestampPair reports the collected set complete.
//
// Every event with a name and a positive timestamp advances *readIndex. Only
// events accepted by isBusinessTXTimestampEventName are kept, and only the
// first value per name. Whatever remains in the error queue afterwards is
// drained. Returns nil when there is no raw connection.
func (c *UDPConn) collectUDPTXTimestampEvents(msg protocol.Message, readIndex *int) map[string]int64 {
	if c.raw == nil {
		return nil
	}

	giveUpAt := time.Now().Add(linuxTXTimestampWaitTimeout)
	collected := make(map[string]int64, 2)

	for time.Now().Before(giveUpAt) {
		ev, recvErr := c.recvUDPTXTimestampOnce()
		if recvErr != nil {
			if !isWouldBlock(recvErr) {
				break
			}
			time.Sleep(linuxTXTimestampPollInterval)
			continue
		}

		if ev.EventName == "" || ev.TSUnixNano <= 0 {
			continue
		}
		*readIndex++

		if isBusinessTXTimestampEventName(ev.EventName) {
			if _, seen := collected[ev.EventName]; !seen {
				collected[ev.EventName] = ev.TSUnixNano
			}
		}
		if hasCompleteTXTimestampPair(collected) {
			break
		}
	}

	c.drainPendingUDPTXTimestampEvents(msg, linuxTXTimestampPhasePostSelectDrain, readIndex)
	return collected
}
// drainPendingUDPTXTimestampEvents reads and discards every entry currently
// in the socket's error queue. Each entry with a name and a positive
// timestamp advances *readIndex.
//
// It stops at the first receive error. Because recvUDPTXTimestampOnce uses
// MSG_DONTWAIT, an empty queue yields EAGAIN. msg and phase are not used by
// this implementation. Does nothing when there is no raw connection.
func (c *UDPConn) drainPendingUDPTXTimestampEvents(msg protocol.Message, phase string, readIndex *int) {
	if c.raw == nil {
		return
	}
	for {
		ev, err := c.recvUDPTXTimestampOnce()
		if err != nil {
			return
		}
		if ev.EventName != "" && ev.TSUnixNano > 0 {
			*readIndex++
		}
	}
}
// recvUDPTXTimestampOnce performs one non-blocking recvmsg(MSG_ERRQUEUE) on
// the raw socket and parses the TX timestamp control messages it returns.
//
// Errors from Control or recvmsg are returned unchanged; because MSG_DONTWAIT
// is set, an empty error queue yields EAGAIN. Parse errors are ignored
// (best-effort), and whatever event value the parser produced is returned.
// Callers skip events with an empty EventName or a non-positive timestamp.
func (c *UDPConn) recvUDPTXTimestampOnce() (txTimestampEvent, error) {
	var (
		event txTimestampEvent
		opErr error
	)
	// Pass a one-byte data buffer on purpose. With an empty data buffer and a
	// non-empty control buffer, syscall.Recvmsg issues an extra
	// getsockopt(SO_TYPE) on every call, and this function is polled in a
	// loop. Timestamping is enabled with SOF_TIMESTAMPING_OPT_TSONLY, so
	// error-queue entries carry no payload worth reading.
	data := make([]byte, 1)
	oob := make([]byte, linuxTimestampControlBufferSize)
	err := c.raw.Control(func(fd uintptr) {
		_, oobn, _, _, recvErr := syscall.Recvmsg(int(fd), data, oob, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT)
		if recvErr != nil {
			opErr = recvErr
			return
		}
		// Best-effort: a malformed or non-timestamp entry is filtered out by
		// callers via its empty name / zero timestamp.
		event, _ = parseTXTimestampControlMessages(oob[:oobn])
	})
	if err != nil {
		return txTimestampEvent{}, err
	}
	if opErr != nil {
		return txTimestampEvent{}, opErr
	}
	return event, nil
}
// udpAddrToSockaddr converts addr into the syscall.Sockaddr form used by
// sendmsg(2).
//
// IPv4 (including IPv4-mapped IPv6) addresses become *syscall.SockaddrInet4.
// Other 16-byte addresses become *syscall.SockaddrInet6, with ZoneId resolved
// from addr.Zone. It returns nil when addr is nil or its IP is not a valid
// 4- or 16-byte address, so callers can fall back to the net package (which
// reports the error).
//
// addr.Zone is resolved first as an interface name and then as a decimal
// interface index, the same order the net package uses. An unresolvable zone
// leaves ZoneId at 0. Dropping the zone would stop the kernel from routing to
// IPv6 link-local peers unless the socket is bound to an interface.
func udpAddrToSockaddr(addr *net.UDPAddr) syscall.Sockaddr {
	if addr == nil {
		return nil
	}
	if ip4 := addr.IP.To4(); ip4 != nil {
		sa := &syscall.SockaddrInet4{Port: addr.Port}
		copy(sa.Addr[:], ip4)
		return sa
	}
	if ip6 := addr.IP.To16(); ip6 != nil {
		sa := &syscall.SockaddrInet6{Port: addr.Port}
		copy(sa.Addr[:], ip6)
		if addr.Zone != "" {
			if iface, err := net.InterfaceByName(addr.Zone); err == nil {
				sa.ZoneId = uint32(iface.Index)
			} else if idx, err := strconv.ParseUint(addr.Zone, 10, 32); err == nil {
				sa.ZoneId = uint32(idx)
			}
		}
		return sa
	}
	return nil
}
// sockaddrToUDPAddr converts a socket address returned by recvmsg(2) into a
// *net.UDPAddr. It returns nil for a nil or non-IPv4/IPv6 sockaddr.
//
// The returned IP slices the Addr array of sa directly (no copy is made).
func sockaddrToUDPAddr(sa syscall.Sockaddr) *net.UDPAddr {
	switch addr := sa.(type) {
	case *syscall.SockaddrInet4:
		return &net.UDPAddr{IP: net.IP(addr.Addr[:]), Port: addr.Port}
	case *syscall.SockaddrInet6:
		return &net.UDPAddr{IP: net.IP(addr.Addr[:]), Port: addr.Port, Zone: zoneToString(addr.ZoneId)}
	}
	return nil
}

// zoneToString converts an IPv6 scope ID (interface index) into the zone
// string stored in net.UDPAddr.Zone.
//
// 0 maps to "". A resolvable index maps to its interface name. An index that
// cannot be resolved is returned in decimal form instead of being dropped,
// as the net package does, so the scope is not silently lost (the net
// package accepts numeric zones).
//
// NOTE(review): every non-zero zone triggers an interface-table lookup; if
// link-local traffic is hot, consider caching.
func zoneToString(zone uint32) string {
	if zone == 0 {
		return ""
	}
	if iface, err := net.InterfaceByIndex(int(zone)); err == nil {
		return iface.Name
	}
	return strconv.FormatUint(uint64(zone), 10)
}