Files
OmniSocketGo/cmd/internal/server/udp_hub_test.go
2026-03-24 15:39:00 +08:00

239 lines
5.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package server
import (
"net"
"testing"
"time"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
// TestUDPHubRegisterAndForward verifies that once two peers have registered
// with the hub, a text message sent by one is forwarded to the other with its
// Type, From, To and Body intact.
func TestUDPHubRegisterAndForward(t *testing.T) {
	const body = "hello from peer-a"

	// Only the hub's address is needed; the hub handle itself is unused.
	_, hubAddr := startUDPHub(t)
	peerA := dialUDPPeer(t, hubAddr)
	peerB := dialUDPPeer(t, hubAddr)

	// Register peer-a.
	sendUDPMessage(t, peerA, protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: "peer-a",
		To:   protocol.ServerPeerID,
	})
	// Register peer-b.
	sendUDPMessage(t, peerB, protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: "peer-b",
		To:   protocol.ServerPeerID,
	})

	// No registration acknowledgement is awaited; give the hub a short window
	// to process both registrations before sending traffic.
	time.Sleep(50 * time.Millisecond)

	// peer-a sends a message addressed to peer-b.
	sendUDPMessage(t, peerA, protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte(body),
	})

	// peer-b should receive the forwarded message unchanged.
	msg := receiveUDPMessage(t, peerB)
	if msg.Type != protocol.MessageTypeText {
		t.Fatalf("message type = %s, want text", msg.Type)
	}
	if msg.From != "peer-a" {
		t.Fatalf("message from = %s, want peer-a", msg.From)
	}
	if msg.To != "peer-b" {
		t.Fatalf("message to = %s, want peer-b", msg.To)
	}
	if string(msg.Body) != body {
		t.Fatalf("message body = %q, want %q", string(msg.Body), body)
	}
}
// TestUDPHubRejectsUnregistered verifies that a business (non-register)
// message from a peer that never registered is answered with an error
// message sent back to that peer.
func TestUDPHubRejectsUnregistered(t *testing.T) {
	_, addr := startUDPHub(t)
	client := dialUDPPeer(t, addr)

	// Skip registration entirely and go straight to a text message.
	unregistered := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("should fail"),
	}
	sendUDPMessage(t, client, unregistered)

	// The reply delivered to the sender must be an error.
	if reply := receiveUDPMessage(t, client); reply.Type != protocol.MessageTypeError {
		t.Fatalf("message type = %s, want error", reply.Type)
	}
}
// TestUDPHubRejectsUnknownTarget verifies that a registered peer sending a
// message to a peer ID nobody has registered receives an error reply.
func TestUDPHubRejectsUnknownTarget(t *testing.T) {
	_, addr := startUDPHub(t)
	client := dialUDPPeer(t, addr)

	// Register as peer-a so the hub accepts business messages from us.
	register := protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: "peer-a",
		To:   protocol.ServerPeerID,
	}
	sendUDPMessage(t, client, register)

	// Allow the hub a moment to record the registration.
	time.Sleep(50 * time.Millisecond)

	// Address a peer that was never registered.
	toNowhere := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-nonexistent",
		Body: []byte("should fail"),
	}
	sendUDPMessage(t, client, toNowhere)

	// The reply delivered to the sender must be an error.
	if reply := receiveUDPMessage(t, client); reply.Type != protocol.MessageTypeError {
		t.Fatalf("message type = %s, want error", reply.Type)
	}
}
// TestUDPHubOverridesFromField verifies that the hub replaces the From field
// of a forwarded message with the sender's registered peer ID, so a peer
// cannot impersonate another identity.
func TestUDPHubOverridesFromField(t *testing.T) {
	_, addr := startUDPHub(t)
	sender := dialUDPPeer(t, addr)
	receiver := dialUDPPeer(t, addr)

	// Register the sender as peer-a first, then the receiver as peer-b.
	registrations := []struct {
		conn *transport.UDPConn
		id   string
	}{
		{conn: sender, id: "peer-a"},
		{conn: receiver, id: "peer-b"},
	}
	for _, reg := range registrations {
		sendUDPMessage(t, reg.conn, protocol.Message{
			Type: protocol.MessageTypeRegister,
			From: reg.id,
			To:   protocol.ServerPeerID,
		})
	}

	// Allow the hub a moment to record both registrations.
	time.Sleep(50 * time.Millisecond)

	// The sender (registered as peer-a) claims to be "fake-id".
	spoofed := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "fake-id",
		To:   "peer-b",
		Body: []byte("spoofed"),
	}
	sendUDPMessage(t, sender, spoofed)

	// The hub must overwrite From with the registered identity "peer-a".
	if got := receiveUDPMessage(t, receiver).From; got != "peer-a" {
		t.Fatalf("message from = %s, want peer-a (server should override)", got)
	}
}
// startUDPHub creates a UDPHub on an ephemeral loopback UDP port and runs its
// Serve loop in a background goroutine. It returns the hub and the address
// the hub is listening on.
//
// The registered cleanup closes the hub and then waits (up to 2s) for Serve
// to return, so the serve goroutine does not outlive the test. Serve's error
// is ignored; it is only expected to surface during shutdown, presumably as a
// closed-socket error (TODO confirm against UDPHub.Serve).
func startUDPHub(t *testing.T) (*UDPHub, *net.UDPAddr) {
	t.Helper()

	addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ResolveUDPAddr() error = %v", err)
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		t.Fatalf("ListenUDP() error = %v", err)
	}
	hub, err := NewUDPHub(conn)
	if err != nil {
		// The hub never took the socket, so release it here.
		_ = conn.Close()
		t.Fatalf("NewUDPHub() error = %v", err)
	}

	served := make(chan struct{})
	go func() {
		defer close(served)
		_ = hub.Serve()
	}()

	t.Cleanup(func() {
		_ = hub.Close()
		// Bounded wait: a Serve loop that ignores Close is a bug worth
		// reporting, but it must not hang the whole test binary.
		select {
		case <-served:
		case <-time.After(2 * time.Second):
			t.Errorf("UDPHub.Serve() did not return within 2s after Close()")
		}
	})

	return hub, conn.LocalAddr().(*net.UDPAddr)
}
// dialUDPPeer opens a UDP socket connected to serverAddr, wraps it in a
// transport.UDPConn, and registers a test cleanup that closes the wrapper.
func dialUDPPeer(t *testing.T, serverAddr *net.UDPAddr) *transport.UDPConn {
	t.Helper()

	socket, dialErr := net.DialUDP("udp", nil, serverAddr)
	if dialErr != nil {
		t.Fatalf("DialUDP() error = %v", dialErr)
	}

	peer, wrapErr := transport.NewUDPConn(socket, nil)
	if wrapErr != nil {
		// No wrapper exists to own the socket, so close it directly.
		_ = socket.Close()
		t.Fatalf("NewUDPConn() error = %v", wrapErr)
	}

	t.Cleanup(func() { _ = peer.Close() })
	return peer
}
// sendUDPMessage sends msg over conn and fails the test immediately if the
// send reports an error.
func sendUDPMessage(t *testing.T, conn *transport.UDPConn, msg protocol.Message) {
	t.Helper()
	err := conn.Send(msg)
	if err == nil {
		return
	}
	t.Fatalf("Send() error = %v", err)
}
// receiveUDPMessage waits up to two seconds for the next message on conn and
// returns it. It fails the test on a receive error or on timeout.
//
// Receive runs in a helper goroutine so that the wait can be bounded. The
// result channel has a buffer of one, so after a timeout the goroutine can
// still deposit its result and exit once Receive returns (presumably when
// conn is closed by the test's cleanup).
func receiveUDPMessage(t *testing.T, conn *transport.UDPConn) protocol.Message {
	t.Helper()

	type outcome struct {
		msg protocol.Message
		err error
	}
	done := make(chan outcome, 1)
	go func() {
		m, _, e := conn.Receive()
		done <- outcome{msg: m, err: e}
	}()

	timer := time.NewTimer(2 * time.Second)
	defer timer.Stop()

	select {
	case <-timer.C:
		t.Fatal("Receive() timed out after 2s")
		return protocol.Message{}
	case got := <-done:
		if got.err != nil {
			t.Fatalf("Receive() error = %v", got.err)
		}
		return got.msg
	}
}