package transport

import (
	"net"
	"reflect"
	"strings"
	"sync"
	"testing"
	"time"

	"omnisocketgo/cmd/internal/latencylog"
	"omnisocketgo/cmd/internal/protocol"
)

// TestUDPSendReceiveMessage verifies that the UDP transport round-trips both
// text and file messages unchanged.
func TestUDPSendReceiveMessage(t *testing.T) {
	tests := []struct {
		name string
		msg  protocol.Message
	}{
		{
			name: "text",
			msg: protocol.Message{
				Type: protocol.MessageTypeText,
				ID:   1,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte("hello udp"),
			},
		},
		{
			name: "file",
			msg: protocol.Message{
				Type:     protocol.MessageTypeFile,
				ID:       2,
				From:     "peer-a",
				To:       "peer-b",
				FileName: "data.bin",
				Body:     []byte{0x00, 0x10, 0xff},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sender, receiver := newUDPConnPair(t, nil, nil)

			sendErr := make(chan error, 1)
			go func() {
				sendErr <- sender.Send(tt.msg)
			}()

			got, _, err := receiver.Receive()
			if err != nil {
				t.Fatalf("Receive() error = %v", err)
			}
			if err := <-sendErr; err != nil {
				t.Fatalf("Send() error = %v", err)
			}
			if !reflect.DeepEqual(got, tt.msg) {
				t.Fatalf("message mismatch: got %+v want %+v", got, tt.msg)
			}
		})
	}
}

// TestUDPSendLogsHandoffEvents verifies that Send on a connection configured
// with WithUDPLogger records a SendHandoffBegin event first and a
// SendHandoffEnd event last.
func TestUDPSendLogsHandoffEvents(t *testing.T) {
	logger := &recordingLogger{}
	sender, receiver := newUDPConnPair(
		t,
		[]UDPOption{WithUDPLogger(logger, latencylog.NodeRolePeer, "peer-a")},
		nil,
	)

	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   7,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	sendErr := make(chan error, 1)
	go func() {
		sendErr <- sender.Send(msg)
	}()

	got, _, err := receiver.Receive()
	if err != nil {
		t.Fatalf("Receive() error = %v", err)
	}
	if err := <-sendErr; err != nil {
		t.Fatalf("Send() error = %v", err)
	}
	if !reflect.DeepEqual(got, msg) {
		t.Fatalf("message mismatch: got %+v want %+v", got, msg)
	}

	events := logger.Events()
	if len(events) < 2 {
		t.Fatalf("event count = %d, want at least 2", len(events))
	}
	if events[0].Event != latencylog.EventSendHandoffBegin {
		t.Fatalf("first event = %q, want %q", events[0].Event, latencylog.EventSendHandoffBegin)
	}
	// The final recorded event must close the handoff.
	lastEvent := events[len(events)-1]
	if lastEvent.Event != latencylog.EventSendHandoffEnd {
		t.Fatalf("last event = %q, want %q", lastEvent.Event, latencylog.EventSendHandoffEnd)
	}
}

