/*
 * Message (de)serialisation for the wire protocol.
 *
 * Datagram layout: [4-byte big-endian header length][JSON header][body bytes]
 * Stream layout:   [4-byte big-endian datagram length][datagram]
 *
 * The JSON header carries "type", "id", "from", "to", an optional
 * "file_name" and "content_length" (the body size in bytes).
 */
#include "protocol.h"
#include "cJSON.h"

#include <errno.h>
#include <stdarg.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Wire names, indexed by message_type_t value. Both lookups below treat the
 * array index and the enum value as interchangeable, so this order must match
 * the message_type_t enumerators.
 */
static const char *const protocol_message_type_table[] = {
    "text", "file", "register", "error", "binary"
};

/* Write v to p[0..3] in network (big-endian) byte order. */
static void protocol_store_be32(uint8_t *p, uint32_t v)
{
    p[0] = (uint8_t) (v >> 24);
    p[1] = (uint8_t) (v >> 16);
    p[2] = (uint8_t) (v >> 8);
    p[3] = (uint8_t) v;
}

/* Read a 32-bit value stored in network (big-endian) byte order at p[0..3]. */
static uint32_t protocol_load_be32(const uint8_t *p)
{
    return ((uint32_t) p[0] << 24) | ((uint32_t) p[1] << 16) |
           ((uint32_t) p[2] << 8) | (uint32_t) p[3];
}

/*
 * Return the wire name of a message type, or "invalid" when the value is
 * outside the table. The returned string is static; do not free it.
 */
const char *protocol_message_type_name(message_type_t type)
{
    if ((int) type < 0 || (size_t) type >= OMNI_ARRAY_LEN(protocol_message_type_table)) {
        return "invalid";
    }
    return protocol_message_type_table[type];
}

/*
 * Map a wire name back to its message type (exact, case-sensitive match).
 * Returns 0 and stores the type in *out on success; returns -1 when raw or
 * out is NULL or the name is unknown, leaving *out untouched.
 */
int protocol_message_type_from_name(const char *raw, message_type_t *out)
{
    size_t i;

    if (raw == NULL || out == NULL) {
        return -1;
    }
    for (i = 0; i < OMNI_ARRAY_LEN(protocol_message_type_table); ++i) {
        if (strcmp(raw, protocol_message_type_table[i]) == 0) {
            *out = (message_type_t) i;
            return 0;
        }
    }
    return -1;
}

/*
 * Zero every field of msg and set its type to MSG_TYPE_INVALID. Does not free
 * an existing body; use protocol_message_clear() on an initialised message.
 * A NULL msg is ignored.
 */
void protocol_message_init(message_t *msg)
{
    if (msg == NULL) {
        return;
    }
    memset(msg, 0, sizeof(*msg));
    msg->type = MSG_TYPE_INVALID;
}

/*
 * Free msg->body and reset msg to the protocol_message_init() state.
 * msg must already be initialised. A NULL msg is ignored.
 */
void protocol_message_clear(message_t *msg)
{
    if (msg == NULL) {
        return;
    }
    free(msg->body);
    protocol_message_init(msg);
}

/*
 * Deep-copy src into dst, giving dst its own heap copy of the body.
 * dst must be initialised; its previous body is released.
 *
 * Returns 0 on success. Returns -1 in three cases, with dst left in the
 * protocol_message_init() state for the last two:
 *   - EINVAL: dst or src is NULL;
 *   - EINVAL: src->body is NULL while body_len > 0;
 *   - ENOMEM: the body copy cannot be allocated.
 * Copying a message onto itself is a no-op.
 */
int protocol_message_copy(message_t *dst, const message_t *src)
{
    uint8_t *body = NULL;

    if (dst == NULL || src == NULL) {
        errno = EINVAL;
        return -1;
    }
    if (dst == src) {
        /* Clearing dst first would free the very body we are copying. */
        return 0;
    }
    /* Duplicate the body before clearing dst, in case dst shares src's body. */
    if (src->body_len > 0) {
        if (src->body == NULL) {
            protocol_message_clear(dst);
            errno = EINVAL;
            return -1;
        }
        body = (uint8_t *) malloc(src->body_len);
        if (body == NULL) {
            protocol_message_clear(dst);
            errno = ENOMEM;
            return -1;
        }
        memcpy(body, src->body, src->body_len);
    }
    protocol_message_clear(dst);
    memcpy(dst, src, sizeof(*dst));
    dst->body = body;
    return 0;
}

