feat:新增server upd转发功能

This commit is contained in:
nnbcccscdscdsc
2026-03-28 13:13:17 +08:00
parent 8e2bd0ffc6
commit 34d2f574ac
3 changed files with 383 additions and 166 deletions

View File

@@ -6,122 +6,114 @@ import (
"net"
"sync"
"omnisocketgo/cmd/internal/transport"
"omnisocketgo/cmd/internal/protocol"
)
// UDPRelay 负责在固定远端与多个客户端之间双向透明转发 KCP UDP datagram
type UDPRelay struct {
conn net.PacketConn
remote *net.UDPAddr
// udpRelayBufSize 是 relay 接收缓冲区大小,与 UDP transport 层保持一致
const udpRelayBufSize = protocol.MaxFrameSize + 1024
mu sync.RWMutex
clients map[uint32]*net.UDPAddr
// UDPRelay 是一个透明的双向 UDP 转发器。
// 它在下游(客户端 A和上游server D之间原样转发 UDP 数据报,
// 不解析也不修改协议内容。
type UDPRelay struct {
downstream *net.UDPConn // 监听端口,等待下游客户端连接
upstream *net.UDPConn // 连接到上游 serverconnected socket
mu sync.RWMutex
clientAddr *net.UDPAddr // 下游客户端地址,从第一个下游包学习
}
// NewUDPRelay 创建一个绑定到给定 PacketConn 的透明 UDP relay。
func NewUDPRelay(conn net.PacketConn, remote *net.UDPAddr) (*UDPRelay, error) {
if conn == nil {
return nil, fmt.Errorf("server: nil udp relay conn")
// NewUDPRelay 创建一个新的 UDP relay。
// listenConn 是已经绑定好的监听 socket供下游客户端连接
// upstreamAddr 是上游 server D 的地址。
func NewUDPRelay(listenConn *net.UDPConn, upstreamAddr string) (*UDPRelay, error) {
udpUpstreamAddr, err := net.ResolveUDPAddr("udp", upstreamAddr)
if err != nil {
return nil, fmt.Errorf("relay: resolve upstream addr %s: %w", upstreamAddr, err)
}
if remote == nil {
return nil, fmt.Errorf("server: nil udp relay remote")
upstreamConn, err := net.DialUDP("udp", nil, udpUpstreamAddr)
if err != nil {
return nil, fmt.Errorf("relay: dial upstream %s: %w", upstreamAddr, err)
}
return &UDPRelay{
conn: conn,
remote: cloneUDPAddr(remote),
clients: make(map[uint32]*net.UDPAddr),
downstream: listenConn,
upstream: upstreamConn,
}, nil
}
// Serve 持续双向转发原始 UDP datagram不解析业务消息
// Serve 启动双向转发循环,阻塞直到任一方向出错
func (r *UDPRelay) Serve() error {
buffer := make([]byte, 64*1024)
errCh := make(chan error, 2)
go func() {
errCh <- r.forwardDownstreamToUpstream()
}()
go func() {
errCh <- r.forwardUpstreamToDownstream()
}()
err := <-errCh
// 关闭两个 conn 让另一个 goroutine 也退出
_ = r.downstream.Close()
_ = r.upstream.Close()
return err
}
// forwardDownstreamToUpstream 从下游读取并转发到上游。
func (r *UDPRelay) forwardDownstreamToUpstream() error {
buf := make([]byte, udpRelayBufSize)
for {
n, addr, err := r.conn.ReadFrom(buffer)
n, addr, err := r.downstream.ReadFromUDP(buf)
if err != nil {
if isExpectedRelayServeExit(err) {
return nil
}
return fmt.Errorf("server: udp relay read packet: %w", err)
return fmt.Errorf("relay: read downstream: %w", err)
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
log.Printf("udp relay dropped packet from non-udp addr %T", addr)
r.mu.Lock()
r.clientAddr = addr
r.mu.Unlock()
if _, err := r.upstream.Write(buf[:n]); err != nil {
return fmt.Errorf("relay: write upstream: %w", err)
}
log.Printf("relay: forwarded %d bytes downstream(%s) -> upstream", n, addr)
}
}
// forwardUpstreamToDownstream 从上游读取并转发到下游。
func (r *UDPRelay) forwardUpstreamToDownstream() error {
buf := make([]byte, udpRelayBufSize)
for {
n, err := r.upstream.Read(buf)
if err != nil {
return fmt.Errorf("relay: read upstream: %w", err)
}
r.mu.RLock()
addr := r.clientAddr
r.mu.RUnlock()
if addr == nil {
log.Printf("relay: dropping %d bytes from upstream (no downstream client yet)", n)
continue
}
payload := append([]byte(nil), buffer[:n]...)
if sameUDPAddr(udpAddr, r.remote) {
if err := r.forwardRemotePacket(payload); err != nil {
log.Printf("udp relay failed forwarding remote packet from %s: %v", udpAddr, err)
}
continue
if _, err := r.downstream.WriteToUDP(buf[:n], addr); err != nil {
return fmt.Errorf("relay: write downstream to %s: %w", addr, err)
}
if err := r.forwardClientPacket(udpAddr, payload); err != nil {
log.Printf("udp relay failed forwarding client packet from %s: %v", udpAddr, err)
}
log.Printf("relay: forwarded %d bytes upstream -> downstream(%s)", n, addr)
}
}
func (r *UDPRelay) forwardClientPacket(addr *net.UDPAddr, payload []byte) error {
convID, ok := transport.ParseKCPConversationID(payload)
if !ok {
return fmt.Errorf("missing kcp conversation id")
// Close 关闭 relay 的上下游连接。
func (r *UDPRelay) Close() error {
err1 := r.downstream.Close()
err2 := r.upstream.Close()
if err1 != nil {
return err1
}
r.mu.Lock()
r.clients[convID] = cloneUDPAddr(addr)
r.mu.Unlock()
if _, err := r.conn.WriteTo(payload, r.remote); err != nil {
return fmt.Errorf("write conv %d to remote %s: %w", convID, r.remote, err)
}
return nil
}
func (r *UDPRelay) forwardRemotePacket(payload []byte) error {
convID, ok := transport.ParseKCPConversationID(payload)
if !ok {
return fmt.Errorf("missing kcp conversation id")
}
r.mu.RLock()
clientAddr := cloneUDPAddr(r.clients[convID])
r.mu.RUnlock()
if clientAddr == nil {
return fmt.Errorf("unknown client for conv %d", convID)
}
if _, err := r.conn.WriteTo(payload, clientAddr); err != nil {
return fmt.Errorf("write conv %d to client %s: %w", convID, clientAddr, err)
}
return nil
}
func cloneUDPAddr(addr *net.UDPAddr) *net.UDPAddr {
if addr == nil {
return nil
}
ipCopy := make([]byte, len(addr.IP))
copy(ipCopy, addr.IP)
return &net.UDPAddr{
IP: ipCopy,
Port: addr.Port,
Zone: addr.Zone,
}
}
func sameUDPAddr(left, right *net.UDPAddr) bool {
if left == nil || right == nil {
return left == right
}
if left.Port != right.Port || left.Zone != right.Zone {
return false
}
return left.IP.Equal(right.IP)
return err2
}