This commit is contained in:
nnbcccscdscdsc
2026-03-23 20:18:53 +08:00
commit 4824675244
28 changed files with 5569 additions and 0 deletions

View File

@@ -0,0 +1,507 @@
package protocol
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"reflect"
"strings"
"testing"
)
// TestEncodeDecodeMessageTextASCII 验证 ASCII 文本可以按 text 消息往返编解码。
func TestEncodeDecodeMessageTextASCII(t *testing.T) {
original := Message{
Type: MessageTypeText,
ID: 42,
From: "peer-a",
To: "peer-b",
Body: []byte("hello"),
}
data, err := EncodeMessage(original)
if err != nil {
t.Fatalf("EncodeMessage() error = %v", err)
}
decoded, err := DecodeMessage(data)
if err != nil {
t.Fatalf("DecodeMessage() error = %v", err)
}
if !reflect.DeepEqual(decoded, original) {
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
}
}
// TestEncodeDecodeMessageTextUTF8 验证 text 消息允许合法 UTF-8
// 从而天然兼容 ASCII 之外的普通文本。
func TestEncodeDecodeMessageTextUTF8(t *testing.T) {
original := Message{
Type: MessageTypeText,
ID: 43,
From: "peer-a",
To: "peer-b",
Body: []byte("你好, world"),
}
data, err := EncodeMessage(original)
if err != nil {
t.Fatalf("EncodeMessage() error = %v", err)
}
decoded, err := DecodeMessage(data)
if err != nil {
t.Fatalf("DecodeMessage() error = %v", err)
}
if !reflect.DeepEqual(decoded, original) {
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
}
}
// TestEncodeDecodeMessageFile 验证 file 消息会保留文件名和原始二进制正文。
func TestEncodeDecodeMessageFile(t *testing.T) {
original := Message{
Type: MessageTypeFile,
ID: 44,
From: "peer-a",
To: "peer-b",
FileName: "data.bin",
Body: []byte{0x00, 0xff, 0x10, 0x7f},
}
data, err := EncodeMessage(original)
if err != nil {
t.Fatalf("EncodeMessage() error = %v", err)
}
decoded, err := DecodeMessage(data)
if err != nil {
t.Fatalf("DecodeMessage() error = %v", err)
}
if !reflect.DeepEqual(decoded, original) {
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
}
}
// TestEncodeDecodeMessageRegister 验证 register 控制消息也能正常编解码。
func TestEncodeDecodeMessageRegister(t *testing.T) {
original := Message{
Type: MessageTypeRegister,
ID: 45,
From: "peer-a",
To: ServerPeerID,
Body: []byte{},
}
data, err := EncodeMessage(original)
if err != nil {
t.Fatalf("EncodeMessage() error = %v", err)
}
decoded, err := DecodeMessage(data)
if err != nil {
t.Fatalf("DecodeMessage() error = %v", err)
}
if !reflect.DeepEqual(decoded, original) {
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
}
}
// TestEncodeDecodeMessageError 验证 error 控制消息会保留 UTF-8 错误文本。
func TestEncodeDecodeMessageError(t *testing.T) {
original := Message{
Type: MessageTypeError,
ID: 46,
From: ServerPeerID,
To: "peer-a",
Body: []byte("unknown target"),
}
data, err := EncodeMessage(original)
if err != nil {
t.Fatalf("EncodeMessage() error = %v", err)
}
decoded, err := DecodeMessage(data)
if err != nil {
t.Fatalf("DecodeMessage() error = %v", err)
}
if !reflect.DeepEqual(decoded, original) {
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
}
}
// TestWriteReadFrame 单独验证最底层的长度前缀帧逻辑,
// 不依赖 Message 结构,方便确认 TCP 粘包拆包问题是否被正确处理。
func TestWriteReadFrame(t *testing.T) {
var buf bytes.Buffer
payload := []byte("header+body")
if err := WriteFrame(&buf, payload); err != nil {
t.Fatalf("WriteFrame() error = %v", err)
}
got, err := ReadFrame(&buf)
if err != nil {
t.Fatalf("ReadFrame() error = %v", err)
}
if !bytes.Equal(got, payload) {
t.Fatalf("payload mismatch: got %q want %q", got, payload)
}
}
// TestWriteReadMessageAllowsEmptyBody 验证空文本和空文件都可以正常通过协议层,
// 因为外层帧非空的前提下,空正文是合法业务内容。
func TestWriteReadMessageAllowsEmptyBody(t *testing.T) {
tests := []struct {
name string
message Message
}{
{
name: "empty text",
message: Message{
Type: MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-b",
Body: []byte(""),
},
},
{
name: "empty file",
message: Message{
Type: MessageTypeFile,
ID: 2,
From: "peer-a",
To: "peer-b",
FileName: "empty.txt",
Body: []byte{},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
if err := WriteMessage(&buf, tt.message); err != nil {
t.Fatalf("WriteMessage() error = %v", err)
}
got, err := ReadMessage(&buf)
if err != nil {
t.Fatalf("ReadMessage() error = %v", err)
}
if !reflect.DeepEqual(got, tt.message) {
t.Fatalf("round trip mismatch: got %+v want %+v", got, tt.message)
}
})
}
}
// TestWriteReadMessageRejectsInvalidMessages 验证协议层会在编码前拦住明显非法的消息。
func TestWriteReadMessageRejectsInvalidMessages(t *testing.T) {
tests := []struct {
name string
message Message
wantErr error
}{
{
name: "invalid type",
message: Message{
Type: MessageType("unknown"),
ID: 1,
From: "peer-a",
To: "peer-b",
},
wantErr: ErrInvalidMessageType,
},
{
name: "missing from",
message: Message{
Type: MessageTypeText,
ID: 2,
To: "peer-b",
},
wantErr: ErrMissingFrom,
},
{
name: "missing to",
message: Message{
Type: MessageTypeText,
ID: 3,
From: "peer-a",
},
wantErr: ErrMissingTo,
},
{
name: "text with file name",
message: Message{
Type: MessageTypeText,
ID: 4,
From: "peer-a",
To: "peer-b",
FileName: "bad.txt",
Body: []byte("hello"),
},
wantErr: ErrUnexpectedFileName,
},
{
name: "text with invalid utf8",
message: Message{
Type: MessageTypeText,
ID: 5,
From: "peer-a",
To: "peer-b",
Body: []byte{0xff, 0xfe},
},
wantErr: ErrInvalidTextBody,
},
{
name: "file without file name",
message: Message{
Type: MessageTypeFile,
ID: 6,
From: "peer-a",
To: "peer-b",
Body: []byte{0x01},
},
wantErr: ErrMissingFileName,
},
{
name: "register with wrong target",
message: Message{
Type: MessageTypeRegister,
ID: 7,
From: "peer-a",
To: "peer-b",
},
wantErr: ErrInvalidRegisterTarget,
},
{
name: "register with body",
message: Message{
Type: MessageTypeRegister,
ID: 8,
From: "peer-a",
To: ServerPeerID,
Body: []byte("unexpected"),
},
wantErr: ErrUnexpectedBody,
},
{
name: "error with wrong source",
message: Message{
Type: MessageTypeError,
ID: 9,
From: "peer-a",
To: "peer-b",
Body: []byte("bad"),
},
wantErr: ErrInvalidErrorSource,
},
{
name: "error with file name",
message: Message{
Type: MessageTypeError,
ID: 10,
From: ServerPeerID,
To: "peer-a",
FileName: "bad.txt",
Body: []byte("bad"),
},
wantErr: ErrUnexpectedFileName,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := EncodeMessage(tt.message)
if !errors.Is(err, tt.wantErr) {
t.Fatalf("EncodeMessage() error = %v, want %v", err, tt.wantErr)
}
})
}
}
// TestReadFrameRejectsInvalidLength 验证长度为 0 的帧会被当成非法输入,
// 而不是被当成一条合法的空消息。
func TestReadFrameRejectsInvalidLength(t *testing.T) {
var buf bytes.Buffer
if err := binary.Write(&buf, binary.BigEndian, uint32(0)); err != nil {
t.Fatalf("binary.Write() error = %v", err)
}
_, err := ReadFrame(&buf)
if !errors.Is(err, ErrInvalidFrameLength) {
t.Fatalf("ReadFrame() error = %v, want %v", err, ErrInvalidFrameLength)
}
}
// TestReadFrameRejectsTooLargeFrame 验证超大帧会在分配消息体前被拒绝,
// 从而保证最大长度限制真正生效。
func TestReadFrameRejectsTooLargeFrame(t *testing.T) {
var buf bytes.Buffer
if err := binary.Write(&buf, binary.BigEndian, uint32(MaxFrameSize+1)); err != nil {
t.Fatalf("binary.Write() error = %v", err)
}
_, err := ReadFrame(&buf)
if !errors.Is(err, ErrFrameTooLarge) {
t.Fatalf("ReadFrame() error = %v, want %v", err, ErrFrameTooLarge)
}
}
// TestWriteFrameRejectsEmptyPayload 验证写入端和读取端的约束保持一致:
// 既然读取端不接受 0 长度帧,写入端也不应该产生这种帧。
func TestWriteFrameRejectsEmptyPayload(t *testing.T) {
var buf bytes.Buffer
err := WriteFrame(&buf, nil)
if !errors.Is(err, ErrInvalidFrameLength) {
t.Fatalf("WriteFrame() error = %v, want %v", err, ErrInvalidFrameLength)
}
}
// TestDecodeMessageRejectsInvalidHeaderLength 验证无法切出完整头部时会被立即拒绝。
func TestDecodeMessageRejectsInvalidHeaderLength(t *testing.T) {
tests := []struct {
name string
data []byte
}{
{
name: "too short for header len",
data: []byte{0x00, 0x00, 0x00},
},
{
name: "zero header len",
data: []byte{0x00, 0x00, 0x00, 0x00},
},
{
name: "header len exceeds payload",
data: []byte{0x00, 0x00, 0x00, 0x10, '{', '}'},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := DecodeMessage(tt.data)
if !errors.Is(err, ErrInvalidHeaderLength) {
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderLength)
}
})
}
}
// TestDecodeMessageRejectsInvalidHeaderJSON 验证头部 JSON 非法时能返回明确错误。
func TestDecodeMessageRejectsInvalidHeaderJSON(t *testing.T) {
data := append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)
_, err := DecodeMessage(data)
if !errors.Is(err, ErrInvalidHeaderJSON) {
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderJSON)
}
}
// TestDecodeMessageRejectsContentLengthMismatch 验证头部声明长度和实际正文不一致时会失败。
func TestDecodeMessageRejectsContentLengthMismatch(t *testing.T) {
headerPayload, err := json.Marshal(messageHeader{
Type: MessageTypeText,
ID: 7,
From: "peer-a",
To: "peer-b",
ContentLength: 10,
})
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
var data bytes.Buffer
if err := binary.Write(&data, binary.BigEndian, uint32(len(headerPayload))); err != nil {
t.Fatalf("binary.Write() error = %v", err)
}
if _, err := data.Write(headerPayload); err != nil {
t.Fatalf("data.Write(headerPayload) error = %v", err)
}
if _, err := data.Write([]byte("hello")); err != nil {
t.Fatalf("data.Write(body) error = %v", err)
}
_, err = DecodeMessage(data.Bytes())
if !errors.Is(err, ErrInvalidContentLength) {
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidContentLength)
}
}
// TestReadMultipleMessages 模拟同一条流中连续写入 text 和 file
// 验证读取端每次都能严格停在当前帧边界,不会串包。
func TestReadMultipleMessages(t *testing.T) {
var buf bytes.Buffer
first := Message{
Type: MessageTypeText,
ID: 1,
From: "peer-a",
To: "peer-b",
Body: []byte("hello"),
}
second := Message{
Type: MessageTypeFile,
ID: 2,
From: "peer-b",
To: "peer-a",
FileName: "payload.bin",
Body: []byte{0x01, 0x02, 0x03},
}
if err := WriteMessage(&buf, first); err != nil {
t.Fatalf("WriteMessage(first) error = %v", err)
}
if err := WriteMessage(&buf, second); err != nil {
t.Fatalf("WriteMessage(second) error = %v", err)
}
gotFirst, err := ReadMessage(&buf)
if err != nil {
t.Fatalf("ReadMessage(first) error = %v", err)
}
gotSecond, err := ReadMessage(&buf)
if err != nil {
t.Fatalf("ReadMessage(second) error = %v", err)
}
if !reflect.DeepEqual(gotFirst, first) {
t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, first)
}
if !reflect.DeepEqual(gotSecond, second) {
t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, second)
}
}
// TestReadMessageWrapsDecodeError 验证 ReadMessage 在返回错误时会保留解码阶段上下文。
func TestReadMessageWrapsDecodeError(t *testing.T) {
var buf bytes.Buffer
if err := WriteFrame(&buf, append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)); err != nil {
t.Fatalf("WriteFrame() error = %v", err)
}
_, err := ReadMessage(&buf)
if err == nil {
t.Fatal("ReadMessage() error = nil, want non-nil")
}
if !strings.Contains(err.Error(), "decode message") {
t.Fatalf("ReadMessage() error = %v, want wrapped decode error", err)
}
}