package server

import (
	"errors"
	"fmt"
	"log"
	"net"
	"sync"
	"syscall"

	"omnisocketgo/cmd/internal/protocol"
)

// udpRelayBufSize is the per-direction read buffer size: one maximum protocol
// frame plus 1 KiB of headroom.
const udpRelayBufSize = protocol.MaxFrameSize + 1024

// UDPRelay transparently forwards UDP datagrams between one downstream client
// and a fixed upstream server.
//
// The downstream client is whichever peer most recently sent a datagram to
// the listening socket; datagrams read from upstream are written back to that
// address.
type UDPRelay struct {
	downstream net.PacketConn
	upstream   *net.UDPConn

	// mu guards clientAddr, which is written by the downstream reader and
	// read by the upstream reader.
	mu         sync.RWMutex
	clientAddr net.Addr

	// closeOnce makes Close idempotent; closeErr records its first result.
	closeOnce sync.Once
	closeErr  error
}

// NewUDPRelay creates a relay that listens on listenConn and forwards all
// traffic to upstreamAddr. It dials (connects) a UDP socket to upstreamAddr
// immediately and returns an error if either argument is nil or the dial
// fails.
func NewUDPRelay(listenConn net.PacketConn, upstreamAddr *net.UDPAddr) (*UDPRelay, error) {
	if listenConn == nil {
		return nil, fmt.Errorf("relay: listen conn is required")
	}
	if upstreamAddr == nil {
		return nil, fmt.Errorf("relay: upstream addr is required")
	}

	upstreamConn, err := net.DialUDP(relayUDPNetwork(upstreamAddr), nil, upstreamAddr)
	if err != nil {
		return nil, fmt.Errorf("relay: dial upstream %s: %w", upstreamAddr, err)
	}

	return &UDPRelay{
		downstream: listenConn,
		upstream:   upstreamConn,
	}, nil
}

// Serve starts bidirectional forwarding and blocks until either direction
// exits with an error. Both sockets are closed before Serve returns, and the
// first error observed is returned.
func (r *UDPRelay) Serve() error {
	// Buffered for both senders so the direction that exits second never
	// blocks on the send after Serve has returned.
	errCh := make(chan error, 2)
	go func() { errCh <- r.forwardDownstreamToUpstream() }()
	go func() { errCh <- r.forwardUpstreamToDownstream() }()

	err := <-errCh
	// Closing both sockets unblocks the other direction's pending read. The
	// close result is intentionally ignored; the forwarding error is what
	// caused Serve to stop.
	_ = r.Close()
	return err
}

// forwardDownstreamToUpstream reads datagrams from the listening socket,
// records the sender as the current downstream client, and writes each
// payload to upstream. It returns only on a non-transient error.
func (r *UDPRelay) forwardDownstreamToUpstream() error {
	buf := make([]byte, udpRelayBufSize)
	for {
		n, addr, err := r.downstream.ReadFrom(buf)
		if err != nil {
			return fmt.Errorf("relay: read downstream: %w", err)
		}

		clientAddr := cloneRelayAddr(addr)
		r.mu.Lock()
		previousAddr := cloneRelayAddr(r.clientAddr)
		r.clientAddr = clientAddr
		r.mu.Unlock()

		switch {
		case previousAddr == nil:
			log.Printf("relay: learned downstream client %s", clientAddr)
		case !sameRelayAddr(previousAddr, clientAddr):
			log.Printf("relay: downstream client changed from %s to %s", previousAddr, clientAddr)
		}

		if _, err := r.upstream.Write(buf[:n]); err != nil {
			// A connected UDP socket reports an ICMP error from an earlier
			// datagram on a later call; drop this datagram and keep going
			// instead of tearing down the relay.
			if isTransientUDPError(err) {
				log.Printf("relay: dropping %d bytes to upstream: %v", n, err)
				continue
			}
			return fmt.Errorf("relay: write upstream: %w", err)
		}
		log.Printf("relay: forwarded %d bytes downstream(%s) -> upstream", n, addr)
	}
}

// forwardUpstreamToDownstream reads datagrams from upstream and writes them
// to the most recently seen downstream client. Datagrams that arrive before
// any client is known are dropped. It returns only on a non-transient error.
func (r *UDPRelay) forwardUpstreamToDownstream() error {
	buf := make([]byte, udpRelayBufSize)
	for {
		n, err := r.upstream.Read(buf)
		if err != nil {
			// E.g. upstream was briefly not listening. The pending error is
			// consumed by this call, so the next Read blocks normally.
			if isTransientUDPError(err) {
				log.Printf("relay: upstream unreachable, continuing: %v", err)
				continue
			}
			return fmt.Errorf("relay: read upstream: %w", err)
		}

		r.mu.RLock()
		addr := cloneRelayAddr(r.clientAddr)
		r.mu.RUnlock()

		if addr == nil {
			log.Printf("relay: dropping %d bytes from upstream (no downstream client yet)", n)
			continue
		}

		if _, err := r.downstream.WriteTo(buf[:n], addr); err != nil {
			return fmt.Errorf("relay: write downstream to %s: %w", addr, err)
		}
		log.Printf("relay: forwarded %d bytes upstream -> downstream(%s)", n, addr)
	}
}

// Close closes both sockets, which makes a running Serve return. It is safe
// to call more than once and concurrently with Serve; every call returns the
// result of the first close: the downstream close error if there was one,
// otherwise the upstream close error.
func (r *UDPRelay) Close() error {
	r.closeOnce.Do(func() {
		err1 := r.downstream.Close()
		err2 := r.upstream.Close()
		if err1 != nil {
			r.closeErr = err1
		} else {
			r.closeErr = err2
		}
	})
	return r.closeErr
}

// isTransientUDPError reports whether err is the ECONNREFUSED a connected UDP
// socket surfaces after an ICMP port-unreachable for an earlier datagram.
// Such errors concern past traffic only, so the relay keeps running.
func isTransientUDPError(err error) bool {
	return errors.Is(err, syscall.ECONNREFUSED)
}

// relayUDPNetwork picks the network name used to dial addr: "udp4" for IPv4
// (including IPv4-mapped IPv6) addresses, "udp6" for other IPv6 addresses,
// and "udp" when addr is nil or carries no IP, leaving the family open.
func relayUDPNetwork(addr *net.UDPAddr) string {
	switch {
	case addr == nil || len(addr.IP) == 0:
		return "udp"
	case addr.IP.To4() != nil:
		return "udp4"
	default:
		return "udp6"
	}
}