Feat: UDP 框架

This commit is contained in:
2026-03-24 15:39:00 +08:00
parent 44f39c12ed
commit c126b05961
12 changed files with 2119 additions and 0 deletions

View File

@@ -0,0 +1,185 @@
package server
import (
"fmt"
"log"
"net"
"sync"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
// UDPOption 用于配置 UDPHub 的可选行为。
type UDPOption func(*UDPHub)
// WithUDPLogger 为 UDP hub 注入时延日志记录器。
func WithUDPLogger(logger latencylog.Logger) UDPOption {
return func(hub *UDPHub) {
hub.logger = logger
}
}
// UDPHub 管理通过 UDP 注册的 peer并负责在它们之间转发消息。
// 与 TCP Hub 不同UDPHub 使用单个 net.UDPConn 与所有 peer 通信,
// 通过维护 peerID -> UDPAddr 映射表来寻址。
type UDPHub struct {
mu sync.RWMutex
peers map[string]*net.UDPAddr // peerID -> 对端 UDP 地址
addrs map[string]string // addr.String() -> peerID用于反查
conn *transport.UDPConn
logger latencylog.Logger
}
// NewUDPHub 创建一个新的 UDP 连接中心。
func NewUDPHub(conn *net.UDPConn, opts ...UDPOption) (*UDPHub, error) {
hub := &UDPHub{
peers: make(map[string]*net.UDPAddr),
addrs: make(map[string]string),
logger: latencylog.NoopLogger{},
}
for _, opt := range opts {
opt(hub)
}
if hub.logger == nil {
hub.logger = latencylog.NoopLogger{}
}
udpConn, err := transport.NewUDPConn(
conn,
nil,
transport.WithUDPLogger(hub.logger, latencylog.NodeRoleServer, "hub"),
)
if err != nil {
return nil, fmt.Errorf("server: create udp transport conn: %w", err)
}
hub.conn = udpConn
return hub, nil
}
// Serve 启动 UDP 接收主循环,持续读取消息并处理注册/转发。
// 此方法会阻塞,直到底层连接关闭或发生不可恢复的错误。
func (h *UDPHub) Serve() error {
return h.conn.ReceiveLoop(func(msg protocol.Message, addr *net.UDPAddr) error {
if err := h.handleMessage(msg, addr); err != nil {
log.Printf("udp hub: handle message from %s: %v", addr, err)
}
return nil // 不因为单条消息处理失败而退出主循环
})
}
// HasPeer 返回给定 ID 是否已注册到 hub。
func (h *UDPHub) HasPeer(peerID string) bool {
h.mu.RLock()
defer h.mu.RUnlock()
_, ok := h.peers[peerID]
return ok
}
// handleMessage 处理从指定地址收到的消息。
func (h *UDPHub) handleMessage(msg protocol.Message, addr *net.UDPAddr) error {
switch msg.Type {
case protocol.MessageTypeRegister:
return h.registerPeer(msg, addr)
case protocol.MessageTypeText, protocol.MessageTypeFile:
return h.forwardMessage(msg, addr)
case protocol.MessageTypeError:
return h.sendErrorTo(addr, msg.From, "peers cannot send error messages")
default:
peerID := h.lookupPeerID(addr)
if peerID == "" {
peerID = msg.From
}
return h.sendErrorTo(addr, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type))
}
}
// registerPeer 处理 peer 的注册请求。
func (h *UDPHub) registerPeer(msg protocol.Message, addr *net.UDPAddr) error {
peerID := msg.From
if peerID == "" {
return h.sendErrorTo(addr, "", "register: missing peer id")
}
h.mu.Lock()
defer h.mu.Unlock()
// 如果同一个 peerID 从新地址注册,更新地址映射(支持 peer 重启换端口)。
if existingAddr, exists := h.peers[peerID]; exists {
// 清理旧地址的反查映射
delete(h.addrs, existingAddr.String())
}
h.peers[peerID] = addr
h.addrs[addr.String()] = peerID
log.Printf("udp hub: registered peer %s from %s", peerID, addr)
return nil
}
// forwardMessage 转发业务消息到目标 peer。
func (h *UDPHub) forwardMessage(msg protocol.Message, senderAddr *net.UDPAddr) error {
// 通过来源地址反查发送者 peerID
senderID := h.lookupPeerID(senderAddr)
if senderID == "" {
return h.sendErrorTo(senderAddr, msg.From, "not registered; send register first")
}
// server 覆盖 From不信任客户端自报身份
msg.From = senderID
// 查找目标 peer 地址
targetAddr := h.lookupAddr(msg.To)
if targetAddr == nil {
return h.sendErrorTo(senderAddr, senderID, fmt.Sprintf("unknown target: %s", msg.To))
}
// 转发消息
if err := h.conn.SendTo(msg, targetAddr); err != nil {
// 转发失败,通知发送方
_ = h.sendErrorTo(senderAddr, senderID, fmt.Sprintf("failed to forward to %s", msg.To))
return fmt.Errorf("forward to %s at %s: %w", msg.To, targetAddr, err)
}
return nil
}
// lookupPeerID 通过 UDP 地址反查 peerID。
func (h *UDPHub) lookupPeerID(addr *net.UDPAddr) string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.addrs[addr.String()]
}
// lookupAddr 通过 peerID 查找 UDP 地址。
func (h *UDPHub) lookupAddr(peerID string) *net.UDPAddr {
h.mu.RLock()
defer h.mu.RUnlock()
return h.peers[peerID]
}
// sendErrorTo 向指定地址发送错误消息。
func (h *UDPHub) sendErrorTo(addr *net.UDPAddr, to, message string) error {
if to == "" {
to = "unknown"
}
return h.conn.SendTo(protocol.Message{
Type: protocol.MessageTypeError,
From: protocol.ServerPeerID,
To: to,
Body: []byte(message),
}, addr)
}
// Close 关闭底层 UDP 连接。
func (h *UDPHub) Close() error {
return h.conn.Close()
}

