128 lines
3.0 KiB
Go
128 lines
3.0 KiB
Go
package server
|
||
|
||
import (
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"sync"
|
||
|
||
"omnisocketgo/cmd/internal/transport"
|
||
)
|
||
|
||
// UDPRelay transparently forwards raw KCP UDP datagrams in both directions
// between one fixed remote endpoint and any number of clients.
//
// Clients are told apart by KCP conversation ID: the relay remembers the
// address that most recently sent each conversation ID and routes the
// remote's datagrams for that conversation back to it.
type UDPRelay struct {
	// conn is the single socket used both to receive datagrams and to send
	// them on to the remote or to clients.
	conn net.PacketConn
	// remote is the fixed upstream endpoint that all client traffic goes to.
	remote *net.UDPAddr

	// mu guards clients.
	mu sync.RWMutex
	// clients maps a KCP conversation ID to the client address last seen
	// using that conversation.
	clients map[uint32]*net.UDPAddr
}
|
||
|
||
// NewUDPRelay 创建一个绑定到给定 PacketConn 的透明 UDP relay。
|
||
func NewUDPRelay(conn net.PacketConn, remote *net.UDPAddr) (*UDPRelay, error) {
|
||
if conn == nil {
|
||
return nil, fmt.Errorf("server: nil udp relay conn")
|
||
}
|
||
if remote == nil {
|
||
return nil, fmt.Errorf("server: nil udp relay remote")
|
||
}
|
||
|
||
return &UDPRelay{
|
||
conn: conn,
|
||
remote: cloneUDPAddr(remote),
|
||
clients: make(map[uint32]*net.UDPAddr),
|
||
}, nil
|
||
}
|
||
|
||
// Serve 持续双向转发原始 UDP datagram,不解析业务消息。
|
||
func (r *UDPRelay) Serve() error {
|
||
buffer := make([]byte, 64*1024)
|
||
for {
|
||
n, addr, err := r.conn.ReadFrom(buffer)
|
||
if err != nil {
|
||
if isExpectedRelayServeExit(err) {
|
||
return nil
|
||
}
|
||
return fmt.Errorf("server: udp relay read packet: %w", err)
|
||
}
|
||
|
||
udpAddr, ok := addr.(*net.UDPAddr)
|
||
if !ok {
|
||
log.Printf("udp relay dropped packet from non-udp addr %T", addr)
|
||
continue
|
||
}
|
||
|
||
payload := append([]byte(nil), buffer[:n]...)
|
||
if sameUDPAddr(udpAddr, r.remote) {
|
||
if err := r.forwardRemotePacket(payload); err != nil {
|
||
log.Printf("udp relay failed forwarding remote packet from %s: %v", udpAddr, err)
|
||
}
|
||
continue
|
||
}
|
||
|
||
if err := r.forwardClientPacket(udpAddr, payload); err != nil {
|
||
log.Printf("udp relay failed forwarding client packet from %s: %v", udpAddr, err)
|
||
}
|
||
}
|
||
}
|
||
|
||
func (r *UDPRelay) forwardClientPacket(addr *net.UDPAddr, payload []byte) error {
|
||
convID, ok := transport.ParseKCPConversationID(payload)
|
||
if !ok {
|
||
return fmt.Errorf("missing kcp conversation id")
|
||
}
|
||
|
||
r.mu.Lock()
|
||
r.clients[convID] = cloneUDPAddr(addr)
|
||
r.mu.Unlock()
|
||
|
||
if _, err := r.conn.WriteTo(payload, r.remote); err != nil {
|
||
return fmt.Errorf("write conv %d to remote %s: %w", convID, r.remote, err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (r *UDPRelay) forwardRemotePacket(payload []byte) error {
|
||
convID, ok := transport.ParseKCPConversationID(payload)
|
||
if !ok {
|
||
return fmt.Errorf("missing kcp conversation id")
|
||
}
|
||
|
||
r.mu.RLock()
|
||
clientAddr := cloneUDPAddr(r.clients[convID])
|
||
r.mu.RUnlock()
|
||
if clientAddr == nil {
|
||
return fmt.Errorf("unknown client for conv %d", convID)
|
||
}
|
||
|
||
if _, err := r.conn.WriteTo(payload, clientAddr); err != nil {
|
||
return fmt.Errorf("write conv %d to client %s: %w", convID, clientAddr, err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func cloneUDPAddr(addr *net.UDPAddr) *net.UDPAddr {
|
||
if addr == nil {
|
||
return nil
|
||
}
|
||
|
||
ipCopy := make([]byte, len(addr.IP))
|
||
copy(ipCopy, addr.IP)
|
||
|
||
return &net.UDPAddr{
|
||
IP: ipCopy,
|
||
Port: addr.Port,
|
||
Zone: addr.Zone,
|
||
}
|
||
}
|
||
|
||
func sameUDPAddr(left, right *net.UDPAddr) bool {
|
||
if left == nil || right == nil {
|
||
return left == right
|
||
}
|
||
if left.Port != right.Port || left.Zone != right.Zone {
|
||
return false
|
||
}
|
||
return left.IP.Equal(right.IP)
|
||
}
|