/*
 * Format an error description into err when err is non-NULL and err_len > 0,
 * set errno to EINVAL and return -1, so callers can write
 * `return protocol_set_err(...);`.
 */
static int protocol_set_err(char *err, size_t err_len, const char *fmt, ...)
{
    va_list args;

    if (err != NULL && err_len > 0) {
        va_start(args, fmt);
        vsnprintf(err, err_len, fmt, args);
        va_end(args);
    }
    /* Set after vsnprintf, which may itself modify errno. */
    errno = EINVAL;
    return -1;
}

/*
 * Check that msg is well-formed for its type. Returns 0 when valid; otherwise
 * returns -1, sets errno = EINVAL and writes a reason into err (if given).
 *
 * Common rules: msg is non-NULL, from and to are non-empty, and body is
 * non-NULL whenever body_len > 0. Per type:
 *   text     - no file_name; body must pass omni_utf8_valid()
 *   file     - file_name required
 *   binary   - no file_name
 *   register - addressed to SERVER_PEER_ID, no file_name, empty body
 *   error    - sent from SERVER_PEER_ID, no file_name; body must pass
 *              omni_utf8_valid()
 */
int protocol_validate_message(const message_t *msg, char *err, size_t err_len)
{
    if (msg == NULL) {
        return protocol_set_err(err, err_len, "protocol: nil message");
    }
    if (msg->from[0] == '\0') {
        return protocol_set_err(err, err_len, "protocol: missing from");
    }
    if (msg->to[0] == '\0') {
        return protocol_set_err(err, err_len, "protocol: missing to");
    }
    if (msg->body_len > 0 && msg->body == NULL) {
        /* Would otherwise be dereferenced by omni_utf8_valid() or memcpy(). */
        return protocol_set_err(err, err_len, "protocol: missing body");
    }

    switch (msg->type) {
    case MSG_TYPE_TEXT:
        if (msg->file_name[0] != '\0') {
            return protocol_set_err(err, err_len, "protocol: unexpected file name");
        }
        if (!omni_utf8_valid(msg->body, msg->body_len)) {
            return protocol_set_err(err, err_len, "protocol: invalid text body");
        }
        break;
    case MSG_TYPE_FILE:
        if (msg->file_name[0] == '\0') {
            return protocol_set_err(err, err_len, "protocol: missing file name");
        }
        break;
    case MSG_TYPE_BINARY:
        if (msg->file_name[0] != '\0') {
            return protocol_set_err(err, err_len, "protocol: unexpected file name");
        }
        break;
    case MSG_TYPE_REGISTER:
        if (strcmp(msg->to, SERVER_PEER_ID) != 0) {
            return protocol_set_err(err, err_len, "protocol: invalid register target");
        }
        if (msg->file_name[0] != '\0') {
            return protocol_set_err(err, err_len, "protocol: unexpected file name");
        }
        if (msg->body_len != 0) {
            return protocol_set_err(err, err_len, "protocol: unexpected body");
        }
        break;
    case MSG_TYPE_ERROR:
        if (strcmp(msg->from, SERVER_PEER_ID) != 0) {
            return protocol_set_err(err, err_len, "protocol: invalid error source");
        }
        if (msg->file_name[0] != '\0') {
            return protocol_set_err(err, err_len, "protocol: unexpected file name");
        }
        if (!omni_utf8_valid(msg->body, msg->body_len)) {
            return protocol_set_err(err, err_len, "protocol: invalid text body");
        }
        break;
    default:
        return protocol_set_err(err, err_len, "protocol: invalid message type");
    }
    return 0;
}

/*
 * Serialise msg's header fields as compact JSON.
 *
 * On success, returns 0 and sets *out_json to a cJSON-allocated string
 * (release it with cJSON_free) and *out_len to its strlen(). Returns -1 with
 * errno = ENOMEM if cJSON fails to allocate.
 *
 * NOTE: id and content_length are emitted as JSON numbers via a cast to
 * double, so values above 2^53 may lose precision on the wire.
 */
