Files
nnbcccscdscdsc 4824675244 init
2026-03-23 20:18:53 +08:00

193 lines
5.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package server
import (
"fmt"
"net"
"sync"
"time"
"omnisocketgo/cmd/internal/latencylog"
"omnisocketgo/cmd/internal/protocol"
"omnisocketgo/cmd/internal/transport"
)
const gracefulRejectCloseTimeout = 100 * time.Millisecond
// Hub 管理已注册 peer 的连接,并负责在它们之间转发消息。
type Hub struct {
mu sync.RWMutex
peers map[string]*transport.TCPConn
logger latencylog.Logger
}
// Option 用于配置 Hub 的可选行为,例如时延日志。
type Option func(*Hub)
// WithLogger 为 hub 注入时延日志记录器。
func WithLogger(logger latencylog.Logger) Option {
return func(hub *Hub) {
hub.logger = logger
}
}
// NewHub 创建一个空的连接中心。
func NewHub(opts ...Option) *Hub {
hub := &Hub{
peers: make(map[string]*transport.TCPConn),
logger: latencylog.NoopLogger{},
}
for _, opt := range opts {
opt(hub)
}
if hub.logger == nil {
hub.logger = latencylog.NoopLogger{}
}
return hub
}
// HasPeer 返回给定 ID 是否已经注册到 hub。
func (h *Hub) HasPeer(peerID string) bool {
h.mu.RLock()
defer h.mu.RUnlock()
_, ok := h.peers[peerID]
return ok
}
// ServeConn 处理一条新接入的底层 TCP 连接。
// 连接上的第一条消息必须是 register之后才允许转发 text/file。
func (h *Hub) ServeConn(rawConn net.Conn) error {
conn, err := transport.NewTCPConn(rawConn)
if err != nil {
_ = rawConn.Close()
return fmt.Errorf("server: create transport conn: %w", err)
}
peerID, gracefulClose, err := h.registerConn(conn)
if err != nil {
h.closeConn(conn, gracefulClose)
return err
}
defer h.unregister(peerID, conn)
if err := h.receivePeerLoop(peerID, conn); err != nil {
return err
}
return nil
}
// registerConn 从新连接上读取第一条消息,验证它是 register 消息,并把连接注册到 hub。
func (h *Hub) registerConn(conn *transport.TCPConn) (string, bool, error) {
msg, err := conn.Receive()
if err != nil {
return "", false, fmt.Errorf("server: receive register: %w", err)
}
if msg.Type != protocol.MessageTypeRegister {
if sendErr := sendServerError(conn, msg.From, "first message must be register"); sendErr != nil {
return "", false, fmt.Errorf("server: reject unregistered peer: %w", sendErr)
}
return "", true, fmt.Errorf("server: first message must be register, got %s", msg.Type)
}
h.mu.Lock()
defer h.mu.Unlock()
if _, exists := h.peers[msg.From]; exists {
if sendErr := sendServerError(conn, msg.From, fmt.Sprintf("duplicate peer id: %s", msg.From)); sendErr != nil {
return "", false, fmt.Errorf("server: duplicate peer id %s: %w", msg.From, sendErr)
}
return "", true, fmt.Errorf("server: duplicate peer id: %s", msg.From)
}
h.peers[msg.From] = conn
return msg.From, false, nil
}
// handlePeerMessage 验证消息类型并执行相应的转发或错误响应。
func (h *Hub) handlePeerMessage(peerID string, conn *transport.TCPConn, msg protocol.Message) (bool, error) {
switch msg.Type {
case protocol.MessageTypeText, protocol.MessageTypeFile: //只允许已注册的 peer 发送文本或文件消息,其他类型都视为协议错误。
msg.From = peerID
targetConn, ok := h.lookup(msg.To)
if !ok {
return false, sendServerError(conn, peerID, fmt.Sprintf("unknown target: %s", msg.To))
}
if err := targetConn.Send(msg); err != nil { //转发消息,如果发送失败,说明目标连接可能已经不可用,此时从 hub 中注销该连接并关闭它,并向发送方返回错误响应。
h.unregister(msg.To, targetConn)
_ = targetConn.Close()
return false, sendServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", msg.To))
}
return false, nil
case protocol.MessageTypeRegister, protocol.MessageTypeError: //已注册的 peer 不允许再发送 register 或 error 消息,这些都视为协议错误。
if err := sendServerError(conn, peerID, "registered peers can only send text or file messages"); err != nil {
return false, fmt.Errorf("server: send protocol error: %w", err)
}
return true, fmt.Errorf("server: unexpected message type from peer %s: %s", peerID, msg.Type)
default: // 其他任何消息类型都视为协议错误。
if err := sendServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); err != nil {
return false, fmt.Errorf("server: send unsupported type error: %w", err)
}
return true, fmt.Errorf("server: unsupported message type: %s", msg.Type)
}
}
func (h *Hub) receivePeerLoop(peerID string, conn *transport.TCPConn) error {
for {
msg, err := conn.Receive()
if err != nil {
_ = conn.Close()
return fmt.Errorf("transport: receive loop read: %w", err)
}
gracefulClose, err := h.handlePeerMessage(peerID, conn, msg)
if err != nil {
h.closeConn(conn, gracefulClose)
return fmt.Errorf("transport: receive loop handler: %w", err)
}
}
}
// lookup 在 hub 中查找目标 peer 的连接。
func (h *Hub) lookup(peerID string) (*transport.TCPConn, bool) {
h.mu.RLock()
defer h.mu.RUnlock()
conn, ok := h.peers[peerID]
return conn, ok
}
// unregister 从 hub 中移除指定 peer 的连接,通常在连接关闭或发生错误时调用。
func (h *Hub) unregister(peerID string, conn *transport.TCPConn) {
h.mu.Lock()
defer h.mu.Unlock()
current, ok := h.peers[peerID]
if ok && current == conn {
delete(h.peers, peerID)
}
}
func (h *Hub) closeConn(conn *transport.TCPConn, graceful bool) {
if graceful {
_ = conn.CloseGracefully(gracefulRejectCloseTimeout)
return
}
_ = conn.Close()
}
// sendServerError 是一个辅助函数,用于向指定 peer 发送错误消息。
func sendServerError(conn *transport.TCPConn, to, message string) error {
return conn.Send(protocol.Message{
Type: protocol.MessageTypeError,
From: protocol.ServerPeerID,
To: to,
Body: []byte(message),
})
}