This commit is contained in:
nnbcccscdscdsc
2026-03-23 20:18:53 +08:00
commit 4824675244
28 changed files with 5569 additions and 0 deletions

192
cmd/internal/server/hub.go Normal file
View File

@@ -0,0 +1,192 @@
package server
import (
"fmt"
"net"
"sync"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
const gracefulRejectCloseTimeout = 100 * time.Millisecond
// Hub 管理已注册 peer 的连接,并负责在它们之间转发消息。
type Hub struct {
mu sync.RWMutex
peers map[string]*transport.TCPConn
logger latencylog.Logger
}
// Option 用于配置 Hub 的可选行为,例如时延日志。
type Option func(*Hub)
// WithLogger 为 hub 注入时延日志记录器。
func WithLogger(logger latencylog.Logger) Option {
return func(hub *Hub) {
hub.logger = logger
}
}
// NewHub 创建一个空的连接中心。
func NewHub(opts ...Option) *Hub {
hub := &Hub{
peers: make(map[string]*transport.TCPConn),
logger: latencylog.NoopLogger{},
}
for _, opt := range opts {
opt(hub)
}
if hub.logger == nil {
hub.logger = latencylog.NoopLogger{}
}
return hub
}
// HasPeer 返回给定 ID 是否已经注册到 hub。
func (h *Hub) HasPeer(peerID string) bool {
h.mu.RLock()
defer h.mu.RUnlock()
_, ok := h.peers[peerID]
return ok
}
// ServeConn 处理一条新接入的底层 TCP 连接。
// 连接上的第一条消息必须是 register之后才允许转发 text/file。
func (h *Hub) ServeConn(rawConn net.Conn) error {
conn, err := transport.NewTCPConn(rawConn)
if err != nil {
_ = rawConn.Close()
return fmt.Errorf("server: create transport conn: %w", err)
}
peerID, gracefulClose, err := h.registerConn(conn)
if err != nil {
h.closeConn(conn, gracefulClose)
return err
}
defer h.unregister(peerID, conn)
if err := h.receivePeerLoop(peerID, conn); err != nil {
return err
}
return nil
}
// registerConn 从新连接上读取第一条消息,验证它是 register 消息,并把连接注册到 hub。
func (h *Hub) registerConn(conn *transport.TCPConn) (string, bool, error) {
msg, err := conn.Receive()
if err != nil {
return "", false, fmt.Errorf("server: receive register: %w", err)
}
if msg.Type != protocol.MessageTypeRegister {
if sendErr := sendServerError(conn, msg.From, "first message must be register"); sendErr != nil {
return "", false, fmt.Errorf("server: reject unregistered peer: %w", sendErr)
}
return "", true, fmt.Errorf("server: first message must be register, got %s", msg.Type)
}
h.mu.Lock()
defer h.mu.Unlock()
if _, exists := h.peers[msg.From]; exists {
if sendErr := sendServerError(conn, msg.From, fmt.Sprintf("duplicate peer id: %s", msg.From)); sendErr != nil {
return "", false, fmt.Errorf("server: duplicate peer id %s: %w", msg.From, sendErr)
}
return "", true, fmt.Errorf("server: duplicate peer id: %s", msg.From)
}
h.peers[msg.From] = conn
return msg.From, false, nil
}
// handlePeerMessage 验证消息类型并执行相应的转发或错误响应。
func (h *Hub) handlePeerMessage(peerID string, conn *transport.TCPConn, msg protocol.Message) (bool, error) {
switch msg.Type {
case protocol.MessageTypeText, protocol.MessageTypeFile: //只允许已注册的 peer 发送文本或文件消息,其他类型都视为协议错误。
msg.From = peerID
targetConn, ok := h.lookup(msg.To)
if !ok {
return false, sendServerError(conn, peerID, fmt.Sprintf("unknown target: %s", msg.To))
}
if err := targetConn.Send(msg); err != nil { //转发消息,如果发送失败,说明目标连接可能已经不可用,此时从 hub 中注销该连接并关闭它,并向发送方返回错误响应。
h.unregister(msg.To, targetConn)
_ = targetConn.Close()
return false, sendServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", msg.To))
}
return false, nil
case protocol.MessageTypeRegister, protocol.MessageTypeError: //已注册的 peer 不允许再发送 register 或 error 消息,这些都视为协议错误。
if err := sendServerError(conn, peerID, "registered peers can only send text or file messages"); err != nil {
return false, fmt.Errorf("server: send protocol error: %w", err)
}
return true, fmt.Errorf("server: unexpected message type from peer %s: %s", peerID, msg.Type)
default: // 其他任何消息类型都视为协议错误。
if err := sendServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); err != nil {
return false, fmt.Errorf("server: send unsupported type error: %w", err)
}
return true, fmt.Errorf("server: unsupported message type: %s", msg.Type)
}
}
func (h *Hub) receivePeerLoop(peerID string, conn *transport.TCPConn) error {
for {
msg, err := conn.Receive()
if err != nil {
_ = conn.Close()
return fmt.Errorf("transport: receive loop read: %w", err)
}
gracefulClose, err := h.handlePeerMessage(peerID, conn, msg)
if err != nil {
h.closeConn(conn, gracefulClose)
return fmt.Errorf("transport: receive loop handler: %w", err)
}
}
}
// lookup 在 hub 中查找目标 peer 的连接。
func (h *Hub) lookup(peerID string) (*transport.TCPConn, bool) {
h.mu.RLock()
defer h.mu.RUnlock()
conn, ok := h.peers[peerID]
return conn, ok
}
// unregister 从 hub 中移除指定 peer 的连接,通常在连接关闭或发生错误时调用。
func (h *Hub) unregister(peerID string, conn *transport.TCPConn) {
h.mu.Lock()
defer h.mu.Unlock()
current, ok := h.peers[peerID]
if ok && current == conn {
delete(h.peers, peerID)
}
}
func (h *Hub) closeConn(conn *transport.TCPConn, graceful bool) {
if graceful {
_ = conn.CloseGracefully(gracefulRejectCloseTimeout)
return
}
_ = conn.Close()
}
// sendServerError 是一个辅助函数,用于向指定 peer 发送错误消息。
func sendServerError(conn *transport.TCPConn, to, message string) error {
return conn.Send(protocol.Message{
Type: protocol.MessageTypeError,
From: protocol.ServerPeerID,
To: to,
Body: []byte(message),
})
}