static int protocol_build_header_json(const message_t *msg, char **out_json, size_t *out_len)
{
    cJSON *root;
    char *json = NULL;
    int ok;

    root = cJSON_CreateObject();
    if (root == NULL) {
        errno = ENOMEM;
        return -1;
    }
    /* Each Add* call returns NULL if it could not allocate; a header with a
     * missing field must not be emitted as a success. */
    ok = cJSON_AddStringToObject(root, "type", protocol_message_type_name(msg->type)) != NULL &&
         cJSON_AddNumberToObject(root, "id", (double) msg->id) != NULL &&
         cJSON_AddStringToObject(root, "from", msg->from) != NULL &&
         cJSON_AddStringToObject(root, "to", msg->to) != NULL &&
         (msg->file_name[0] == '\0' ||
          cJSON_AddStringToObject(root, "file_name", msg->file_name) != NULL) &&
         cJSON_AddNumberToObject(root, "content_length", (double) msg->body_len) != NULL;
    if (ok) {
        json = cJSON_PrintUnformatted(root);
    }
    cJSON_Delete(root);
    if (json == NULL) {
        errno = ENOMEM;
        return -1;
    }
    *out_len = strlen(json);
    *out_json = json;
    return 0;
}

/*
 * Encode msg as a datagram: [4-byte big-endian header length][JSON header][body].
 *
 * On success, returns 0 and sets *out to a malloc'd buffer (caller frees)
 * and *out_len to its size. On failure, returns -1 with errno set:
 *   - EINVAL: NULL outputs or a message that fails protocol_validate_message();
 *   - EMSGSIZE: the datagram would exceed OMNI_MAX_FRAME_SIZE;
 *   - ENOMEM: allocation failed.
 * *out / *out_len are reset to NULL / 0 first whenever they are non-NULL.
 * NOTE(review): assumes OMNI_MAX_FRAME_SIZE >= 4 and fits in 32 bits.
 */
int protocol_encode_message_datagram(const message_t *msg, uint8_t **out, size_t *out_len)
{
    const size_t max_frame = (size_t) OMNI_MAX_FRAME_SIZE;
    uint8_t *buffer;
    char *header_json;
    size_t header_len;
    size_t total_len;
    char err[128];

    if (out == NULL || out_len == NULL) {
        errno = EINVAL;
        return -1;
    }
    *out = NULL;
    *out_len = 0;

    if (protocol_validate_message(msg, err, sizeof(err)) != 0) {
        errno = EINVAL;
        return -1;
    }
    if (protocol_build_header_json(msg, &header_json, &header_len) != 0) {
        return -1;
    }
    /* Check each term against the remaining budget so the sum cannot wrap. */
    if (header_len > max_frame - 4U || msg->body_len > max_frame - 4U - header_len) {
        cJSON_free(header_json);
        errno = EMSGSIZE;
        return -1;
    }
    total_len = 4U + header_len + msg->body_len;

    buffer = (uint8_t *) malloc(total_len);
    if (buffer == NULL) {
        cJSON_free(header_json);
        errno = ENOMEM;
        return -1;
    }
    protocol_store_be32(buffer, (uint32_t) header_len);
    memcpy(buffer + 4, header_json, header_len);
    if (msg->body_len > 0) {
        memcpy(buffer + 4 + header_len, msg->body, msg->body_len);
    }
    cJSON_free(header_json);

    *out = buffer;
    *out_len = total_len;
    return 0;
}

/*
 * Copy the string member `field` of object into dst (capacity dst_len,
 * including the terminating NUL). A missing optional field yields "".
 *
 * Fails (returns -1, writes err) in three cases:
 *   - a required field is absent;
 *   - the member is not a string;
 *   - the value does not fit in dst.
 * Truncating instead would silently alter identifiers such as from/to.
 */
static int protocol_copy_string_field(char *dst, size_t dst_len, const cJSON *object,
                                      const char *field, int required, char *err, size_t err_len)
{
    const cJSON *item = cJSON_GetObjectItemCaseSensitive(object, field);
    size_t len;

    if (item == NULL) {
        if (required) {
            return protocol_set_err(err, err_len, "protocol: missing %s", field);
        }
        dst[0] = '\0';
        return 0;
    }
    if (!cJSON_IsString(item) || item->valuestring == NULL) {
        return protocol_set_err(err, err_len, "protocol: invalid %s", field);
    }
    len = strlen(item->valuestring);
    if (len >= dst_len) {
        return protocol_set_err(err, err_len, "protocol: %s too long", field);
    }
    memcpy(dst, item->valuestring, len + 1U);
    return 0;
}

