feat:多跳(B->D->C->A)功能
This commit is contained in:
127
cmd/internal/server/udp_relay.go
Normal file
127
cmd/internal/server/udp_relay.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
// UDPRelay 负责在固定远端与多个客户端之间双向透明转发 KCP UDP datagram。
|
||||
type UDPRelay struct {
|
||||
conn net.PacketConn
|
||||
remote *net.UDPAddr
|
||||
|
||||
mu sync.RWMutex
|
||||
clients map[uint32]*net.UDPAddr
|
||||
}
|
||||
|
||||
// NewUDPRelay 创建一个绑定到给定 PacketConn 的透明 UDP relay。
|
||||
func NewUDPRelay(conn net.PacketConn, remote *net.UDPAddr) (*UDPRelay, error) {
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("server: nil udp relay conn")
|
||||
}
|
||||
if remote == nil {
|
||||
return nil, fmt.Errorf("server: nil udp relay remote")
|
||||
}
|
||||
|
||||
return &UDPRelay{
|
||||
conn: conn,
|
||||
remote: cloneUDPAddr(remote),
|
||||
clients: make(map[uint32]*net.UDPAddr),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Serve 持续双向转发原始 UDP datagram,不解析业务消息。
|
||||
func (r *UDPRelay) Serve() error {
|
||||
buffer := make([]byte, 64*1024)
|
||||
for {
|
||||
n, addr, err := r.conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
if isExpectedRelayServeExit(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("server: udp relay read packet: %w", err)
|
||||
}
|
||||
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
log.Printf("udp relay dropped packet from non-udp addr %T", addr)
|
||||
continue
|
||||
}
|
||||
|
||||
payload := append([]byte(nil), buffer[:n]...)
|
||||
if sameUDPAddr(udpAddr, r.remote) {
|
||||
if err := r.forwardRemotePacket(payload); err != nil {
|
||||
log.Printf("udp relay failed forwarding remote packet from %s: %v", udpAddr, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.forwardClientPacket(udpAddr, payload); err != nil {
|
||||
log.Printf("udp relay failed forwarding client packet from %s: %v", udpAddr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *UDPRelay) forwardClientPacket(addr *net.UDPAddr, payload []byte) error {
|
||||
convID, ok := transport.ParseKCPConversationID(payload)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing kcp conversation id")
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
r.clients[convID] = cloneUDPAddr(addr)
|
||||
r.mu.Unlock()
|
||||
|
||||
if _, err := r.conn.WriteTo(payload, r.remote); err != nil {
|
||||
return fmt.Errorf("write conv %d to remote %s: %w", convID, r.remote, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UDPRelay) forwardRemotePacket(payload []byte) error {
|
||||
convID, ok := transport.ParseKCPConversationID(payload)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing kcp conversation id")
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
clientAddr := cloneUDPAddr(r.clients[convID])
|
||||
r.mu.RUnlock()
|
||||
if clientAddr == nil {
|
||||
return fmt.Errorf("unknown client for conv %d", convID)
|
||||
}
|
||||
|
||||
if _, err := r.conn.WriteTo(payload, clientAddr); err != nil {
|
||||
return fmt.Errorf("write conv %d to client %s: %w", convID, clientAddr, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneUDPAddr(addr *net.UDPAddr) *net.UDPAddr {
|
||||
if addr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipCopy := make([]byte, len(addr.IP))
|
||||
copy(ipCopy, addr.IP)
|
||||
|
||||
return &net.UDPAddr{
|
||||
IP: ipCopy,
|
||||
Port: addr.Port,
|
||||
Zone: addr.Zone,
|
||||
}
|
||||
}
|
||||
|
||||
func sameUDPAddr(left, right *net.UDPAddr) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
if left.Port != right.Port || left.Zone != right.Zone {
|
||||
return false
|
||||
}
|
||||
return left.IP.Equal(right.IP)
|
||||
}
|
||||
103
cmd/internal/server/udp_relay_test.go
Normal file
103
cmd/internal/server/udp_relay_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUDPRelayRoutesPacketsByKCPConversationID(t *testing.T) {
|
||||
remote, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ListenPacket(remote) error = %v", err)
|
||||
}
|
||||
defer remote.Close()
|
||||
|
||||
relayConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ListenPacket(relay) error = %v", err)
|
||||
}
|
||||
|
||||
relay, err := NewUDPRelay(relayConn, remote.LocalAddr().(*net.UDPAddr))
|
||||
if err != nil {
|
||||
_ = relayConn.Close()
|
||||
t.Fatalf("NewUDPRelay() error = %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if serveErr := relay.Serve(); serveErr != nil {
|
||||
t.Errorf("relay.Serve() error = %v", serveErr)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = relayConn.Close()
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
client1, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ListenPacket(client1) error = %v", err)
|
||||
}
|
||||
defer client1.Close()
|
||||
|
||||
client2, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ListenPacket(client2) error = %v", err)
|
||||
}
|
||||
defer client2.Close()
|
||||
|
||||
relayAddr := relayConn.LocalAddr()
|
||||
|
||||
sendPacket(t, client1, relayAddr, buildRelayTestPacket(1, []byte("client-one")))
|
||||
assertPacketReceived(t, remote, buildRelayTestPacket(1, []byte("client-one")))
|
||||
|
||||
sendPacket(t, client2, relayAddr, buildRelayTestPacket(2, []byte("client-two")))
|
||||
assertPacketReceived(t, remote, buildRelayTestPacket(2, []byte("client-two")))
|
||||
|
||||
sendPacket(t, remote, relayAddr, buildRelayTestPacket(2, []byte("reply-two")))
|
||||
assertPacketReceived(t, client2, buildRelayTestPacket(2, []byte("reply-two")))
|
||||
|
||||
sendPacket(t, remote, relayAddr, buildRelayTestPacket(1, []byte("reply-one")))
|
||||
assertPacketReceived(t, client1, buildRelayTestPacket(1, []byte("reply-one")))
|
||||
}
|
||||
|
||||
func buildRelayTestPacket(convID uint32, body []byte) []byte {
|
||||
packet := make([]byte, 4+len(body))
|
||||
binary.LittleEndian.PutUint32(packet[:4], convID)
|
||||
copy(packet[4:], body)
|
||||
return packet
|
||||
}
|
||||
|
||||
func sendPacket(t *testing.T, conn net.PacketConn, addr net.Addr, payload []byte) {
|
||||
t.Helper()
|
||||
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("SetWriteDeadline() error = %v", err)
|
||||
}
|
||||
if _, err := conn.WriteTo(payload, addr); err != nil {
|
||||
t.Fatalf("WriteTo(%s) error = %v", addr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertPacketReceived(t *testing.T, conn net.PacketConn, want []byte) {
|
||||
t.Helper()
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("SetReadDeadline() error = %v", err)
|
||||
}
|
||||
|
||||
buffer := make([]byte, 1024)
|
||||
n, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrom() error = %v", err)
|
||||
}
|
||||
got := buffer[:n]
|
||||
if string(got) != string(want) {
|
||||
t.Fatalf("packet = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user