del: 将go版本的内容删除,只保留处理日志功能
This commit is contained in:
279
go/cmd/internal/protocol/codec.go
Normal file
279
go/cmd/internal/protocol/codec.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// MaxFrameSize 用于限制单个帧的最大长度,
|
||||
// 避免异常对端通过伪造超大长度值导致接收方无上限分配内存。
|
||||
const MaxFrameSize = 8 * 1024 * 1024 // 先临时设置传输的视频帧不超过8MB
|
||||
|
||||
var (
|
||||
ErrInvalidFrameLength = errors.New("protocol: invalid frame length") // 表示帧长度非法,例如长度为 0。
|
||||
ErrFrameTooLarge = errors.New("protocol: frame too large") // 表示帧长度超过允许的上限。
|
||||
ErrInvalidMessageType = errors.New("protocol: invalid message type") // 表示消息类型不是当前协议支持的类型。
|
||||
ErrMissingFrom = errors.New("protocol: missing from") // 表示消息缺少发送方标识。
|
||||
ErrMissingTo = errors.New("protocol: missing to") // 表示消息缺少接收方标识。
|
||||
ErrMissingFileName = errors.New("protocol: missing file name") // 表示 file 消息缺少文件名。
|
||||
ErrUnexpectedFileName = errors.New("protocol: unexpected file name") // 表示 text 消息错误地携带了文件名。
|
||||
ErrInvalidTextBody = errors.New("protocol: invalid text body") // 表示 text 消息正文不是合法 UTF-8。
|
||||
ErrUnexpectedBody = errors.New("protocol: unexpected body") // 表示某些控制消息不允许携带正文。
|
||||
ErrInvalidRegisterTarget = errors.New("protocol: invalid register target") // 表示 register 消息没有发往 server。
|
||||
ErrInvalidErrorSource = errors.New("protocol: invalid error source") // 表示 error 消息不是由 server 发出。
|
||||
ErrInvalidHeaderLength = errors.New("protocol: invalid header length") // 表示 header 长度字段为 0、越界或无法完整切分。
|
||||
ErrInvalidHeaderJSON = errors.New("protocol: invalid header json") // 表示 header JSON 无法解析,可能是格式错误或缺少必要字段。
|
||||
ErrInvalidContentLength = errors.New("protocol: invalid content length") // 表示头部记录的正文长度与实际正文不一致。
|
||||
)
|
||||
|
||||
// 应用层消息:[4字节 frameLength][4字节 headerLen][header JSON(下面自定义的Message头)][body bytes]
|
||||
// 写了 tag:JSON 字段名是你指定的 type;不写 tag:JSON 字段名默认是 Go 字段名 Type
|
||||
type messageHeader struct {
|
||||
Type MessageType `json:"type"`
|
||||
ID uint64 `json:"id"`
|
||||
From string `json:"from"`
|
||||
To string `json:"to"`
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
ContentLength int `json:"content_length"`
|
||||
}
|
||||
|
||||
// EncodeMessage 将逻辑消息编码为帧内字节格式:
|
||||
// 1. 4 字节大端序 header 长度
|
||||
// 2. header JSON
|
||||
// 3. 原始 body 字节
|
||||
func EncodeMessage(msg Message) ([]byte, error) {
|
||||
if err := validateMessage(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header := messageHeader{
|
||||
Type: msg.Type,
|
||||
ID: msg.ID,
|
||||
From: msg.From,
|
||||
To: msg.To,
|
||||
FileName: msg.FileName,
|
||||
ContentLength: len(msg.Body),
|
||||
}
|
||||
|
||||
headerPayload, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("protocol: encode header: %w", err)
|
||||
}
|
||||
// 创建一个新的字节切片来存储完整的帧内容,避免直接在 headerPayload 上修改导致数据混乱。
|
||||
payload := make([]byte, 4+len(headerPayload)+len(msg.Body))
|
||||
// 在 payload 前 4 字节写入 header 长度,后续内容依次是 header JSON(第五个字节开始) 和 body。
|
||||
binary.BigEndian.PutUint32(payload[:4], uint32(len(headerPayload)))
|
||||
copy(payload[4:], headerPayload)
|
||||
copy(payload[4+len(headerPayload):], msg.Body)
|
||||
|
||||
//检查整个帧长度是否合法,避免上层调用者构造的消息过大导致发送失败。
|
||||
if len(payload) > MaxFrameSize {
|
||||
return nil, ErrFrameTooLarge
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// DecodeMessage 将帧内字节格式还原为 Message。
|
||||
func DecodeMessage(data []byte) (Message, error) {
|
||||
if len(data) > MaxFrameSize {
|
||||
return Message{}, ErrFrameTooLarge
|
||||
}
|
||||
if len(data) < 4 {
|
||||
return Message{}, ErrInvalidHeaderLength
|
||||
}
|
||||
|
||||
headerLen := int(binary.BigEndian.Uint32(data[:4]))
|
||||
if headerLen == 0 || headerLen > len(data)-4 {
|
||||
return Message{}, ErrInvalidHeaderLength
|
||||
}
|
||||
|
||||
headerPayload := data[4 : 4+headerLen]
|
||||
body := data[4+headerLen:]
|
||||
|
||||
var header messageHeader
|
||||
if err := json.Unmarshal(headerPayload, &header); err != nil {
|
||||
return Message{}, fmt.Errorf("protocol: decode header: %w", errors.Join(ErrInvalidHeaderJSON, err))
|
||||
}
|
||||
|
||||
if header.ContentLength < 0 || header.ContentLength != len(body) {
|
||||
return Message{}, ErrInvalidContentLength
|
||||
}
|
||||
|
||||
bodyCopy := make([]byte, len(body))
|
||||
copy(bodyCopy, body)
|
||||
|
||||
msg := Message{
|
||||
Type: header.Type,
|
||||
ID: header.ID,
|
||||
From: header.From,
|
||||
To: header.To,
|
||||
FileName: header.FileName,
|
||||
Body: bodyCopy,
|
||||
}
|
||||
|
||||
if err := validateMessage(msg); err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// WriteFrame 向流中写入一个带长度前缀的帧。
|
||||
// TCP帧格式如下:
|
||||
// 1. 4 字节大端序长度
|
||||
// 2. 后续 payload 内容
|
||||
//
|
||||
// TCP 是字节流协议,没有天然的消息边界。
|
||||
// 增加显式长度前缀后,接收方就知道一条完整消息应该读取多少字节,
|
||||
// 从而解决粘包和拆包问题。
|
||||
func WriteFrame(w io.Writer, payload []byte) error {
|
||||
size := len(payload)
|
||||
//空帧
|
||||
if size == 0 {
|
||||
return ErrInvalidFrameLength
|
||||
}
|
||||
//帧过大
|
||||
if size > MaxFrameSize {
|
||||
return ErrFrameTooLarge
|
||||
}
|
||||
|
||||
var header [4]byte
|
||||
binary.BigEndian.PutUint32(header[:], uint32(size))
|
||||
|
||||
// 先写长度头,接收方才能根据长度一次性读取完整消息体。
|
||||
if err := writeFull(w, header[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeFull(w, payload)
|
||||
}
|
||||
|
||||
// ReadFrame 从流中读取一个完整的长度前缀帧。
|
||||
// 它会先读取固定 4 字节长度头,校验长度是否合法,
|
||||
// 再使用 io.ReadFull 按长度读取完整消息体,
|
||||
// 这样即使底层 TCP 发生分段读取,也不会把半条消息暴露给上层。
|
||||
func ReadFrame(r io.Reader) ([]byte, error) {
|
||||
var header [4]byte
|
||||
if _, err := io.ReadFull(r, header[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
size := binary.BigEndian.Uint32(header[:])
|
||||
// 长度为 0 的帧被认为是非法输入,而不是合法的空消息。
|
||||
if size == 0 {
|
||||
return nil, ErrInvalidFrameLength
|
||||
}
|
||||
// 长度超过上限的帧会被拒绝,避免接收方无上限分配内存。
|
||||
if size > MaxFrameSize {
|
||||
return nil, ErrFrameTooLarge
|
||||
}
|
||||
|
||||
payload := make([]byte, int(size))
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// WriteMessage 是给上层直接使用的完整发送路径:
|
||||
// 把一条结构化消息完整编码并发送出去”的总入口。
|
||||
// Message -> header+body -> 长度前缀帧 -> io.Writer。
|
||||
func WriteMessage(w io.Writer, msg Message) error {
|
||||
payload, err := EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("protocol: encode message: %w", err)
|
||||
}
|
||||
|
||||
if err := WriteFrame(w, payload); err != nil {
|
||||
return fmt.Errorf("protocol: write frame: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadMessage 是给上层直接使用的完整接收路径:
|
||||
// io.Reader -> 长度前缀帧 -> header+body -> Message。
|
||||
func ReadMessage(r io.Reader) (Message, error) {
|
||||
payload, err := ReadFrame(r)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("protocol: read frame: %w", err)
|
||||
}
|
||||
|
||||
msg, err := DecodeMessage(payload)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("protocol: decode message: %w", err)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// validateMessage 检查 Message 传输的类型(只接受 text 和 file )。
|
||||
func validateMessage(msg Message) error {
|
||||
if msg.From == "" {
|
||||
return ErrMissingFrom
|
||||
}
|
||||
if msg.To == "" {
|
||||
return ErrMissingTo
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case MessageTypeText:
|
||||
if msg.FileName != "" {
|
||||
return ErrUnexpectedFileName
|
||||
}
|
||||
if !utf8.Valid(msg.Body) {
|
||||
return ErrInvalidTextBody
|
||||
}
|
||||
case MessageTypeFile:
|
||||
if msg.FileName == "" {
|
||||
return ErrMissingFileName
|
||||
}
|
||||
case MessageTypeRegister:
|
||||
if msg.To != ServerPeerID {
|
||||
return ErrInvalidRegisterTarget
|
||||
}
|
||||
if msg.FileName != "" {
|
||||
return ErrUnexpectedFileName
|
||||
}
|
||||
if len(msg.Body) != 0 {
|
||||
return ErrUnexpectedBody
|
||||
}
|
||||
case MessageTypeError:
|
||||
if msg.From != ServerPeerID {
|
||||
return ErrInvalidErrorSource
|
||||
}
|
||||
if msg.FileName != "" {
|
||||
return ErrUnexpectedFileName
|
||||
}
|
||||
if !utf8.Valid(msg.Body) {
|
||||
return ErrInvalidTextBody
|
||||
}
|
||||
default:
|
||||
return ErrInvalidMessageType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeFull 会持续写入,直到所有字节都写完或者底层返回错误。
|
||||
// 这样可以避免某些 Writer 发生部分写入时破坏帧格式。
|
||||
func writeFull(w io.Writer, data []byte) error {
|
||||
for len(data) > 0 {
|
||||
n, err := w.Write(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
507
go/cmd/internal/protocol/codec_test.go
Normal file
507
go/cmd/internal/protocol/codec_test.go
Normal file
@@ -0,0 +1,507 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestEncodeDecodeMessageTextASCII 验证 ASCII 文本可以按 text 消息往返编解码。
|
||||
func TestEncodeDecodeMessageTextASCII(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 42,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageTextUTF8 验证 text 消息允许合法 UTF-8,
|
||||
// 从而天然兼容 ASCII 之外的普通文本。
|
||||
func TestEncodeDecodeMessageTextUTF8(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 43,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("你好, world"),
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageFile 验证 file 消息会保留文件名和原始二进制正文。
|
||||
func TestEncodeDecodeMessageFile(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 44,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "data.bin",
|
||||
Body: []byte{0x00, 0xff, 0x10, 0x7f},
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageRegister 验证 register 控制消息也能正常编解码。
|
||||
func TestEncodeDecodeMessageRegister(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeRegister,
|
||||
ID: 45,
|
||||
From: "peer-a",
|
||||
To: ServerPeerID,
|
||||
Body: []byte{},
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageError 验证 error 控制消息会保留 UTF-8 错误文本。
|
||||
func TestEncodeDecodeMessageError(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeError,
|
||||
ID: 46,
|
||||
From: ServerPeerID,
|
||||
To: "peer-a",
|
||||
Body: []byte("unknown target"),
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteReadFrame 单独验证最底层的长度前缀帧逻辑,
|
||||
// 不依赖 Message 结构,方便确认 TCP 粘包拆包问题是否被正确处理。
|
||||
func TestWriteReadFrame(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
payload := []byte("header+body")
|
||||
|
||||
if err := WriteFrame(&buf, payload); err != nil {
|
||||
t.Fatalf("WriteFrame() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := ReadFrame(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame() error = %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(got, payload) {
|
||||
t.Fatalf("payload mismatch: got %q want %q", got, payload)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteReadMessageAllowsEmptyBody 验证空文本和空文件都可以正常通过协议层,
|
||||
// 因为外层帧非空的前提下,空正文是合法业务内容。
|
||||
func TestWriteReadMessageAllowsEmptyBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message Message
|
||||
}{
|
||||
{
|
||||
name: "empty text",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte(""),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty file",
|
||||
message: Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "empty.txt",
|
||||
Body: []byte{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := WriteMessage(&buf, tt.message); err != nil {
|
||||
t.Fatalf("WriteMessage() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := ReadMessage(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tt.message) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", got, tt.message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteReadMessageRejectsInvalidMessages 验证协议层会在编码前拦住明显非法的消息。
|
||||
func TestWriteReadMessageRejectsInvalidMessages(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message Message
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid type",
|
||||
message: Message{
|
||||
Type: MessageType("unknown"),
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
},
|
||||
wantErr: ErrInvalidMessageType,
|
||||
},
|
||||
{
|
||||
name: "missing from",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 2,
|
||||
To: "peer-b",
|
||||
},
|
||||
wantErr: ErrMissingFrom,
|
||||
},
|
||||
{
|
||||
name: "missing to",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 3,
|
||||
From: "peer-a",
|
||||
},
|
||||
wantErr: ErrMissingTo,
|
||||
},
|
||||
{
|
||||
name: "text with file name",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 4,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "bad.txt",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
wantErr: ErrUnexpectedFileName,
|
||||
},
|
||||
{
|
||||
name: "text with invalid utf8",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 5,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte{0xff, 0xfe},
|
||||
},
|
||||
wantErr: ErrInvalidTextBody,
|
||||
},
|
||||
{
|
||||
name: "file without file name",
|
||||
message: Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 6,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte{0x01},
|
||||
},
|
||||
wantErr: ErrMissingFileName,
|
||||
},
|
||||
{
|
||||
name: "register with wrong target",
|
||||
message: Message{
|
||||
Type: MessageTypeRegister,
|
||||
ID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
},
|
||||
wantErr: ErrInvalidRegisterTarget,
|
||||
},
|
||||
{
|
||||
name: "register with body",
|
||||
message: Message{
|
||||
Type: MessageTypeRegister,
|
||||
ID: 8,
|
||||
From: "peer-a",
|
||||
To: ServerPeerID,
|
||||
Body: []byte("unexpected"),
|
||||
},
|
||||
wantErr: ErrUnexpectedBody,
|
||||
},
|
||||
{
|
||||
name: "error with wrong source",
|
||||
message: Message{
|
||||
Type: MessageTypeError,
|
||||
ID: 9,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("bad"),
|
||||
},
|
||||
wantErr: ErrInvalidErrorSource,
|
||||
},
|
||||
{
|
||||
name: "error with file name",
|
||||
message: Message{
|
||||
Type: MessageTypeError,
|
||||
ID: 10,
|
||||
From: ServerPeerID,
|
||||
To: "peer-a",
|
||||
FileName: "bad.txt",
|
||||
Body: []byte("bad"),
|
||||
},
|
||||
wantErr: ErrUnexpectedFileName,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := EncodeMessage(tt.message)
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("EncodeMessage() error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFrameRejectsInvalidLength 验证长度为 0 的帧会被当成非法输入,
|
||||
// 而不是被当成一条合法的空消息。
|
||||
func TestReadFrameRejectsInvalidLength(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := binary.Write(&buf, binary.BigEndian, uint32(0)); err != nil {
|
||||
t.Fatalf("binary.Write() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := ReadFrame(&buf)
|
||||
if !errors.Is(err, ErrInvalidFrameLength) {
|
||||
t.Fatalf("ReadFrame() error = %v, want %v", err, ErrInvalidFrameLength)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFrameRejectsTooLargeFrame 验证超大帧会在分配消息体前被拒绝,
|
||||
// 从而保证最大长度限制真正生效。
|
||||
func TestReadFrameRejectsTooLargeFrame(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := binary.Write(&buf, binary.BigEndian, uint32(MaxFrameSize+1)); err != nil {
|
||||
t.Fatalf("binary.Write() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := ReadFrame(&buf)
|
||||
if !errors.Is(err, ErrFrameTooLarge) {
|
||||
t.Fatalf("ReadFrame() error = %v, want %v", err, ErrFrameTooLarge)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteFrameRejectsEmptyPayload 验证写入端和读取端的约束保持一致:
|
||||
// 既然读取端不接受 0 长度帧,写入端也不应该产生这种帧。
|
||||
func TestWriteFrameRejectsEmptyPayload(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
err := WriteFrame(&buf, nil)
|
||||
if !errors.Is(err, ErrInvalidFrameLength) {
|
||||
t.Fatalf("WriteFrame() error = %v, want %v", err, ErrInvalidFrameLength)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeMessageRejectsInvalidHeaderLength 验证无法切出完整头部时会被立即拒绝。
|
||||
func TestDecodeMessageRejectsInvalidHeaderLength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
}{
|
||||
{
|
||||
name: "too short for header len",
|
||||
data: []byte{0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
name: "zero header len",
|
||||
data: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
name: "header len exceeds payload",
|
||||
data: []byte{0x00, 0x00, 0x00, 0x10, '{', '}'},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := DecodeMessage(tt.data)
|
||||
if !errors.Is(err, ErrInvalidHeaderLength) {
|
||||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderLength)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeMessageRejectsInvalidHeaderJSON 验证头部 JSON 非法时能返回明确错误。
|
||||
func TestDecodeMessageRejectsInvalidHeaderJSON(t *testing.T) {
|
||||
data := append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)
|
||||
|
||||
_, err := DecodeMessage(data)
|
||||
if !errors.Is(err, ErrInvalidHeaderJSON) {
|
||||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeMessageRejectsContentLengthMismatch 验证头部声明长度和实际正文不一致时会失败。
|
||||
func TestDecodeMessageRejectsContentLengthMismatch(t *testing.T) {
|
||||
headerPayload, err := json.Marshal(messageHeader{
|
||||
Type: MessageTypeText,
|
||||
ID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
ContentLength: 10,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
var data bytes.Buffer
|
||||
if err := binary.Write(&data, binary.BigEndian, uint32(len(headerPayload))); err != nil {
|
||||
t.Fatalf("binary.Write() error = %v", err)
|
||||
}
|
||||
if _, err := data.Write(headerPayload); err != nil {
|
||||
t.Fatalf("data.Write(headerPayload) error = %v", err)
|
||||
}
|
||||
if _, err := data.Write([]byte("hello")); err != nil {
|
||||
t.Fatalf("data.Write(body) error = %v", err)
|
||||
}
|
||||
|
||||
_, err = DecodeMessage(data.Bytes())
|
||||
if !errors.Is(err, ErrInvalidContentLength) {
|
||||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadMultipleMessages 模拟同一条流中连续写入 text 和 file,
|
||||
// 验证读取端每次都能严格停在当前帧边界,不会串包。
|
||||
func TestReadMultipleMessages(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
first := Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
second := Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-b",
|
||||
To: "peer-a",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
}
|
||||
|
||||
if err := WriteMessage(&buf, first); err != nil {
|
||||
t.Fatalf("WriteMessage(first) error = %v", err)
|
||||
}
|
||||
if err := WriteMessage(&buf, second); err != nil {
|
||||
t.Fatalf("WriteMessage(second) error = %v", err)
|
||||
}
|
||||
|
||||
gotFirst, err := ReadMessage(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage(first) error = %v", err)
|
||||
}
|
||||
gotSecond, err := ReadMessage(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage(second) error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotFirst, first) {
|
||||
t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, first)
|
||||
}
|
||||
if !reflect.DeepEqual(gotSecond, second) {
|
||||
t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, second)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadMessageWrapsDecodeError 验证 ReadMessage 在返回错误时会保留解码阶段上下文。
|
||||
func TestReadMessageWrapsDecodeError(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := WriteFrame(&buf, append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)); err != nil {
|
||||
t.Fatalf("WriteFrame() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := ReadMessage(&buf)
|
||||
if err == nil {
|
||||
t.Fatal("ReadMessage() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "decode message") {
|
||||
t.Fatalf("ReadMessage() error = %v, want wrapped decode error", err)
|
||||
}
|
||||
}
|
||||
33
go/cmd/internal/protocol/message.go
Normal file
33
go/cmd/internal/protocol/message.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package protocol
|
||||
|
||||
// MessageType 表示一条消息的传输类型。
|
||||
// v1 只区分普通文本和文件两类负载。
|
||||
type MessageType string
|
||||
|
||||
const (
|
||||
// MessageTypeText 表示正文按 UTF-8 文本解释,天然兼容 ASCII。
|
||||
MessageTypeText MessageType = "text"
|
||||
// MessageTypeFile 表示正文是原始文件字节。
|
||||
MessageTypeFile MessageType = "file"
|
||||
// MessageTypeRegister 表示 peer 向 server 显式注册自己的身份。
|
||||
MessageTypeRegister MessageType = "register"
|
||||
// MessageTypeError 表示 server 向 peer 返回错误信息。
|
||||
MessageTypeError MessageType = "error"
|
||||
)
|
||||
|
||||
// ServerPeerID 是协议中约定的 server 端固定标识。
|
||||
const ServerPeerID = "server"
|
||||
|
||||
// Message 是 peer 和 server 共用的传输消息结构。
|
||||
// 头部元信息会被编码为 JSON,Body 则作为原始字节拼接在头部之后。
|
||||
type Message struct {
|
||||
Type MessageType `json:"type"` // 消息类型,只允许 text 或 file。
|
||||
ID uint64 `json:"id"` // 由发送方生成,用于追踪消息。
|
||||
From string `json:"from"` // 发送方标识。
|
||||
To string `json:"to"` // 接收方标识。
|
||||
|
||||
// FileName 仅在 Type 为 file 时使用。
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
// Body 是真正传输的正文内容,不进入头部 JSON。
|
||||
Body []byte `json:"-"`
|
||||
}
|
||||
Reference in New Issue
Block a user