/*
 * Read the numeric member `field` of object into *dst. A missing optional
 * field yields 0.
 *
 * Fails (returns -1, writes err) in three cases:
 *   - a required field is absent;
 *   - the member is not a number;
 *   - the number is not a non-negative integer below 2^64.
 */
static int protocol_copy_u64_field(uint64_t *dst, const cJSON *object, const char *field,
                                   int required, char *err, size_t err_len)
{
    const cJSON *item = cJSON_GetObjectItemCaseSensitive(object, field);
    double value;
    uint64_t converted;

    if (item == NULL) {
        if (required) {
            return protocol_set_err(err, err_len, "protocol: missing %s", field);
        }
        *dst = 0;
        return 0;
    }
    if (!cJSON_IsNumber(item)) {
        return protocol_set_err(err, err_len, "protocol: invalid %s", field);
    }
    value = item->valuedouble;
    /* Converting a NaN, negative or >= 2^64 double to uint64_t is undefined
     * behaviour, so range-check first. The negated form also rejects NaN.
     * 18446744073709551616.0 is 2^64, exactly representable as a double. */
    if (!(value >= 0.0 && value < 18446744073709551616.0)) {
        return protocol_set_err(err, err_len, "protocol: invalid %s", field);
    }
    converted = (uint64_t) value;
    /* Reject fractional values instead of silently truncating them. */
    if ((double) converted != value) {
        return protocol_set_err(err, err_len, "protocol: invalid %s", field);
    }
    *dst = converted;
    return 0;
}

/*
 * Parse one datagram ([4-byte big-endian header length][JSON header][body])
 * into out_msg.
 *
 * out_msg must be initialised, because it is cleared (its old body freed)
 * once the basic size checks pass. On success, returns 0; the caller owns
 * out_msg->body.
 *
 * On failure, returns -1. If the datagram is rejected before that clear,
 * out_msg is untouched; otherwise out_msg is left cleared. Protocol errors
 * set errno = EINVAL and write a reason into err. Allocation failures set
 * errno = ENOMEM and leave err unwritten.
 *
 * The header is rejected if it contains a NUL byte, or anything other than
 * whitespace after the JSON object. content_length must equal the number of
 * bytes that follow the header. The decoded message must also pass
 * protocol_validate_message().
 */
