104 lines
2.7 KiB
Go
104 lines
2.7 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"net"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestUDPRelayRoutesPacketsByKCPConversationID(t *testing.T) {
|
|
remote, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("ListenPacket(remote) error = %v", err)
|
|
}
|
|
defer remote.Close()
|
|
|
|
relayConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("ListenPacket(relay) error = %v", err)
|
|
}
|
|
|
|
relay, err := NewUDPRelay(relayConn, remote.LocalAddr().(*net.UDPAddr))
|
|
if err != nil {
|
|
_ = relayConn.Close()
|
|
t.Fatalf("NewUDPRelay() error = %v", err)
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
if serveErr := relay.Serve(); serveErr != nil {
|
|
t.Errorf("relay.Serve() error = %v", serveErr)
|
|
}
|
|
}()
|
|
defer func() {
|
|
_ = relayConn.Close()
|
|
wg.Wait()
|
|
}()
|
|
|
|
client1, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("ListenPacket(client1) error = %v", err)
|
|
}
|
|
defer client1.Close()
|
|
|
|
client2, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("ListenPacket(client2) error = %v", err)
|
|
}
|
|
defer client2.Close()
|
|
|
|
relayAddr := relayConn.LocalAddr()
|
|
|
|
sendPacket(t, client1, relayAddr, buildRelayTestPacket(1, []byte("client-one")))
|
|
assertPacketReceived(t, remote, buildRelayTestPacket(1, []byte("client-one")))
|
|
|
|
sendPacket(t, client2, relayAddr, buildRelayTestPacket(2, []byte("client-two")))
|
|
assertPacketReceived(t, remote, buildRelayTestPacket(2, []byte("client-two")))
|
|
|
|
sendPacket(t, remote, relayAddr, buildRelayTestPacket(2, []byte("reply-two")))
|
|
assertPacketReceived(t, client2, buildRelayTestPacket(2, []byte("reply-two")))
|
|
|
|
sendPacket(t, remote, relayAddr, buildRelayTestPacket(1, []byte("reply-one")))
|
|
assertPacketReceived(t, client1, buildRelayTestPacket(1, []byte("reply-one")))
|
|
}
|
|
|
|
// buildRelayTestPacket returns a new datagram whose first four bytes hold
// convID in little-endian order, followed by a copy of body. The result never
// aliases body; a nil or empty body yields just the 4-byte header.
func buildRelayTestPacket(convID uint32, body []byte) []byte {
	// Reserve room for the body up front so the append below never reallocates.
	header := make([]byte, 4, 4+len(body))
	binary.LittleEndian.PutUint32(header, convID)
	return append(header, body...)
}
|
|
|
|
// sendPacket writes payload as a single datagram from conn to addr, allowing
// the write at most two seconds. Any error setting the deadline or writing
// fails the test immediately via t.Fatalf.
func sendPacket(t *testing.T, conn net.PacketConn, addr net.Addr, payload []byte) {
	t.Helper()

	deadline := time.Now().Add(2 * time.Second)
	err := conn.SetWriteDeadline(deadline)
	if err != nil {
		t.Fatalf("SetWriteDeadline() error = %v", err)
	}

	_, err = conn.WriteTo(payload, addr)
	if err != nil {
		t.Fatalf("WriteTo(%s) error = %v", addr, err)
	}
}
|
|
|
|
// assertPacketReceived reads one datagram from conn, waiting at most two
// seconds, and fails the test immediately (t.Fatalf) unless the datagram's
// payload equals want byte for byte.
//
// The read buffer is sized for the largest possible UDP payload, so a datagram
// longer than a small fixed buffer is read in full rather than cut short and
// misreported as a mismatch. Mismatches are printed with %q, together with
// the sender address, so a binary conversation-ID prefix and an ASCII body are
// both readable in the failure message.
func assertPacketReceived(t *testing.T, conn net.PacketConn, want []byte) {
	t.Helper()

	if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatalf("SetReadDeadline() error = %v", err)
	}

	// 64 KiB covers the maximum UDP payload (65,507 bytes over IPv4).
	buffer := make([]byte, 64*1024)
	n, from, err := conn.ReadFrom(buffer)
	if err != nil {
		t.Fatalf("ReadFrom() error = %v", err)
	}

	if got := buffer[:n]; string(got) != string(want) {
		t.Fatalf("packet from %v = %q, want %q", from, got, want)
	}
}
|