508 lines
12 KiB
Go
508 lines
12 KiB
Go
package protocol
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/binary"
|
||
"encoding/json"
|
||
"errors"
|
||
"reflect"
|
||
"strings"
|
||
"testing"
|
||
)
|
||
|
||
// TestEncodeDecodeMessageTextASCII 验证 ASCII 文本可以按 text 消息往返编解码。
|
||
func TestEncodeDecodeMessageTextASCII(t *testing.T) {
|
||
original := Message{
|
||
Type: MessageTypeText,
|
||
ID: 42,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
Body: []byte("hello"),
|
||
}
|
||
|
||
data, err := EncodeMessage(original)
|
||
if err != nil {
|
||
t.Fatalf("EncodeMessage() error = %v", err)
|
||
}
|
||
|
||
decoded, err := DecodeMessage(data)
|
||
if err != nil {
|
||
t.Fatalf("DecodeMessage() error = %v", err)
|
||
}
|
||
|
||
if !reflect.DeepEqual(decoded, original) {
|
||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||
}
|
||
}
|
||
|
||
// TestEncodeDecodeMessageTextUTF8 验证 text 消息允许合法 UTF-8,
|
||
// 从而天然兼容 ASCII 之外的普通文本。
|
||
func TestEncodeDecodeMessageTextUTF8(t *testing.T) {
|
||
original := Message{
|
||
Type: MessageTypeText,
|
||
ID: 43,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
Body: []byte("你好, world"),
|
||
}
|
||
|
||
data, err := EncodeMessage(original)
|
||
if err != nil {
|
||
t.Fatalf("EncodeMessage() error = %v", err)
|
||
}
|
||
|
||
decoded, err := DecodeMessage(data)
|
||
if err != nil {
|
||
t.Fatalf("DecodeMessage() error = %v", err)
|
||
}
|
||
|
||
if !reflect.DeepEqual(decoded, original) {
|
||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||
}
|
||
}
|
||
|
||
// TestEncodeDecodeMessageFile 验证 file 消息会保留文件名和原始二进制正文。
|
||
func TestEncodeDecodeMessageFile(t *testing.T) {
|
||
original := Message{
|
||
Type: MessageTypeFile,
|
||
ID: 44,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
FileName: "data.bin",
|
||
Body: []byte{0x00, 0xff, 0x10, 0x7f},
|
||
}
|
||
|
||
data, err := EncodeMessage(original)
|
||
if err != nil {
|
||
t.Fatalf("EncodeMessage() error = %v", err)
|
||
}
|
||
|
||
decoded, err := DecodeMessage(data)
|
||
if err != nil {
|
||
t.Fatalf("DecodeMessage() error = %v", err)
|
||
}
|
||
|
||
if !reflect.DeepEqual(decoded, original) {
|
||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||
}
|
||
}
|
||
|
||
// TestEncodeDecodeMessageRegister 验证 register 控制消息也能正常编解码。
|
||
func TestEncodeDecodeMessageRegister(t *testing.T) {
|
||
original := Message{
|
||
Type: MessageTypeRegister,
|
||
ID: 45,
|
||
From: "peer-a",
|
||
To: ServerPeerID,
|
||
Body: []byte{},
|
||
}
|
||
|
||
data, err := EncodeMessage(original)
|
||
if err != nil {
|
||
t.Fatalf("EncodeMessage() error = %v", err)
|
||
}
|
||
|
||
decoded, err := DecodeMessage(data)
|
||
if err != nil {
|
||
t.Fatalf("DecodeMessage() error = %v", err)
|
||
}
|
||
|
||
if !reflect.DeepEqual(decoded, original) {
|
||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||
}
|
||
}
|
||
|
||
// TestEncodeDecodeMessageError 验证 error 控制消息会保留 UTF-8 错误文本。
|
||
func TestEncodeDecodeMessageError(t *testing.T) {
|
||
original := Message{
|
||
Type: MessageTypeError,
|
||
ID: 46,
|
||
From: ServerPeerID,
|
||
To: "peer-a",
|
||
Body: []byte("unknown target"),
|
||
}
|
||
|
||
data, err := EncodeMessage(original)
|
||
if err != nil {
|
||
t.Fatalf("EncodeMessage() error = %v", err)
|
||
}
|
||
|
||
decoded, err := DecodeMessage(data)
|
||
if err != nil {
|
||
t.Fatalf("DecodeMessage() error = %v", err)
|
||
}
|
||
|
||
if !reflect.DeepEqual(decoded, original) {
|
||
t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original)
|
||
}
|
||
}
|
||
|
||
// TestWriteReadFrame 单独验证最底层的长度前缀帧逻辑,
|
||
// 不依赖 Message 结构,方便确认 TCP 粘包拆包问题是否被正确处理。
|
||
func TestWriteReadFrame(t *testing.T) {
|
||
var buf bytes.Buffer
|
||
payload := []byte("header+body")
|
||
|
||
if err := WriteFrame(&buf, payload); err != nil {
|
||
t.Fatalf("WriteFrame() error = %v", err)
|
||
}
|
||
|
||
got, err := ReadFrame(&buf)
|
||
if err != nil {
|
||
t.Fatalf("ReadFrame() error = %v", err)
|
||
}
|
||
|
||
if !bytes.Equal(got, payload) {
|
||
t.Fatalf("payload mismatch: got %q want %q", got, payload)
|
||
}
|
||
}
|
||
|
||
// TestWriteReadMessageAllowsEmptyBody 验证空文本和空文件都可以正常通过协议层,
|
||
// 因为外层帧非空的前提下,空正文是合法业务内容。
|
||
func TestWriteReadMessageAllowsEmptyBody(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
message Message
|
||
}{
|
||
{
|
||
name: "empty text",
|
||
message: Message{
|
||
Type: MessageTypeText,
|
||
ID: 1,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
Body: []byte(""),
|
||
},
|
||
},
|
||
{
|
||
name: "empty file",
|
||
message: Message{
|
||
Type: MessageTypeFile,
|
||
ID: 2,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
FileName: "empty.txt",
|
||
Body: []byte{},
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
var buf bytes.Buffer
|
||
|
||
if err := WriteMessage(&buf, tt.message); err != nil {
|
||
t.Fatalf("WriteMessage() error = %v", err)
|
||
}
|
||
|
||
got, err := ReadMessage(&buf)
|
||
if err != nil {
|
||
t.Fatalf("ReadMessage() error = %v", err)
|
||
}
|
||
|
||
if !reflect.DeepEqual(got, tt.message) {
|
||
t.Fatalf("round trip mismatch: got %+v want %+v", got, tt.message)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestWriteReadMessageRejectsInvalidMessages 验证协议层会在编码前拦住明显非法的消息。
|
||
func TestWriteReadMessageRejectsInvalidMessages(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
message Message
|
||
wantErr error
|
||
}{
|
||
{
|
||
name: "invalid type",
|
||
message: Message{
|
||
Type: MessageType("unknown"),
|
||
ID: 1,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
},
|
||
wantErr: ErrInvalidMessageType,
|
||
},
|
||
{
|
||
name: "missing from",
|
||
message: Message{
|
||
Type: MessageTypeText,
|
||
ID: 2,
|
||
To: "peer-b",
|
||
},
|
||
wantErr: ErrMissingFrom,
|
||
},
|
||
{
|
||
name: "missing to",
|
||
message: Message{
|
||
Type: MessageTypeText,
|
||
ID: 3,
|
||
From: "peer-a",
|
||
},
|
||
wantErr: ErrMissingTo,
|
||
},
|
||
{
|
||
name: "text with file name",
|
||
message: Message{
|
||
Type: MessageTypeText,
|
||
ID: 4,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
FileName: "bad.txt",
|
||
Body: []byte("hello"),
|
||
},
|
||
wantErr: ErrUnexpectedFileName,
|
||
},
|
||
{
|
||
name: "text with invalid utf8",
|
||
message: Message{
|
||
Type: MessageTypeText,
|
||
ID: 5,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
Body: []byte{0xff, 0xfe},
|
||
},
|
||
wantErr: ErrInvalidTextBody,
|
||
},
|
||
{
|
||
name: "file without file name",
|
||
message: Message{
|
||
Type: MessageTypeFile,
|
||
ID: 6,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
Body: []byte{0x01},
|
||
},
|
||
wantErr: ErrMissingFileName,
|
||
},
|
||
{
|
||
name: "register with wrong target",
|
||
message: Message{
|
||
Type: MessageTypeRegister,
|
||
ID: 7,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
},
|
||
wantErr: ErrInvalidRegisterTarget,
|
||
},
|
||
{
|
||
name: "register with body",
|
||
message: Message{
|
||
Type: MessageTypeRegister,
|
||
ID: 8,
|
||
From: "peer-a",
|
||
To: ServerPeerID,
|
||
Body: []byte("unexpected"),
|
||
},
|
||
wantErr: ErrUnexpectedBody,
|
||
},
|
||
{
|
||
name: "error with wrong source",
|
||
message: Message{
|
||
Type: MessageTypeError,
|
||
ID: 9,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
Body: []byte("bad"),
|
||
},
|
||
wantErr: ErrInvalidErrorSource,
|
||
},
|
||
{
|
||
name: "error with file name",
|
||
message: Message{
|
||
Type: MessageTypeError,
|
||
ID: 10,
|
||
From: ServerPeerID,
|
||
To: "peer-a",
|
||
FileName: "bad.txt",
|
||
Body: []byte("bad"),
|
||
},
|
||
wantErr: ErrUnexpectedFileName,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
_, err := EncodeMessage(tt.message)
|
||
if !errors.Is(err, tt.wantErr) {
|
||
t.Fatalf("EncodeMessage() error = %v, want %v", err, tt.wantErr)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestReadFrameRejectsInvalidLength 验证长度为 0 的帧会被当成非法输入,
|
||
// 而不是被当成一条合法的空消息。
|
||
func TestReadFrameRejectsInvalidLength(t *testing.T) {
|
||
var buf bytes.Buffer
|
||
|
||
if err := binary.Write(&buf, binary.BigEndian, uint32(0)); err != nil {
|
||
t.Fatalf("binary.Write() error = %v", err)
|
||
}
|
||
|
||
_, err := ReadFrame(&buf)
|
||
if !errors.Is(err, ErrInvalidFrameLength) {
|
||
t.Fatalf("ReadFrame() error = %v, want %v", err, ErrInvalidFrameLength)
|
||
}
|
||
}
|
||
|
||
// TestReadFrameRejectsTooLargeFrame 验证超大帧会在分配消息体前被拒绝,
|
||
// 从而保证最大长度限制真正生效。
|
||
func TestReadFrameRejectsTooLargeFrame(t *testing.T) {
|
||
var buf bytes.Buffer
|
||
|
||
if err := binary.Write(&buf, binary.BigEndian, uint32(MaxFrameSize+1)); err != nil {
|
||
t.Fatalf("binary.Write() error = %v", err)
|
||
}
|
||
|
||
_, err := ReadFrame(&buf)
|
||
if !errors.Is(err, ErrFrameTooLarge) {
|
||
t.Fatalf("ReadFrame() error = %v, want %v", err, ErrFrameTooLarge)
|
||
}
|
||
}
|
||
|
||
// TestWriteFrameRejectsEmptyPayload 验证写入端和读取端的约束保持一致:
|
||
// 既然读取端不接受 0 长度帧,写入端也不应该产生这种帧。
|
||
func TestWriteFrameRejectsEmptyPayload(t *testing.T) {
|
||
var buf bytes.Buffer
|
||
|
||
err := WriteFrame(&buf, nil)
|
||
if !errors.Is(err, ErrInvalidFrameLength) {
|
||
t.Fatalf("WriteFrame() error = %v, want %v", err, ErrInvalidFrameLength)
|
||
}
|
||
}
|
||
|
||
// TestDecodeMessageRejectsInvalidHeaderLength 验证无法切出完整头部时会被立即拒绝。
|
||
func TestDecodeMessageRejectsInvalidHeaderLength(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
data []byte
|
||
}{
|
||
{
|
||
name: "too short for header len",
|
||
data: []byte{0x00, 0x00, 0x00},
|
||
},
|
||
{
|
||
name: "zero header len",
|
||
data: []byte{0x00, 0x00, 0x00, 0x00},
|
||
},
|
||
{
|
||
name: "header len exceeds payload",
|
||
data: []byte{0x00, 0x00, 0x00, 0x10, '{', '}'},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
_, err := DecodeMessage(tt.data)
|
||
if !errors.Is(err, ErrInvalidHeaderLength) {
|
||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderLength)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestDecodeMessageRejectsInvalidHeaderJSON 验证头部 JSON 非法时能返回明确错误。
|
||
func TestDecodeMessageRejectsInvalidHeaderJSON(t *testing.T) {
|
||
data := append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)
|
||
|
||
_, err := DecodeMessage(data)
|
||
if !errors.Is(err, ErrInvalidHeaderJSON) {
|
||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderJSON)
|
||
}
|
||
}
|
||
|
||
// TestDecodeMessageRejectsContentLengthMismatch 验证头部声明长度和实际正文不一致时会失败。
|
||
func TestDecodeMessageRejectsContentLengthMismatch(t *testing.T) {
|
||
headerPayload, err := json.Marshal(messageHeader{
|
||
Type: MessageTypeText,
|
||
ID: 7,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
ContentLength: 10,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("json.Marshal() error = %v", err)
|
||
}
|
||
|
||
var data bytes.Buffer
|
||
if err := binary.Write(&data, binary.BigEndian, uint32(len(headerPayload))); err != nil {
|
||
t.Fatalf("binary.Write() error = %v", err)
|
||
}
|
||
if _, err := data.Write(headerPayload); err != nil {
|
||
t.Fatalf("data.Write(headerPayload) error = %v", err)
|
||
}
|
||
if _, err := data.Write([]byte("hello")); err != nil {
|
||
t.Fatalf("data.Write(body) error = %v", err)
|
||
}
|
||
|
||
_, err = DecodeMessage(data.Bytes())
|
||
if !errors.Is(err, ErrInvalidContentLength) {
|
||
t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidContentLength)
|
||
}
|
||
}
|
||
|
||
// TestReadMultipleMessages 模拟同一条流中连续写入 text 和 file,
|
||
// 验证读取端每次都能严格停在当前帧边界,不会串包。
|
||
func TestReadMultipleMessages(t *testing.T) {
|
||
var buf bytes.Buffer
|
||
|
||
first := Message{
|
||
Type: MessageTypeText,
|
||
ID: 1,
|
||
From: "peer-a",
|
||
To: "peer-b",
|
||
Body: []byte("hello"),
|
||
}
|
||
|
||
second := Message{
|
||
Type: MessageTypeFile,
|
||
ID: 2,
|
||
From: "peer-b",
|
||
To: "peer-a",
|
||
FileName: "payload.bin",
|
||
Body: []byte{0x01, 0x02, 0x03},
|
||
}
|
||
|
||
if err := WriteMessage(&buf, first); err != nil {
|
||
t.Fatalf("WriteMessage(first) error = %v", err)
|
||
}
|
||
if err := WriteMessage(&buf, second); err != nil {
|
||
t.Fatalf("WriteMessage(second) error = %v", err)
|
||
}
|
||
|
||
gotFirst, err := ReadMessage(&buf)
|
||
if err != nil {
|
||
t.Fatalf("ReadMessage(first) error = %v", err)
|
||
}
|
||
gotSecond, err := ReadMessage(&buf)
|
||
if err != nil {
|
||
t.Fatalf("ReadMessage(second) error = %v", err)
|
||
}
|
||
|
||
if !reflect.DeepEqual(gotFirst, first) {
|
||
t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, first)
|
||
}
|
||
if !reflect.DeepEqual(gotSecond, second) {
|
||
t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, second)
|
||
}
|
||
}
|
||
|
||
// TestReadMessageWrapsDecodeError 验证 ReadMessage 在返回错误时会保留解码阶段上下文。
|
||
func TestReadMessageWrapsDecodeError(t *testing.T) {
|
||
var buf bytes.Buffer
|
||
|
||
if err := WriteFrame(&buf, append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)); err != nil {
|
||
t.Fatalf("WriteFrame() error = %v", err)
|
||
}
|
||
|
||
_, err := ReadMessage(&buf)
|
||
if err == nil {
|
||
t.Fatal("ReadMessage() error = nil, want non-nil")
|
||
}
|
||
if !strings.Contains(err.Error(), "decode message") {
|
||
t.Fatalf("ReadMessage() error = %v, want wrapped decode error", err)
|
||
}
|
||
}
|