int protocol_decode_message_datagram(const uint8_t *data, size_t data_len, message_t *out_msg,
                                     char *err, size_t err_len)
{
    uint32_t header_len;
    size_t body_len;
    char *header_text;
    cJSON *header;
    const cJSON *type_item;
    uint64_t content_length = 0;
    int rc;

    if (data == NULL || out_msg == NULL || data_len < 4U) {
        return protocol_set_err(err, err_len, "protocol: invalid datagram");
    }
    if (data_len > OMNI_MAX_FRAME_SIZE) {
        return protocol_set_err(err, err_len, "protocol: frame too large");
    }
    protocol_message_clear(out_msg);

    header_len = protocol_load_be32(data);
    if (header_len == 0 || (size_t) header_len > data_len - 4U) {
        return protocol_set_err(err, err_len, "protocol: invalid header length");
    }
    body_len = data_len - 4U - (size_t) header_len;

    /* cJSON parses a NUL-terminated string, so an embedded NUL would hide the
     * rest of the header from the parser. */
    if (memchr(data + 4, '\0', header_len) != NULL) {
        return protocol_set_err(err, err_len, "protocol: invalid header json");
    }
    header_text = (char *) malloc((size_t) header_len + 1U);
    if (header_text == NULL) {
        errno = ENOMEM;
        return -1;
    }
    memcpy(header_text, data + 4, header_len);
    header_text[header_len] = '\0';
    /* require_null_terminated = 1: reject trailing bytes after the object. */
    header = cJSON_ParseWithOpts(header_text, NULL, 1);
    free(header_text);
    if (header == NULL || !cJSON_IsObject(header)) {
        if (header != NULL) {
            cJSON_Delete(header);
        }
        return protocol_set_err(err, err_len, "protocol: invalid header json");
    }

    type_item = cJSON_GetObjectItemCaseSensitive(header, "type");
    if (type_item == NULL || !cJSON_IsString(type_item) ||
        protocol_message_type_from_name(type_item->valuestring, &out_msg->type) != 0) {
        rc = protocol_set_err(err, err_len, "protocol: invalid message type");
    } else if (protocol_copy_u64_field(&out_msg->id, header, "id", 1, err, err_len) != 0 ||
               protocol_copy_string_field(out_msg->from, sizeof(out_msg->from), header,
                                          "from", 1, err, err_len) != 0 ||
               protocol_copy_string_field(out_msg->to, sizeof(out_msg->to), header,
                                          "to", 1, err, err_len) != 0 ||
               protocol_copy_string_field(out_msg->file_name, sizeof(out_msg->file_name), header,
                                          "file_name", 0, err, err_len) != 0 ||
               protocol_copy_u64_field(&content_length, header, "content_length", 1,
                                       err, err_len) != 0) {
        rc = -1;
    } else {
        rc = 0;
    }
    cJSON_Delete(header);
    if (rc != 0) {
        protocol_message_clear(out_msg);
        return -1;
    }

    /* Compare in 64 bits: casting content_length down to a 32-bit size_t
     * could make a forged length appear to match. */
    if (content_length != (uint64_t) body_len) {
        protocol_message_clear(out_msg);
        return protocol_set_err(err, err_len, "protocol: invalid content length");
    }
    if (body_len > 0) {
        out_msg->body = (uint8_t *) malloc(body_len);
        if (out_msg->body == NULL) {
            protocol_message_clear(out_msg);
            errno = ENOMEM;
            return -1;
        }
        memcpy(out_msg->body, data + 4U + header_len, body_len);
    }
    out_msg->body_len = body_len;

    if (protocol_validate_message(out_msg, err, err_len) != 0) {
        protocol_message_clear(out_msg);
        return -1;
    }
    return 0;
}

/*
 * Encode msg as a stream frame: [4-byte big-endian datagram length][datagram].
 *
 * On success, returns 0 and sets *out to a malloc'd buffer (caller frees)
 * and *out_len to its size. On failure, returns -1 with errno set (see
 * protocol_encode_message_datagram). *out / *out_len are reset to NULL / 0
 * first whenever they are non-NULL.
 */
int protocol_encode_message_stream(const message_t *msg, uint8_t **out, size_t *out_len)
{
    uint8_t *payload;
    uint8_t *buffer;
    size_t payload_len;

    if (out == NULL || out_len == NULL) {
        errno = EINVAL;
        return -1;
    }
    *out = NULL;
    *out_len = 0;

    if (protocol_encode_message_datagram(msg, &payload, &payload_len) != 0) {
        return -1;
    }
    /* The datagram encoder caps payload_len at OMNI_MAX_FRAME_SIZE, so the
     * extra 4 bytes cannot wrap. */
    buffer = (uint8_t *) malloc(payload_len + 4U);
    if (buffer == NULL) {
        free(payload);
        errno = ENOMEM;
        return -1;
    }
    protocol_store_be32(buffer, (uint32_t) payload_len);
    memcpy(buffer + 4, payload, payload_len);
    free(payload);

    *out = buffer;
    *out_len = payload_len + 4U;
    return 0;
}

/*
 * Decode the payload of a stream frame (as returned by
 * protocol_frame_decoder_next). The payload is a plain datagram, so this
 * forwards to protocol_decode_message_datagram() with the same contract.
 */
int protocol_decode_message_stream_payload(const uint8_t *payload, size_t payload_len,
                                           message_t *out_msg, char *err, size_t err_len)
{
    return protocol_decode_message_datagram(payload, payload_len, out_msg, err, err_len);
}

/* Prepare an empty decoder; no buffer is allocated yet. A NULL decoder is ignored. */
void protocol_frame_decoder_init(protocol_frame_decoder_t *decoder)
{
    if (decoder == NULL) {
        return;
    }
    memset(decoder, 0, sizeof(*decoder));
}

/* Discard all buffered bytes but keep the allocation for reuse. A NULL decoder is ignored. */
void protocol_frame_decoder_reset(protocol_frame_decoder_t *decoder)
{
    if (decoder == NULL) {
        return;
    }
    decoder->len = 0;
}

