Files

280 lines
8.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package protocol
import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"unicode/utf8"
)
// MaxFrameSize 用于限制单个帧的最大长度,
// 避免异常对端通过伪造超大长度值导致接收方无上限分配内存。
const MaxFrameSize = 8 * 1024 * 1024 // 先临时设置传输的视频帧不超过8MB
var (
ErrInvalidFrameLength = errors.New("protocol: invalid frame length") // 表示帧长度非法,例如长度为 0。
ErrFrameTooLarge = errors.New("protocol: frame too large") // 表示帧长度超过允许的上限。
ErrInvalidMessageType = errors.New("protocol: invalid message type") // 表示消息类型不是当前协议支持的类型。
ErrMissingFrom = errors.New("protocol: missing from") // 表示消息缺少发送方标识。
ErrMissingTo = errors.New("protocol: missing to") // 表示消息缺少接收方标识。
ErrMissingFileName = errors.New("protocol: missing file name") // 表示 file 消息缺少文件名。
ErrUnexpectedFileName = errors.New("protocol: unexpected file name") // 表示 text 消息错误地携带了文件名。
ErrInvalidTextBody = errors.New("protocol: invalid text body") // 表示 text 消息正文不是合法 UTF-8。
ErrUnexpectedBody = errors.New("protocol: unexpected body") // 表示某些控制消息不允许携带正文。
ErrInvalidRegisterTarget = errors.New("protocol: invalid register target") // 表示 register 消息没有发往 server。
ErrInvalidErrorSource = errors.New("protocol: invalid error source") // 表示 error 消息不是由 server 发出。
ErrInvalidHeaderLength = errors.New("protocol: invalid header length") // 表示 header 长度字段为 0、越界或无法完整切分。
ErrInvalidHeaderJSON = errors.New("protocol: invalid header json") // 表示 header JSON 无法解析,可能是格式错误或缺少必要字段。
ErrInvalidContentLength = errors.New("protocol: invalid content length") // 表示头部记录的正文长度与实际正文不一致。
)
// 应用层消息:[4字节 frameLength][4字节 headerLen][header JSON下面自定义的Message头][body bytes]
// 写了 tagJSON 字段名是你指定的 type不写 tagJSON 字段名默认是 Go 字段名 Type
type messageHeader struct {
Type MessageType `json:"type"`
ID uint64 `json:"id"`
From string `json:"from"`
To string `json:"to"`
FileName string `json:"file_name,omitempty"`
ContentLength int `json:"content_length"`
}
// EncodeMessage 将逻辑消息编码为帧内字节格式:
// 1. 4 字节大端序 header 长度
// 2. header JSON
// 3. 原始 body 字节
func EncodeMessage(msg Message) ([]byte, error) {
if err := validateMessage(msg); err != nil {
return nil, err
}
header := messageHeader{
Type: msg.Type,
ID: msg.ID,
From: msg.From,
To: msg.To,
FileName: msg.FileName,
ContentLength: len(msg.Body),
}
headerPayload, err := json.Marshal(header)
if err != nil {
return nil, fmt.Errorf("protocol: encode header: %w", err)
}
// 创建一个新的字节切片来存储完整的帧内容,避免直接在 headerPayload 上修改导致数据混乱。
payload := make([]byte, 4+len(headerPayload)+len(msg.Body))
// 在 payload 前 4 字节写入 header 长度,后续内容依次是 header JSON第五个字节开始 和 body。
binary.BigEndian.PutUint32(payload[:4], uint32(len(headerPayload)))
copy(payload[4:], headerPayload)
copy(payload[4+len(headerPayload):], msg.Body)
//检查整个帧长度是否合法,避免上层调用者构造的消息过大导致发送失败。
if len(payload) > MaxFrameSize {
return nil, ErrFrameTooLarge
}
return payload, nil
}
// DecodeMessage 将帧内字节格式还原为 Message。
func DecodeMessage(data []byte) (Message, error) {
if len(data) > MaxFrameSize {
return Message{}, ErrFrameTooLarge
}
if len(data) < 4 {
return Message{}, ErrInvalidHeaderLength
}
headerLen := int(binary.BigEndian.Uint32(data[:4]))
if headerLen == 0 || headerLen > len(data)-4 {
return Message{}, ErrInvalidHeaderLength
}
headerPayload := data[4 : 4+headerLen]
body := data[4+headerLen:]
var header messageHeader
if err := json.Unmarshal(headerPayload, &header); err != nil {
return Message{}, fmt.Errorf("protocol: decode header: %w", errors.Join(ErrInvalidHeaderJSON, err))
}
if header.ContentLength < 0 || header.ContentLength != len(body) {
return Message{}, ErrInvalidContentLength
}
bodyCopy := make([]byte, len(body))
copy(bodyCopy, body)
msg := Message{
Type: header.Type,
ID: header.ID,
From: header.From,
To: header.To,
FileName: header.FileName,
Body: bodyCopy,
}
if err := validateMessage(msg); err != nil {
return Message{}, err
}
return msg, nil
}
// WriteFrame 向流中写入一个带长度前缀的帧。
// TCP帧格式如下
// 1. 4 字节大端序长度
// 2. 后续 payload 内容
//
// TCP 是字节流协议,没有天然的消息边界。
// 增加显式长度前缀后,接收方就知道一条完整消息应该读取多少字节,
// 从而解决粘包和拆包问题。
func WriteFrame(w io.Writer, payload []byte) error {
size := len(payload)
//空帧
if size == 0 {
return ErrInvalidFrameLength
}
//帧过大
if size > MaxFrameSize {
return ErrFrameTooLarge
}
var header [4]byte
binary.BigEndian.PutUint32(header[:], uint32(size))
// 先写长度头,接收方才能根据长度一次性读取完整消息体。
if err := writeFull(w, header[:]); err != nil {
return err
}
return writeFull(w, payload)
}
// ReadFrame 从流中读取一个完整的长度前缀帧。
// 它会先读取固定 4 字节长度头,校验长度是否合法,
// 再使用 io.ReadFull 按长度读取完整消息体,
// 这样即使底层 TCP 发生分段读取,也不会把半条消息暴露给上层。
func ReadFrame(r io.Reader) ([]byte, error) {
var header [4]byte
if _, err := io.ReadFull(r, header[:]); err != nil {
return nil, err
}
size := binary.BigEndian.Uint32(header[:])
// 长度为 0 的帧被认为是非法输入,而不是合法的空消息。
if size == 0 {
return nil, ErrInvalidFrameLength
}
// 长度超过上限的帧会被拒绝,避免接收方无上限分配内存。
if size > MaxFrameSize {
return nil, ErrFrameTooLarge
}
payload := make([]byte, int(size))
if _, err := io.ReadFull(r, payload); err != nil {
return nil, err
}
return payload, nil
}
// WriteMessage 是给上层直接使用的完整发送路径:
// 把一条结构化消息完整编码并发送出去”的总入口。
// Message -> header+body -> 长度前缀帧 -> io.Writer。
func WriteMessage(w io.Writer, msg Message) error {
payload, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("protocol: encode message: %w", err)
}
if err := WriteFrame(w, payload); err != nil {
return fmt.Errorf("protocol: write frame: %w", err)
}
return nil
}
// ReadMessage 是给上层直接使用的完整接收路径:
// io.Reader -> 长度前缀帧 -> header+body -> Message。
func ReadMessage(r io.Reader) (Message, error) {
payload, err := ReadFrame(r)
if err != nil {
return Message{}, fmt.Errorf("protocol: read frame: %w", err)
}
msg, err := DecodeMessage(payload)
if err != nil {
return Message{}, fmt.Errorf("protocol: decode message: %w", err)
}
return msg, nil
}
// validateMessage 检查 Message 传输的类型(只接受 text 和 file )。
func validateMessage(msg Message) error {
if msg.From == "" {
return ErrMissingFrom
}
if msg.To == "" {
return ErrMissingTo
}
switch msg.Type {
case MessageTypeText:
if msg.FileName != "" {
return ErrUnexpectedFileName
}
if !utf8.Valid(msg.Body) {
return ErrInvalidTextBody
}
case MessageTypeFile:
if msg.FileName == "" {
return ErrMissingFileName
}
case MessageTypeRegister:
if msg.To != ServerPeerID {
return ErrInvalidRegisterTarget
}
if msg.FileName != "" {
return ErrUnexpectedFileName
}
if len(msg.Body) != 0 {
return ErrUnexpectedBody
}
case MessageTypeError:
if msg.From != ServerPeerID {
return ErrInvalidErrorSource
}
if msg.FileName != "" {
return ErrUnexpectedFileName
}
if !utf8.Valid(msg.Body) {
return ErrInvalidTextBody
}
default:
return ErrInvalidMessageType
}
return nil
}
// writeFull 会持续写入,直到所有字节都写完或者底层返回错误。
// 这样可以避免某些 Writer 发生部分写入时破坏帧格式。
func writeFull(w io.Writer, data []byte) error {
for len(data) > 0 {
n, err := w.Write(data)
if err != nil {
return err
}
if n == 0 {
return io.ErrShortWrite
}
data = data[n:]
}
return nil
}