package protocol import ( "bytes" "encoding/binary" "encoding/json" "errors" "reflect" "strings" "testing" ) // TestEncodeDecodeMessageTextASCII 验证 ASCII 文本可以按 text 消息往返编解码。 func TestEncodeDecodeMessageTextASCII(t *testing.T) { original := Message{ Type: MessageTypeText, ID: 42, From: "peer-a", To: "peer-b", Body: []byte("hello"), } data, err := EncodeMessage(original) if err != nil { t.Fatalf("EncodeMessage() error = %v", err) } decoded, err := DecodeMessage(data) if err != nil { t.Fatalf("DecodeMessage() error = %v", err) } if !reflect.DeepEqual(decoded, original) { t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original) } } // TestEncodeDecodeMessageTextUTF8 验证 text 消息允许合法 UTF-8, // 从而天然兼容 ASCII 之外的普通文本。 func TestEncodeDecodeMessageTextUTF8(t *testing.T) { original := Message{ Type: MessageTypeText, ID: 43, From: "peer-a", To: "peer-b", Body: []byte("你好, world"), } data, err := EncodeMessage(original) if err != nil { t.Fatalf("EncodeMessage() error = %v", err) } decoded, err := DecodeMessage(data) if err != nil { t.Fatalf("DecodeMessage() error = %v", err) } if !reflect.DeepEqual(decoded, original) { t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original) } } // TestEncodeDecodeMessageFile 验证 file 消息会保留文件名和原始二进制正文。 func TestEncodeDecodeMessageFile(t *testing.T) { original := Message{ Type: MessageTypeFile, ID: 44, From: "peer-a", To: "peer-b", FileName: "data.bin", Body: []byte{0x00, 0xff, 0x10, 0x7f}, } data, err := EncodeMessage(original) if err != nil { t.Fatalf("EncodeMessage() error = %v", err) } decoded, err := DecodeMessage(data) if err != nil { t.Fatalf("DecodeMessage() error = %v", err) } if !reflect.DeepEqual(decoded, original) { t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original) } } // TestEncodeDecodeMessageRegister 验证 register 控制消息也能正常编解码。 func TestEncodeDecodeMessageRegister(t *testing.T) { original := Message{ Type: MessageTypeRegister, ID: 45, From: "peer-a", To: ServerPeerID, Body: []byte{}, } data, err := EncodeMessage(original) if err != nil { t.Fatalf("EncodeMessage() error = %v", err) } decoded, err := DecodeMessage(data) if err != nil { t.Fatalf("DecodeMessage() error = %v", err) } if !reflect.DeepEqual(decoded, original) { t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original) } } // TestEncodeDecodeMessageError 验证 error 控制消息会保留 UTF-8 错误文本。 func TestEncodeDecodeMessageError(t *testing.T) { original := Message{ Type: MessageTypeError, ID: 46, From: ServerPeerID, To: "peer-a", Body: []byte("unknown target"), } data, err := EncodeMessage(original) if err != nil { t.Fatalf("EncodeMessage() error = %v", err) } decoded, err := DecodeMessage(data) if err != nil { t.Fatalf("DecodeMessage() error = %v", err) } if !reflect.DeepEqual(decoded, original) { t.Fatalf("round trip mismatch: got %+v want %+v", decoded, original) } } // TestWriteReadFrame 单独验证最底层的长度前缀帧逻辑, // 不依赖 Message 结构,方便确认 TCP 粘包拆包问题是否被正确处理。 func TestWriteReadFrame(t *testing.T) { var buf bytes.Buffer payload := []byte("header+body") if err := WriteFrame(&buf, payload); err != nil { t.Fatalf("WriteFrame() error = %v", err) } got, err := ReadFrame(&buf) if err != nil { t.Fatalf("ReadFrame() error = %v", err) } if !bytes.Equal(got, payload) { t.Fatalf("payload mismatch: got %q want %q", got, payload) } } // TestWriteReadMessageAllowsEmptyBody 验证空文本和空文件都可以正常通过协议层, // 因为外层帧非空的前提下,空正文是合法业务内容。 func TestWriteReadMessageAllowsEmptyBody(t *testing.T) { tests := []struct { name string message Message }{ { name: "empty text", message: Message{ Type: MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte(""), }, }, { name: "empty file", message: Message{ Type: MessageTypeFile, ID: 2, From: "peer-a", To: "peer-b", FileName: "empty.txt", Body: []byte{}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var buf bytes.Buffer if err := WriteMessage(&buf, tt.message); err != nil { t.Fatalf("WriteMessage() error = %v", err) } got, err := ReadMessage(&buf) if err != nil { t.Fatalf("ReadMessage() error = %v", err) } if !reflect.DeepEqual(got, tt.message) { t.Fatalf("round trip mismatch: got %+v want %+v", got, tt.message) } }) } } // TestWriteReadMessageRejectsInvalidMessages 验证协议层会在编码前拦住明显非法的消息。 func TestWriteReadMessageRejectsInvalidMessages(t *testing.T) { tests := []struct { name string message Message wantErr error }{ { name: "invalid type", message: Message{ Type: MessageType("unknown"), ID: 1, From: "peer-a", To: "peer-b", }, wantErr: ErrInvalidMessageType, }, { name: "missing from", message: Message{ Type: MessageTypeText, ID: 2, To: "peer-b", }, wantErr: ErrMissingFrom, }, { name: "missing to", message: Message{ Type: MessageTypeText, ID: 3, From: "peer-a", }, wantErr: ErrMissingTo, }, { name: "text with file name", message: Message{ Type: MessageTypeText, ID: 4, From: "peer-a", To: "peer-b", FileName: "bad.txt", Body: []byte("hello"), }, wantErr: ErrUnexpectedFileName, }, { name: "text with invalid utf8", message: Message{ Type: MessageTypeText, ID: 5, From: "peer-a", To: "peer-b", Body: []byte{0xff, 0xfe}, }, wantErr: ErrInvalidTextBody, }, { name: "file without file name", message: Message{ Type: MessageTypeFile, ID: 6, From: "peer-a", To: "peer-b", Body: []byte{0x01}, }, wantErr: ErrMissingFileName, }, { name: "register with wrong target", message: Message{ Type: MessageTypeRegister, ID: 7, From: "peer-a", To: "peer-b", }, wantErr: ErrInvalidRegisterTarget, }, { name: "register with body", message: Message{ Type: MessageTypeRegister, ID: 8, From: "peer-a", To: ServerPeerID, Body: []byte("unexpected"), }, wantErr: ErrUnexpectedBody, }, { name: "error with wrong source", message: Message{ Type: MessageTypeError, ID: 9, From: "peer-a", To: "peer-b", Body: []byte("bad"), }, wantErr: ErrInvalidErrorSource, }, { name: "error with file name", message: Message{ Type: MessageTypeError, ID: 10, From: ServerPeerID, To: "peer-a", FileName: "bad.txt", Body: []byte("bad"), }, wantErr: ErrUnexpectedFileName, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := EncodeMessage(tt.message) if !errors.Is(err, tt.wantErr) { t.Fatalf("EncodeMessage() error = %v, want %v", err, tt.wantErr) } }) } } // TestReadFrameRejectsInvalidLength 验证长度为 0 的帧会被当成非法输入, // 而不是被当成一条合法的空消息。 func TestReadFrameRejectsInvalidLength(t *testing.T) { var buf bytes.Buffer if err := binary.Write(&buf, binary.BigEndian, uint32(0)); err != nil { t.Fatalf("binary.Write() error = %v", err) } _, err := ReadFrame(&buf) if !errors.Is(err, ErrInvalidFrameLength) { t.Fatalf("ReadFrame() error = %v, want %v", err, ErrInvalidFrameLength) } } // TestReadFrameRejectsTooLargeFrame 验证超大帧会在分配消息体前被拒绝, // 从而保证最大长度限制真正生效。 func TestReadFrameRejectsTooLargeFrame(t *testing.T) { var buf bytes.Buffer if err := binary.Write(&buf, binary.BigEndian, uint32(MaxFrameSize+1)); err != nil { t.Fatalf("binary.Write() error = %v", err) } _, err := ReadFrame(&buf) if !errors.Is(err, ErrFrameTooLarge) { t.Fatalf("ReadFrame() error = %v, want %v", err, ErrFrameTooLarge) } } // TestWriteFrameRejectsEmptyPayload 验证写入端和读取端的约束保持一致: // 既然读取端不接受 0 长度帧,写入端也不应该产生这种帧。 func TestWriteFrameRejectsEmptyPayload(t *testing.T) { var buf bytes.Buffer err := WriteFrame(&buf, nil) if !errors.Is(err, ErrInvalidFrameLength) { t.Fatalf("WriteFrame() error = %v, want %v", err, ErrInvalidFrameLength) } } // TestDecodeMessageRejectsInvalidHeaderLength 验证无法切出完整头部时会被立即拒绝。 func TestDecodeMessageRejectsInvalidHeaderLength(t *testing.T) { tests := []struct { name string data []byte }{ { name: "too short for header len", data: []byte{0x00, 0x00, 0x00}, }, { name: "zero header len", data: []byte{0x00, 0x00, 0x00, 0x00}, }, { name: "header len exceeds payload", data: []byte{0x00, 0x00, 0x00, 0x10, '{', '}'}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := DecodeMessage(tt.data) if !errors.Is(err, ErrInvalidHeaderLength) { t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderLength) } }) } } // TestDecodeMessageRejectsInvalidHeaderJSON 验证头部 JSON 非法时能返回明确错误。 func TestDecodeMessageRejectsInvalidHeaderJSON(t *testing.T) { data := append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...) _, err := DecodeMessage(data) if !errors.Is(err, ErrInvalidHeaderJSON) { t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidHeaderJSON) } } // TestDecodeMessageRejectsContentLengthMismatch 验证头部声明长度和实际正文不一致时会失败。 func TestDecodeMessageRejectsContentLengthMismatch(t *testing.T) { headerPayload, err := json.Marshal(messageHeader{ Type: MessageTypeText, ID: 7, From: "peer-a", To: "peer-b", ContentLength: 10, }) if err != nil { t.Fatalf("json.Marshal() error = %v", err) } var data bytes.Buffer if err := binary.Write(&data, binary.BigEndian, uint32(len(headerPayload))); err != nil { t.Fatalf("binary.Write() error = %v", err) } if _, err := data.Write(headerPayload); err != nil { t.Fatalf("data.Write(headerPayload) error = %v", err) } if _, err := data.Write([]byte("hello")); err != nil { t.Fatalf("data.Write(body) error = %v", err) } _, err = DecodeMessage(data.Bytes()) if !errors.Is(err, ErrInvalidContentLength) { t.Fatalf("DecodeMessage() error = %v, want %v", err, ErrInvalidContentLength) } } // TestReadMultipleMessages 模拟同一条流中连续写入 text 和 file, // 验证读取端每次都能严格停在当前帧边界,不会串包。 func TestReadMultipleMessages(t *testing.T) { var buf bytes.Buffer first := Message{ Type: MessageTypeText, ID: 1, From: "peer-a", To: "peer-b", Body: []byte("hello"), } second := Message{ Type: MessageTypeFile, ID: 2, From: "peer-b", To: "peer-a", FileName: "payload.bin", Body: []byte{0x01, 0x02, 0x03}, } if err := WriteMessage(&buf, first); err != nil { t.Fatalf("WriteMessage(first) error = %v", err) } if err := WriteMessage(&buf, second); err != nil { t.Fatalf("WriteMessage(second) error = %v", err) } gotFirst, err := ReadMessage(&buf) if err != nil { t.Fatalf("ReadMessage(first) error = %v", err) } gotSecond, err := ReadMessage(&buf) if err != nil { t.Fatalf("ReadMessage(second) error = %v", err) } if !reflect.DeepEqual(gotFirst, first) { t.Fatalf("first message mismatch: got %+v want %+v", gotFirst, first) } if !reflect.DeepEqual(gotSecond, second) { t.Fatalf("second message mismatch: got %+v want %+v", gotSecond, second) } } // TestReadMessageWrapsDecodeError 验证 ReadMessage 在返回错误时会保留解码阶段上下文。 func TestReadMessageWrapsDecodeError(t *testing.T) { var buf bytes.Buffer if err := WriteFrame(&buf, append([]byte{0x00, 0x00, 0x00, 0x09}, []byte("{invalid}")...)); err != nil { t.Fatalf("WriteFrame() error = %v", err) } _, err := ReadMessage(&buf) if err == nil { t.Fatal("ReadMessage() error = nil, want non-nil") } if !strings.Contains(err.Error(), "decode message") { t.Fatalf("ReadMessage() error = %v, want wrapped decode error", err) } }