190 lines
4.9 KiB
Go
190 lines
4.9 KiB
Go
package server
|
||
|
||
import (
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"sync"
|
||
|
||
"omnisocketgo/cmd/internal/latencylog"
|
||
"omnisocketgo/cmd/internal/protocol"
|
||
"omnisocketgo/cmd/internal/transport"
|
||
)
|
||
|
||
// UDPOption 用于配置 UDPHub 的可选行为。
|
||
type UDPOption func(*UDPHub)
|
||
|
||
// WithUDPLogger 为 UDP hub 注入时延日志记录器。
|
||
func WithUDPLogger(logger latencylog.Logger) UDPOption {
|
||
return func(hub *UDPHub) {
|
||
hub.logger = logger
|
||
}
|
||
}
|
||
|
||
// WithUDPTXTimestampDebugLogger 为 UDP hub 注入 TX errqueue 调试日志器。
|
||
func WithUDPTXTimestampDebugLogger(logger transport.TXTimestampDebugLogger) UDPOption {
|
||
return func(hub *UDPHub) {
|
||
hub.txTimestampDebugLogger = logger
|
||
}
|
||
}
|
||
|
||
// WithUDPLinuxTimestamping controls whether the UDP hub enables Linux timestamping.
|
||
func WithUDPLinuxTimestamping(enabled bool) UDPOption {
|
||
return func(hub *UDPHub) {
|
||
hub.linuxTimestampingEnabled = enabled
|
||
}
|
||
}
|
||
|
||
// UDPHub 管理通过 UDP 注册的 peer,并负责在它们之间转发消息。
|
||
type UDPHub struct {
|
||
mu sync.RWMutex
|
||
peers map[string]*net.UDPAddr
|
||
addrs map[string]string
|
||
|
||
conn *transport.UDPConn
|
||
logger latencylog.Logger
|
||
txTimestampDebugLogger transport.TXTimestampDebugLogger
|
||
linuxTimestampingEnabled bool
|
||
}
|
||
|
||
// NewUDPHub 创建一个新的 UDP 连接中心。
|
||
func NewUDPHub(conn *net.UDPConn, opts ...UDPOption) (*UDPHub, error) {
|
||
hub := &UDPHub{
|
||
peers: make(map[string]*net.UDPAddr),
|
||
addrs: make(map[string]string),
|
||
logger: latencylog.NoopLogger{},
|
||
linuxTimestampingEnabled: true,
|
||
}
|
||
|
||
for _, opt := range opts {
|
||
opt(hub)
|
||
}
|
||
|
||
if hub.logger == nil {
|
||
hub.logger = latencylog.NoopLogger{}
|
||
}
|
||
|
||
udpConn, err := transport.NewUDPConn(
|
||
conn,
|
||
nil,
|
||
transport.WithUDPLogger(hub.logger, latencylog.NodeRoleServer, "hub"),
|
||
transport.WithUDPLinuxTimestamping(hub.linuxTimestampingEnabled),
|
||
transport.WithUDPTXTimestampDebugLogger(hub.txTimestampDebugLogger),
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("server: create udp transport conn: %w", err)
|
||
}
|
||
|
||
hub.conn = udpConn
|
||
return hub, nil
|
||
}
|
||
|
||
// Serve 启动 UDP 接收主循环,持续读取消息并处理注册/转发。
|
||
func (h *UDPHub) Serve() error {
|
||
return h.conn.ReceiveLoop(func(msg protocol.Message, addr *net.UDPAddr) error {
|
||
if err := h.handleMessage(msg, addr); err != nil {
|
||
log.Printf("udp hub: handle message from %s: %v", addr, err)
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// HasPeer 返回给定 ID 是否已经注册到 hub。
|
||
func (h *UDPHub) HasPeer(peerID string) bool {
|
||
h.mu.RLock()
|
||
defer h.mu.RUnlock()
|
||
|
||
_, ok := h.peers[peerID]
|
||
return ok
|
||
}
|
||
|
||
func (h *UDPHub) handleMessage(msg protocol.Message, addr *net.UDPAddr) error {
|
||
switch msg.Type {
|
||
case protocol.MessageTypeRegister:
|
||
return h.registerPeer(msg, addr)
|
||
case protocol.MessageTypeText, protocol.MessageTypeFile:
|
||
return h.forwardMessage(msg, addr)
|
||
case protocol.MessageTypeError:
|
||
return h.sendErrorTo(addr, msg.From, "peers cannot send error messages")
|
||
default:
|
||
peerID := h.lookupPeerID(addr)
|
||
if peerID == "" {
|
||
peerID = msg.From
|
||
}
|
||
return h.sendErrorTo(addr, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type))
|
||
}
|
||
}
|
||
|
||
func (h *UDPHub) registerPeer(msg protocol.Message, addr *net.UDPAddr) error {
|
||
peerID := msg.From
|
||
if peerID == "" {
|
||
return h.sendErrorTo(addr, "", "register: missing peer id")
|
||
}
|
||
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
if existingAddr, exists := h.peers[peerID]; exists {
|
||
delete(h.addrs, existingAddr.String())
|
||
}
|
||
|
||
h.peers[peerID] = addr
|
||
h.addrs[addr.String()] = peerID
|
||
log.Printf("udp hub: registered peer %s from %s", peerID, addr)
|
||
return nil
|
||
}
|
||
|
||
func (h *UDPHub) forwardMessage(msg protocol.Message, senderAddr *net.UDPAddr) error {
|
||
senderID := h.lookupPeerID(senderAddr)
|
||
if senderID == "" {
|
||
return h.sendErrorTo(senderAddr, msg.From, "not registered; send register first")
|
||
}
|
||
|
||
msg.From = senderID
|
||
|
||
targetAddr := h.lookupAddr(msg.To)
|
||
if targetAddr == nil {
|
||
return h.sendErrorTo(senderAddr, senderID, fmt.Sprintf("unknown target: %s", msg.To))
|
||
}
|
||
|
||
if err := h.conn.SendTo(msg, targetAddr); err != nil {
|
||
_ = h.sendErrorTo(senderAddr, senderID, fmt.Sprintf("failed to forward to %s", msg.To))
|
||
return fmt.Errorf("forward to %s at %s: %w", msg.To, targetAddr, err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// lookupPeerID 通过 UDP 地址反查 peerID。
|
||
func (h *UDPHub) lookupPeerID(addr *net.UDPAddr) string {
|
||
h.mu.RLock()
|
||
defer h.mu.RUnlock()
|
||
|
||
return h.addrs[addr.String()]
|
||
}
|
||
|
||
// lookupAddr 通过 peerID 查找 UDP 地址。
|
||
func (h *UDPHub) lookupAddr(peerID string) *net.UDPAddr {
|
||
h.mu.RLock()
|
||
defer h.mu.RUnlock()
|
||
|
||
return h.peers[peerID]
|
||
}
|
||
|
||
func (h *UDPHub) sendErrorTo(addr *net.UDPAddr, to, message string) error {
|
||
if to == "" {
|
||
to = "unknown"
|
||
}
|
||
return h.conn.SendTo(protocol.Message{
|
||
Type: protocol.MessageTypeError,
|
||
From: protocol.ServerPeerID,
|
||
To: to,
|
||
Body: []byte(message),
|
||
}, addr)
|
||
}
|
||
|
||
// Close 关闭底层 UDP 连接。
|
||
func (h *UDPHub) Close() error {
|
||
return h.conn.Close()
|
||
}
|