init
This commit is contained in:
151
cmd/internal/transport/tcp.go
Normal file
151
cmd/internal/transport/tcp.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
// TCPConn 是对单条活跃 TCP 连接的轻量封装。
|
||||
// 它负责把协议层的单条消息读写,提升为可复用的收发接口。
|
||||
type TCPConn struct {
|
||||
conn net.Conn
|
||||
raw syscall.RawConn // 连接对应的底层 syscall 句柄,用于 Linux socket timestamping 收发。
|
||||
|
||||
logger latencylog.Logger
|
||||
nodeRole string // 日志中记录的节点角色,例如 "server" 或 "peer"
|
||||
nodeID string // 日志中记录的节点 ID,例如 peer 的 ID 或 server 的 "hub"
|
||||
writeMu sync.Mutex // 保护 Send 方法的互斥锁,确保同一时刻只有一条完整协议消息被写入连接,防止多条消息字节交叉
|
||||
closeOnce sync.Once // 保护 Close 方法的 sync.Once,确保连接只被关闭一次
|
||||
closeErr error // 连接关闭时的错误,如果连接成功关闭则为 nil,重复调用 Close 时会返回同样的错误
|
||||
}
|
||||
|
||||
// Option 用于为 TCPConn 注入可选行为,例如时延日志。
|
||||
type Option func(*TCPConn)
|
||||
|
||||
// WithLogger 为连接发送路径注入业务消息日志上下文。
|
||||
func WithLogger(logger latencylog.Logger, nodeRole, nodeID string) Option {
|
||||
return func(conn *TCPConn) {
|
||||
conn.logger = logger
|
||||
conn.nodeRole = nodeRole
|
||||
conn.nodeID = nodeID
|
||||
}
|
||||
}
|
||||
|
||||
// NewTCPConn 用已有的 net.Conn 创建 transport 连接封装。
|
||||
func NewTCPConn(conn net.Conn, opts ...Option) (*TCPConn, error) {
|
||||
tcpConn := &TCPConn{
|
||||
conn: conn,
|
||||
logger: latencylog.NoopLogger{},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(tcpConn)
|
||||
}
|
||||
|
||||
if tcpConn.logger == nil {
|
||||
tcpConn.logger = latencylog.NoopLogger{}
|
||||
}
|
||||
|
||||
if err := tcpConn.initLinuxTimestamping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tcpConn, nil
|
||||
}
|
||||
|
||||
// Send 将一条协议消息完整写入底层连接。
|
||||
// 多个 goroutine 可以并发调用,内部会串行化写入。
|
||||
func (c *TCPConn) Send(msg protocol.Message) error {
|
||||
c.writeMu.Lock() //“同一时刻只能有一条完整协议消息往连接里写,防止多条消息字节交叉
|
||||
defer c.writeMu.Unlock()
|
||||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffBegin, msg)
|
||||
|
||||
if err := c.sendMessageLinux(msg); err != nil {
|
||||
return fmt.Errorf("transport: send message: %w", err)
|
||||
}
|
||||
//记录发送完成的时延日志事件,事件类型为 EventSendHandoffEnd,包含消息的基本信息(类型、ID、来源、目标)。
|
||||
latencylog.LogMessageEvent(c.logger, c.nodeRole, c.nodeID, latencylog.EventSendHandoffEnd, msg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receive 从底层连接读取一条完整协议消息。
|
||||
// 同一条连接应只由单个 reader 持续调用该方法。
|
||||
func (c *TCPConn) Receive() (protocol.Message, error) {
|
||||
msg, err := c.receiveMessageLinux()
|
||||
if err != nil {
|
||||
return protocol.Message{}, fmt.Errorf("transport: receive message: %w", err)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// ReceiveLoop 持续读取消息并交给 handler 处理。
|
||||
// 读取错误、handler 错误或连接关闭都会结束循环,并关闭连接。
|
||||
func (c *TCPConn) ReceiveLoop(handler func(protocol.Message) error) error {
|
||||
for {
|
||||
msg, err := c.Receive()
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("transport: receive loop read: %w", err)
|
||||
}
|
||||
|
||||
if err := handler(msg); err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("transport: receive loop handler: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CloseGracefully 在支持 half-close 的连接上先关闭写方向,给对端留出读取最终响应的机会,
|
||||
// 然后在短暂等待后再彻底关闭连接。
|
||||
func (c *TCPConn) CloseGracefully(drainTimeout time.Duration) error {
|
||||
if closeWriter, ok := c.conn.(interface{ CloseWrite() error }); ok {
|
||||
if err := closeWriter.CloseWrite(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
if drainTimeout > 0 {
|
||||
_ = c.conn.SetReadDeadline(time.Now().Add(drainTimeout))
|
||||
defer func() {
|
||||
_ = c.conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
|
||||
var buf [256]byte
|
||||
for {
|
||||
_, err := c.conn.Read(buf[:])
|
||||
switch {
|
||||
case err == nil:
|
||||
continue
|
||||
case errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
|
||||
return c.Close()
|
||||
default:
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return c.Close()
|
||||
}
|
||||
return c.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
// Close 关闭底层连接,并保证重复调用是安全的。
|
||||
func (c *TCPConn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
c.closeErr = c.conn.Close()
|
||||
})
|
||||
|
||||
return c.closeErr
|
||||
}
|
||||
462
cmd/internal/transport/tcp_linux.go
Normal file
462
cmd/internal/transport/tcp_linux.go
Normal file
@@ -0,0 +1,462 @@
|
||||
//go:build linux
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
linuxTimestampControlBufferSize = 256 // 控制消息缓冲区。
|
||||
linuxTXTimestampWaitTimeout = 250 * time.Millisecond // 等待 TX 时间戳的上限。
|
||||
linuxTXTimestampPollInterval = time.Millisecond // 轮询 errqueue 的间隔。
|
||||
|
||||
linuxSOTimestampingNew = 0x41
|
||||
linuxSCMTimestampingNew = linuxSOTimestampingNew
|
||||
linuxSOEEOriginTimestamping = 4 // timestamping errqueue 事件。
|
||||
linuxSCMTstampSnd = 0 // 对应 A_TX_SOFTWARE。
|
||||
linuxSCMTstampSched = 1 // 对应 A_TX_SCHED。
|
||||
|
||||
linuxSOFTimestampingTXSoftware = 1 << 1 // 打开 TX software timestamp。
|
||||
linuxSOFTimestampingRXSoftware = 1 << 3 // 打开 RX software timestamp。
|
||||
linuxSOFTimestampingSoftware = 1 << 4 // software timestamp 总开关。
|
||||
linuxSOFTimestampingOptID = 1 << 7 // 给时间戳关联 ID。
|
||||
linuxSOFTimestampingTXSched = 1 << 8 // 打开 TX sched timestamp。
|
||||
linuxSOFTimestampingOptTSONLY = 1 << 11 // 只回时间戳。
|
||||
linuxSOFTimestampingOptIDTCP = 1 << 16 // 让 TCP 也带 timestamp ID。
|
||||
)
|
||||
|
||||
// 拿到底层 fd,并打开 Linux timestamping。
|
||||
func (c *TCPConn) initLinuxTimestamping() error {
|
||||
sysConn, ok := c.conn.(interface {
|
||||
SyscallConn() (syscall.RawConn, error)
|
||||
})
|
||||
if !ok {
|
||||
return fmt.Errorf("transport: connection does not support SyscallConn")
|
||||
}
|
||||
|
||||
rawConn, err := sysConn.SyscallConn()
|
||||
if err != nil || rawConn == nil {
|
||||
if err != nil {
|
||||
return fmt.Errorf("transport: get syscall conn: %w", err)
|
||||
}
|
||||
return fmt.Errorf("transport: missing syscall conn")
|
||||
}
|
||||
|
||||
//socket是否可以成功打开 timestamping 取决于内核版本和配置,尝试多个 flag 组合直到成功或遇到非 EINVAL 错误。
|
||||
if err := enableLinuxTimestamping(rawConn); err != nil {
|
||||
return fmt.Errorf("transport: enable linux timestamping: %w", err)
|
||||
}
|
||||
//成功打开 timestamping 后,rawConn 就可以用来收 TX/RX 时间戳了。
|
||||
c.raw = rawConn
|
||||
return nil
|
||||
}
|
||||
|
||||
// 给 socket开权限打开TX software timestamping。
|
||||
func enableLinuxTimestamping(rawConn syscall.RawConn) error {
|
||||
flagCandidates := []int{ //不同linux版本可能支持不同的 flag 组合,尝试多个组合直到成功。
|
||||
linuxSOFTimestampingTXSched |
|
||||
linuxSOFTimestampingTXSoftware |
|
||||
linuxSOFTimestampingRXSoftware |
|
||||
linuxSOFTimestampingSoftware |
|
||||
linuxSOFTimestampingOptID | //TCP 协议栈给每个时间戳生成一个序列号
|
||||
linuxSOFTimestampingOptIDTCP |
|
||||
linuxSOFTimestampingOptTSONLY,
|
||||
linuxSOFTimestampingTXSched |
|
||||
linuxSOFTimestampingTXSoftware |
|
||||
linuxSOFTimestampingRXSoftware |
|
||||
linuxSOFTimestampingSoftware |
|
||||
linuxSOFTimestampingOptID |
|
||||
linuxSOFTimestampingOptTSONLY,
|
||||
linuxSOFTimestampingTXSched |
|
||||
linuxSOFTimestampingTXSoftware |
|
||||
linuxSOFTimestampingRXSoftware |
|
||||
linuxSOFTimestampingSoftware |
|
||||
linuxSOFTimestampingOptTSONLY,
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, flags := range flagCandidates { //尝试不同的 flag 组合,直到成功或遇到非 EINVAL 错误。
|
||||
// 内核根据 fd 找到对应的内存结构体(Socket 缓冲区)
|
||||
err := rawConn.Control(func(fd uintptr) { //Control 方法保证在回调里 fd 是有效的,可以安全地调用 syscall.SetsockoptInt。
|
||||
lastErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, linuxSOTimestampingNew, flags)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if lastErr == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(lastErr, syscall.EINVAL) {
|
||||
return lastErr
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// sendMessageLinux 编码消息、写完整帧,再记录 TX 时间戳。
|
||||
func (c *TCPConn) sendMessageLinux(msg protocol.Message) error {
|
||||
payload, err := protocol.EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("protocol: encode message: %w", err)
|
||||
}
|
||||
|
||||
//编码后的消息 payload 前面加 4 字节长度,构成完整帧。
|
||||
frame := make([]byte, 4+len(payload))
|
||||
binary.BigEndian.PutUint32(frame[:4], uint32(len(payload)))
|
||||
copy(frame[4:], payload)
|
||||
|
||||
if err := c.writeFrameLinux(frame); err != nil {
|
||||
return fmt.Errorf("protocol: write frame: %w", err)
|
||||
}
|
||||
//记录发送延时日志
|
||||
c.logTXTimestampEvents(msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeFrameLinux 用 sendmsg 写完整帧。
|
||||
func (c *TCPConn) writeFrameLinux(frame []byte) error {
|
||||
written := 0
|
||||
var opErr error
|
||||
|
||||
err := c.raw.Write(func(fd uintptr) bool {
|
||||
if written >= len(frame) {
|
||||
return true
|
||||
}
|
||||
|
||||
n, sendErr := syscall.SendmsgN(int(fd), frame[written:], nil, nil, 0)
|
||||
switch {
|
||||
case sendErr == nil:
|
||||
if n <= 0 {
|
||||
opErr = io.ErrShortWrite
|
||||
return true
|
||||
}
|
||||
written += n
|
||||
return written >= len(frame)
|
||||
case errors.Is(sendErr, syscall.EAGAIN), errors.Is(sendErr, syscall.EWOULDBLOCK):
|
||||
return false
|
||||
default:
|
||||
opErr = sendErr
|
||||
return true
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if opErr != nil {
|
||||
return opErr
|
||||
}
|
||||
if written != len(frame) {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 把 A_TX_SCHED / A_TX_SOFTWARE 写入日志。(发送过程中)
|
||||
func (c *TCPConn) logTXTimestampEvents(msg protocol.Message) {
|
||||
timestamps := c.collectTXTimestampEvents()
|
||||
|
||||
if ts, ok := timestamps[latencylog.EventATXSched]; ok {
|
||||
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSched, ts, msg)
|
||||
}
|
||||
if ts, ok := timestamps[latencylog.EventATXSoftware]; ok {
|
||||
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventATXSoftware, ts, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// 在 errqueue 里等两类 TX 时间戳。
|
||||
func (c *TCPConn) collectTXTimestampEvents() map[string]int64 {
|
||||
timestamps := make(map[string]int64, 2)
|
||||
//设置合理等待上限
|
||||
deadline := time.Now().Add(linuxTXTimestampWaitTimeout)
|
||||
|
||||
//轮询 errqueue 直到拿到两类时间戳,或超时,或遇到非 EAGAIN 错误。
|
||||
for len(timestamps) < 2 && time.Now().Before(deadline) {
|
||||
eventName, ts, err := c.recvTXTimestampOnce()
|
||||
if err != nil {
|
||||
if isWouldBlock(err) {
|
||||
time.Sleep(linuxTXTimestampPollInterval)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if eventName == "" || ts <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := timestamps[eventName]; !exists {
|
||||
timestamps[eventName] = ts
|
||||
}
|
||||
}
|
||||
|
||||
return timestamps
|
||||
}
|
||||
|
||||
// recvTXTimestampOnce 从 errqueue 读一次时间戳事件。
|
||||
func (c *TCPConn) recvTXTimestampOnce() (string, int64, error) {
|
||||
var (
|
||||
eventName string // 事件名,例如 A_TX_SCHED 或 A_TX_SOFTWARE。
|
||||
tsUnixNS int64 // 时间戳的 UnixNano 表示。
|
||||
opErr error
|
||||
)
|
||||
|
||||
err := c.raw.Control(func(fd uintptr) {
|
||||
//设置足够大的 oob buffer 来接收控制消息,调用 recvmsg 从 errqueue 读一条消息。
|
||||
oob := make([]byte, linuxTimestampControlBufferSize)
|
||||
//recvmsg 的 flags 里必须带 MSG_ERRQUEUE,才能从 errqueue 里读消息,非阻塞模式下如果没有消息可读会返回 EAGAIN。
|
||||
_, oobn, _, _, recvErr := syscall.Recvmsg(int(fd), nil, oob, syscall.MSG_ERRQUEUE|syscall.MSG_DONTWAIT)
|
||||
if recvErr != nil {
|
||||
opErr = recvErr
|
||||
return
|
||||
}
|
||||
//解析控制消息,看看是不是我们关心的 TX 时间戳事件,如果是就拿到事件名和时间戳。
|
||||
eventName, tsUnixNS = parseTXTimestampControlMessages(oob[:oobn])
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if opErr != nil {
|
||||
return "", 0, opErr
|
||||
}
|
||||
|
||||
return eventName, tsUnixNS, nil //如果成功拿到时间戳事件,eventName 会是 A_TX_SCHED 或 A_TX_SOFTWARE 之一,tsUnixNS 是对应的时间戳;如果没有拿到事件或时间戳无效,eventName 会是空字符串,tsUnixNS 会是 0。
|
||||
}
|
||||
|
||||
// 把底层时间戳映射成日志事件名。
|
||||
func parseTXTimestampControlMessages(oob []byte) (string, int64) {
|
||||
if len(oob) == 0 {
|
||||
return "", 0
|
||||
}
|
||||
//解析控制消息,看看是不是我们关心的 TX 时间戳事件,如果是就拿到事件名和时间戳。
|
||||
controlMessages, err := syscall.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var (
|
||||
tsUnixNS int64 //时间戳的 UnixNano 表示。
|
||||
tsKind uint32 //extended err里,告诉我们这个时间戳是 sched 还是 software。
|
||||
hasTS bool // 是否拿到时间戳了。
|
||||
hasKind bool // 是否拿到时间戳类型了。
|
||||
)
|
||||
//一个 recvmsg 可能会收到多个控制消息,循环找我们关心的时间戳事件,拿到时间戳和事件类型。
|
||||
for _, controlMessage := range controlMessages {
|
||||
switch {
|
||||
case controlMessage.Header.Level == syscall.SOL_SOCKET && controlMessage.Header.Type == linuxSCMTimestampingNew:
|
||||
if ts := parseSCMTimestampingData(controlMessage.Data); ts > 0 {
|
||||
tsUnixNS = ts
|
||||
hasTS = true
|
||||
}
|
||||
case isSocketExtendedErr(controlMessage): //判断时间戳是否进入了errqueue,
|
||||
if info, ok := parseSocketExtendedErrInfo(controlMessage.Data); ok {
|
||||
tsKind = info //时间戳类型被内核放在 extended err 的附加信息里,解析出来。
|
||||
hasKind = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasTS || !hasKind {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
switch tsKind { //把内核的时间戳类型映射成日志事件名。(记录时只关心 sched 和 software 两类时间戳)
|
||||
case linuxSCMTstampSched:
|
||||
return latencylog.EventATXSched, tsUnixNS
|
||||
case linuxSCMTstampSnd:
|
||||
return latencylog.EventATXSoftware, tsUnixNS
|
||||
default:
|
||||
return "", 0
|
||||
}
|
||||
}
|
||||
|
||||
// 判断控制消息是否来自 socket extended err。
|
||||
// 内核产生的时间戳并不会混合在普通的数据流里,而是被包装成一种特殊的“错误消息”丢进 Error Queue。
|
||||
func isSocketExtendedErr(controlMessage syscall.SocketControlMessage) bool {
|
||||
switch {
|
||||
case controlMessage.Header.Level == syscall.SOL_IP && controlMessage.Header.Type == syscall.IP_RECVERR:
|
||||
return true
|
||||
case controlMessage.Header.Level == syscall.SOL_IPV6 && controlMessage.Header.Type == syscall.IPV6_RECVERR:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 从 socket extended err 的数据里取 origin timestamping 信息。
|
||||
func parseSocketExtendedErrInfo(data []byte) (uint32, bool) {
|
||||
if len(data) < 16 {
|
||||
return 0, false
|
||||
}
|
||||
if data[4] != linuxSOEEOriginTimestamping {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return binary.NativeEndian.Uint32(data[8:12]), true
|
||||
}
|
||||
|
||||
// 读一条完整消息,并记录 B_RX_SOFTWARE。
|
||||
func (c *TCPConn) receiveMessageLinux() (protocol.Message, error) {
|
||||
payload, rxTimestamp, err := c.readFrameLinux()
|
||||
if err != nil {
|
||||
return protocol.Message{}, fmt.Errorf("protocol: read frame: %w", err)
|
||||
}
|
||||
|
||||
msg, err := protocol.DecodeMessage(payload)
|
||||
if err != nil {
|
||||
return protocol.Message{}, fmt.Errorf("protocol: decode message: %w", err)
|
||||
}
|
||||
|
||||
if rxTimestamp > 0 {
|
||||
latencylog.LogMessageEventAt(c.logger, c.nodeRole, c.nodeID, latencylog.EventBRXSoftware, rxTimestamp, msg)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// readFrameLinux 先读 4 字节长度,再读整条 payload。
|
||||
func (c *TCPConn) readFrameLinux() ([]byte, int64, error) {
|
||||
var frameHeader [4]byte
|
||||
rxTimestamp, err := c.readFullLinux(frameHeader[:])
|
||||
if err != nil {
|
||||
return nil, rxTimestamp, err
|
||||
}
|
||||
|
||||
size := binary.BigEndian.Uint32(frameHeader[:])
|
||||
switch {
|
||||
case size == 0:
|
||||
return nil, rxTimestamp, protocol.ErrInvalidFrameLength
|
||||
case size > protocol.MaxFrameSize:
|
||||
return nil, rxTimestamp, protocol.ErrFrameTooLarge
|
||||
}
|
||||
|
||||
payload := make([]byte, int(size))
|
||||
bodyTimestamp, err := c.readFullLinux(payload)
|
||||
if rxTimestamp == 0 {
|
||||
rxTimestamp = bodyTimestamp
|
||||
}
|
||||
if err != nil {
|
||||
return nil, rxTimestamp, err
|
||||
}
|
||||
|
||||
return payload, rxTimestamp, nil
|
||||
}
|
||||
|
||||
// 读满 buf,并保留首个 RX_SOFTWARE(返回进入tcp协议栈的时间戳)。
|
||||
func (c *TCPConn) readFullLinux(buf []byte) (int64, error) {
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
offset int
|
||||
firstRXTime int64
|
||||
)
|
||||
|
||||
for offset < len(buf) {
|
||||
n, rxTimestamp, err := c.recvmsgLinux(buf[offset:])
|
||||
if firstRXTime == 0 && rxTimestamp > 0 {
|
||||
firstRXTime = rxTimestamp
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) && offset > 0 {
|
||||
return firstRXTime, io.ErrUnexpectedEOF
|
||||
}
|
||||
return firstRXTime, err
|
||||
}
|
||||
|
||||
offset += n
|
||||
}
|
||||
|
||||
return firstRXTime, nil
|
||||
}
|
||||
|
||||
// recvmsgLinux 用 recvmsg 同时读取数据和控制消息。
|
||||
func (c *TCPConn) recvmsgLinux(buf []byte) (int, int64, error) {
|
||||
var (
|
||||
n int
|
||||
rxTimeNS int64
|
||||
opErr error
|
||||
)
|
||||
|
||||
err := c.raw.Read(func(fd uintptr) bool {
|
||||
oob := make([]byte, linuxTimestampControlBufferSize)
|
||||
readN, oobN, _, _, recvErr := syscall.Recvmsg(int(fd), buf, oob, 0)
|
||||
switch {
|
||||
case recvErr == nil:
|
||||
if readN == 0 {
|
||||
opErr = io.EOF
|
||||
return true
|
||||
}
|
||||
n = readN
|
||||
rxTimeNS = parseRXTimestampControlMessages(oob[:oobN])
|
||||
return true
|
||||
case errors.Is(recvErr, syscall.EAGAIN), errors.Is(recvErr, syscall.EWOULDBLOCK):
|
||||
return false
|
||||
default:
|
||||
opErr = recvErr
|
||||
return true
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if opErr != nil {
|
||||
return 0, 0, opErr
|
||||
}
|
||||
|
||||
return n, rxTimeNS, nil
|
||||
}
|
||||
|
||||
// 从控制消息里取 RX_SOFTWARE。
|
||||
func parseRXTimestampControlMessages(oob []byte) int64 {
|
||||
if len(oob) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
controlMessages, err := syscall.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
for _, controlMessage := range controlMessages {
|
||||
if controlMessage.Header.Level != syscall.SOL_SOCKET || controlMessage.Header.Type != linuxSCMTimestampingNew {
|
||||
continue
|
||||
}
|
||||
|
||||
if ts := parseSCMTimestampingData(controlMessage.Data); ts > 0 {
|
||||
return ts
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// 取第一个非零 timespec。
|
||||
func parseSCMTimestampingData(data []byte) int64 {
|
||||
const timespec64Size = 16
|
||||
|
||||
for offset := 0; offset+timespec64Size <= len(data); offset += timespec64Size {
|
||||
sec := int64(binary.NativeEndian.Uint64(data[offset : offset+8]))
|
||||
nsec := int64(binary.NativeEndian.Uint64(data[offset+8 : offset+16]))
|
||||
if sec == 0 && nsec == 0 {
|
||||
continue
|
||||
}
|
||||
return sec*int64(time.Second) + nsec
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// 判断错误是否是 EAGAIN 或 EWOULDBLOCK。
|
||||
func isWouldBlock(err error) bool {
|
||||
return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EWOULDBLOCK)
|
||||
}
|
||||
140
cmd/internal/transport/tcp_linux_test.go
Normal file
140
cmd/internal/transport/tcp_linux_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
//go:build linux
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
func TestLinuxTimestampingRecordsKernelEvents(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg protocol.Message
|
||||
}{
|
||||
{
|
||||
name: "text",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 41,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello over tcp"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 42,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x00, 0x01, 0x02, 0xff},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clientConn, serverConn := newTCPPair(t)
|
||||
|
||||
senderLogger := &recordingLogger{}
|
||||
receiverLogger := &recordingLogger{}
|
||||
sender, err := NewTCPConn(
|
||||
clientConn,
|
||||
WithLogger(senderLogger, latencylog.NodeRolePeer, "peer-a"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPConn(sender) error = %v", err)
|
||||
}
|
||||
receiver, err := NewTCPConn(
|
||||
serverConn,
|
||||
WithLogger(receiverLogger, latencylog.NodeRolePeer, "peer-b"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPConn(receiver) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = sender.Close()
|
||||
_ = receiver.Close()
|
||||
})
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(tt.msg)
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
|
||||
}
|
||||
|
||||
assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSched, tt.msg.ID)
|
||||
assertHasEvent(t, senderLogger.Events(), latencylog.EventATXSoftware, tt.msg.ID)
|
||||
assertHasEvent(t, receiverLogger.Events(), latencylog.EventBRXSoftware, tt.msg.ID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTCPPair(t *testing.T) (net.Conn, net.Conn) {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen() error = %v", err)
|
||||
}
|
||||
|
||||
type acceptResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
accepted := make(chan acceptResult, 1)
|
||||
go func() {
|
||||
conn, acceptErr := listener.Accept()
|
||||
accepted <- acceptResult{conn: conn, err: acceptErr}
|
||||
}()
|
||||
|
||||
clientConn, err := net.Dial("tcp", listener.Addr().String())
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
t.Fatalf("net.Dial() error = %v", err)
|
||||
}
|
||||
|
||||
result := <-accepted
|
||||
if err := listener.Close(); err != nil {
|
||||
t.Fatalf("listener.Close() error = %v", err)
|
||||
}
|
||||
if result.err != nil {
|
||||
_ = clientConn.Close()
|
||||
t.Fatalf("listener.Accept() error = %v", result.err)
|
||||
}
|
||||
|
||||
return clientConn, result.conn
|
||||
}
|
||||
|
||||
func assertHasEvent(t *testing.T, events []latencylog.Event, wantEvent string, wantMessageID uint64) {
|
||||
t.Helper()
|
||||
|
||||
for _, event := range events {
|
||||
if event.Event == wantEvent && event.MessageID == wantMessageID {
|
||||
if event.TsUnixNano <= 0 {
|
||||
t.Fatalf("event %s timestamp must be positive: %+v", wantEvent, event)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatalf("missing event %s for message %d in %+v", wantEvent, wantMessageID, events)
|
||||
}
|
||||
416
cmd/internal/transport/tcp_test.go
Normal file
416
cmd/internal/transport/tcp_test.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
)
|
||||
|
||||
type recordingLogger struct {
|
||||
mu sync.Mutex
|
||||
events []latencylog.Event
|
||||
}
|
||||
|
||||
func (l *recordingLogger) LogEvent(event latencylog.Event) error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.events = append(l.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *recordingLogger) Events() []latencylog.Event {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
return append([]latencylog.Event(nil), l.events...)
|
||||
}
|
||||
|
||||
type failingLogger struct{}
|
||||
|
||||
func (failingLogger) LogEvent(latencylog.Event) error {
|
||||
return errors.New("log failed")
|
||||
}
|
||||
|
||||
// TestSendReceiveMessage 验证 transport 可以在单条连接上正常收发 text 和 file 消息。
|
||||
func TestSendReceiveMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg protocol.Message
|
||||
}{
|
||||
{
|
||||
name: "text",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
msg: protocol.Message{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "data.bin",
|
||||
Body: []byte{0x00, 0x10, 0xff},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
//创建一个容量为1的缓冲通道sendErr,用于接收发送操作的错误结果。
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(tt.msg) //发送消息,并将结果(错误或nil)发送到sendErr通道。
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil { //接受发送结果,如果发送过程中发生错误,则测试失败。
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tt.msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendLogsHandoffEvents(t *testing.T) {
|
||||
logger := &recordingLogger{}
|
||||
sender, receiver := newTransportConnPair(
|
||||
t,
|
||||
[]Option{WithLogger(logger, latencylog.NodeRolePeer, "peer-a")},
|
||||
nil,
|
||||
)
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(msg)
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
|
||||
}
|
||||
|
||||
events := logger.Events()
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("event count = %d, want 4", len(events))
|
||||
}
|
||||
if events[0].Event != latencylog.EventSendHandoffBegin {
|
||||
t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin)
|
||||
}
|
||||
if events[1].Event != latencylog.EventATXSched {
|
||||
t.Fatalf("second event = %q, want %q", events[1].Event, latencylog.EventATXSched)
|
||||
}
|
||||
if events[2].Event != latencylog.EventATXSoftware {
|
||||
t.Fatalf("third event = %q, want %q", events[2].Event, latencylog.EventATXSoftware)
|
||||
}
|
||||
if events[3].Event != latencylog.EventSendHandoffEnd {
|
||||
t.Fatalf("fourth event = %q, want %q", events[3].Event, latencylog.EventSendHandoffEnd)
|
||||
}
|
||||
for i, event := range events {
|
||||
if event.MessageID != msg.ID {
|
||||
t.Fatalf("event[%d] message ID = %d, want %d", i, event.MessageID, msg.ID)
|
||||
}
|
||||
}
|
||||
if events[0].NodeRole != latencylog.NodeRolePeer || events[0].NodeID != "peer-a" {
|
||||
t.Fatalf("node info = (%s,%s), want (%s,%s)", events[0].NodeRole, events[0].NodeID, latencylog.NodeRolePeer, "peer-a")
|
||||
}
|
||||
if events[0].TsUnixNano <= 0 || events[1].TsUnixNano <= 0 || events[2].TsUnixNano <= 0 || events[3].TsUnixNano <= 0 {
|
||||
t.Fatalf("timestamps must be positive: %+v", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendIgnoresLoggerFailure(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(
|
||||
t,
|
||||
[]Option{WithLogger(failingLogger{}, latencylog.NodeRolePeer, "peer-a")},
|
||||
nil,
|
||||
)
|
||||
|
||||
msg := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 9,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
sendErr := make(chan error, 1)
|
||||
go func() {
|
||||
sendErr <- sender.Send(msg)
|
||||
}()
|
||||
|
||||
got, err := receiver.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
if err := <-sendErr; err != nil {
|
||||
t.Fatalf("Send() error = %v, want nil even if logger fails", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch: got %+v want %+v", got, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveLoopDeliversMessages 验证 ReceiveLoop 会逐条交付连续到达的消息。
|
||||
func TestReceiveLoopDeliversMessages(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
|
||||
want := []protocol.Message{
|
||||
{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
{
|
||||
Type: protocol.MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
got []protocol.Message
|
||||
)
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
got = append(got, msg)
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
for _, msg := range want {
|
||||
if err := sender.Send(msg); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
}
|
||||
if err := sender.Close(); err != nil {
|
||||
t.Fatalf("sender.Close() error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if err == nil {
|
||||
t.Fatal("ReceiveLoop() error = nil, want non-nil after peer close")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "receive loop read") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want read context", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentSendKeepsMessagesIntact 验证并发发送时消息不会因为写入交叉而损坏。
|
||||
func TestConcurrentSendKeepsMessagesIntact(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
// 发送方将多条消息并发发送到接收方,接收方通过 ReceiveLoop 逐条读取并验证每条消息的完整性和正确性。
|
||||
want := []protocol.Message{
|
||||
{Type: protocol.MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("one")},
|
||||
{Type: protocol.MessageTypeText, ID: 2, From: "peer-a", To: "peer-b", Body: []byte("two")},
|
||||
{Type: protocol.MessageTypeText, ID: 3, From: "peer-a", To: "peer-b", Body: []byte("three")},
|
||||
{Type: protocol.MessageTypeText, ID: 4, From: "peer-a", To: "peer-b", Body: []byte("four")},
|
||||
}
|
||||
|
||||
received := make(chan protocol.Message, len(want))
|
||||
readErr := make(chan error, 1)
|
||||
go func() { //异步地运行一个 goroutine
|
||||
for range want {
|
||||
msg, err := receiver.Receive()
|
||||
if err != nil {
|
||||
readErr <- err
|
||||
return
|
||||
}
|
||||
received <- msg
|
||||
}
|
||||
readErr <- nil
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, msg := range want {
|
||||
msg := msg
|
||||
wg.Add(1)
|
||||
go func() { //异步处理
|
||||
defer wg.Done()
|
||||
if err := sender.Send(msg); err != nil {
|
||||
t.Errorf("Send() error = %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if err := <-readErr; err != nil {
|
||||
t.Fatalf("Receive() error = %v", err)
|
||||
}
|
||||
|
||||
gotByID := make(map[uint64]protocol.Message, len(want))
|
||||
for range want {
|
||||
msg := <-received
|
||||
gotByID[msg.ID] = msg
|
||||
}
|
||||
|
||||
for _, msg := range want {
|
||||
got, ok := gotByID[msg.ID]
|
||||
if !ok {
|
||||
t.Fatalf("missing message with ID %d", msg.ID)
|
||||
}
|
||||
if !reflect.DeepEqual(got, msg) {
|
||||
t.Fatalf("message mismatch for ID %d: got %+v want %+v", msg.ID, got, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveLoopStopsOnHandlerError 验证 handler 返回错误时 ReceiveLoop 会退出并关闭连接。
|
||||
func TestReceiveLoopStopsOnHandlerError(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
|
||||
wantErr := errors.New("stop loop")
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
|
||||
return wantErr
|
||||
})
|
||||
}()
|
||||
|
||||
first := protocol.Message{
|
||||
Type: protocol.MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
if err := sender.Send(first); err != nil {
|
||||
t.Fatalf("Send(first) error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want %v", err, wantErr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "receive loop handler") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want handler context", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveLoopStopsOnReadError 验证对端关闭时 ReceiveLoop 会以读取错误退出。
|
||||
func TestReceiveLoopStopsOnReadError(t *testing.T) {
|
||||
sender, receiver := newTransportConnPair(t, nil, nil)
|
||||
|
||||
loopErr := make(chan error, 1)
|
||||
go func() {
|
||||
loopErr <- receiver.ReceiveLoop(func(msg protocol.Message) error {
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
if err := sender.Close(); err != nil {
|
||||
t.Fatalf("sender.Close() error = %v", err)
|
||||
}
|
||||
|
||||
err := <-loopErr
|
||||
if err == nil {
|
||||
t.Fatal("ReceiveLoop() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "receive loop read") {
|
||||
t.Fatalf("ReceiveLoop() error = %v, want read context", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCloseIsIdempotent 验证 Close 可以安全地被重复调用。
|
||||
func TestCloseIsIdempotent(t *testing.T) {
|
||||
conn, peer := newTransportConnPair(t, nil, nil)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Fatalf("Close(first) error = %v", err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Fatalf("Close(second) error = %v, want nil", err)
|
||||
}
|
||||
if err := peer.Close(); err != nil && !strings.Contains(err.Error(), "closed") {
|
||||
t.Fatalf("peer.Close() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveReturnsWrappedReadError 验证 Receive 在底层读取失败时会保留 transport 上下文。
|
||||
func TestReceiveReturnsWrappedReadError(t *testing.T) {
|
||||
conn, peer := newTransportConnPair(t, nil, nil)
|
||||
go func() {
|
||||
_ = peer.Close()
|
||||
}()
|
||||
|
||||
_, err := conn.Receive()
|
||||
if err == nil {
|
||||
t.Fatal("Receive() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "transport: receive message") {
|
||||
t.Fatalf("Receive() error = %v, want wrapped receive error", err)
|
||||
}
|
||||
if !errors.Is(err, io.EOF) && !strings.Contains(err.Error(), "closed") {
|
||||
t.Fatalf("Receive() error = %v, want underlying read failure", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newTransportConnPair(t *testing.T, senderOpts []Option, receiverOpts []Option) (*TCPConn, *TCPConn) {
|
||||
t.Helper()
|
||||
|
||||
left, right := newTCPPair(t)
|
||||
|
||||
sender, err := NewTCPConn(left, senderOpts...)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPConn(sender) error = %v", err)
|
||||
}
|
||||
receiver, err := NewTCPConn(right, receiverOpts...)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPConn(receiver) error = %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = sender.Close()
|
||||
_ = receiver.Close()
|
||||
})
|
||||
|
||||
return sender, receiver
|
||||
}
|
||||
Reference in New Issue
Block a user