136 lines
3.1 KiB
Go
136 lines
3.1 KiB
Go
package server
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"sync"
|
|
|
|
"omnisocketgo/cmd/internal/protocol"
|
|
)
|
|
|
|
// udpRelayBufSize is the length of the read buffer each forwarding direction
// allocates: room for one protocol.MaxFrameSize frame plus 1024 bytes of
// headroom. NOTE(review): presumably sized so any valid frame fits in a single
// datagram read — confirm against protocol.MaxFrameSize semantics.
const udpRelayBufSize = 1024 + protocol.MaxFrameSize
|
|
|
|
// UDPRelay forwards UDP datagrams unchanged between a single downstream client
// and one fixed upstream server.
//
// The downstream client is whichever peer most recently sent a datagram to the
// listening socket; replies read from upstream are delivered to that peer.
type UDPRelay struct {
	downstream net.PacketConn // listening socket that receives client datagrams
	upstream   *net.UDPConn   // connected socket to the upstream server

	// mu guards clientAddr, which is written by the downstream->upstream loop
	// and read by the upstream->downstream loop.
	mu         sync.RWMutex
	clientAddr net.Addr
}
|
|
|
|
// NewUDPRelay creates a relay that listens on listenConn and forwards all
|
|
// traffic to upstreamAddr.
|
|
func NewUDPRelay(listenConn net.PacketConn, upstreamAddr *net.UDPAddr) (*UDPRelay, error) {
|
|
if listenConn == nil {
|
|
return nil, fmt.Errorf("relay: listen conn is required")
|
|
}
|
|
if upstreamAddr == nil {
|
|
return nil, fmt.Errorf("relay: upstream addr is required")
|
|
}
|
|
|
|
upstreamConn, err := net.DialUDP(relayUDPNetwork(upstreamAddr), nil, upstreamAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("relay: dial upstream %s: %w", upstreamAddr, err)
|
|
}
|
|
|
|
return &UDPRelay{
|
|
downstream: listenConn,
|
|
upstream: upstreamConn,
|
|
}, nil
|
|
}
|
|
|
|
// Serve starts bidirectional forwarding and blocks until either direction
|
|
// exits with an error.
|
|
func (r *UDPRelay) Serve() error {
|
|
errCh := make(chan error, 2)
|
|
|
|
go func() {
|
|
errCh <- r.forwardDownstreamToUpstream()
|
|
}()
|
|
go func() {
|
|
errCh <- r.forwardUpstreamToDownstream()
|
|
}()
|
|
|
|
err := <-errCh
|
|
_ = r.downstream.Close()
|
|
_ = r.upstream.Close()
|
|
return err
|
|
}
|
|
|
|
func (r *UDPRelay) forwardDownstreamToUpstream() error {
|
|
buf := make([]byte, udpRelayBufSize)
|
|
for {
|
|
n, addr, err := r.downstream.ReadFrom(buf)
|
|
if err != nil {
|
|
return fmt.Errorf("relay: read downstream: %w", err)
|
|
}
|
|
|
|
clientAddr := cloneRelayAddr(addr)
|
|
|
|
r.mu.Lock()
|
|
previousAddr := cloneRelayAddr(r.clientAddr)
|
|
r.clientAddr = clientAddr
|
|
r.mu.Unlock()
|
|
|
|
switch {
|
|
case previousAddr == nil:
|
|
log.Printf("relay: learned downstream client %s", clientAddr)
|
|
case !sameRelayAddr(previousAddr, clientAddr):
|
|
log.Printf("relay: downstream client changed from %s to %s", previousAddr, clientAddr)
|
|
}
|
|
|
|
if _, err := r.upstream.Write(buf[:n]); err != nil {
|
|
return fmt.Errorf("relay: write upstream: %w", err)
|
|
}
|
|
|
|
log.Printf("relay: forwarded %d bytes downstream(%s) -> upstream", n, addr)
|
|
}
|
|
}
|
|
|
|
func (r *UDPRelay) forwardUpstreamToDownstream() error {
|
|
buf := make([]byte, udpRelayBufSize)
|
|
for {
|
|
n, err := r.upstream.Read(buf)
|
|
if err != nil {
|
|
return fmt.Errorf("relay: read upstream: %w", err)
|
|
}
|
|
|
|
r.mu.RLock()
|
|
addr := cloneRelayAddr(r.clientAddr)
|
|
r.mu.RUnlock()
|
|
|
|
if addr == nil {
|
|
log.Printf("relay: dropping %d bytes from upstream (no downstream client yet)", n)
|
|
continue
|
|
}
|
|
|
|
if _, err := r.downstream.WriteTo(buf[:n], addr); err != nil {
|
|
return fmt.Errorf("relay: write downstream to %s: %w", addr, err)
|
|
}
|
|
|
|
log.Printf("relay: forwarded %d bytes upstream -> downstream(%s)", n, addr)
|
|
}
|
|
}
|
|
|
|
func (r *UDPRelay) Close() error {
|
|
err1 := r.downstream.Close()
|
|
err2 := r.upstream.Close()
|
|
if err1 != nil {
|
|
return err1
|
|
}
|
|
return err2
|
|
}
|
|
|
|
// relayUDPNetwork picks the network name to pass to net.DialUDP for addr:
// "udp4" for IPv4 addresses (including IPv4-mapped IPv6 forms) and "udp6" for
// other IPv6 addresses. It returns plain "udp", letting the net package choose
// the family, when addr is nil or carries no IP.
func relayUDPNetwork(addr *net.UDPAddr) string {
	switch {
	case addr == nil || len(addr.IP) == 0:
		// A nil and a zero-length net.IP both mean "no address"; treat them
		// alike instead of misclassifying an empty slice as IPv6.
		return "udp"
	case addr.IP.To4() != nil:
		return "udp4"
	default:
		return "udp6"
	}
}
|