// TestUDPReceiveLoopDeliversMessages verifies that ReceiveLoop hands each
// consecutively arriving message to the handler, in order, and that the loop
// returns a read error once its socket is closed.
//
// UDP is connectionless: closing the sending socket transmits nothing to the
// receiver, so it cannot unblock the receiver's read. The test instead waits
// until every message has been delivered and then closes the receiver itself.
func TestUDPReceiveLoopDeliversMessages(t *testing.T) {
	// Bounds each wait so a lost datagram or a stuck loop fails fast instead
	// of hanging until the global `go test` timeout.
	const waitTimeout = 5 * time.Second

	sender, receiver := newUDPConnPair(t, nil, nil)

	want := []protocol.Message{
		{
			Type: protocol.MessageTypeText,
			ID:   1,
			From: "peer-a",
			To:   "peer-b",
			Body: []byte("hello"),
		},
		{
			Type:     protocol.MessageTypeFile,
			ID:       2,
			From:     "peer-a",
			To:       "peer-b",
			FileName: "payload.bin",
			Body:     []byte{0x01, 0x02, 0x03},
		},
	}

	var (
		mu  sync.Mutex
		got []protocol.Message
	)
	// allDelivered is closed by the handler exactly once, when the number of
	// received messages reaches len(want).
	allDelivered := make(chan struct{})

	loopErr := make(chan error, 1)
	go func() {
		loopErr <- receiver.ReceiveLoop(func(msg protocol.Message, _ *net.UDPAddr) error {
			mu.Lock()
			defer mu.Unlock()
			got = append(got, msg)
			if len(got) == len(want) {
				close(allDelivered)
			}
			return nil
		})
	}()

	for _, msg := range want {
		if err := sender.Send(msg); err != nil {
			t.Fatalf("Send() error = %v", err)
		}
	}

	select {
	case <-allDelivered:
	case err := <-loopErr:
		t.Fatalf("ReceiveLoop() returned before all messages were delivered: %v", err)
	case <-time.After(waitTimeout):
		t.Fatalf("timed out waiting for %d messages", len(want))
	}

	// Closing the receiver's own socket is what makes the blocked read inside
	// ReceiveLoop fail; closing the sender would not.
	if err := receiver.Close(); err != nil {
		t.Fatalf("receiver.Close() error = %v", err)
	}

	var err error
	select {
	case err = <-loopErr:
	case <-time.After(waitTimeout):
		t.Fatal("ReceiveLoop() did not return after receiver.Close()")
	}
	if err == nil {
		t.Fatal("ReceiveLoop() error = nil, want non-nil after close")
	}
	if !strings.Contains(err.Error(), "udp receive loop read") {
		t.Fatalf("ReceiveLoop() error = %v, want read context", err)
	}

	mu.Lock()
	defer mu.Unlock()
	if !reflect.DeepEqual(got, want) {
		t.Fatalf("received messages mismatch: got %+v want %+v", got, want)
	}
}

// TestUDPCloseIsIdempotent verifies that calling Close a second time on the
// same connection returns nil.
func TestUDPCloseIsIdempotent(t *testing.T) {
	conn, peer := newUDPConnPair(t, nil, nil)

	if err := conn.Close(); err != nil {
		t.Fatalf("Close(first) error = %v", err)
	}
	if err := conn.Close(); err != nil {
		t.Fatalf("Close(second) error = %v, want nil", err)
	}

	_ = peer.Close() // best-effort; t.Cleanup closes it again anyway
}

// TestUDPSendToMessage verifies the server-style flow: a listening connection
// receives a message from a dialed peer, then replies with SendTo using the
// sender address reported by Receive.
func TestUDPSendToMessage(t *testing.T) {
	serverConn := newUDPListener(t)
	peerConn := newUDPDialed(t, serverConn.conn.LocalAddr().(*net.UDPAddr))

	msg := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "server",
		Body: []byte("hello sendto"),
	}

	// The peer sends to the server.
	sendErr := make(chan error, 1)
	go func() {
		sendErr <- peerConn.Send(msg)
	}()

	got, addr, err := serverConn.Receive()
	if err != nil {
		t.Fatalf("Receive() error = %v", err)
	}
	if err := <-sendErr; err != nil {
		t.Fatalf("Send() error = %v", err)
	}
	if !reflect.DeepEqual(got, msg) {
		t.Fatalf("message mismatch: got %+v want %+v", got, msg)
	}

	// The server replies to the peer's address via SendTo.
	reply := protocol.Message{
		Type: protocol.MessageTypeText,
		ID:   2,
		From: "server",
		To:   "peer-a",
		Body: []byte("reply"),
	}

	sendErr2 := make(chan error, 1)
	go func() {
		sendErr2 <- serverConn.SendTo(reply, addr)
	}()

	gotReply, _, err := peerConn.Receive()
	if err != nil {
		t.Fatalf("peer Receive() error = %v", err)
	}
	if err := <-sendErr2; err != nil {
		t.Fatalf("SendTo() error = %v", err)
	}
	if !reflect.DeepEqual(gotReply, reply) {
		t.Fatalf("reply mismatch: got %+v want %+v", gotReply, reply)
	}
}

