Files
OmniSocketGo/cmd/internal/server/udp_hub.go

190 lines
4.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package server
import (
"fmt"
"log"
"net"
"sync"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
// UDPOption customizes an optional aspect of a UDPHub. NewUDPHub applies the
// options in the order they are passed.
type UDPOption func(*UDPHub)

// WithUDPLogger installs the latency logger the hub hands to its transport.
// Passing nil makes NewUDPHub fall back to a no-op logger.
func WithUDPLogger(logger latencylog.Logger) UDPOption {
	return func(h *UDPHub) { h.logger = logger }
}
// WithUDPTXTimestampDebugLogger installs the debug logger for TX errqueue
// timestamp events, which the hub forwards to its transport connection.
func WithUDPTXTimestampDebugLogger(logger transport.TXTimestampDebugLogger) UDPOption {
	return func(h *UDPHub) { h.txTimestampDebugLogger = logger }
}
// WithUDPLinuxTimestamping controls whether the hub asks its transport to
// enable Linux timestamping. NewUDPHub defaults this to true.
func WithUDPLinuxTimestamping(enabled bool) UDPOption {
	return func(h *UDPHub) { h.linuxTimestampingEnabled = enabled }
}
// UDPHub tracks peers that registered over UDP and relays messages between
// them. Create one with NewUDPHub; the zero value is not usable.
type UDPHub struct {
	// mu guards peers and addrs.
	mu sync.RWMutex
	// peers maps a registered peer ID to the UDP address it registered from.
	peers map[string]*net.UDPAddr
	// addrs is the reverse index of peers, keyed by UDPAddr.String().
	addrs map[string]string
	// conn is the transport connection used for receiving and sending.
	conn *transport.UDPConn
	// logger is handed to the transport for latency logging.
	logger latencylog.Logger
	// txTimestampDebugLogger is handed to the transport; it may be nil.
	txTimestampDebugLogger transport.TXTimestampDebugLogger
	// linuxTimestampingEnabled is handed to the transport; NewUDPHub defaults it to true.
	linuxTimestampingEnabled bool
}
// NewUDPHub wraps conn in a transport.UDPConn and returns a hub ready to Serve.
//
// Defaults: a no-op latency logger and Linux timestamping enabled. opts are
// applied in order and may override either default. If an option leaves the
// logger nil, the no-op logger is restored. Any error from creating the
// transport connection is wrapped and returned with a nil hub.
func NewUDPHub(conn *net.UDPConn, opts ...UDPOption) (*UDPHub, error) {
	h := &UDPHub{
		peers:                    make(map[string]*net.UDPAddr),
		addrs:                    make(map[string]string),
		logger:                   latencylog.NoopLogger{},
		linuxTimestampingEnabled: true,
	}
	for i := range opts {
		opts[i](h)
	}
	if h.logger == nil {
		h.logger = latencylog.NoopLogger{}
	}

	tc, createErr := transport.NewUDPConn(
		conn,
		nil,
		transport.WithUDPLogger(h.logger, latencylog.NodeRoleServer, "hub"),
		transport.WithUDPLinuxTimestamping(h.linuxTimestampingEnabled),
		transport.WithUDPTXTimestampDebugLogger(h.txTimestampDebugLogger),
	)
	if createErr != nil {
		return nil, fmt.Errorf("server: create udp transport conn: %w", createErr)
	}
	h.conn = tc
	return h, nil
}
// Serve runs the transport's receive loop, dispatching each incoming message
// to registration or forwarding. It returns whatever ReceiveLoop returns.
func (h *UDPHub) Serve() error {
	dispatch := func(msg protocol.Message, from *net.UDPAddr) error {
		// Per-message failures are only logged; the handler always reports
		// nil, presumably so ReceiveLoop keeps running (confirm in transport).
		err := h.handleMessage(msg, from)
		if err != nil {
			log.Printf("udp hub: handle message from %s: %v", from, err)
		}
		return nil
	}
	return h.conn.ReceiveLoop(dispatch)
}
// HasPeer reports whether peerID currently has a registration in the hub.
// It is safe to call concurrently with Serve.
func (h *UDPHub) HasPeer(peerID string) bool {
	h.mu.RLock()
	_, registered := h.peers[peerID]
	h.mu.RUnlock()
	return registered
}
// handleMessage routes one incoming message by type:
//   - register messages bind the sender's ID to addr;
//   - text and file messages are relayed to msg.To;
//   - error messages and unknown types are rejected with an error reply to addr.
//
// It returns any error from the chosen handler or from sending the reply.
func (h *UDPHub) handleMessage(msg protocol.Message, addr *net.UDPAddr) error {
	switch msg.Type {
	case protocol.MessageTypeRegister:
		return h.registerPeer(msg, addr)
	case protocol.MessageTypeText, protocol.MessageTypeFile:
		return h.forwardMessage(msg, addr)
	case protocol.MessageTypeError:
		return h.sendErrorTo(addr, h.replyPeerID(msg, addr), "peers cannot send error messages")
	default:
		return h.sendErrorTo(addr, h.replyPeerID(msg, addr), fmt.Sprintf("unsupported message type: %s", msg.Type))
	}
}

// replyPeerID chooses the To field for an error reply sent to addr.
// It prefers the peer ID registered for addr and falls back to the
// unverified, self-reported msg.From only when addr is not registered.
func (h *UDPHub) replyPeerID(msg protocol.Message, addr *net.UDPAddr) string {
	if id := h.lookupPeerID(addr); id != "" {
		return id
	}
	return msg.From
}
// registerPeer binds the peer ID in msg.From to addr.
//
// A message with an empty From is rejected with an error reply to addr. Any
// previous binding of either the peer ID or the address is removed first, so
// peers and addrs remain exact inverses of each other.
func (h *UDPHub) registerPeer(msg protocol.Message, addr *net.UDPAddr) error {
	peerID := msg.From
	if peerID == "" {
		return h.sendErrorTo(addr, "", "register: missing peer id")
	}
	key := addr.String()

	h.mu.Lock()
	defer h.mu.Unlock()

	// The peer re-registered, possibly from a new address: drop its old
	// reverse entry.
	if oldAddr, ok := h.peers[peerID]; ok {
		delete(h.addrs, oldAddr.String())
	}
	// A different peer previously owned this address. Drop that peer's forward
	// entry; otherwise messages addressed to it would be delivered to the new
	// owner of the address.
	if oldID, ok := h.addrs[key]; ok && oldID != peerID {
		delete(h.peers, oldID)
		log.Printf("udp hub: peer %s displaced by %s at %s", oldID, peerID, addr)
	}

	h.peers[peerID] = addr
	h.addrs[key] = peerID
	log.Printf("udp hub: registered peer %s from %s", peerID, addr)
	return nil
}
// forwardMessage relays msg from the peer registered at senderAddr to the
// peer named in msg.To.
//
// An unregistered sender or an unknown target gets an error reply. If the
// relay send fails, the sender is notified (best effort) and the send error
// is returned wrapped with the target's ID and address.
func (h *UDPHub) forwardMessage(msg protocol.Message, senderAddr *net.UDPAddr) error {
	from := h.lookupPeerID(senderAddr)
	if from == "" {
		return h.sendErrorTo(senderAddr, msg.From, "not registered; send register first")
	}
	// Replace the self-reported sender with the registered identity so a peer
	// cannot claim to be someone else.
	msg.From = from

	dest := h.lookupAddr(msg.To)
	if dest == nil {
		return h.sendErrorTo(senderAddr, from, fmt.Sprintf("unknown target: %s", msg.To))
	}

	sendErr := h.conn.SendTo(msg, dest)
	if sendErr == nil {
		return nil
	}
	// The notification is best effort. The relay failure is the error worth
	// surfacing, so a failure to notify the sender is intentionally ignored.
	_ = h.sendErrorTo(senderAddr, from, fmt.Sprintf("failed to forward to %s", msg.To))
	return fmt.Errorf("forward to %s at %s: %w", msg.To, dest, sendErr)
}
// lookupPeerID returns the peer ID registered for addr, or "" if addr has no
// registration.
func (h *UDPHub) lookupPeerID(addr *net.UDPAddr) string {
	key := addr.String()
	h.mu.RLock()
	id := h.addrs[key]
	h.mu.RUnlock()
	return id
}
// lookupAddr returns the UDP address peerID registered from, or nil if peerID
// has no registration.
func (h *UDPHub) lookupAddr(peerID string) *net.UDPAddr {
	h.mu.RLock()
	target := h.peers[peerID]
	h.mu.RUnlock()
	return target
}
// sendErrorTo sends an error message from the server to addr.
// The message body is message, and the To field is to, or "unknown" when to
// is empty. It returns the transport's send error, if any.
func (h *UDPHub) sendErrorTo(addr *net.UDPAddr, to, message string) error {
	recipient := to
	if recipient == "" {
		recipient = "unknown"
	}
	reply := protocol.Message{
		Type: protocol.MessageTypeError,
		From: protocol.ServerPeerID,
		To:   recipient,
		Body: []byte(message),
	}
	return h.conn.SendTo(reply, addr)
}
// Close closes the hub's underlying transport connection and returns the
// error from that close.
func (h *UDPHub) Close() error {
	tc := h.conn
	return tc.Close()
}