Files
OmniSocketGo/cmd/internal/server/udp_relay.go
2026-03-28 15:28:19 +08:00

136 lines
3.1 KiB
Go

package server
import (
"fmt"
"log"
"net"
"sync"
"omnisocketgo/cmd/internal/protocol"
)
// udpRelayBufSize is the size of the read buffer allocated for each forwarding
// direction: the protocol's maximum frame size plus 1024 (1<<10) extra bytes.
// NOTE(review): the extra KiB presumably leaves slack so datagrams slightly
// larger than a maximal frame are not truncated on read — confirm.
const udpRelayBufSize = protocol.MaxFrameSize + 1<<10
// UDPRelay transparently forwards UDP datagrams between one downstream client
// and a fixed upstream server. The downstream client is whichever peer sent
// the most recent datagram to the listening socket; datagrams read from
// upstream are written back to that address.
type UDPRelay struct {
	// downstream is the listening socket that the client sends to.
	downstream net.PacketConn
	// upstream is the UDP socket dialed to the fixed upstream server.
	upstream *net.UDPConn

	// mu guards clientAddr: the downstream reader replaces it on every
	// datagram and the upstream reader reads it before each reply.
	mu sync.RWMutex
	// clientAddr is nil until the first downstream datagram arrives.
	clientAddr net.Addr
}
// NewUDPRelay builds a relay that accepts datagrams on listenConn and forwards
// them to upstreamAddr over a freshly dialed UDP socket. The address family
// used for dialing is derived from upstreamAddr.
//
// On success the relay owns both listenConn and the dialed upstream socket;
// Serve and Close each close them. If dialing fails, listenConn is left open
// and remains the caller's responsibility.
//
// It returns an error if either argument is nil or the upstream dial fails.
func NewUDPRelay(listenConn net.PacketConn, upstreamAddr *net.UDPAddr) (*UDPRelay, error) {
	switch {
	case listenConn == nil:
		return nil, fmt.Errorf("relay: listen conn is required")
	case upstreamAddr == nil:
		return nil, fmt.Errorf("relay: upstream addr is required")
	}

	network := relayUDPNetwork(upstreamAddr)
	conn, dialErr := net.DialUDP(network, nil, upstreamAddr)
	if dialErr != nil {
		return nil, fmt.Errorf("relay: dial upstream %s: %w", upstreamAddr, dialErr)
	}

	relay := &UDPRelay{downstream: listenConn, upstream: conn}
	return relay, nil
}
// Serve starts bidirectional forwarding and blocks until either direction
// exits with an error. It then closes both sockets, waits for the other
// direction to stop, and returns the error from the direction that failed
// first.
//
// Both forwarding loops only return on a read or write error, so the result
// is never nil. Calling Close from another goroutine is the way to stop a
// running Serve. When Serve returns, neither forwarding goroutine is still
// running.
func (r *UDPRelay) Serve() error {
	// Capacity 2 so each forwarding goroutine can deliver its result without
	// blocking, regardless of whether Serve is still receiving.
	errCh := make(chan error, 2)
	go func() { errCh <- r.forwardDownstreamToUpstream() }()
	go func() { errCh <- r.forwardUpstreamToDownstream() }()

	err := <-errCh

	// Closing the sockets unblocks the direction that is still reading.
	// Errors are deliberately ignored: a socket may already be closed, either
	// because Close was called or because its own loop just failed.
	_ = r.downstream.Close()
	_ = r.upstream.Close()

	// Wait for the remaining direction so it cannot outlive Serve (for
	// example, by logging after Serve has returned).
	<-errCh
	return err
}
// forwardDownstreamToUpstream reads datagrams from the downstream socket,
// records each sender as the current client (logging when it is first learned
// or when it changes), and writes the payload to the upstream socket. It loops
// until a read or write fails and returns that failure wrapped.
//
// NOTE(review): there is no source filtering. Any peer that reaches the
// listening socket becomes the client that upstream replies are sent to.
// NOTE(review): one log line is emitted per datagram. Consider gating this
// behind a debug switch if throughput matters.
func (r *UDPRelay) forwardDownstreamToUpstream() error {
	packet := make([]byte, udpRelayBufSize)
	for {
		size, from, readErr := r.downstream.ReadFrom(packet)
		if readErr != nil {
			return fmt.Errorf("relay: read downstream: %w", readErr)
		}

		current := cloneRelayAddr(from)

		// Swap in the new client while remembering who it replaced.
		r.mu.Lock()
		prior := cloneRelayAddr(r.clientAddr)
		r.clientAddr = current
		r.mu.Unlock()

		if prior == nil {
			log.Printf("relay: learned downstream client %s", current)
		} else if !sameRelayAddr(prior, current) {
			log.Printf("relay: downstream client changed from %s to %s", prior, current)
		}

		if _, writeErr := r.upstream.Write(packet[:size]); writeErr != nil {
			return fmt.Errorf("relay: write upstream: %w", writeErr)
		}
		log.Printf("relay: forwarded %d bytes downstream(%s) -> upstream", size, from)
	}
}
// forwardUpstreamToDownstream reads datagrams from the upstream socket and
// writes each one to the most recently learned downstream client. Datagrams
// that arrive before any client is known are logged and dropped. It loops
// until a read or write fails and returns that failure wrapped.
//
// NOTE(review): upstream is a connected UDP socket. An ICMP port-unreachable
// from the upstream host can surface here as a read error (e.g. connection
// refused), which ends the whole relay. Confirm that is the intended policy.
func (r *UDPRelay) forwardUpstreamToDownstream() error {
	reply := make([]byte, udpRelayBufSize)
	for {
		size, readErr := r.upstream.Read(reply)
		if readErr != nil {
			return fmt.Errorf("relay: read upstream: %w", readErr)
		}

		// Snapshot the client under the read lock; the downstream loop may
		// replace it at any time.
		r.mu.RLock()
		target := cloneRelayAddr(r.clientAddr)
		r.mu.RUnlock()

		if target == nil {
			log.Printf("relay: dropping %d bytes from upstream (no downstream client yet)", size)
			continue
		}

		if _, writeErr := r.downstream.WriteTo(reply[:size], target); writeErr != nil {
			return fmt.Errorf("relay: write downstream to %s: %w", target, writeErr)
		}
		log.Printf("relay: forwarded %d bytes upstream -> downstream(%s)", size, target)
	}
}
// Close closes both the downstream listening socket and the upstream socket.
// Closing them unblocks the readers in a running Serve, which then returns.
// Both sockets are always closed. If both closes fail, the downstream error
// is the one reported.
//
// NOTE: Serve closes both sockets itself on exit, so calling Close after Serve
// has returned is expected to report an already-closed error.
func (r *UDPRelay) Close() error {
	downErr := r.downstream.Close()
	upErr := r.upstream.Close()
	if downErr == nil {
		return upErr
	}
	return downErr
}
// relayUDPNetwork picks the network name used to dial addr. It returns:
//   - "udp4" for IPv4 addresses, including IPv4-mapped IPv6 forms;
//   - "udp6" for any other IP;
//   - plain "udp" when addr is nil or carries no IP, which leaves the
//     address-family choice to the net package.
func relayUDPNetwork(addr *net.UDPAddr) string {
	// A non-nil but empty IP means "unspecified" to the net package, exactly
	// like a nil IP (net.IP.String reports "<nil>" for both). Checking only
	// for nil would send it down the "udp6" path, and on hosts without IPv6
	// the dial could then fail.
	if addr == nil || len(addr.IP) == 0 {
		return "udp"
	}
	if addr.IP.To4() != nil {
		return "udp4"
	}
	return "udp6"
}