// newUDPConnPair returns two UDP transport connections (sender, receiver)
// whose underlying loopback sockets are connected to each other.
// senderOpts and receiverOpts are passed to NewUDPConn for the respective
// side. Both connections are closed via t.Cleanup.
func newUDPConnPair(t *testing.T, senderOpts []UDPOption, receiverOpts []UDPOption) (*UDPConn, *UDPConn) {
	t.Helper()

	loopback, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ResolveUDPAddr() error = %v", err)
	}

	// A placeholder socket reserves a free ephemeral port for the receiver;
	// the receiver rebinds that exact address below.
	placeholder, err := net.ListenUDP("udp", loopback)
	if err != nil {
		t.Fatalf("ListenUDP() error = %v", err)
	}
	receiverAddr := placeholder.LocalAddr().(*net.UDPAddr)

	// Dial the sender while the placeholder still holds the port, so the
	// kernel cannot pick that same port as the sender's ephemeral address.
	senderRaw, err := net.DialUDP("udp", nil, receiverAddr)
	if err != nil {
		_ = placeholder.Close()
		t.Fatalf("DialUDP(sender) error = %v", err)
	}

	// The placeholder must be released BEFORE the receiver binds the same
	// address: Go only enables address reuse for multicast UDP sockets, so
	// binding while the placeholder is open fails with "address already in
	// use". Another process could grab the port in the brief gap; that is an
	// accepted risk for loopback tests.
	if err := placeholder.Close(); err != nil {
		_ = senderRaw.Close()
		t.Fatalf("placeholder.Close() error = %v", err)
	}

	receiverRaw, err := net.DialUDP("udp", receiverAddr, senderRaw.LocalAddr().(*net.UDPAddr))
	if err != nil {
		_ = senderRaw.Close()
		t.Fatalf("DialUDP(receiver) error = %v", err)
	}

	sender, err := NewUDPConn(senderRaw, nil, senderOpts...)
	if err != nil {
		_ = senderRaw.Close()
		_ = receiverRaw.Close()
		t.Fatalf("NewUDPConn(sender) error = %v", err)
	}

	receiver, err := NewUDPConn(receiverRaw, nil, receiverOpts...)
	if err != nil {
		_ = sender.Close()
		_ = receiverRaw.Close()
		t.Fatalf("NewUDPConn(receiver) error = %v", err)
	}

	t.Cleanup(func() {
		_ = sender.Close()
		_ = receiver.Close()
	})

	return sender, receiver
}

// newUDPListener returns a transport connection over an unconnected UDP
// socket bound to an ephemeral loopback port, for server-side tests.
// The connection is closed via t.Cleanup.
func newUDPListener(t *testing.T) *UDPConn {
	t.Helper()

	addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ResolveUDPAddr() error = %v", err)
	}

	raw, err := net.ListenUDP("udp", addr)
	if err != nil {
		t.Fatalf("ListenUDP() error = %v", err)
	}

	conn, err := NewUDPConn(raw, nil)
	if err != nil {
		_ = raw.Close()
		t.Fatalf("NewUDPConn() error = %v", err)
	}

	t.Cleanup(func() {
		_ = conn.Close()
	})

	return conn
}

// newUDPDialed returns a transport connection over a UDP socket connected to
// serverAddr, for peer-side tests. The connection is closed via t.Cleanup.
func newUDPDialed(t *testing.T, serverAddr *net.UDPAddr) *UDPConn {
	t.Helper()

	raw, err := net.DialUDP("udp", nil, serverAddr)
	if err != nil {
		t.Fatalf("DialUDP() error = %v", err)
	}

	conn, err := NewUDPConn(raw, nil)
	if err != nil {
		_ = raw.Close()
		t.Fatalf("NewUDPConn() error = %v", err)
	}

	t.Cleanup(func() {
		_ = conn.Close()
	})

	return conn
}