Files
OmniSocketGo/go/cmd/internal/protocol/codec_test.go

508 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package protocol
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"reflect"
"strings"
"testing"
)
// TestEncodeDecodeMessageTextASCII checks that a text message with a plain
// ASCII body survives an EncodeMessage/DecodeMessage round trip unchanged.
// All fields (type, ID, peers, body) are compared via reflect.DeepEqual.
func TestEncodeDecodeMessageTextASCII(t *testing.T) {
	want := Message{
		Type: MessageTypeText,
		ID:   42,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}

	encoded, err := EncodeMessage(want)
	if err != nil {
		t.Fatalf("EncodeMessage() error = %v", err)
	}

	got, err := DecodeMessage(encoded)
	if err != nil {
		t.Fatalf("DecodeMessage() error = %v", err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Fatalf("round trip mismatch: got %+v want %+v", got, want)
	}
}
// TestEncodeDecodeMessageTextUTF8 checks that a text message accepts any
// valid UTF-8 body, not just ASCII. The body mixes Chinese characters with
// ASCII, and the decoded message must equal the original exactly.
func TestEncodeDecodeMessageTextUTF8(t *testing.T) {
	want := Message{
		Type: MessageTypeText,
		ID:   43,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("你好, world"),
	}

	encoded, err := EncodeMessage(want)
	if err != nil {
		t.Fatalf("EncodeMessage() error = %v", err)
	}

	got, err := DecodeMessage(encoded)
	if err != nil {
		t.Fatalf("DecodeMessage() error = %v", err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Fatalf("round trip mismatch: got %+v want %+v", got, want)
	}
}
// TestEncodeDecodeMessageFile checks that a file message keeps both its
// FileName and its raw binary body across an encode/decode round trip. The
// body deliberately contains bytes that are not valid UTF-8 (0xff) as well as
// NUL and DEL.
func TestEncodeDecodeMessageFile(t *testing.T) {
	want := Message{
		Type:     MessageTypeFile,
		ID:       44,
		From:     "peer-a",
		To:       "peer-b",
		FileName: "data.bin",
		Body:     []byte{0x00, 0xff, 0x10, 0x7f},
	}

	encoded, err := EncodeMessage(want)
	if err != nil {
		t.Fatalf("EncodeMessage() error = %v", err)
	}

	got, err := DecodeMessage(encoded)
	if err != nil {
		t.Fatalf("DecodeMessage() error = %v", err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Fatalf("round trip mismatch: got %+v want %+v", got, want)
	}
}
// TestEncodeDecodeMessageRegister checks that a register control message,
// addressed to ServerPeerID with an empty (non-nil) body, round-trips through
// EncodeMessage/DecodeMessage unchanged.
func TestEncodeDecodeMessageRegister(t *testing.T) {
	want := Message{
		Type: MessageTypeRegister,
		ID:   45,
		From: "peer-a",
		To:   ServerPeerID,
		Body: []byte{},
	}

	encoded, err := EncodeMessage(want)
	if err != nil {
		t.Fatalf("EncodeMessage() error = %v", err)
	}

	got, err := DecodeMessage(encoded)
	if err != nil {
		t.Fatalf("DecodeMessage() error = %v", err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Fatalf("round trip mismatch: got %+v want %+v", got, want)
	}
}
// TestEncodeDecodeMessageError checks that an error control message sent from
// ServerPeerID keeps its UTF-8 error text across an encode/decode round trip.
func TestEncodeDecodeMessageError(t *testing.T) {
	want := Message{
		Type: MessageTypeError,
		ID:   46,
		From: ServerPeerID,
		To:   "peer-a",
		Body: []byte("unknown target"),
	}

	encoded, err := EncodeMessage(want)
	if err != nil {
		t.Fatalf("EncodeMessage() error = %v", err)
	}

	got, err := DecodeMessage(encoded)
	if err != nil {
		t.Fatalf("DecodeMessage() error = %v", err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Fatalf("round trip mismatch: got %+v want %+v", got, want)
	}
}
// TestWriteReadFrame exercises the lowest-level length-prefixed framing on its
// own, without involving Message. It writes a single payload with WriteFrame
// and checks that ReadFrame returns exactly the same bytes. This framing is
// what the codec relies on to delimit messages on a byte stream.
func TestWriteReadFrame(t *testing.T) {
	want := []byte("header+body")

	var wire bytes.Buffer
	if err := WriteFrame(&wire, want); err != nil {
		t.Fatalf("WriteFrame() error = %v", err)
	}

	got, err := ReadFrame(&wire)
	if err != nil {
		t.Fatalf("ReadFrame() error = %v", err)
	}

	if !bytes.Equal(got, want) {
		t.Fatalf("payload mismatch: got %q want %q", got, want)
	}
}
// TestWriteReadMessageAllowsEmptyBody checks that an empty text message and an
// empty file both pass through WriteMessage/ReadMessage unchanged. The outer
// frame is still non-empty because it carries the header, so an empty body is
// legitimate application content.
func TestWriteReadMessageAllowsEmptyBody(t *testing.T) {
	cases := []struct {
		name    string
		message Message
	}{
		{
			name: "empty text",
			message: Message{
				Type: MessageTypeText,
				ID:   1,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte(""),
			},
		},
		{
			name: "empty file",
			message: Message{
				Type:     MessageTypeFile,
				ID:       2,
				From:     "peer-a",
				To:       "peer-b",
				FileName: "empty.txt",
				Body:     []byte{},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var stream bytes.Buffer
			if err := WriteMessage(&stream, tc.message); err != nil {
				t.Fatalf("WriteMessage() error = %v", err)
			}

			got, err := ReadMessage(&stream)
			if err != nil {
				t.Fatalf("ReadMessage() error = %v", err)
			}

			if !reflect.DeepEqual(got, tc.message) {
				t.Fatalf("round trip mismatch: got %+v want %+v", got, tc.message)
			}
		})
	}
}
// TestWriteReadMessageRejectsInvalidMessages checks that obviously invalid
// messages are rejected before any bytes are encoded. Each case breaks exactly
// one rule and must make EncodeMessage fail with the matching sentinel error,
// compared via errors.Is. Note: despite the test name, only EncodeMessage is
// called here; WriteMessage/ReadMessage are not exercised.
func TestWriteReadMessageRejectsInvalidMessages(t *testing.T) {
	cases := []struct {
		name    string
		message Message
		wantErr error
	}{
		{
			name: "invalid type",
			message: Message{
				Type: MessageType("unknown"),
				ID:   1,
				From: "peer-a",
				To:   "peer-b",
			},
			wantErr: ErrInvalidMessageType,
		},
		{
			name: "missing from",
			message: Message{
				Type: MessageTypeText,
				ID:   2,
				To:   "peer-b",
			},
			wantErr: ErrMissingFrom,
		},
		{
			name: "missing to",
			message: Message{
				Type: MessageTypeText,
				ID:   3,
				From: "peer-a",
			},
			wantErr: ErrMissingTo,
		},
		{
			// Text messages must not carry a file name.
			name: "text with file name",
			message: Message{
				Type:     MessageTypeText,
				ID:       4,
				From:     "peer-a",
				To:       "peer-b",
				FileName: "bad.txt",
				Body:     []byte("hello"),
			},
			wantErr: ErrUnexpectedFileName,
		},
		{
			// 0xff 0xfe is not a valid UTF-8 sequence.
			name: "text with invalid utf8",
			message: Message{
				Type: MessageTypeText,
				ID:   5,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte{0xff, 0xfe},
			},
			wantErr: ErrInvalidTextBody,
		},
		{
			name: "file without file name",
			message: Message{
				Type: MessageTypeFile,
				ID:   6,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte{0x01},
			},
			wantErr: ErrMissingFileName,
		},
		{
			// Register messages must be addressed to ServerPeerID.
			name: "register with wrong target",
			message: Message{
				Type: MessageTypeRegister,
				ID:   7,
				From: "peer-a",
				To:   "peer-b",
			},
			wantErr: ErrInvalidRegisterTarget,
		},
		{
			name: "register with body",
			message: Message{
				Type: MessageTypeRegister,
				ID:   8,
				From: "peer-a",
				To:   ServerPeerID,
				Body: []byte("unexpected"),
			},
			wantErr: ErrUnexpectedBody,
		},
		{
			// Error messages must originate from ServerPeerID.
			name: "error with wrong source",
			message: Message{
				Type: MessageTypeError,
				ID:   9,
				From: "peer-a",
				To:   "peer-b",
				Body: []byte("bad"),
			},
			wantErr: ErrInvalidErrorSource,
		},
		{
			name: "error with file name",
			message: Message{
				Type:     MessageTypeError,
				ID:       10,
				From:     ServerPeerID,
				To:       "peer-a",
				FileName: "bad.txt",
				Body:     []byte("bad"),
			},
			wantErr: ErrUnexpectedFileName,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := EncodeMessage(tc.message); !errors.Is(err, tc.wantErr) {
				t.Fatalf("EncodeMessage() error = %v, want %v", err, tc.wantErr)
			}
		})
	}
}
// TestReadFrameRejectsInvalidLength checks that a frame whose big-endian
// length prefix is zero is rejected with ErrInvalidFrameLength, rather than
// being accepted as a legitimate empty message.
func TestReadFrameRejectsInvalidLength(t *testing.T) {
	var stream bytes.Buffer
	// Only the 4-byte zero length prefix is written; no payload follows.
	if err := binary.Write(&stream, binary.BigEndian, uint32(0)); err != nil {
		t.Fatalf("binary.Write() error = %v", err)
	}

	if _, err := ReadFrame(&stream); !errors.Is(err, ErrInvalidFrameLength) {
		t.Fatalf("ReadFrame() error = %v, want %v", err, ErrInvalidFrameLength)
	}
}
// TestReadFrameRejectsTooLargeFrame checks that a length prefix of
// MaxFrameSize+1 is rejected with ErrFrameTooLarge. No body bytes are supplied
// at all, so the rejection has to come from the length prefix alone, before
// any body is read. That is what makes the maximum-size limit effective.
func TestReadFrameRejectsTooLargeFrame(t *testing.T) {
	var stream bytes.Buffer
	if err := binary.Write(&stream, binary.BigEndian, uint32(MaxFrameSize+1)); err != nil {
		t.Fatalf("binary.Write() error = %v", err)
	}

	if _, err := ReadFrame(&stream); !errors.Is(err, ErrFrameTooLarge) {
		t.Fatalf("ReadFrame() error = %v, want %v", err, ErrFrameTooLarge)
	}
}
// TestWriteFrameRejectsEmptyPayload checks that the writer enforces the same
// constraint as the reader. ReadFrame refuses zero-length frames, so WriteFrame
// must refuse to produce one.
//
// Both a nil slice and a non-nil empty slice have length zero, and both must
// be rejected with ErrInvalidFrameLength. A rejected payload must also leave the
// destination untouched; a stray length prefix written before the check would
// corrupt the stream for the reader.
func TestWriteFrameRejectsEmptyPayload(t *testing.T) {
	tests := []struct {
		name    string
		payload []byte
	}{
		{name: "nil payload", payload: nil},
		{name: "empty non-nil payload", payload: []byte{}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var buf bytes.Buffer
			err := WriteFrame(&buf, tt.payload)
			if !errors.Is(err, ErrInvalidFrameLength) {
				t.Fatalf("WriteFrame() error = %v, want %v", err, ErrInvalidFrameLength)
			}
			if buf.Len() != 0 {
				t.Fatalf("WriteFrame() wrote %d bytes for a rejected payload, want 0", buf.Len())
			}
		})
	}
}
// TestDecodeMessageRejectsInvalidHeaderLength checks that DecodeMessage fails
// with ErrInvalidHeaderLength whenever a complete header cannot be sliced out
// of the payload. The cases are: fewer than 4 bytes for the length field, a
// zero header length, and a declared header length longer than the bytes that
// follow.
func TestDecodeMessageRejectsInvalidHeaderLength(t *testing.T) {
	for _, tc := range []struct {
		name string
		data []byte
	}{
		{
			// Only 3 of the 4 length-prefix bytes are present.
			name: "too short for header len",
			data: []byte{0x00, 0x00, 0x00},
		},
		{
			name: "zero header len",
			data: []byte{0x00, 0x00, 0x00, 0x00},
		},
		{
			// Declares 16 header bytes but only "{}" (2 bytes) follows.
			name: "header len exceeds payload",
			data: []byte{0x00, 0x00, 0x00, 0x10, '{', '}'},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := DecodeMessage(tc.data); !errors.Is(err, ErrInvalidHeaderLength) {
				t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderLength)
			}
		})
	}
}
// TestDecodeMessageRejectsInvalidHeaderJSON checks that a correctly
// length-prefixed header which is not valid JSON makes DecodeMessage fail with
// ErrInvalidHeaderJSON.
func TestDecodeMessageRejectsInvalidHeaderJSON(t *testing.T) {
	header := []byte("{invalid}")

	// Big-endian header length derived from the header itself, followed by the
	// malformed header bytes.
	data := make([]byte, 4, 4+len(header))
	binary.BigEndian.PutUint32(data, uint32(len(header)))
	data = append(data, header...)

	if _, err := DecodeMessage(data); !errors.Is(err, ErrInvalidHeaderJSON) {
		t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderJSON)
	}
}
// TestDecodeMessageRejectsContentLengthMismatch checks that DecodeMessage
// fails with ErrInvalidContentLength when the header's ContentLength (10)
// disagrees with the number of body bytes that actually follow it (5).
func TestDecodeMessageRejectsContentLengthMismatch(t *testing.T) {
	body := []byte("hello")
	header := messageHeader{
		Type:          MessageTypeText,
		ID:            7,
		From:          "peer-a",
		To:            "peer-b",
		ContentLength: 10,
	}

	headerPayload, err := json.Marshal(header)
	if err != nil {
		t.Fatalf("json.Marshal() error = %v", err)
	}

	// Layout: big-endian header length, header JSON, then the short body.
	var data bytes.Buffer
	if err := binary.Write(&data, binary.BigEndian, uint32(len(headerPayload))); err != nil {
		t.Fatalf("binary.Write() error = %v", err)
	}
	if _, err := data.Write(headerPayload); err != nil {
		t.Fatalf("data.Write(headerPayload) error = %v", err)
	}
	if _, err := data.Write(body); err != nil {
		t.Fatalf("data.Write(body) error = %v", err)
	}

	if _, err := DecodeMessage(data.Bytes()); !errors.Is(err, ErrInvalidContentLength) {
		t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidContentLength)
	}
}
// TestReadMultipleMessages writes a text message and then a file message
// back-to-back into one stream. It checks that each ReadMessage call stops
// exactly at its own frame boundary, so neither message bleeds into the other.
func TestReadMultipleMessages(t *testing.T) {
	first := Message{
		Type: MessageTypeText,
		ID:   1,
		From: "peer-a",
		To:   "peer-b",
		Body: []byte("hello"),
	}
	second := Message{
		Type:     MessageTypeFile,
		ID:       2,
		From:     "peer-b",
		To:       "peer-a",
		FileName: "payload.bin",
		Body:     []byte{0x01, 0x02, 0x03},
	}

	var stream bytes.Buffer
	for _, step := range []struct {
		label string
		msg   Message
	}{
		{label: "first", msg: first},
		{label: "second", msg: second},
	} {
		if err := WriteMessage(&stream, step.msg); err != nil {
			t.Fatalf("WriteMessage(%s) error = %v", step.label, err)
		}
	}

	// Read both frames back before comparing, so a read failure on the second
	// frame is reported ahead of any content mismatch in the first.
	gotFirst, err := ReadMessage(&stream)
	if err != nil {
		t.Fatalf("ReadMessage(first) error = %v", err)
	}
	gotSecond, err := ReadMessage(&stream)
	if err != nil {
		t.Fatalf("ReadMessage(second) error = %v", err)
	}

	if !reflect.DeepEqual(gotFirst, first) {
		t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, first)
	}
	if !reflect.DeepEqual(gotSecond, second) {
		t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, second)
	}
}
// TestReadMessageWrapsDecodeError checks that ReadMessage adds decode-stage
// context to its error. A well-formed frame carries a length-prefixed but
// malformed JSON header, and the resulting error text must contain
// "decode message".
func TestReadMessageWrapsDecodeError(t *testing.T) {
	header := []byte("{invalid}")

	// Frame payload: big-endian header length followed by the bad header.
	payload := make([]byte, 4, 4+len(header))
	binary.BigEndian.PutUint32(payload, uint32(len(header)))
	payload = append(payload, header...)

	var stream bytes.Buffer
	if err := WriteFrame(&stream, payload); err != nil {
		t.Fatalf("WriteFrame() error = %v", err)
	}

	_, err := ReadMessage(&stream)
	switch {
	case err == nil:
		t.Fatal("ReadMessage() error = nil, want non-nil")
	case !strings.Contains(err.Error(), "decode message"):
		t.Fatalf("ReadMessage() error = %v, want wrapped decode error", err)
	}
}