Files
OmniSocketGo/cmd/internal/server/udp_hub.go
2026-03-24 15:39:00 +08:00

186 lines
4.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package server
import (
"fmt"
"log"
"net"
"sync"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
// UDPOption 用于配置 UDPHub 的可选行为。
type UDPOption func(*UDPHub)
// WithUDPLogger 为 UDP hub 注入时延日志记录器。
func WithUDPLogger(logger latencylog.Logger) UDPOption {
return func(hub *UDPHub) {
hub.logger = logger
}
}
// UDPHub 管理通过 UDP 注册的 peer并负责在它们之间转发消息。
// 与 TCP Hub 不同UDPHub 使用单个 net.UDPConn 与所有 peer 通信,
// 通过维护 peerID -> UDPAddr 映射表来寻址。
type UDPHub struct {
mu sync.RWMutex
peers map[string]*net.UDPAddr // peerID -> 对端 UDP 地址
addrs map[string]string // addr.String() -> peerID用于反查
conn *transport.UDPConn
logger latencylog.Logger
}
// NewUDPHub 创建一个新的 UDP 连接中心。
func NewUDPHub(conn *net.UDPConn, opts ...UDPOption) (*UDPHub, error) {
hub := &UDPHub{
peers: make(map[string]*net.UDPAddr),
addrs: make(map[string]string),
logger: latencylog.NoopLogger{},
}
for _, opt := range opts {
opt(hub)
}
if hub.logger == nil {
hub.logger = latencylog.NoopLogger{}
}
udpConn, err := transport.NewUDPConn(
conn,
nil,
transport.WithUDPLogger(hub.logger, latencylog.NodeRoleServer, "hub"),
)
if err != nil {
return nil, fmt.Errorf("server: create udp transport conn: %w", err)
}
hub.conn = udpConn
return hub, nil
}
// Serve 启动 UDP 接收主循环,持续读取消息并处理注册/转发。
// 此方法会阻塞,直到底层连接关闭或发生不可恢复的错误。
func (h *UDPHub) Serve() error {
return h.conn.ReceiveLoop(func(msg protocol.Message, addr *net.UDPAddr) error {
if err := h.handleMessage(msg, addr); err != nil {
log.Printf("udp hub: handle message from %s: %v", addr, err)
}
return nil // 不因为单条消息处理失败而退出主循环
})
}
// HasPeer 返回给定 ID 是否已注册到 hub。
func (h *UDPHub) HasPeer(peerID string) bool {
h.mu.RLock()
defer h.mu.RUnlock()
_, ok := h.peers[peerID]
return ok
}
// handleMessage 处理从指定地址收到的消息。
func (h *UDPHub) handleMessage(msg protocol.Message, addr *net.UDPAddr) error {
switch msg.Type {
case protocol.MessageTypeRegister:
return h.registerPeer(msg, addr)
case protocol.MessageTypeText, protocol.MessageTypeFile:
return h.forwardMessage(msg, addr)
case protocol.MessageTypeError:
return h.sendErrorTo(addr, msg.From, "peers cannot send error messages")
default:
peerID := h.lookupPeerID(addr)
if peerID == "" {
peerID = msg.From
}
return h.sendErrorTo(addr, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type))
}
}
// registerPeer 处理 peer 的注册请求。
func (h *UDPHub) registerPeer(msg protocol.Message, addr *net.UDPAddr) error {
peerID := msg.From
if peerID == "" {
return h.sendErrorTo(addr, "", "register: missing peer id")
}
h.mu.Lock()
defer h.mu.Unlock()
// 如果同一个 peerID 从新地址注册,更新地址映射(支持 peer 重启换端口)。
if existingAddr, exists := h.peers[peerID]; exists {
// 清理旧地址的反查映射
delete(h.addrs, existingAddr.String())
}
h.peers[peerID] = addr
h.addrs[addr.String()] = peerID
log.Printf("udp hub: registered peer %s from %s", peerID, addr)
return nil
}
// forwardMessage 转发业务消息到目标 peer。
func (h *UDPHub) forwardMessage(msg protocol.Message, senderAddr *net.UDPAddr) error {
// 通过来源地址反查发送者 peerID
senderID := h.lookupPeerID(senderAddr)
if senderID == "" {
return h.sendErrorTo(senderAddr, msg.From, "not registered; send register first")
}
// server 覆盖 From不信任客户端自报身份
msg.From = senderID
// 查找目标 peer 地址
targetAddr := h.lookupAddr(msg.To)
if targetAddr == nil {
return h.sendErrorTo(senderAddr, senderID, fmt.Sprintf("unknown target: %s", msg.To))
}
// 转发消息
if err := h.conn.SendTo(msg, targetAddr); err != nil {
// 转发失败,通知发送方
_ = h.sendErrorTo(senderAddr, senderID, fmt.Sprintf("failed to forward to %s", msg.To))
return fmt.Errorf("forward to %s at %s: %w", msg.To, targetAddr, err)
}
return nil
}
// lookupPeerID 通过 UDP 地址反查 peerID。
func (h *UDPHub) lookupPeerID(addr *net.UDPAddr) string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.addrs[addr.String()]
}
// lookupAddr 通过 peerID 查找 UDP 地址。
func (h *UDPHub) lookupAddr(peerID string) *net.UDPAddr {
h.mu.RLock()
defer h.mu.RUnlock()
return h.peers[peerID]
}
// sendErrorTo 向指定地址发送错误消息。
func (h *UDPHub) sendErrorTo(addr *net.UDPAddr, to, message string) error {
if to == "" {
to = "unknown"
}
return h.conn.SendTo(protocol.Message{
Type: protocol.MessageTypeError,
From: protocol.ServerPeerID,
To: to,
Body: []byte(message),
}, addr)
}
// Close 关闭底层 UDP 连接。
func (h *UDPHub) Close() error {
return h.conn.Close()
}