package protocol import ( "encoding/binary" "encoding/json" "errors" "fmt" "io" "unicode/utf8" ) // MaxFrameSize 用于限制单个帧的最大长度, // 避免异常对端通过伪造超大长度值导致接收方无上限分配内存。 const MaxFrameSize = 8 * 1024 * 1024 // 先临时设置传输的视频帧不超过8MB var ( ErrInvalidFrameLength = errors.New("protocol: invalid frame length") // 表示帧长度非法,例如长度为 0。 ErrFrameTooLarge = errors.New("protocol: frame too large") // 表示帧长度超过允许的上限。 ErrInvalidMessageType = errors.New("protocol: invalid message type") // 表示消息类型不是当前协议支持的类型。 ErrMissingFrom = errors.New("protocol: missing from") // 表示消息缺少发送方标识。 ErrMissingTo = errors.New("protocol: missing to") // 表示消息缺少接收方标识。 ErrMissingFileName = errors.New("protocol: missing file name") // 表示 file 消息缺少文件名。 ErrUnexpectedFileName = errors.New("protocol: unexpected file name") // 表示 text 消息错误地携带了文件名。 ErrInvalidTextBody = errors.New("protocol: invalid text body") // 表示 text 消息正文不是合法 UTF-8。 ErrUnexpectedBody = errors.New("protocol: unexpected body") // 表示某些控制消息不允许携带正文。 ErrInvalidRegisterTarget = errors.New("protocol: invalid register target") // 表示 register 消息没有发往 server。 ErrInvalidErrorSource = errors.New("protocol: invalid error source") // 表示 error 消息不是由 server 发出。 ErrInvalidHeaderLength = errors.New("protocol: invalid header length") // 表示 header 长度字段为 0、越界或无法完整切分。 ErrInvalidHeaderJSON = errors.New("protocol: invalid header json") // 表示 header JSON 无法解析,可能是格式错误或缺少必要字段。 ErrInvalidContentLength = errors.New("protocol: invalid content length") // 表示头部记录的正文长度与实际正文不一致。 ) // 应用层消息:[4字节 frameLength][4字节 headerLen][header JSON(下面自定义的Message头)][body bytes] // 写了 tag:JSON 字段名是你指定的 type;不写 tag:JSON 字段名默认是 Go 字段名 Type type messageHeader struct { Type MessageType `json:"type"` ID uint64 `json:"id"` From string `json:"from"` To string `json:"to"` FileName string `json:"file_name,omitempty"` ContentLength int `json:"content_length"` } // EncodeMessage 将逻辑消息编码为帧内字节格式: // 1. 4 字节大端序 header 长度 // 2. header JSON // 3. 原始 body 字节 func EncodeMessage(msg Message) ([]byte, error) { if err := validateMessage(msg); err != nil { return nil, err } header := messageHeader{ Type: msg.Type, ID: msg.ID, From: msg.From, To: msg.To, FileName: msg.FileName, ContentLength: len(msg.Body), } headerPayload, err := json.Marshal(header) if err != nil { return nil, fmt.Errorf("protocol: encode header: %w", err) } // 创建一个新的字节切片来存储完整的帧内容,避免直接在 headerPayload 上修改导致数据混乱。 payload := make([]byte, 4+len(headerPayload)+len(msg.Body)) // 在 payload 前 4 字节写入 header 长度,后续内容依次是 header JSON(第五个字节开始) 和 body。 binary.BigEndian.PutUint32(payload[:4], uint32(len(headerPayload))) copy(payload[4:], headerPayload) copy(payload[4+len(headerPayload):], msg.Body) //检查整个帧长度是否合法,避免上层调用者构造的消息过大导致发送失败。 if len(payload) > MaxFrameSize { return nil, ErrFrameTooLarge } return payload, nil } // DecodeMessage 将帧内字节格式还原为 Message。 func DecodeMessage(data []byte) (Message, error) { if len(data) > MaxFrameSize { return Message{}, ErrFrameTooLarge } if len(data) < 4 { return Message{}, ErrInvalidHeaderLength } headerLen := int(binary.BigEndian.Uint32(data[:4])) if headerLen == 0 || headerLen > len(data)-4 { return Message{}, ErrInvalidHeaderLength } headerPayload := data[4 : 4+headerLen] body := data[4+headerLen:] var header messageHeader if err := json.Unmarshal(headerPayload, &header); err != nil { return Message{}, fmt.Errorf("protocol: decode header: %w", errors.Join(ErrInvalidHeaderJSON, err)) } if header.ContentLength < 0 || header.ContentLength != len(body) { return Message{}, ErrInvalidContentLength } bodyCopy := make([]byte, len(body)) copy(bodyCopy, body) msg := Message{ Type: header.Type, ID: header.ID, From: header.From, To: header.To, FileName: header.FileName, Body: bodyCopy, } if err := validateMessage(msg); err != nil { return Message{}, err } return msg, nil } // WriteFrame 向流中写入一个带长度前缀的帧。 // TCP帧格式如下: // 1. 4 字节大端序长度 // 2. 后续 payload 内容 // // TCP 是字节流协议,没有天然的消息边界。 // 增加显式长度前缀后,接收方就知道一条完整消息应该读取多少字节, // 从而解决粘包和拆包问题。 func WriteFrame(w io.Writer, payload []byte) error { size := len(payload) //空帧 if size == 0 { return ErrInvalidFrameLength } //帧过大 if size > MaxFrameSize { return ErrFrameTooLarge } var header [4]byte binary.BigEndian.PutUint32(header[:], uint32(size)) // 先写长度头,接收方才能根据长度一次性读取完整消息体。 if err := writeFull(w, header[:]); err != nil { return err } return writeFull(w, payload) } // ReadFrame 从流中读取一个完整的长度前缀帧。 // 它会先读取固定 4 字节长度头,校验长度是否合法, // 再使用 io.ReadFull 按长度读取完整消息体, // 这样即使底层 TCP 发生分段读取,也不会把半条消息暴露给上层。 func ReadFrame(r io.Reader) ([]byte, error) { var header [4]byte if _, err := io.ReadFull(r, header[:]); err != nil { return nil, err } size := binary.BigEndian.Uint32(header[:]) // 长度为 0 的帧被认为是非法输入,而不是合法的空消息。 if size == 0 { return nil, ErrInvalidFrameLength } // 长度超过上限的帧会被拒绝,避免接收方无上限分配内存。 if size > MaxFrameSize { return nil, ErrFrameTooLarge } payload := make([]byte, int(size)) if _, err := io.ReadFull(r, payload); err != nil { return nil, err } return payload, nil } // WriteMessage 是给上层直接使用的完整发送路径: // 把一条结构化消息完整编码并发送出去”的总入口。 // Message -> header+body -> 长度前缀帧 -> io.Writer。 func WriteMessage(w io.Writer, msg Message) error { payload, err := EncodeMessage(msg) if err != nil { return fmt.Errorf("protocol: encode message: %w", err) } if err := WriteFrame(w, payload); err != nil { return fmt.Errorf("protocol: write frame: %w", err) } return nil } // ReadMessage 是给上层直接使用的完整接收路径: // io.Reader -> 长度前缀帧 -> header+body -> Message。 func ReadMessage(r io.Reader) (Message, error) { payload, err := ReadFrame(r) if err != nil { return Message{}, fmt.Errorf("protocol: read frame: %w", err) } msg, err := DecodeMessage(payload) if err != nil { return Message{}, fmt.Errorf("protocol: decode message: %w", err) } return msg, nil } // validateMessage 检查 Message 传输的类型(只接受 text 和 file )。 func validateMessage(msg Message) error { if msg.From == "" { return ErrMissingFrom } if msg.To == "" { return ErrMissingTo } switch msg.Type { case MessageTypeText: if msg.FileName != "" { return ErrUnexpectedFileName } if !utf8.Valid(msg.Body) { return ErrInvalidTextBody } case MessageTypeFile: if msg.FileName == "" { return ErrMissingFileName } case MessageTypeRegister: if msg.To != ServerPeerID { return ErrInvalidRegisterTarget } if msg.FileName != "" { return ErrUnexpectedFileName } if len(msg.Body) != 0 { return ErrUnexpectedBody } case MessageTypeError: if msg.From != ServerPeerID { return ErrInvalidErrorSource } if msg.FileName != "" { return ErrUnexpectedFileName } if !utf8.Valid(msg.Body) { return ErrInvalidTextBody } default: return ErrInvalidMessageType } return nil } // writeFull 会持续写入,直到所有字节都写完或者底层返回错误。 // 这样可以避免某些 Writer 发生部分写入时破坏帧格式。 func writeFull(w io.Writer, data []byte) error { for len(data) > 0 { n, err := w.Write(data) if err != nil { return err } if n == 0 { return io.ErrShortWrite } data = data[n:] } return nil }