284 lines
6.6 KiB
Go
284 lines
6.6 KiB
Go
package server
|
|
|
|
import (
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
kcp "github.com/xtaci/kcp-go/v5"
|
|
|
|
"omnisocketgo/cmd/internal/latencylog"
|
|
"omnisocketgo/cmd/internal/protocol"
|
|
"omnisocketgo/cmd/internal/transport"
|
|
)
|
|
|
|
func TestUDPRelayKCPForwardAndReturn(t *testing.T) {
|
|
hub, hubAddr, hubCleanup := startKCPHubForRelay(t)
|
|
defer hubCleanup()
|
|
|
|
relayAddr, relay := startUDPRelay(t, hubAddr)
|
|
|
|
peerBConn := dialKCPPeer(t, hubAddr)
|
|
peerAConn := dialKCPPeer(t, relayAddr)
|
|
|
|
if err := peerBConn.Send(protocol.Message{
|
|
Type: protocol.MessageTypeRegister,
|
|
From: "peer-b",
|
|
To: protocol.ServerPeerID,
|
|
}); err != nil {
|
|
t.Fatalf("peerB register: %v", err)
|
|
}
|
|
if err := peerAConn.Send(protocol.Message{
|
|
Type: protocol.MessageTypeRegister,
|
|
From: "peer-a",
|
|
To: protocol.ServerPeerID,
|
|
}); err != nil {
|
|
t.Fatalf("peerA register: %v", err)
|
|
}
|
|
|
|
waitForRelay(t, func() bool {
|
|
return hub.HasPeer("peer-a") && hub.HasPeer("peer-b")
|
|
}, "both peers to be registered")
|
|
waitForRelay(t, func() bool {
|
|
relay.mu.RLock()
|
|
defer relay.mu.RUnlock()
|
|
return relay.clientAddr != nil
|
|
}, "relay to learn the downstream peer")
|
|
|
|
if err := peerBConn.Send(protocol.Message{
|
|
Type: protocol.MessageTypeText,
|
|
ID: 1,
|
|
From: "peer-b",
|
|
To: "peer-a",
|
|
Body: []byte("hello from peer-b"),
|
|
}); err != nil {
|
|
t.Fatalf("peerB send text: %v", err)
|
|
}
|
|
|
|
msg, err := peerAConn.Receive()
|
|
if err != nil {
|
|
t.Fatalf("peerA receive: %v", err)
|
|
}
|
|
if msg.Type != protocol.MessageTypeText {
|
|
t.Fatalf("message type = %s, want text", msg.Type)
|
|
}
|
|
if msg.From != "peer-b" {
|
|
t.Fatalf("message from = %s, want peer-b", msg.From)
|
|
}
|
|
if string(msg.Body) != "hello from peer-b" {
|
|
t.Fatalf("message body = %q, want %q", string(msg.Body), "hello from peer-b")
|
|
}
|
|
|
|
if err := peerAConn.Send(protocol.Message{
|
|
Type: protocol.MessageTypeText,
|
|
ID: 2,
|
|
From: "peer-a",
|
|
To: "peer-b",
|
|
Body: []byte("reply from peer-a"),
|
|
}); err != nil {
|
|
t.Fatalf("peerA send text: %v", err)
|
|
}
|
|
|
|
msg2, err := peerBConn.Receive()
|
|
if err != nil {
|
|
t.Fatalf("peerB receive: %v", err)
|
|
}
|
|
if msg2.Type != protocol.MessageTypeText {
|
|
t.Fatalf("message type = %s, want text", msg2.Type)
|
|
}
|
|
if msg2.From != "peer-a" {
|
|
t.Fatalf("message from = %s, want peer-a", msg2.From)
|
|
}
|
|
if string(msg2.Body) != "reply from peer-a" {
|
|
t.Fatalf("message body = %q, want %q", string(msg2.Body), "reply from peer-a")
|
|
}
|
|
}
|
|
|
|
func TestUDPRelayKCPFileMessage(t *testing.T) {
|
|
hub, hubAddr, hubCleanup := startKCPHubForRelay(t)
|
|
defer hubCleanup()
|
|
|
|
relayAddr, relay := startUDPRelay(t, hubAddr)
|
|
|
|
peerBConn := dialKCPPeer(t, hubAddr)
|
|
peerAConn := dialKCPPeer(t, relayAddr)
|
|
|
|
if err := peerBConn.Send(protocol.Message{
|
|
Type: protocol.MessageTypeRegister,
|
|
From: "peer-b",
|
|
To: protocol.ServerPeerID,
|
|
}); err != nil {
|
|
t.Fatalf("peerB register: %v", err)
|
|
}
|
|
if err := peerAConn.Send(protocol.Message{
|
|
Type: protocol.MessageTypeRegister,
|
|
From: "peer-a",
|
|
To: protocol.ServerPeerID,
|
|
}); err != nil {
|
|
t.Fatalf("peerA register: %v", err)
|
|
}
|
|
|
|
waitForRelay(t, func() bool {
|
|
return hub.HasPeer("peer-a") && hub.HasPeer("peer-b")
|
|
}, "both peers to be registered")
|
|
waitForRelay(t, func() bool {
|
|
relay.mu.RLock()
|
|
defer relay.mu.RUnlock()
|
|
return relay.clientAddr != nil
|
|
}, "relay to learn the downstream peer")
|
|
|
|
if err := peerBConn.Send(protocol.Message{
|
|
Type: protocol.MessageTypeFile,
|
|
ID: 1,
|
|
From: "peer-b",
|
|
To: "peer-a",
|
|
FileName: "test.bin",
|
|
Body: []byte{0xDE, 0xAD, 0xBE, 0xEF},
|
|
}); err != nil {
|
|
t.Fatalf("peerB send file: %v", err)
|
|
}
|
|
|
|
msg, err := peerAConn.Receive()
|
|
if err != nil {
|
|
t.Fatalf("peerA receive: %v", err)
|
|
}
|
|
if msg.Type != protocol.MessageTypeFile {
|
|
t.Fatalf("message type = %s, want file", msg.Type)
|
|
}
|
|
if msg.FileName != "test.bin" {
|
|
t.Fatalf("file name = %q, want %q", msg.FileName, "test.bin")
|
|
}
|
|
if string(msg.Body) != string([]byte{0xDE, 0xAD, 0xBE, 0xEF}) {
|
|
t.Fatalf("file body = %v, want %v", msg.Body, []byte{0xDE, 0xAD, 0xBE, 0xEF})
|
|
}
|
|
}
|
|
|
|
func startKCPHubForRelay(t *testing.T) (*KCPHub, string, func()) {
|
|
t.Helper()
|
|
|
|
hub := NewKCPHub()
|
|
|
|
listener, packetConn, err := transport.ListenKCPSessions("127.0.0.1:0", "", nil, latencylog.NodeRoleServer, "hub")
|
|
if err != nil {
|
|
t.Fatalf("ListenKCPSessions() error = %v", err)
|
|
}
|
|
|
|
var (
|
|
wg sync.WaitGroup
|
|
stop = make(chan struct{})
|
|
)
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
for {
|
|
session, acceptErr := listener.AcceptKCP()
|
|
if acceptErr != nil {
|
|
select {
|
|
case <-stop:
|
|
return
|
|
default:
|
|
}
|
|
if strings.Contains(acceptErr.Error(), "closed") {
|
|
return
|
|
}
|
|
t.Errorf("AcceptKCP() error = %v", acceptErr)
|
|
return
|
|
}
|
|
|
|
wg.Add(1)
|
|
go func(sess *kcp.UDPSession) {
|
|
defer wg.Done()
|
|
if serveErr := hub.ServeSession(sess); serveErr != nil {
|
|
msg := serveErr.Error()
|
|
if !strings.Contains(msg, "closed") && !strings.Contains(msg, "broken pipe") {
|
|
t.Logf("hub.ServeSession() ended with %v", serveErr)
|
|
}
|
|
}
|
|
}(session)
|
|
}
|
|
}()
|
|
|
|
cleanup := func() {
|
|
close(stop)
|
|
_ = listener.Close()
|
|
_ = packetConn.Close()
|
|
wg.Wait()
|
|
}
|
|
|
|
return hub, listener.Addr().String(), cleanup
|
|
}
|
|
|
|
func dialKCPPeer(t *testing.T, serverAddr string) *transport.KCPConn {
|
|
t.Helper()
|
|
|
|
session, err := transport.DialKCPSession(serverAddr, "", "", nil, latencylog.NodeRolePeer, "test")
|
|
if err != nil {
|
|
t.Fatalf("DialKCPSession(%s) error = %v", serverAddr, err)
|
|
}
|
|
|
|
conn, err := transport.NewKCPConn(session)
|
|
if err != nil {
|
|
_ = session.Close()
|
|
t.Fatalf("NewKCPConn() error = %v", err)
|
|
}
|
|
|
|
t.Cleanup(func() {
|
|
_ = conn.Close()
|
|
})
|
|
|
|
return conn
|
|
}
|
|
|
|
func startUDPRelay(t *testing.T, upstreamAddr string) (string, *UDPRelay) {
|
|
t.Helper()
|
|
|
|
remoteAddr, err := net.ResolveUDPAddr("udp", upstreamAddr)
|
|
if err != nil {
|
|
t.Fatalf("ResolveUDPAddr(%s) error = %v", upstreamAddr, err)
|
|
}
|
|
|
|
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("ListenPacket() error = %v", err)
|
|
}
|
|
|
|
relay, err := NewUDPRelay(conn, remoteAddr)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
t.Fatalf("NewUDPRelay() error = %v", err)
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
if serveErr := relay.Serve(); serveErr != nil && !isExpectedRelayServeExit(serveErr) {
|
|
t.Errorf("relay.Serve() error = %v", serveErr)
|
|
}
|
|
}()
|
|
|
|
t.Cleanup(func() {
|
|
_ = relay.Close()
|
|
wg.Wait()
|
|
})
|
|
|
|
return conn.LocalAddr().String(), relay
|
|
}
|
|
|
|
// waitForRelay polls condition every 10ms for up to 2 seconds and returns as
// soon as condition reports true. If it is still false once the deadline has
// passed, the test fails with "timed out waiting for <description>".
//
// The condition is evaluated one last time after the deadline. Otherwise a
// state change that happens during the final sleep would be reported as a
// spurious timeout.
func waitForRelay(t *testing.T, condition func() bool, description string) {
	t.Helper()

	const (
		timeout  = 2 * time.Second
		interval = 10 * time.Millisecond
	)

	deadline := time.Now().Add(timeout)
	for {
		if condition() {
			return
		}
		if !time.Now().Before(deadline) {
			break
		}
		time.Sleep(interval)
	}

	t.Fatalf("timed out waiting for %s", description)
}
|