120 lines
3.1 KiB
Go
120 lines
3.1 KiB
Go
package server
|
||
|
||
import (
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"sync"
|
||
|
||
"omnisocketgo/cmd/internal/protocol"
|
||
)
|
||
|
||
// udpRelayBufSize 是 relay 接收缓冲区大小,与 UDP transport 层保持一致。
|
||
const udpRelayBufSize = protocol.MaxFrameSize + 1024
|
||
|
||
// UDPRelay is a transparent, bidirectional UDP forwarder.
//
// It copies datagrams unchanged between a downstream client (client A) and an
// upstream server (server D) without parsing or modifying the protocol
// payload.
type UDPRelay struct {
	downstream *net.UDPConn // listening socket that downstream clients send to
	upstream   *net.UDPConn // connected socket to the upstream server

	// mu guards clientAddr, which is written by the downstream->upstream
	// forwarding goroutine and read by the upstream->downstream one.
	mu sync.RWMutex

	// clientAddr is the destination for datagrams coming back from upstream.
	// It is nil until the first downstream datagram arrives; after that it is
	// overwritten with the source address of every downstream datagram, so it
	// always names the most recent sender (not just the first one).
	//
	// NOTE(review): any datagram that reaches the listening socket redirects
	// upstream replies to its sender — confirm this is acceptable for the
	// deployment.
	clientAddr *net.UDPAddr
}
|
||
|
||
// NewUDPRelay 创建一个新的 UDP relay。
|
||
// listenConn 是已经绑定好的监听 socket(供下游客户端连接),
|
||
// upstreamAddr 是上游 server D 的地址。
|
||
func NewUDPRelay(listenConn *net.UDPConn, upstreamAddr string) (*UDPRelay, error) {
|
||
udpUpstreamAddr, err := net.ResolveUDPAddr("udp", upstreamAddr)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("relay: resolve upstream addr %s: %w", upstreamAddr, err)
|
||
}
|
||
|
||
upstreamConn, err := net.DialUDP("udp", nil, udpUpstreamAddr)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("relay: dial upstream %s: %w", upstreamAddr, err)
|
||
}
|
||
|
||
return &UDPRelay{
|
||
downstream: listenConn,
|
||
upstream: upstreamConn,
|
||
}, nil
|
||
}
|
||
|
||
// Serve 启动双向转发循环,阻塞直到任一方向出错。
|
||
func (r *UDPRelay) Serve() error {
|
||
errCh := make(chan error, 2)
|
||
|
||
go func() {
|
||
errCh <- r.forwardDownstreamToUpstream()
|
||
}()
|
||
go func() {
|
||
errCh <- r.forwardUpstreamToDownstream()
|
||
}()
|
||
|
||
err := <-errCh
|
||
// 关闭两个 conn 让另一个 goroutine 也退出
|
||
_ = r.downstream.Close()
|
||
_ = r.upstream.Close()
|
||
return err
|
||
}
|
||
|
||
// forwardDownstreamToUpstream 从下游读取并转发到上游。
|
||
func (r *UDPRelay) forwardDownstreamToUpstream() error {
|
||
buf := make([]byte, udpRelayBufSize)
|
||
for {
|
||
n, addr, err := r.downstream.ReadFromUDP(buf)
|
||
if err != nil {
|
||
return fmt.Errorf("relay: read downstream: %w", err)
|
||
}
|
||
|
||
r.mu.Lock()
|
||
r.clientAddr = addr
|
||
r.mu.Unlock()
|
||
|
||
if _, err := r.upstream.Write(buf[:n]); err != nil {
|
||
return fmt.Errorf("relay: write upstream: %w", err)
|
||
}
|
||
|
||
log.Printf("relay: forwarded %d bytes downstream(%s) -> upstream", n, addr)
|
||
}
|
||
}
|
||
|
||
// forwardUpstreamToDownstream 从上游读取并转发到下游。
|
||
func (r *UDPRelay) forwardUpstreamToDownstream() error {
|
||
buf := make([]byte, udpRelayBufSize)
|
||
for {
|
||
n, err := r.upstream.Read(buf)
|
||
if err != nil {
|
||
return fmt.Errorf("relay: read upstream: %w", err)
|
||
}
|
||
|
||
r.mu.RLock()
|
||
addr := r.clientAddr
|
||
r.mu.RUnlock()
|
||
|
||
if addr == nil {
|
||
log.Printf("relay: dropping %d bytes from upstream (no downstream client yet)", n)
|
||
continue
|
||
}
|
||
|
||
if _, err := r.downstream.WriteToUDP(buf[:n], addr); err != nil {
|
||
return fmt.Errorf("relay: write downstream to %s: %w", addr, err)
|
||
}
|
||
|
||
log.Printf("relay: forwarded %d bytes upstream -> downstream(%s)", n, addr)
|
||
}
|
||
}
|
||
|
||
// Close 关闭 relay 的上下游连接。
|
||
func (r *UDPRelay) Close() error {
|
||
err1 := r.downstream.Close()
|
||
err2 := r.upstream.Close()
|
||
if err1 != nil {
|
||
return err1
|
||
}
|
||
return err2
|
||
}
|