init
This commit is contained in:
507
cmd/internal/protocol/codec_test.go
Normal file
507
cmd/internal/protocol/codec_test.go
Normal file
@@ -0,0 +1,507 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestEncodeDecodeMessageTextASCII 验证 ASCII 文本可以按 text 消息往返编解码。
|
||||
func TestEncodeDecodeMessageTextASCII(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 42,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageTextUTF8 验证 text 消息允许合法 UTF-8,
|
||||
// 从而天然兼容 ASCII 之外的普通文本。
|
||||
func TestEncodeDecodeMessageTextUTF8(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 43,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("你好, world"),
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageFile 验证 file 消息会保留文件名和原始二进制正文。
|
||||
func TestEncodeDecodeMessageFile(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 44,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "data.bin",
|
||||
Body: []byte{0x00, 0xff, 0x10, 0x7f},
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageRegister 验证 register 控制消息也能正常编解码。
|
||||
func TestEncodeDecodeMessageRegister(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeRegister,
|
||||
ID: 45,
|
||||
From: "peer-a",
|
||||
To: ServerPeerID,
|
||||
Body: []byte{},
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeDecodeMessageError 验证 error 控制消息会保留 UTF-8 错误文本。
|
||||
func TestEncodeDecodeMessageError(t *testing.T) {
|
||||
original := Message{
|
||||
Type: MessageTypeError,
|
||||
ID: 46,
|
||||
From: ServerPeerID,
|
||||
To: "peer-a",
|
||||
Body: []byte("unknown target"),
|
||||
}
|
||||
|
||||
data, err := EncodeMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
decoded, err := DecodeMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(decoded, original) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteReadFrame 单独验证最底层的长度前缀帧逻辑,
|
||||
// 不依赖 Message 结构,方便确认 TCP 粘包拆包问题是否被正确处理。
|
||||
func TestWriteReadFrame(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
payload := []byte("header+body")
|
||||
|
||||
if err := WriteFrame(&buf, payload); err != nil {
|
||||
t.Fatalf("WriteFrame() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := ReadFrame(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame() error = %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(got, payload) {
|
||||
t.Fatalf("payload mismatch: got %q want %q", got, payload)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteReadMessageAllowsEmptyBody 验证空文本和空文件都可以正常通过协议层,
|
||||
// 因为外层帧非空的前提下,空正文是合法业务内容。
|
||||
func TestWriteReadMessageAllowsEmptyBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message Message
|
||||
}{
|
||||
{
|
||||
name: "empty text",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte(""),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty file",
|
||||
message: Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "empty.txt",
|
||||
Body: []byte{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := WriteMessage(&buf, tt.message); err != nil {
|
||||
t.Fatalf("WriteMessage() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := ReadMessage(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tt.message) {
|
||||
t.Fatalf("round trip mismatch: got %+v want %+v", got, tt.message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteReadMessageRejectsInvalidMessages 验证协议层会在编码前拦住明显非法的消息。
|
||||
func TestWriteReadMessageRejectsInvalidMessages(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message Message
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid type",
|
||||
message: Message{
|
||||
Type: MessageType("unknown"),
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
},
|
||||
wantErr: ErrInvalidMessageType,
|
||||
},
|
||||
{
|
||||
name: "missing from",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 2,
|
||||
To: "peer-b",
|
||||
},
|
||||
wantErr: ErrMissingFrom,
|
||||
},
|
||||
{
|
||||
name: "missing to",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 3,
|
||||
From: "peer-a",
|
||||
},
|
||||
wantErr: ErrMissingTo,
|
||||
},
|
||||
{
|
||||
name: "text with file name",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 4,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
FileName: "bad.txt",
|
||||
Body: []byte("hello"),
|
||||
},
|
||||
wantErr: ErrUnexpectedFileName,
|
||||
},
|
||||
{
|
||||
name: "text with invalid utf8",
|
||||
message: Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 5,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte{0xff, 0xfe},
|
||||
},
|
||||
wantErr: ErrInvalidTextBody,
|
||||
},
|
||||
{
|
||||
name: "file without file name",
|
||||
message: Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 6,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte{0x01},
|
||||
},
|
||||
wantErr: ErrMissingFileName,
|
||||
},
|
||||
{
|
||||
name: "register with wrong target",
|
||||
message: Message{
|
||||
Type: MessageTypeRegister,
|
||||
ID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
},
|
||||
wantErr: ErrInvalidRegisterTarget,
|
||||
},
|
||||
{
|
||||
name: "register with body",
|
||||
message: Message{
|
||||
Type: MessageTypeRegister,
|
||||
ID: 8,
|
||||
From: "peer-a",
|
||||
To: ServerPeerID,
|
||||
Body: []byte("unexpected"),
|
||||
},
|
||||
wantErr: ErrUnexpectedBody,
|
||||
},
|
||||
{
|
||||
name: "error with wrong source",
|
||||
message: Message{
|
||||
Type: MessageTypeError,
|
||||
ID: 9,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("bad"),
|
||||
},
|
||||
wantErr: ErrInvalidErrorSource,
|
||||
},
|
||||
{
|
||||
name: "error with file name",
|
||||
message: Message{
|
||||
Type: MessageTypeError,
|
||||
ID: 10,
|
||||
From: ServerPeerID,
|
||||
To: "peer-a",
|
||||
FileName: "bad.txt",
|
||||
Body: []byte("bad"),
|
||||
},
|
||||
wantErr: ErrUnexpectedFileName,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := EncodeMessage(tt.message)
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("EncodeMessage() error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFrameRejectsInvalidLength 验证长度为 0 的帧会被当成非法输入,
|
||||
// 而不是被当成一条合法的空消息。
|
||||
func TestReadFrameRejectsInvalidLength(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := binary.Write(&buf, binary.BigEndian, uint32(0)); err != nil {
|
||||
t.Fatalf("binary.Write() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := ReadFrame(&buf)
|
||||
if !errors.Is(err, ErrInvalidFrameLength) {
|
||||
t.Fatalf("ReadFrame() error = %v, want %v", err, ErrInvalidFrameLength)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFrameRejectsTooLargeFrame 验证超大帧会在分配消息体前被拒绝,
|
||||
// 从而保证最大长度限制真正生效。
|
||||
func TestReadFrameRejectsTooLargeFrame(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := binary.Write(&buf, binary.BigEndian, uint32(MaxFrameSize+1)); err != nil {
|
||||
t.Fatalf("binary.Write() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := ReadFrame(&buf)
|
||||
if !errors.Is(err, ErrFrameTooLarge) {
|
||||
t.Fatalf("ReadFrame() error = %v, want %v", err, ErrFrameTooLarge)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteFrameRejectsEmptyPayload 验证写入端和读取端的约束保持一致:
|
||||
// 既然读取端不接受 0 长度帧,写入端也不应该产生这种帧。
|
||||
func TestWriteFrameRejectsEmptyPayload(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
err := WriteFrame(&buf, nil)
|
||||
if !errors.Is(err, ErrInvalidFrameLength) {
|
||||
t.Fatalf("WriteFrame() error = %v, want %v", err, ErrInvalidFrameLength)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeMessageRejectsInvalidHeaderLength 验证无法切出完整头部时会被立即拒绝。
|
||||
func TestDecodeMessageRejectsInvalidHeaderLength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
}{
|
||||
{
|
||||
name: "too short for header len",
|
||||
data: []byte{0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
name: "zero header len",
|
||||
data: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
name: "header len exceeds payload",
|
||||
data: []byte{0x00, 0x00, 0x00, 0x10, '{', '}'},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := DecodeMessage(tt.data)
|
||||
if !errors.Is(err, ErrInvalidHeaderLength) {
|
||||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderLength)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeMessageRejectsInvalidHeaderJSON 验证头部 JSON 非法时能返回明确错误。
|
||||
func TestDecodeMessageRejectsInvalidHeaderJSON(t *testing.T) {
|
||||
data := append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)
|
||||
|
||||
_, err := DecodeMessage(data)
|
||||
if !errors.Is(err, ErrInvalidHeaderJSON) {
|
||||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeMessageRejectsContentLengthMismatch 验证头部声明长度和实际正文不一致时会失败。
|
||||
func TestDecodeMessageRejectsContentLengthMismatch(t *testing.T) {
|
||||
headerPayload, err := json.Marshal(messageHeader{
|
||||
Type: MessageTypeText,
|
||||
ID: 7,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
ContentLength: 10,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
var data bytes.Buffer
|
||||
if err := binary.Write(&data, binary.BigEndian, uint32(len(headerPayload))); err != nil {
|
||||
t.Fatalf("binary.Write() error = %v", err)
|
||||
}
|
||||
if _, err := data.Write(headerPayload); err != nil {
|
||||
t.Fatalf("data.Write(headerPayload) error = %v", err)
|
||||
}
|
||||
if _, err := data.Write([]byte("hello")); err != nil {
|
||||
t.Fatalf("data.Write(body) error = %v", err)
|
||||
}
|
||||
|
||||
_, err = DecodeMessage(data.Bytes())
|
||||
if !errors.Is(err, ErrInvalidContentLength) {
|
||||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadMultipleMessages 模拟同一条流中连续写入 text 和 file,
|
||||
// 验证读取端每次都能严格停在当前帧边界,不会串包。
|
||||
func TestReadMultipleMessages(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
first := Message{
|
||||
Type: MessageTypeText,
|
||||
ID: 1,
|
||||
From: "peer-a",
|
||||
To: "peer-b",
|
||||
Body: []byte("hello"),
|
||||
}
|
||||
|
||||
second := Message{
|
||||
Type: MessageTypeFile,
|
||||
ID: 2,
|
||||
From: "peer-b",
|
||||
To: "peer-a",
|
||||
FileName: "payload.bin",
|
||||
Body: []byte{0x01, 0x02, 0x03},
|
||||
}
|
||||
|
||||
if err := WriteMessage(&buf, first); err != nil {
|
||||
t.Fatalf("WriteMessage(first) error = %v", err)
|
||||
}
|
||||
if err := WriteMessage(&buf, second); err != nil {
|
||||
t.Fatalf("WriteMessage(second) error = %v", err)
|
||||
}
|
||||
|
||||
gotFirst, err := ReadMessage(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage(first) error = %v", err)
|
||||
}
|
||||
gotSecond, err := ReadMessage(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage(second) error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotFirst, first) {
|
||||
t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, first)
|
||||
}
|
||||
if !reflect.DeepEqual(gotSecond, second) {
|
||||
t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, second)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadMessageWrapsDecodeError 验证 ReadMessage 在返回错误时会保留解码阶段上下文。
|
||||
func TestReadMessageWrapsDecodeError(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := WriteFrame(&buf, append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)); err != nil {
|
||||
t.Fatalf("WriteFrame() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := ReadMessage(&buf)
|
||||
if err == nil {
|
||||
t.Fatal("ReadMessage() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "decode message") {
|
||||
t.Fatalf("ReadMessage() error = %v, want wrapped decode error", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user