View File

@@ -0,0 +1,238 @@
package server
import (
"net"
"testing"
"time"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
// TestUDPHubRegisterAndForward 验证 peer 注册后可以互相转发消息。
func TestUDPHubRegisterAndForward(t *testing.T) {
hub, hubAddr := startUDPHub(t)
_ = hub
peerA := dialUDPPeer(t, hubAddr)
peerB := dialUDPPeer(t, hubAddr)
// 注册 peer-a
sendUDPMessage(t, peerA, protocol.Message{
Type: protocol.MessageTypeRegister,
From: "peer-a",
To: protocol.ServerPeerID,
})
// 注册 peer-b
sendUDPMessage(t, peerB, protocol.Message{
Type: protocol.MessageTypeRegister,
From: "peer-b",
To: protocol.ServerPeerID,
})
// 等待注册被处理
time.Sleep(50 * time.Millisecond)
// peer-a 发送消息给 peer-b
sendUDPMessage(t, peerA, protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-b",
Body: []byte("hello from peer-a"),
})
// peer-b 应该收到消息
msg := receiveUDPMessage(t, peerB)
if msg.Type != protocol.MessageTypeText {
t.Fatalf("message type = %s, want text", msg.Type)
}
if msg.From != "peer-a" {
t.Fatalf("message from = %s, want peer-a", msg.From)
}
if msg.To != "peer-b" {
t.Fatalf("message to = %s, want peer-b", msg.To)
}
if string(msg.Body) != "hello from peer-a" {
t.Fatalf("message body = %q, want %q", string(msg.Body), "hello from peer-a")
}
}
// TestUDPHubRejectsUnregistered 验证未注册的 peer 发送业务消息会收到错误。
func TestUDPHubRejectsUnregistered(t *testing.T) {
_, hubAddr := startUDPHub(t)
peer := dialUDPPeer(t, hubAddr)
// 直接发送业务消息而不注册
sendUDPMessage(t, peer, protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-b",
Body: []byte("should fail"),
})
// 应该收到错误响应
msg := receiveUDPMessage(t, peer)
if msg.Type != protocol.MessageTypeError {
t.Fatalf("message type = %s, want error", msg.Type)
}
}
// TestUDPHubRejectsUnknownTarget 验证发送到不存在的目标会返回错误。
func TestUDPHubRejectsUnknownTarget(t *testing.T) {
_, hubAddr := startUDPHub(t)
peer := dialUDPPeer(t, hubAddr)
// 注册
sendUDPMessage(t, peer, protocol.Message{
Type: protocol.MessageTypeRegister,
From: "peer-a",
To: protocol.ServerPeerID,
})
time.Sleep(50 * time.Millisecond)
// 发送到不存在的目标
sendUDPMessage(t, peer, protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-nonexistent",
Body: []byte("should fail"),
})
// 应该收到错误响应
msg := receiveUDPMessage(t, peer)
if msg.Type != protocol.MessageTypeError {
t.Fatalf("message type = %s, want error", msg.Type)
}
}
// TestUDPHubOverridesFromField 验证 server 会覆盖消息的 From 字段。
func TestUDPHubOverridesFromField(t *testing.T) {
_, hubAddr := startUDPHub(t)
peerA := dialUDPPeer(t, hubAddr)
peerB := dialUDPPeer(t, hubAddr)
sendUDPMessage(t, peerA, protocol.Message{
Type: protocol.MessageTypeRegister,
From: "peer-a",
To: protocol.ServerPeerID,
})
sendUDPMessage(t, peerB, protocol.Message{
Type: protocol.MessageTypeRegister,
From: "peer-b",
To: protocol.ServerPeerID,
})
time.Sleep(50 * time.Millisecond)
// peer-a 伪造 From 为 "fake-id"
sendUDPMessage(t, peerA, protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "fake-id",
To: "peer-b",
Body: []byte("spoofed"),
})
msg := receiveUDPMessage(t, peerB)
// server 应该用实际注册的 "peer-a" 覆盖 From
if msg.From != "peer-a" {
t.Fatalf("message from = %s, want peer-a (server should override)", msg.From)
}
}
// startUDPHub 创建并启动一个 UDPHub返回 hub 和监听地址。
func startUDPHub(t *testing.T) (*UDPHub, *net.UDPAddr) {
t.Helper()
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ResolveUDPAddr() error = %v", err)
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
t.Fatalf("ListenUDP() error = %v", err)
}
hub, err := NewUDPHub(conn)
if err != nil {
_ = conn.Close()
t.Fatalf("NewUDPHub() error = %v", err)
}
go func() {
_ = hub.Serve()
}()
t.Cleanup(func() {
_ = hub.Close()
})
return hub, conn.LocalAddr().(*net.UDPAddr)
}
// dialUDPPeer 创建一个连接到指定地址的 UDP transport 连接。
func dialUDPPeer(t *testing.T, serverAddr *net.UDPAddr) *transport.UDPConn {
t.Helper()
raw, err := net.DialUDP("udp", nil, serverAddr)
if err != nil {
t.Fatalf("DialUDP() error = %v", err)
}
conn, err := transport.NewUDPConn(raw, nil)
if err != nil {
_ = raw.Close()
t.Fatalf("NewUDPConn() error = %v", err)
}
t.Cleanup(func() {
_ = conn.Close()
})
return conn
}
// sendUDPMessage 发送一条 UDP 消息。
func sendUDPMessage(t *testing.T, conn *transport.UDPConn, msg protocol.Message) {
t.Helper()
if err := conn.Send(msg); err != nil {
t.Fatalf("Send() error = %v", err)
}
}
// receiveUDPMessage 接收一条 UDP 消息,带超时。
func receiveUDPMessage(t *testing.T, conn *transport.UDPConn) protocol.Message {
t.Helper()
type result struct {
msg protocol.Message
err error
}
ch := make(chan result, 1)
go func() {
msg, _, err := conn.Receive()
ch <- result{msg: msg, err: err}
}()
select {
case r := <-ch:
if r.err != nil {
t.Fatalf("Receive() error = %v", r.err)
}
return r.msg
case <-time.After(2 * time.Second):
t.Fatal("Receive() timed out after 2s")
return protocol.Message{}
}
}