package server

import (
	"net"
	"testing"
	"time"

	"omnisocketgo/cmd/internal/protocol"
	"omnisocketgo/cmd/internal/transport"
)

const (
	// registrationSettleDelay is how long tests wait after sending register
	// messages before sending traffic that depends on them. These tests assume
	// registration produces no reply datagram: the first datagram a registered
	// peer reads is expected to be forwarded traffic. So there is nothing to
	// wait on, and this fixed delay makes the tests timing-dependent.
	registrationSettleDelay = 50 * time.Millisecond

	// receiveTimeout bounds how long receiveUDPMessage waits for one datagram.
	receiveTimeout = 2 * time.Second

	// serveShutdownTimeout bounds how long startUDPHub's cleanup waits for
	// Serve to return after the hub has been closed.
	serveShutdownTimeout = 2 * time.Second
)

// TestUDPHubRegisterAndForward verifies that, once two peers have registered,
// a text message sent by one is forwarded to the other with its sender,
// recipient and body intact.
func TestUDPHubRegisterAndForward(t *testing.T) {
	_, hubAddr := startUDPHub(t)
	peerA := dialUDPPeer(t, hubAddr)
	peerB := dialUDPPeer(t, hubAddr)

	sendUDPMessage(t, peerA, protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: "peer-a",
		To:   protocol.ServerPeerID,
	})
	sendUDPMessage(t, peerB, protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: "peer-b",
		To:   protocol.ServerPeerID,
	})
	time.Sleep(registrationSettleDelay)

	sendUDPMessage(t, peerA, protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello from peer-a"),
	})

	msg := receiveUDPMessage(t, peerB)
	assertTextMessage(t, msg, "peer-a", "peer-b", "hello from peer-a")
}

// TestUDPHubRejectsUnregistered verifies that a peer sending an application
// message without registering first receives an error message back.
func TestUDPHubRejectsUnregistered(t *testing.T) {
	_, hubAddr := startUDPHub(t)
	peer := dialUDPPeer(t, hubAddr)

	// Send application traffic without any preceding register message.
	sendUDPMessage(t, peer, protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("should fail"),
	})

	msg := receiveUDPMessage(t, peer)
	if msg.Type != protocol.MessageTypeError {
		t.Fatalf("message type = %v, want error", msg.Type)
	}
}

// TestUDPHubRejectsUnknownTarget verifies that a registered peer sending to
// a peer ID nobody has registered receives an error message back.
func TestUDPHubRejectsUnknownTarget(t *testing.T) {
	_, hubAddr := startUDPHub(t)
	peer := dialUDPPeer(t, hubAddr)

	sendUDPMessage(t, peer, protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: "peer-a",
		To:   protocol.ServerPeerID,
	})
	time.Sleep(registrationSettleDelay)

	sendUDPMessage(t, peer, protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-nonexistent",
		Body: []byte("should fail"),
	})

	msg := receiveUDPMessage(t, peer)
	if msg.Type != protocol.MessageTypeError {
		t.Fatalf("message type = %v, want error", msg.Type)
	}
}

// TestUDPHubOverridesFromField verifies that the hub replaces a spoofed From
// field with the ID the sending peer actually registered under.
func TestUDPHubOverridesFromField(t *testing.T) {
	_, hubAddr := startUDPHub(t)
	peerA := dialUDPPeer(t, hubAddr)
	peerB := dialUDPPeer(t, hubAddr)

	sendUDPMessage(t, peerA, protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: "peer-a",
		To:   protocol.ServerPeerID,
	})
	sendUDPMessage(t, peerB, protocol.Message{
		Type: protocol.MessageTypeRegister,
		From: "peer-b",
		To:   protocol.ServerPeerID,
	})
	time.Sleep(registrationSettleDelay)

	// peer-a claims to be "fake-id".
	sendUDPMessage(t, peerA, protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "fake-id",
		To:   "peer-b",
		Body: []byte("spoofed"),
	})

	// The hub must deliver it as coming from the registered "peer-a".
	msg := receiveUDPMessage(t, peerB)
	assertTextMessage(t, msg, "peer-a", "peer-b", "spoofed")
}

