239 lines
5.4 KiB
Go
239 lines
5.4 KiB
Go
package server
|
||
|
||
import (
|
||
"net"
|
||
"testing"
|
||
"time"
|
||
|
||
"omnisocketgo/cmd/internal/protocol"
|
||
"omnisocketgo/cmd/internal/transport"
|
||
)
|
||
|
||
// TestUDPHubRegisterAndForward 验证 peer 注册后可以互相转发消息。
|
||
func TestUDPHubRegisterAndForward(t *testing.T) {
|
||
hub, hubAddr := startUDPHub(t)
|
||
_ = hub
|
||
|
||
peerA := dialUDPPeer(t, hubAddr)
|
||
peerB := dialUDPPeer(t, hubAddr)
|
||
|
||
// 注册 peer-a
|
||
sendUDPMessage(t, peerA, protocol.Message{
|
||
Type: protocol.MessageTypeRegister,
|
||
From: "peer-a",
|
||
To: protocol.ServerPeerID,
|
||
})
|
||
|
||
// 注册 peer-b
|
||
sendUDPMessage(t, peerB, protocol.Message{
|
||
Type: protocol.MessageTypeRegister,
|
||
From: "peer-b",
|
||
To: protocol.ServerPeerID,
|
||
})
|
||
|
||
// 等待注册被处理
|
||
time.Sleep(50 * time.Millisecond)
|
||
|
||
// peer-a 发送消息给 peer-b
|
||
sendUDPMessage(t, peerA, protocol.Message{
|
||
Type: protocol.MessageTypeText,
|
||
ID: 1,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
Body: []byte("hello from peer-a"),
|
||
})
|
||
|
||
// peer-b 应该收到消息
|
||
msg := receiveUDPMessage(t, peerB)
|
||
if msg.Type != protocol.MessageTypeText {
|
||
t.Fatalf("message type = %s, want text", msg.Type)
|
||
}
|
||
if msg.From != "peer-a" {
|
||
t.Fatalf("message from = %s, want peer-a", msg.From)
|
||
}
|
||
if msg.To != "peer-b" {
|
||
t.Fatalf("message to = %s, want peer-b", msg.To)
|
||
}
|
||
if string(msg.Body) != "hello from peer-a" {
|
||
t.Fatalf("message body = %q, want %q", string(msg.Body), "hello from peer-a")
|
||
}
|
||
}
|
||
|
||
// TestUDPHubRejectsUnregistered 验证未注册的 peer 发送业务消息会收到错误。
|
||
func TestUDPHubRejectsUnregistered(t *testing.T) {
|
||
_, hubAddr := startUDPHub(t)
|
||
|
||
peer := dialUDPPeer(t, hubAddr)
|
||
|
||
// 直接发送业务消息而不注册
|
||
sendUDPMessage(t, peer, protocol.Message{
|
||
Type: protocol.MessageTypeText,
|
||
ID: 1,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
Body: []byte("should fail"),
|
||
})
|
||
|
||
// 应该收到错误响应
|
||
msg := receiveUDPMessage(t, peer)
|
||
if msg.Type != protocol.MessageTypeError {
|
||
t.Fatalf("message type = %s, want error", msg.Type)
|
||
}
|
||
}
|
||
|
||
// TestUDPHubRejectsUnknownTarget 验证发送到不存在的目标会返回错误。
|
||
func TestUDPHubRejectsUnknownTarget(t *testing.T) {
|
||
_, hubAddr := startUDPHub(t)
|
||
|
||
peer := dialUDPPeer(t, hubAddr)
|
||
|
||
// 注册
|
||
sendUDPMessage(t, peer, protocol.Message{
|
||
Type: protocol.MessageTypeRegister,
|
||
From: "peer-a",
|
||
To: protocol.ServerPeerID,
|
||
})
|
||
|
||
time.Sleep(50 * time.Millisecond)
|
||
|
||
// 发送到不存在的目标
|
||
sendUDPMessage(t, peer, protocol.Message{
|
||
Type: protocol.MessageTypeText,
|
||
ID: 1,
|
||
From: "peer-a",
|
||
To: "peer-nonexistent",
|
||
Body: []byte("should fail"),
|
||
})
|
||
|
||
// 应该收到错误响应
|
||
msg := receiveUDPMessage(t, peer)
|
||
if msg.Type != protocol.MessageTypeError {
|
||
t.Fatalf("message type = %s, want error", msg.Type)
|
||
}
|
||
}
|
||
|
||
// TestUDPHubOverridesFromField 验证 server 会覆盖消息的 From 字段。
|
||
func TestUDPHubOverridesFromField(t *testing.T) {
|
||
_, hubAddr := startUDPHub(t)
|
||
|
||
peerA := dialUDPPeer(t, hubAddr)
|
||
peerB := dialUDPPeer(t, hubAddr)
|
||
|
||
sendUDPMessage(t, peerA, protocol.Message{
|
||
Type: protocol.MessageTypeRegister,
|
||
From: "peer-a",
|
||
To: protocol.ServerPeerID,
|
||
})
|
||
sendUDPMessage(t, peerB, protocol.Message{
|
||
Type: protocol.MessageTypeRegister,
|
||
From: "peer-b",
|
||
To: protocol.ServerPeerID,
|
||
})
|
||
|
||
time.Sleep(50 * time.Millisecond)
|
||
|
||
// peer-a 伪造 From 为 "fake-id"
|
||
sendUDPMessage(t, peerA, protocol.Message{
|
||
Type: protocol.MessageTypeText,
|
||
ID: 1,
|
||
From: "fake-id",
|
||
To: "peer-b",
|
||
Body: []byte("spoofed"),
|
||
})
|
||
|
||
msg := receiveUDPMessage(t, peerB)
|
||
// server 应该用实际注册的 "peer-a" 覆盖 From
|
||
if msg.From != "peer-a" {
|
||
t.Fatalf("message from = %s, want peer-a (server should override)", msg.From)
|
||
}
|
||
}
|
||
|
||
// startUDPHub 创建并启动一个 UDPHub,返回 hub 和监听地址。
|
||
func startUDPHub(t *testing.T) (*UDPHub, *net.UDPAddr) {
|
||
t.Helper()
|
||
|
||
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("ResolveUDPAddr() error = %v", err)
|
||
}
|
||
|
||
conn, err := net.ListenUDP("udp", addr)
|
||
if err != nil {
|
||
t.Fatalf("ListenUDP() error = %v", err)
|
||
}
|
||
|
||
hub, err := NewUDPHub(conn)
|
||
if err != nil {
|
||
_ = conn.Close()
|
||
t.Fatalf("NewUDPHub() error = %v", err)
|
||
}
|
||
|
||
go func() {
|
||
_ = hub.Serve()
|
||
}()
|
||
|
||
t.Cleanup(func() {
|
||
_ = hub.Close()
|
||
})
|
||
|
||
return hub, conn.LocalAddr().(*net.UDPAddr)
|
||
}
|
||
|
||
// dialUDPPeer 创建一个连接到指定地址的 UDP transport 连接。
|
||
func dialUDPPeer(t *testing.T, serverAddr *net.UDPAddr) *transport.UDPConn {
|
||
t.Helper()
|
||
|
||
raw, err := net.DialUDP("udp", nil, serverAddr)
|
||
if err != nil {
|
||
t.Fatalf("DialUDP() error = %v", err)
|
||
}
|
||
|
||
conn, err := transport.NewUDPConn(raw, nil)
|
||
if err != nil {
|
||
_ = raw.Close()
|
||
t.Fatalf("NewUDPConn() error = %v", err)
|
||
}
|
||
|
||
t.Cleanup(func() {
|
||
_ = conn.Close()
|
||
})
|
||
|
||
return conn
|
||
}
|
||
|
||
// sendUDPMessage 发送一条 UDP 消息。
|
||
func sendUDPMessage(t *testing.T, conn *transport.UDPConn, msg protocol.Message) {
|
||
t.Helper()
|
||
|
||
if err := conn.Send(msg); err != nil {
|
||
t.Fatalf("Send() error = %v", err)
|
||
}
|
||
}
|
||
|
||
// receiveUDPMessage 接收一条 UDP 消息,带超时。
|
||
func receiveUDPMessage(t *testing.T, conn *transport.UDPConn) protocol.Message {
|
||
t.Helper()
|
||
|
||
type result struct {
|
||
msg protocol.Message
|
||
err error
|
||
}
|
||
|
||
ch := make(chan result, 1)
|
||
go func() {
|
||
msg, _, err := conn.Receive()
|
||
ch <- result{msg: msg, err: err}
|
||
}()
|
||
|
||
select {
|
||
case r := <-ch:
|
||
if r.err != nil {
|
||
t.Fatalf("Receive() error = %v", r.err)
|
||
}
|
||
return r.msg
|
||
case <-time.After(2 * time.Second):
|
||
t.Fatal("Receive() timed out after 2s")
|
||
return protocol.Message{}
|
||
}
|
||
}
|