280 lines
8.7 KiB
Go
280 lines
8.7 KiB
Go
package protocol
|
||
|
||
import (
|
||
"encoding/binary"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"unicode/utf8"
|
||
)
|
||
|
||
// MaxFrameSize 用于限制单个帧的最大长度,
|
||
// 避免异常对端通过伪造超大长度值导致接收方无上限分配内存。
|
||
const MaxFrameSize = 8 * 1024 * 1024 // 先临时设置传输的视频帧不超过8MB
|
||
|
||
var (
|
||
ErrInvalidFrameLength = errors.New("protocol: invalid frame length") // 表示帧长度非法,例如长度为 0。
|
||
ErrFrameTooLarge = errors.New("protocol: frame too large") // 表示帧长度超过允许的上限。
|
||
ErrInvalidMessageType = errors.New("protocol: invalid message type") // 表示消息类型不是当前协议支持的类型。
|
||
ErrMissingFrom = errors.New("protocol: missing from") // 表示消息缺少发送方标识。
|
||
ErrMissingTo = errors.New("protocol: missing to") // 表示消息缺少接收方标识。
|
||
ErrMissingFileName = errors.New("protocol: missing file name") // 表示 file 消息缺少文件名。
|
||
ErrUnexpectedFileName = errors.New("protocol: unexpected file name") // 表示 text 消息错误地携带了文件名。
|
||
ErrInvalidTextBody = errors.New("protocol: invalid text body") // 表示 text 消息正文不是合法 UTF-8。
|
||
ErrUnexpectedBody = errors.New("protocol: unexpected body") // 表示某些控制消息不允许携带正文。
|
||
ErrInvalidRegisterTarget = errors.New("protocol: invalid register target") // 表示 register 消息没有发往 server。
|
||
ErrInvalidErrorSource = errors.New("protocol: invalid error source") // 表示 error 消息不是由 server 发出。
|
||
ErrInvalidHeaderLength = errors.New("protocol: invalid header length") // 表示 header 长度字段为 0、越界或无法完整切分。
|
||
ErrInvalidHeaderJSON = errors.New("protocol: invalid header json") // 表示 header JSON 无法解析,可能是格式错误或缺少必要字段。
|
||
ErrInvalidContentLength = errors.New("protocol: invalid content length") // 表示头部记录的正文长度与实际正文不一致。
|
||
)
|
||
|
||
// 应用层消息:[4字节 frameLength][4字节 headerLen][header JSON(下面自定义的Message头)][body bytes]
|
||
// 写了 tag:JSON 字段名是你指定的 type;不写 tag:JSON 字段名默认是 Go 字段名 Type
|
||
type messageHeader struct {
|
||
Type MessageType `json:"type"`
|
||
ID uint64 `json:"id"`
|
||
From string `json:"from"`
|
||
To string `json:"to"`
|
||
FileName string `json:"file_name,omitempty"`
|
||
ContentLength int `json:"content_length"`
|
||
}
|
||
|
||
// EncodeMessage 将逻辑消息编码为帧内字节格式:
|
||
// 1. 4 字节大端序 header 长度
|
||
// 2. header JSON
|
||
// 3. 原始 body 字节
|
||
func EncodeMessage(msg Message) ([]byte, error) {
|
||
if err := validateMessage(msg); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
header := messageHeader{
|
||
Type: msg.Type,
|
||
ID: msg.ID,
|
||
From: msg.From,
|
||
To: msg.To,
|
||
FileName: msg.FileName,
|
||
ContentLength: len(msg.Body),
|
||
}
|
||
|
||
headerPayload, err := json.Marshal(header)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("protocol: encode header: %w", err)
|
||
}
|
||
// 创建一个新的字节切片来存储完整的帧内容,避免直接在 headerPayload 上修改导致数据混乱。
|
||
payload := make([]byte, 4+len(headerPayload)+len(msg.Body))
|
||
// 在 payload 前 4 字节写入 header 长度,后续内容依次是 header JSON(第五个字节开始) 和 body。
|
||
binary.BigEndian.PutUint32(payload[:4], uint32(len(headerPayload)))
|
||
copy(payload[4:], headerPayload)
|
||
copy(payload[4+len(headerPayload):], msg.Body)
|
||
|
||
//检查整个帧长度是否合法,避免上层调用者构造的消息过大导致发送失败。
|
||
if len(payload) > MaxFrameSize {
|
||
return nil, ErrFrameTooLarge
|
||
}
|
||
|
||
return payload, nil
|
||
}
|
||
|
||
// DecodeMessage 将帧内字节格式还原为 Message。
|
||
func DecodeMessage(data []byte) (Message, error) {
|
||
if len(data) > MaxFrameSize {
|
||
return Message{}, ErrFrameTooLarge
|
||
}
|
||
if len(data) < 4 {
|
||
return Message{}, ErrInvalidHeaderLength
|
||
}
|
||
|
||
headerLen := int(binary.BigEndian.Uint32(data[:4]))
|
||
if headerLen == 0 || headerLen > len(data)-4 {
|
||
return Message{}, ErrInvalidHeaderLength
|
||
}
|
||
|
||
headerPayload := data[4 : 4+headerLen]
|
||
body := data[4+headerLen:]
|
||
|
||
var header messageHeader
|
||
if err := json.Unmarshal(headerPayload, &header); err != nil {
|
||
return Message{}, fmt.Errorf("protocol: decode header: %w", errors.Join(ErrInvalidHeaderJSON, err))
|
||
}
|
||
|
||
if header.ContentLength < 0 || header.ContentLength != len(body) {
|
||
return Message{}, ErrInvalidContentLength
|
||
}
|
||
|
||
bodyCopy := make([]byte, len(body))
|
||
copy(bodyCopy, body)
|
||
|
||
msg := Message{
|
||
Type: header.Type,
|
||
ID: header.ID,
|
||
From: header.From,
|
||
To: header.To,
|
||
FileName: header.FileName,
|
||
Body: bodyCopy,
|
||
}
|
||
|
||
if err := validateMessage(msg); err != nil {
|
||
return Message{}, err
|
||
}
|
||
|
||
return msg, nil
|
||
}
|
||
|
||
// WriteFrame 向流中写入一个带长度前缀的帧。
|
||
// TCP帧格式如下:
|
||
// 1. 4 字节大端序长度
|
||
// 2. 后续 payload 内容
|
||
//
|
||
// TCP 是字节流协议,没有天然的消息边界。
|
||
// 增加显式长度前缀后,接收方就知道一条完整消息应该读取多少字节,
|
||
// 从而解决粘包和拆包问题。
|
||
func WriteFrame(w io.Writer, payload []byte) error {
|
||
size := len(payload)
|
||
//空帧
|
||
if size == 0 {
|
||
return ErrInvalidFrameLength
|
||
}
|
||
//帧过大
|
||
if size > MaxFrameSize {
|
||
return ErrFrameTooLarge
|
||
}
|
||
|
||
var header [4]byte
|
||
binary.BigEndian.PutUint32(header[:], uint32(size))
|
||
|
||
// 先写长度头,接收方才能根据长度一次性读取完整消息体。
|
||
if err := writeFull(w, header[:]); err != nil {
|
||
return err
|
||
}
|
||
|
||
return writeFull(w, payload)
|
||
}
|
||
|
||
// ReadFrame 从流中读取一个完整的长度前缀帧。
|
||
// 它会先读取固定 4 字节长度头,校验长度是否合法,
|
||
// 再使用 io.ReadFull 按长度读取完整消息体,
|
||
// 这样即使底层 TCP 发生分段读取,也不会把半条消息暴露给上层。
|
||
func ReadFrame(r io.Reader) ([]byte, error) {
|
||
var header [4]byte
|
||
if _, err := io.ReadFull(r, header[:]); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
size := binary.BigEndian.Uint32(header[:])
|
||
// 长度为 0 的帧被认为是非法输入,而不是合法的空消息。
|
||
if size == 0 {
|
||
return nil, ErrInvalidFrameLength
|
||
}
|
||
// 长度超过上限的帧会被拒绝,避免接收方无上限分配内存。
|
||
if size > MaxFrameSize {
|
||
return nil, ErrFrameTooLarge
|
||
}
|
||
|
||
payload := make([]byte, int(size))
|
||
if _, err := io.ReadFull(r, payload); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return payload, nil
|
||
}
|
||
|
||
// WriteMessage 是给上层直接使用的完整发送路径:
|
||
// 把一条结构化消息完整编码并发送出去”的总入口。
|
||
// Message -> header+body -> 长度前缀帧 -> io.Writer。
|
||
func WriteMessage(w io.Writer, msg Message) error {
|
||
payload, err := EncodeMessage(msg)
|
||
if err != nil {
|
||
return fmt.Errorf("protocol: encode message: %w", err)
|
||
}
|
||
|
||
if err := WriteFrame(w, payload); err != nil {
|
||
return fmt.Errorf("protocol: write frame: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ReadMessage 是给上层直接使用的完整接收路径:
|
||
// io.Reader -> 长度前缀帧 -> header+body -> Message。
|
||
func ReadMessage(r io.Reader) (Message, error) {
|
||
payload, err := ReadFrame(r)
|
||
if err != nil {
|
||
return Message{}, fmt.Errorf("protocol: read frame: %w", err)
|
||
}
|
||
|
||
msg, err := DecodeMessage(payload)
|
||
if err != nil {
|
||
return Message{}, fmt.Errorf("protocol: decode message: %w", err)
|
||
}
|
||
|
||
return msg, nil
|
||
}
|
||
|
||
// validateMessage 检查 Message 传输的类型(只接受 text 和 file )。
|
||
func validateMessage(msg Message) error {
|
||
if msg.From == "" {
|
||
return ErrMissingFrom
|
||
}
|
||
if msg.To == "" {
|
||
return ErrMissingTo
|
||
}
|
||
|
||
switch msg.Type {
|
||
case MessageTypeText:
|
||
if msg.FileName != "" {
|
||
return ErrUnexpectedFileName
|
||
}
|
||
if !utf8.Valid(msg.Body) {
|
||
return ErrInvalidTextBody
|
||
}
|
||
case MessageTypeFile:
|
||
if msg.FileName == "" {
|
||
return ErrMissingFileName
|
||
}
|
||
case MessageTypeRegister:
|
||
if msg.To != ServerPeerID {
|
||
return ErrInvalidRegisterTarget
|
||
}
|
||
if msg.FileName != "" {
|
||
return ErrUnexpectedFileName
|
||
}
|
||
if len(msg.Body) != 0 {
|
||
return ErrUnexpectedBody
|
||
}
|
||
case MessageTypeError:
|
||
if msg.From != ServerPeerID {
|
||
return ErrInvalidErrorSource
|
||
}
|
||
if msg.FileName != "" {
|
||
return ErrUnexpectedFileName
|
||
}
|
||
if !utf8.Valid(msg.Body) {
|
||
return ErrInvalidTextBody
|
||
}
|
||
default:
|
||
return ErrInvalidMessageType
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// writeFull 会持续写入,直到所有字节都写完或者底层返回错误。
|
||
// 这样可以避免某些 Writer 发生部分写入时破坏帧格式。
|
||
func writeFull(w io.Writer, data []byte) error {
|
||
for len(data) > 0 {
|
||
n, err := w.Write(data)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if n == 0 {
|
||
return io.ErrShortWrite
|
||
}
|
||
data = data[n:]
|
||
}
|
||
|
||
return nil
|
||
}
|