/* Free the buffer and return the decoder to the init state. A NULL decoder is ignored. */
void protocol_frame_decoder_destroy(protocol_frame_decoder_t *decoder)
{
    if (decoder == NULL) {
        return;
    }
    free(decoder->buffer);
    memset(decoder, 0, sizeof(*decoder));
}

/*
 * Append data_len bytes from data to the decoder's buffer, growing the
 * capacity by doubling (starting at 4096 bytes).
 *
 * Buffered bytes may not exceed 2 * OMNI_MAX_FRAME_SIZE. Exceeding it fails
 * with EMSGSIZE and leaves the buffer unchanged. Returns 0 on success, or -1
 * with errno set:
 *   - EINVAL: NULL decoder, or NULL data with data_len > 0;
 *   - EMSGSIZE: the buffer limit would be exceeded;
 *   - ENOMEM: the buffer cannot be grown.
 */
int protocol_frame_decoder_feed(protocol_frame_decoder_t *decoder, const uint8_t *data,
                                size_t data_len)
{
    const size_t limit = (size_t) OMNI_MAX_FRAME_SIZE * 2U;
    uint8_t *next_buffer;
    size_t needed;
    size_t next_cap;

    if (decoder == NULL || (data == NULL && data_len > 0)) {
        errno = EINVAL;
        return -1;
    }
    if (data_len == 0) {
        /* Nothing to append; also avoids memcpy on a possibly NULL buffer. */
        return 0;
    }
    /* Compare against the remaining budget so len + data_len cannot wrap. */
    if (decoder->len > limit || data_len > limit - decoder->len) {
        errno = EMSGSIZE;
        return -1;
    }
    needed = decoder->len + data_len;

    if (needed > decoder->cap) {
        next_cap = decoder->cap == 0 ? 4096U : decoder->cap;
        while (next_cap < needed) {
            next_cap *= 2U;
        }
        /* Assign through a temporary so the old buffer survives a failed realloc. */
        next_buffer = (uint8_t *) realloc(decoder->buffer, next_cap);
        if (next_buffer == NULL) {
            errno = ENOMEM;
            return -1;
        }
        decoder->buffer = next_buffer;
        decoder->cap = next_cap;
    }
    memcpy(decoder->buffer + decoder->len, data, data_len);
    decoder->len = needed;
    return 0;
}

/*
 * Extract the next complete stream frame from the decoder's buffer.
 *
 * Return values:
 *   1  - a frame was extracted. *payload is a malloc'd copy of the frame
 *        body (caller frees) and *payload_len is its length; the frame's
 *        bytes are removed from the buffer.
 *   0  - a complete frame has not been buffered yet.
 *  -1  - errno is set: EINVAL for NULL arguments, EMSGSIZE when the length
 *        prefix is 0 or exceeds OMNI_MAX_FRAME_SIZE, ENOMEM on allocation
 *        failure.
 *
 * On -1 nothing is consumed, so a bad prefix stays buffered until
 * protocol_frame_decoder_reset() is called.
 */
int protocol_frame_decoder_next(protocol_frame_decoder_t *decoder, uint8_t **payload,
                                size_t *payload_len)
{
    size_t frame_len;
    size_t consumed;
    uint8_t *frame;

    if (decoder == NULL || payload == NULL || payload_len == NULL) {
        errno = EINVAL;
        return -1;
    }
    *payload = NULL;
    *payload_len = 0;

    if (decoder->len < 4U) {
        return 0;
    }
    frame_len = (size_t) protocol_load_be32(decoder->buffer);
    if (frame_len == 0 || frame_len > OMNI_MAX_FRAME_SIZE) {
        errno = EMSGSIZE;
        return -1;
    }
    if (decoder->len - 4U < frame_len) {
        return 0;
    }

    frame = (uint8_t *) malloc(frame_len);
    if (frame == NULL) {
        errno = ENOMEM;
        return -1;
    }
    memcpy(frame, decoder->buffer + 4, frame_len);

    /* Shift any bytes belonging to following frames to the buffer start. */
    consumed = 4U + frame_len;
    memmove(decoder->buffer, decoder->buffer + consumed, decoder->len - consumed);
    decoder->len -= consumed;

    *payload = frame;
    *payload_len = frame_len;
    return 1;
}