Files
OmniSocketGo/cmd/internal/server/udp_relay.go
2026-03-28 13:17:38 +08:00

120 lines
3.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package server
import (
"fmt"
"log"
"net"
"sync"
"omnisocketgo/cmd/internal/protocol"
)
// udpRelayBufSize is the size of each relay receive buffer: one maximum-size
// protocol frame plus 1 KiB of headroom, kept in line with the UDP transport
// layer's buffer size.
const udpRelayBufSize = protocol.MaxFrameSize + 1<<10
// UDPRelay is a transparent, bidirectional UDP forwarder.
//
// It copies datagrams unchanged between a downstream client (client A) and an
// upstream server (server D); it neither parses nor modifies their contents.
type UDPRelay struct {
	downstream *net.UDPConn // bound listening socket that downstream clients send to
	upstream   *net.UDPConn // socket dialed to the upstream server (connected socket)

	// mu guards clientAddr: it is written by forwardDownstreamToUpstream and
	// read by forwardUpstreamToDownstream, which run on separate goroutines.
	mu sync.RWMutex
	// clientAddr is the source address of the most recently received
	// downstream datagram; upstream replies are sent there. It is nil until the
	// first downstream datagram arrives and is overwritten by every later one,
	// so the relay answers only one client at a time.
	// NOTE(review): any sender that can reach the listening socket becomes the
	// reply target — confirm that is acceptable for this deployment.
	clientAddr *net.UDPAddr
}
// NewUDPRelay constructs a relay between an already-bound downstream socket
// and an upstream server.
//
// listenConn is the bound socket that downstream clients send to; the relay
// takes ownership of it (Serve and Close both close it). upstreamAddr is the
// "host:port" of upstream server D. It is resolved and dialed here with a nil
// local address, so the OS picks the local endpoint of the upstream socket.
//
// An error is returned if upstreamAddr cannot be resolved or dialed.
func NewUDPRelay(listenConn *net.UDPConn, upstreamAddr string) (*UDPRelay, error) {
	target, err := net.ResolveUDPAddr("udp", upstreamAddr)
	if err != nil {
		return nil, fmt.Errorf("relay: resolve upstream addr %s: %w", upstreamAddr, err)
	}

	conn, err := net.DialUDP("udp", nil, target)
	if err != nil {
		return nil, fmt.Errorf("relay: dial upstream %s: %w", upstreamAddr, err)
	}

	relay := new(UDPRelay)
	relay.downstream = listenConn
	relay.upstream = conn
	return relay, nil
}
// Serve runs both forwarding directions concurrently and blocks until either
// one fails.
//
// When the first direction returns, Serve closes both sockets. This unblocks
// the other direction's pending Read/Write so it also returns. Serve then
// waits for that second direction to exit before returning the first error
// observed. Once Serve returns, no relay goroutine is still running.
func (r *UDPRelay) Serve() error {
	// Room for both results, so neither goroutine ever blocks on its send.
	errCh := make(chan error, 2)
	go func() {
		errCh <- r.forwardDownstreamToUpstream()
	}()
	go func() {
		errCh <- r.forwardUpstreamToDownstream()
	}()

	err := <-errCh

	// Close errors are deliberately ignored: a socket may already be closed
	// (e.g. by Close), and err is the failure worth reporting.
	_ = r.downstream.Close()
	_ = r.upstream.Close()

	// Both sockets are closed, so the remaining loop's next I/O fails and it
	// returns; wait for it so it cannot outlive Serve.
	<-errCh

	return err
}
// forwardDownstreamToUpstream copies datagrams from the listening socket to
// the upstream server until a read or write fails, and returns that error.
//
// Before forwarding, every datagram's source address is stored in clientAddr
// (under mu). This lets forwardUpstreamToDownstream route replies to the
// latest sender.
func (r *UDPRelay) forwardDownstreamToUpstream() error {
	packet := make([]byte, udpRelayBufSize)

	rememberClient := func(from *net.UDPAddr) {
		r.mu.Lock()
		defer r.mu.Unlock()
		r.clientAddr = from
	}

	for {
		size, from, readErr := r.downstream.ReadFromUDP(packet)
		if readErr != nil {
			return fmt.Errorf("relay: read downstream: %w", readErr)
		}
		rememberClient(from)

		payload := packet[:size]
		if _, writeErr := r.upstream.Write(payload); writeErr != nil {
			return fmt.Errorf("relay: write upstream: %w", writeErr)
		}
		log.Printf("relay: forwarded %d bytes downstream(%s) -> upstream", size, from)
	}
}
// forwardUpstreamToDownstream copies datagrams from the upstream server back
// to the most recently seen downstream client. It runs until a read or write
// fails, and returns that error.
//
// Datagrams that arrive before any downstream client is known are logged and
// dropped.
// NOTE(review): any upstream read error ends the relay, including errors that
// may be transient on a connected UDP socket (e.g. a refused-connection report
// while server D is briefly unreachable) — confirm this is intended.
func (r *UDPRelay) forwardUpstreamToDownstream() error {
	packet := make([]byte, udpRelayBufSize)

	currentClient := func() *net.UDPAddr {
		r.mu.RLock()
		defer r.mu.RUnlock()
		return r.clientAddr
	}

	for {
		size, readErr := r.upstream.Read(packet)
		if readErr != nil {
			return fmt.Errorf("relay: read upstream: %w", readErr)
		}

		dst := currentClient()
		if dst == nil {
			log.Printf("relay: dropping %d bytes from upstream (no downstream client yet)", size)
			continue
		}

		if _, writeErr := r.downstream.WriteToUDP(packet[:size], dst); writeErr != nil {
			return fmt.Errorf("relay: write downstream to %s: %w", dst, writeErr)
		}
		log.Printf("relay: forwarded %d bytes upstream -> downstream(%s)", size, dst)
	}
}
// Close closes both the downstream and upstream sockets of the relay.
//
// Both sockets are always closed. The downstream close error is returned if
// there was one; otherwise the upstream close error is returned. Serve closes
// both sockets itself before returning, so calling Close afterwards is
// expected to report an already-closed error.
func (r *UDPRelay) Close() error {
	downErr := r.downstream.Close()
	upErr := r.upstream.Close()
	if downErr == nil {
		return upErr
	}
	return downErr
}