// assertTextMessage fails the test unless msg is a text message with the given
// sender, recipient and body. Fields are compared through string conversions
// so this works whether they are declared as string or as a named string type.
func assertTextMessage(t *testing.T, msg protocol.Message, from, to, body string) {
	t.Helper()
	if msg.Type != protocol.MessageTypeText {
		t.Fatalf("message type = %v, want text", msg.Type)
	}
	if got := string(msg.From); got != from {
		t.Fatalf("message from = %q, want %q", got, from)
	}
	if got := string(msg.To); got != to {
		t.Fatalf("message to = %q, want %q", got, to)
	}
	if got := string(msg.Body); got != body {
		t.Fatalf("message body = %q, want %q", got, body)
	}
}

// startUDPHub listens on an ephemeral loopback UDP port, wraps the socket in
// a UDPHub, and runs Serve in a background goroutine. It returns the hub and
// the address peers should dial.
//
// Cleanup closes the hub and then waits up to serveShutdownTimeout for Serve
// to return. A Serve loop that outlives Close is therefore reported instead of
// silently leaking into later tests.
func startUDPHub(t *testing.T) (*UDPHub, *net.UDPAddr) {
	t.Helper()

	addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ResolveUDPAddr() error = %v", err)
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		t.Fatalf("ListenUDP() error = %v", err)
	}
	hub, err := NewUDPHub(conn)
	if err != nil {
		_ = conn.Close()
		t.Fatalf("NewUDPHub() error = %v", err)
	}

	served := make(chan struct{})
	go func() {
		defer close(served)
		// The error returned once the hub is closed is not asserted: which
		// shutdown error (if any) Serve reports is not visible from here.
		_ = hub.Serve()
	}()

	t.Cleanup(func() {
		// Best-effort shutdown; the important signal is Serve returning.
		_ = hub.Close()
		select {
		case <-served:
		case <-time.After(serveShutdownTimeout):
			t.Errorf("UDPHub.Serve() did not return within %v after Close()", serveShutdownTimeout)
		}
	})

	return hub, conn.LocalAddr().(*net.UDPAddr)
}

// dialUDPPeer opens a UDP socket connected to serverAddr and wraps it in a
// transport.UDPConn. The connection is closed automatically when the test ends.
func dialUDPPeer(t *testing.T, serverAddr *net.UDPAddr) *transport.UDPConn {
	t.Helper()

	raw, err := net.DialUDP("udp", nil, serverAddr)
	if err != nil {
		t.Fatalf("DialUDP() error = %v", err)
	}
	// NOTE(review): the meaning of NewUDPConn's second (nil) argument is
	// defined in the transport package; confirm nil is the intended default.
	conn, err := transport.NewUDPConn(raw, nil)
	if err != nil {
		_ = raw.Close()
		t.Fatalf("NewUDPConn() error = %v", err)
	}
	t.Cleanup(func() { _ = conn.Close() })
	return conn
}

// sendUDPMessage sends msg on conn and fails the test if Send returns an error.
func sendUDPMessage(t *testing.T, conn *transport.UDPConn, msg protocol.Message) {
	t.Helper()
	if err := conn.Send(msg); err != nil {
		t.Fatalf("Send() error = %v", err)
	}
}

// receiveUDPMessage waits up to receiveTimeout for one message on conn.
// It fails the test on a receive error or a timeout.
func receiveUDPMessage(t *testing.T, conn *transport.UDPConn) protocol.Message {
	t.Helper()

	type result struct {
		msg protocol.Message
		err error
	}
	// Buffered so the reader goroutine can always deliver and exit. After a
	// timeout it stays blocked in Receive until dialUDPPeer's cleanup closes
	// conn.
	ch := make(chan result, 1)
	go func() {
		// The second return value of Receive is not needed here.
		msg, _, err := conn.Receive()
		ch <- result{msg: msg, err: err}
	}()

	select {
	case r := <-ch:
		if r.err != nil {
			t.Fatalf("Receive() error = %v", r.err)
		}
		return r.msg
	case <-time.After(receiveTimeout):
		t.Fatalf("Receive() timed out after %v", receiveTimeout)
		return protocol.Message{} // unreachable; satisfies the compiler
	}
}