package server

import (
	"encoding/binary"
	"net"
	"sync"
	"testing"
	"time"
)

// relayTestIOTimeout bounds every read and write performed by these tests so
// that a dropped or misrouted datagram fails the test instead of hanging it.
const relayTestIOTimeout = 2 * time.Second

// TestUDPRelayRoutesPacketsByKCPConversationID starts a UDPRelay in front of a
// single remote endpoint and drives it from two independent clients.
//
// It checks two things:
//   - a packet sent by a client to the relay arrives at the remote unchanged;
//   - a reply sent by the remote to the relay is delivered to the client that
//     previously used the same conversation ID.
//
// Replies are sent in reverse order (conversation 2 first). Routing that
// simply followed arrival order would therefore not pass.
func TestUDPRelayRoutesPacketsByKCPConversationID(t *testing.T) {
	remote, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ListenPacket(remote) error = %v", err)
	}
	defer remote.Close()

	relayConn, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ListenPacket(relay) error = %v", err)
	}

	relay, err := NewUDPRelay(relayConn, remote.LocalAddr().(*net.UDPAddr))
	if err != nil {
		_ = relayConn.Close()
		t.Fatalf("NewUDPRelay() error = %v", err)
	}

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		if serveErr := relay.Serve(); serveErr != nil {
			t.Errorf("relay.Serve() error = %v", serveErr)
		}
	}()
	// The test relies on Serve returning once relayConn is closed. Waiting here
	// keeps the Serve goroutine, and its t.Errorf, from outliving the test.
	defer func() {
		_ = relayConn.Close()
		wg.Wait()
	}()

	client1, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ListenPacket(client1) error = %v", err)
	}
	defer client1.Close()

	client2, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ListenPacket(client2) error = %v", err)
	}
	defer client2.Close()

	relayAddr := relayConn.LocalAddr()

	// Client -> remote: each client's datagram must be forwarded unchanged.
	fromClient1 := buildRelayTestPacket(1, []byte("client-one"))
	sendPacket(t, client1, relayAddr, fromClient1)
	assertPacketReceived(t, remote, fromClient1)

	fromClient2 := buildRelayTestPacket(2, []byte("client-two"))
	sendPacket(t, client2, relayAddr, fromClient2)
	assertPacketReceived(t, remote, fromClient2)

	// Remote -> client: each reply must reach the client that used its
	// conversation ID.
	replyTwo := buildRelayTestPacket(2, []byte("reply-two"))
	sendPacket(t, remote, relayAddr, replyTwo)
	assertPacketReceived(t, client2, replyTwo)

	replyOne := buildRelayTestPacket(1, []byte("reply-one"))
	sendPacket(t, remote, relayAddr, replyOne)
	assertPacketReceived(t, client1, replyOne)
}

// buildRelayTestPacket returns a datagram made of convID encoded as four
// little-endian bytes, followed by body.
//
// This test assumes the relay reads the conversation ID from exactly this
// prefix.
func buildRelayTestPacket(convID uint32, body []byte) []byte {
	packet := make([]byte, 4+len(body))
	binary.LittleEndian.PutUint32(packet[:4], convID)
	copy(packet[4:], body)
	return packet
}

// sendPacket writes payload to addr as one datagram on conn.
//
// It fails the test if the write deadline cannot be set, or if the write
// errors or does not finish within relayTestIOTimeout.
func sendPacket(t *testing.T, conn net.PacketConn, addr net.Addr, payload []byte) {
	t.Helper()
	if err := conn.SetWriteDeadline(time.Now().Add(relayTestIOTimeout)); err != nil {
		t.Fatalf("SetWriteDeadline() error = %v", err)
	}
	if _, err := conn.WriteTo(payload, addr); err != nil {
		t.Fatalf("WriteTo(%s) error = %v", addr, err)
	}
}

// assertPacketReceived waits up to relayTestIOTimeout for the next datagram on
// conn and fails the test unless its contents equal want.
//
// On failure the message names the receiving endpoint (for timeouts) or the
// sender (for content mismatches), so misrouting is easy to diagnose.
func assertPacketReceived(t *testing.T, conn net.PacketConn, want []byte) {
	t.Helper()

	// Large enough for any UDP payload. A smaller buffer would truncate long
	// datagrams and could hide trailing bytes from the comparison below.
	const maxDatagramSize = 64 * 1024

	if err := conn.SetReadDeadline(time.Now().Add(relayTestIOTimeout)); err != nil {
		t.Fatalf("SetReadDeadline() error = %v", err)
	}
	buffer := make([]byte, maxDatagramSize)
	n, from, err := conn.ReadFrom(buffer)
	if err != nil {
		t.Fatalf("ReadFrom() on %s error = %v", conn.LocalAddr(), err)
	}
	if got := buffer[:n]; string(got) != string(want) {
		t.Fatalf("packet from %s = %q, want %q", from, got, want)
	}
}