添加新的网口标识

This commit is contained in:
nnbcccscdscdsc
2026-03-23 20:50:44 +08:00
parent 4824675244
commit 2dd33bf73e
4 changed files with 123 additions and 8 deletions

View File

@@ -6,16 +6,19 @@ import (
"os"
"path/filepath"
"sync/atomic"
"syscall"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
var dialServer = net.Dial
var dialServer = dialServerWithOptions
type clientOptions struct {
logger latencylog.Logger
logger latencylog.Logger
bindIP string
bindDevice string
}
// Option 用于配置 Client 的可选行为,例如时延日志。
@@ -28,6 +31,20 @@ func WithLogger(logger latencylog.Logger) Option {
}
}
// WithBindIP 指定拨号时使用的本地源 IP。
func WithBindIP(ip string) Option {
return func(options *clientOptions) {
options.bindIP = ip
}
}
// WithBindDevice 指定拨号时绑定的 Linux 网卡名,例如 eth0 或 wwan0。
func WithBindDevice(device string) Option {
return func(options *clientOptions) {
options.bindDevice = device
}
}
// Client 表示一个已经连接到 server 的 peer。
type Client struct {
id string
@@ -49,7 +66,7 @@ func Dial(serverAddr, peerID string, opts ...Option) (*Client, error) {
options.logger = latencylog.NoopLogger{}
}
rawConn, err := dialServer("tcp", serverAddr) //使用 net.Dial 连接到 serverAddr 指定的 TCP 地址,返回一个 net.Conn。
rawConn, err := dialServer(serverAddr, options)
if err != nil {
return nil, fmt.Errorf("peer: dial server: %w", err)
}
@@ -177,3 +194,42 @@ func (c *Client) Close() error {
func (c *Client) nextMessageID() uint64 {
return atomic.AddUint64(&c.nextID, 1)
}
func dialServerWithOptions(serverAddr string, options clientOptions) (net.Conn, error) {
dialer, err := buildDialer(options)
if err != nil {
return nil, err
}
return dialer.Dial("tcp", serverAddr)
}
func buildDialer(options clientOptions) (*net.Dialer, error) {
dialer := &net.Dialer{}
if options.bindIP != "" {
ip := net.ParseIP(options.bindIP)
if ip == nil {
return nil, fmt.Errorf("peer: invalid bind ip %q", options.bindIP)
}
dialer.LocalAddr = &net.TCPAddr{IP: ip}
}
if options.bindDevice != "" {
device := options.bindDevice
dialer.Control = func(_, _ string, rawConn syscall.RawConn) error {
var bindErr error
if err := rawConn.Control(func(fd uintptr) {
bindErr = syscall.BindToDevice(int(fd), device)
}); err != nil {
return err
}
if bindErr != nil {
return fmt.Errorf("peer: bind device %s: %w", device, bindErr)
}
return nil
}
}
return dialer, nil
}

View File

@@ -58,6 +58,51 @@ func TestDialRegistersPeer(t *testing.T) {
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
}
func TestDialRegistersPeerWithBindIP(t *testing.T) {
hub := server.NewHub()
cleanup := stubDialToHub(t, hub)
defer cleanup()
client, err := Dial("ignored", "peer-a", WithBindIP("127.0.0.1"))
if err != nil {
t.Fatalf("Dial() with bind ip error = %v", err)
}
defer func() { _ = client.Close() }()
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
}
func TestDialRejectsInvalidBindIP(t *testing.T) {
_, err := Dial("ignored", "peer-a", WithBindIP("not-an-ip"))
if err == nil {
t.Fatal("Dial() error = nil, want invalid bind ip error")
}
if !strings.Contains(err.Error(), `invalid bind ip "not-an-ip"`) {
t.Fatalf("Dial() error = %v, want invalid bind ip error", err)
}
}
func TestDialPassesBindDeviceOptionToDialer(t *testing.T) {
originalDial := dialServer
defer func() {
dialServer = originalDial
}()
gotDevice := ""
dialServer = func(_ string, options clientOptions) (net.Conn, error) {
gotDevice = options.bindDevice
return nil, net.ErrClosed
}
_, err := Dial("ignored", "peer-a", WithBindDevice("wwan0"))
if err == nil {
t.Fatal("Dial() error = nil, want dial error")
}
if gotDevice != "wwan0" {
t.Fatalf("bind device = %q, want %q", gotDevice, "wwan0")
}
}
func TestClientsExchangeTextAndFileMessages(t *testing.T) {
hub := server.NewHub()
cleanup := stubDialToHub(t, hub)
@@ -662,8 +707,12 @@ func stubDialToHub(t *testing.T, hub *server.Hub) func() {
originalDial := dialServer
serverAddr, cleanup := startRealHubServer(t, hub)
dialServer = func(network, addr string) (net.Conn, error) {
return net.Dial(network, serverAddr)
dialServer = func(_ string, options clientOptions) (net.Conn, error) {
dialer, err := buildDialer(options)
if err != nil {
return nil, err
}
return dialer.Dial("tcp", serverAddr)
}
return func() {