View File

@@ -0,0 +1,398 @@
package server
import (
"net"
"reflect"
"strings"
"sync"
"testing"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
type recordingLogger struct {
mu sync.Mutex
events []latencylog.Event
}
func (l *recordingLogger) LogEvent(event latencylog.Event) error {
l.mu.Lock()
defer l.mu.Unlock()
l.events = append(l.events, event)
return nil
}
func (l *recordingLogger) Events() []latencylog.Event {
l.mu.Lock()
defer l.mu.Unlock()
return append([]latencylog.Event(nil), l.events...)
}
func TestServeConnRegistersPeer(t *testing.T) {
hub := NewHub()
client, done := startHubConn(t, hub)
if err := client.Send(protocol.Message{
Type: protocol.MessageTypeRegister,
From: "peer-a",
To: protocol.ServerPeerID,
}); err != nil {
t.Fatalf("Send(register) error = %v", err)
}
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
if err := client.Close(); err != nil {
t.Fatalf("client.Close() error = %v", err)
}
err := <-done
if err == nil || !strings.Contains(err.Error(), "receive loop read") {
t.Fatalf("ServeConn() error = %v, want read-loop shutdown error", err)
}
}
func TestServeConnRejectsDuplicatePeerID(t *testing.T) {
hub := NewHub()
first, firstDone := startHubConn(t, hub)
registerPeer(t, first, "peer-a")
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
second, secondDone := startHubConn(t, hub)
registerPeer(t, second, "peer-a")
got, err := second.Receive()
if err != nil {
t.Fatalf("second.Receive() error = %v", err)
}
if got.Type != protocol.MessageTypeError {
t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
}
if string(got.Body) != "duplicate peer id: peer-a" {
t.Fatalf("error body = %q, want duplicate peer message", got.Body)
}
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "original peer-a to remain registered")
if err := first.Close(); err != nil {
t.Fatalf("first.Close() error = %v", err)
}
if err := <-secondDone; err == nil || !strings.Contains(err.Error(), "duplicate peer id") {
t.Fatalf("second ServeConn() error = %v, want duplicate peer id error", err)
}
if err := <-firstDone; err == nil || !strings.Contains(err.Error(), "receive loop read") {
t.Fatalf("first ServeConn() error = %v, want read-loop shutdown error", err)
}
}
func TestServeConnForwardsMessages(t *testing.T) {
tests := []struct {
name string
msg protocol.Message
}{
{
name: "text",
msg: protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "spoofed",
To: "peer-b",
Body: []byte("hello"),
},
},
{
name: "file",
msg: protocol.Message{
Type: protocol.MessageTypeFile,
ID: 2,
From: "spoofed",
To: "peer-b",
FileName: "payload.bin",
Body: []byte{0x01, 0x02, 0x03},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hub := NewHub()
sender, senderDone := startHubConn(t, hub)
receiver, receiverDone := startHubConn(t, hub)
registerPeer(t, sender, "peer-a")
registerPeer(t, receiver, "peer-b")
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
if err := sender.Send(tt.msg); err != nil {
t.Fatalf("sender.Send() error = %v", err)
}
got, err := receiver.Receive()
if err != nil {
t.Fatalf("receiver.Receive() error = %v", err)
}
want := tt.msg
want.From = "peer-a"
if !reflect.DeepEqual(got, want) {
t.Fatalf("forwarded message mismatch: got %+v want %+v", got, want)
}
_ = sender.Close()
_ = receiver.Close()
<-senderDone
<-receiverDone
})
}
}
func TestServeConnReturnsErrorForUnknownTarget(t *testing.T) {
hub := NewHub()
client, done := startHubConn(t, hub)
registerPeer(t, client, "peer-a")
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
if err := client.Send(protocol.Message{
Type: protocol.MessageTypeText,
ID: 1,
From: "peer-a",
To: "missing-peer",
Body: []byte("hello"),
}); err != nil {
t.Fatalf("Send(text) error = %v", err)
}
got, err := client.Receive()
if err != nil {
t.Fatalf("Receive() error = %v", err)
}
if got.Type != protocol.MessageTypeError {
t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
}
if string(got.Body) != "unknown target: missing-peer" {
t.Fatalf("error body = %q, want unknown target message", got.Body)
}
if !hub.HasPeer("peer-a") {
t.Fatal("peer-a should remain registered after unknown target error")
}
_ = client.Close()
<-done
}
func TestServeConnRejectsRegisterAfterRegistration(t *testing.T) {
hub := NewHub()
client, done := startHubConn(t, hub)
registerPeer(t, client, "peer-a")
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
if err := client.Send(protocol.Message{
Type: protocol.MessageTypeRegister,
From: "peer-a",
To: protocol.ServerPeerID,
}); err != nil {
t.Fatalf("Send(register again) error = %v", err)
}
got, err := client.Receive()
if err != nil {
t.Fatalf("Receive() error = %v", err)
}
if got.Type != protocol.MessageTypeError {
t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
}
if string(got.Body) != "registered peers can only send text or file messages" {
t.Fatalf("error body = %q, want registered-peer protocol error", got.Body)
}
if err := <-done; err == nil || !strings.Contains(err.Error(), "unexpected message type from peer peer-a: register") {
t.Fatalf("ServeConn() error = %v, want unexpected register message error", err)
}
}
func TestServeConnUnregistersPeerOnClose(t *testing.T) {
hub := NewHub()
client, done := startHubConn(t, hub)
registerPeer(t, client, "peer-a")
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
if err := client.Close(); err != nil {
t.Fatalf("client.Close() error = %v", err)
}
<-done
waitFor(t, func() bool { return !hub.HasPeer("peer-a") }, "peer-a to be unregistered")
}
func TestServeConnDoesNotEmitEndpointLatencyEventsOnForward(t *testing.T) {
logger := &recordingLogger{}
hub := NewHub(WithLogger(logger))
sender, senderDone := startHubConn(t, hub)
receiver, receiverDone := startHubConn(t, hub)
registerPeer(t, sender, "peer-a")
registerPeer(t, receiver, "peer-b")
waitFor(t, func() bool { return hub.HasPeer("peer-a") && hub.HasPeer("peer-b") }, "both peers to be registered")
msg := protocol.Message{
Type: protocol.MessageTypeText,
ID: 11,
From: "spoofed",
To: "peer-b",
Body: []byte("hello"),
}
if err := sender.Send(msg); err != nil {
t.Fatalf("sender.Send() error = %v", err)
}
got, err := receiver.Receive()
if err != nil {
t.Fatalf("receiver.Receive() error = %v", err)
}
msg.From = "peer-a"
if !reflect.DeepEqual(got, msg) {
t.Fatalf("forwarded message mismatch: got %+v want %+v", got, msg)
}
events := logger.Events()
if len(events) != 0 {
t.Fatalf("event count = %d, want 0 because server is a black-box relay", len(events))
}
_ = sender.Close()
_ = receiver.Close()
<-senderDone
<-receiverDone
}
func TestServeConnDoesNotLogLatencyEventsForUnknownTarget(t *testing.T) {
logger := &recordingLogger{}
hub := NewHub(WithLogger(logger))
client, done := startHubConn(t, hub)
registerPeer(t, client, "peer-a")
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
if err := client.Send(protocol.Message{
Type: protocol.MessageTypeText,
ID: 15,
From: "peer-a",
To: "missing-peer",
Body: []byte("hello"),
}); err != nil {
t.Fatalf("Send(text) error = %v", err)
}
got, err := client.Receive()
if err != nil {
t.Fatalf("Receive() error = %v", err)
}
if got.Type != protocol.MessageTypeError {
t.Fatalf("got message type %s, want %s", got.Type, protocol.MessageTypeError)
}
if events := logger.Events(); len(events) != 0 {
t.Fatalf("event count = %d, want 0 for unknown target path", len(events))
}
_ = client.Close()
<-done
}
func TestServeConnDoesNotLogLatencyEventsForDuplicateRegister(t *testing.T) {
logger := &recordingLogger{}
hub := NewHub(WithLogger(logger))
first, firstDone := startHubConn(t, hub)
registerPeer(t, first, "peer-a")
waitFor(t, func() bool { return hub.HasPeer("peer-a") }, "peer-a to be registered")
second, secondDone := startHubConn(t, hub)
registerPeer(t, second, "peer-a")
got, err := second.Receive()
if err != nil {
t.Fatalf("second.Receive() error = %v", err)
}
if got.Type != protocol.MessageTypeError {
t.Fatalf("got type %s, want %s", got.Type, protocol.MessageTypeError)
}
if events := logger.Events(); len(events) != 0 {
t.Fatalf("event count = %d, want 0 for duplicate register path", len(events))
}
_ = first.Close()
<-secondDone
<-firstDone
}
func startHubConn(t *testing.T, hub *Hub) (*transport.TCPConn, <-chan error) {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen() error = %v", err)
}
done := make(chan error, 1)
go func() {
serverSide, acceptErr := listener.Accept()
if acceptErr != nil {
done <- acceptErr
return
}
done <- hub.ServeConn(serverSide)
}()
clientSide, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
_ = listener.Close()
t.Fatalf("net.Dial() error = %v", err)
}
if err := listener.Close(); err != nil {
t.Fatalf("listener.Close() error = %v", err)
}
conn, err := transport.NewTCPConn(clientSide)
if err != nil {
_ = clientSide.Close()
t.Fatalf("transport.NewTCPConn() error = %v", err)
}
return conn, done
}
func registerPeer(t *testing.T, conn *transport.TCPConn, peerID string) {
t.Helper()
if err := conn.Send(protocol.Message{
Type: protocol.MessageTypeRegister,
From: peerID,
To: protocol.ServerPeerID,
}); err != nil {
t.Fatalf("Send(register %s) error = %v", peerID, err)
}
}
func waitFor(t *testing.T, condition func() bool, description string) {
t.Helper()
deadline := time.Now().Add(500 * time.Millisecond)
for time.Now().Before(deadline) {
if condition() {
return
}
time.Sleep(10 * time.Millisecond)
}
t.Fatalf("timed out waiting for %s", description)
}