init
This commit is contained in:
192
cmd/internal/server/hub.go
Normal file
192
cmd/internal/server/hub.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"omnisocketgo/cmd/internal/latencylog"
|
||||
"omnisocketgo/cmd/internal/protocol"
|
||||
"omnisocketgo/cmd/internal/transport"
|
||||
)
|
||||
|
||||
const gracefulRejectCloseTimeout = 100 * time.Millisecond
|
||||
|
||||
// Hub 管理已注册 peer 的连接,并负责在它们之间转发消息。
|
||||
type Hub struct {
|
||||
mu sync.RWMutex
|
||||
peers map[string]*transport.TCPConn
|
||||
logger latencylog.Logger
|
||||
}
|
||||
|
||||
// Option 用于配置 Hub 的可选行为,例如时延日志。
|
||||
type Option func(*Hub)
|
||||
|
||||
// WithLogger 为 hub 注入时延日志记录器。
|
||||
func WithLogger(logger latencylog.Logger) Option {
|
||||
return func(hub *Hub) {
|
||||
hub.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// NewHub 创建一个空的连接中心。
|
||||
func NewHub(opts ...Option) *Hub {
|
||||
hub := &Hub{
|
||||
peers: make(map[string]*transport.TCPConn),
|
||||
logger: latencylog.NoopLogger{},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(hub)
|
||||
}
|
||||
|
||||
if hub.logger == nil {
|
||||
hub.logger = latencylog.NoopLogger{}
|
||||
}
|
||||
|
||||
return hub
|
||||
}
|
||||
|
||||
// HasPeer 返回给定 ID 是否已经注册到 hub。
|
||||
func (h *Hub) HasPeer(peerID string) bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
_, ok := h.peers[peerID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ServeConn 处理一条新接入的底层 TCP 连接。
|
||||
// 连接上的第一条消息必须是 register,之后才允许转发 text/file。
|
||||
func (h *Hub) ServeConn(rawConn net.Conn) error {
|
||||
conn, err := transport.NewTCPConn(rawConn)
|
||||
if err != nil {
|
||||
_ = rawConn.Close()
|
||||
return fmt.Errorf("server: create transport conn: %w", err)
|
||||
}
|
||||
|
||||
peerID, gracefulClose, err := h.registerConn(conn)
|
||||
if err != nil {
|
||||
h.closeConn(conn, gracefulClose)
|
||||
return err
|
||||
}
|
||||
defer h.unregister(peerID, conn)
|
||||
|
||||
if err := h.receivePeerLoop(peerID, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerConn 从新连接上读取第一条消息,验证它是 register 消息,并把连接注册到 hub。
|
||||
func (h *Hub) registerConn(conn *transport.TCPConn) (string, bool, error) {
|
||||
msg, err := conn.Receive()
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("server: receive register: %w", err)
|
||||
}
|
||||
|
||||
if msg.Type != protocol.MessageTypeRegister {
|
||||
if sendErr := sendServerError(conn, msg.From, "first message must be register"); sendErr != nil {
|
||||
return "", false, fmt.Errorf("server: reject unregistered peer: %w", sendErr)
|
||||
}
|
||||
return "", true, fmt.Errorf("server: first message must be register, got %s", msg.Type)
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if _, exists := h.peers[msg.From]; exists {
|
||||
if sendErr := sendServerError(conn, msg.From, fmt.Sprintf("duplicate peer id: %s", msg.From)); sendErr != nil {
|
||||
return "", false, fmt.Errorf("server: duplicate peer id %s: %w", msg.From, sendErr)
|
||||
}
|
||||
return "", true, fmt.Errorf("server: duplicate peer id: %s", msg.From)
|
||||
}
|
||||
|
||||
h.peers[msg.From] = conn
|
||||
return msg.From, false, nil
|
||||
}
|
||||
|
||||
// handlePeerMessage 验证消息类型并执行相应的转发或错误响应。
|
||||
func (h *Hub) handlePeerMessage(peerID string, conn *transport.TCPConn, msg protocol.Message) (bool, error) {
|
||||
switch msg.Type {
|
||||
case protocol.MessageTypeText, protocol.MessageTypeFile: //只允许已注册的 peer 发送文本或文件消息,其他类型都视为协议错误。
|
||||
msg.From = peerID
|
||||
targetConn, ok := h.lookup(msg.To)
|
||||
if !ok {
|
||||
return false, sendServerError(conn, peerID, fmt.Sprintf("unknown target: %s", msg.To))
|
||||
}
|
||||
if err := targetConn.Send(msg); err != nil { //转发消息,如果发送失败,说明目标连接可能已经不可用,此时从 hub 中注销该连接并关闭它,并向发送方返回错误响应。
|
||||
h.unregister(msg.To, targetConn)
|
||||
_ = targetConn.Close()
|
||||
return false, sendServerError(conn, peerID, fmt.Sprintf("failed to forward to %s", msg.To))
|
||||
}
|
||||
return false, nil
|
||||
case protocol.MessageTypeRegister, protocol.MessageTypeError: //已注册的 peer 不允许再发送 register 或 error 消息,这些都视为协议错误。
|
||||
if err := sendServerError(conn, peerID, "registered peers can only send text or file messages"); err != nil {
|
||||
return false, fmt.Errorf("server: send protocol error: %w", err)
|
||||
}
|
||||
return true, fmt.Errorf("server: unexpected message type from peer %s: %s", peerID, msg.Type)
|
||||
default: // 其他任何消息类型都视为协议错误。
|
||||
if err := sendServerError(conn, peerID, fmt.Sprintf("unsupported message type: %s", msg.Type)); err != nil {
|
||||
return false, fmt.Errorf("server: send unsupported type error: %w", err)
|
||||
}
|
||||
return true, fmt.Errorf("server: unsupported message type: %s", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) receivePeerLoop(peerID string, conn *transport.TCPConn) error {
|
||||
for {
|
||||
msg, err := conn.Receive()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("transport: receive loop read: %w", err)
|
||||
}
|
||||
|
||||
gracefulClose, err := h.handlePeerMessage(peerID, conn, msg)
|
||||
if err != nil {
|
||||
h.closeConn(conn, gracefulClose)
|
||||
return fmt.Errorf("transport: receive loop handler: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// lookup 在 hub 中查找目标 peer 的连接。
|
||||
func (h *Hub) lookup(peerID string) (*transport.TCPConn, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
conn, ok := h.peers[peerID]
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
// unregister 从 hub 中移除指定 peer 的连接,通常在连接关闭或发生错误时调用。
|
||||
func (h *Hub) unregister(peerID string, conn *transport.TCPConn) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
current, ok := h.peers[peerID]
|
||||
if ok && current == conn {
|
||||
delete(h.peers, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) closeConn(conn *transport.TCPConn, graceful bool) {
|
||||
if graceful {
|
||||
_ = conn.CloseGracefully(gracefulRejectCloseTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
// sendServerError 是一个辅助函数,用于向指定 peer 发送错误消息。
|
||||
func sendServerError(conn *transport.TCPConn, to, message string) error {
|
||||
return conn.Send(protocol.Message{
|
||||
Type: protocol.MessageTypeError,
|
||||
From: protocol.ServerPeerID,
|
||||
To: to,
|
||||
Body: []byte(message),
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user