/*
 * OmniSocketGo/src/protocol.c
 * (source-viewer listing metadata: 416 lines, 13 KiB, C)
 */

#include "protocol.h"
#include "cJSON.h"
#include <arpa/inet.h>
/*
 * Wire names for each message type, indexed by the numeric message_type_t
 * value. protocol_message_type_name() indexes this table with the enum value
 * and protocol_message_type_from_name() casts the matching index back to
 * message_type_t.
 * NOTE(review): this assumes MSG_TYPE_TEXT, MSG_TYPE_FILE, MSG_TYPE_REGISTER,
 * MSG_TYPE_ERROR and MSG_TYPE_BINARY are 0..4 in exactly this order — confirm
 * against protocol.h whenever the enum changes.
 */
static const char *protocol_message_type_table[] = {
"text",
"file",
"register",
"error",
"binary"
};
/*
 * Map a message type to its wire name.
 *
 * Returns a pointer to a static string owned by this module. Any value that
 * does not index protocol_message_type_table (negative or past its end)
 * yields "invalid".
 */
const char *protocol_message_type_name(message_type_t type) {
    const int signed_index = (int) type;
    const int in_range = signed_index >= 0 &&
                         (size_t) type < OMNI_ARRAY_LEN(protocol_message_type_table);

    return in_range ? protocol_message_type_table[type] : "invalid";
}
/*
 * Parse a wire name back into a message type.
 *
 * Matching is exact and case-sensitive. On a match the type is stored in
 * *out and 0 is returned; otherwise (NULL argument or unknown name) -1 is
 * returned and *out is not written.
 */
int protocol_message_type_from_name(const char *raw, message_type_t *out) {
    size_t idx = 0;

    if (raw != NULL && out != NULL) {
        while (idx < OMNI_ARRAY_LEN(protocol_message_type_table)) {
            if (!strcmp(protocol_message_type_table[idx], raw)) {
                *out = (message_type_t) idx;
                return 0;
            }
            idx++;
        }
    }
    return -1;
}
/*
 * Put a message into the empty state: all fields zeroed (so body is NULL and
 * body_len is 0) and type set to MSG_TYPE_INVALID. Nothing is freed — use
 * protocol_message_clear for a message that may own a body. NULL is a no-op.
 */
void protocol_message_init(message_t *msg) {
    if (msg != NULL) {
        memset(msg, 0, sizeof *msg);
        msg->type = MSG_TYPE_INVALID;
    }
}
/*
 * Free the heap body owned by a message and reset every field via
 * protocol_message_init. NULL is a no-op.
 */
void protocol_message_clear(message_t *msg) {
    if (msg != NULL) {
        free(msg->body);
        protocol_message_init(msg);
    }
}
/*
 * Deep-copy src into dst.
 *
 * Any body previously owned by dst is freed. Header fields are copied
 * verbatim; when src has a body, dst receives its own heap copy so the two
 * messages can be cleared independently. Copying a message onto itself is a
 * no-op.
 *
 * Returns 0 on success. On failure returns -1 and sets errno:
 *   EINVAL - dst or src is NULL, or src claims body_len > 0 with a NULL body
 *            (dst is left untouched);
 *   ENOMEM - the body copy could not be allocated (dst is left empty).
 */
int protocol_message_copy(message_t *dst, const message_t *src) {
    if (dst == NULL || src == NULL) {
        errno = EINVAL;
        return -1;
    }
    /* Without this check the clear below would free src->body (same object)
     * before it is duplicated, and memcpy would be called on identical
     * pointers. */
    if (dst == src) {
        return 0;
    }
    /* A length without a buffer cannot be duplicated; reject it before
     * touching dst. */
    if (src->body_len > 0 && src->body == NULL) {
        errno = EINVAL;
        return -1;
    }
    protocol_message_clear(dst);
    memcpy(dst, src, sizeof(*dst));
    dst->body = NULL; /* never share src's buffer */
    if (src->body_len > 0) {
        dst->body = (uint8_t *) malloc(src->body_len);
        if (dst->body == NULL) {
            protocol_message_init(dst);
            errno = ENOMEM;
            return -1;
        }
        memcpy(dst->body, src->body, src->body_len);
    }
    return 0;
}
/*
 * printf-style helper that formats a diagnostic into the caller's optional
 * error buffer (truncating to err_len, always NUL-terminated). The buffer is
 * not touched when err is NULL or err_len is 0.
 *
 * Always returns -1 so call sites can write `return protocol_set_err(...)`.
 */
static int protocol_set_err(char *err, size_t err_len, const char *fmt, ...) {
    const int has_buffer = (err != NULL) && (err_len != 0);
    va_list ap;

    if (has_buffer) {
        va_start(ap, fmt);
        (void) vsnprintf(err, err_len, fmt, ap);
        va_end(ap);
    }
    return -1;
}
/*
 * Check that a message is well formed for its type.
 *
 * Common rules: from and to must be non-empty, and a non-zero body_len must
 * come with a non-NULL body. Per type:
 *   text     - no file name; body must pass omni_utf8_valid
 *   file     - file name required
 *   binary   - no file name
 *   register - to must be SERVER_PEER_ID; no file name; empty body
 *   error    - from must be SERVER_PEER_ID; no file name; body must pass
 *              omni_utf8_valid
 * Any other type is rejected.
 *
 * Returns 0 when valid, otherwise -1 with a description written to err
 * (when err is non-NULL and err_len > 0).
 */
int protocol_validate_message(const message_t *msg, char *err, size_t err_len) {
    if (msg == NULL) {
        return protocol_set_err(err, err_len, "protocol: nil message");
    }
    if (msg->from[0] == '\0') {
        return protocol_set_err(err, err_len, "protocol: missing from");
    }
    if (msg->to[0] == '\0') {
        return protocol_set_err(err, err_len, "protocol: missing to");
    }
    /* The UTF-8 checks below and protocol_encode_message_datagram read
     * body_len bytes from body, so a length without a buffer is invalid. */
    if (msg->body_len > 0 && msg->body == NULL) {
        return protocol_set_err(err, err_len, "protocol: missing body");
    }
    switch (msg->type) {
    case MSG_TYPE_TEXT:
        if (msg->file_name[0] != '\0') {
            return protocol_set_err(err, err_len, "protocol: unexpected file name");
        }
        if (!omni_utf8_valid(msg->body, msg->body_len)) {
            return protocol_set_err(err, err_len, "protocol: invalid text body");
        }
        break;
    case MSG_TYPE_FILE:
        if (msg->file_name[0] == '\0') {
            return protocol_set_err(err, err_len, "protocol: missing file name");
        }
        break;
    case MSG_TYPE_BINARY:
        if (msg->file_name[0] != '\0') {
            return protocol_set_err(err, err_len, "protocol: unexpected file name");
        }
        break;
    case MSG_TYPE_REGISTER:
        if (strcmp(msg->to, SERVER_PEER_ID) != 0) {
            return protocol_set_err(err, err_len, "protocol: invalid register target");
        }
        if (msg->file_name[0] != '\0') {
            return protocol_set_err(err, err_len, "protocol: unexpected file name");
        }
        if (msg->body_len != 0) {
            return protocol_set_err(err, err_len, "protocol: unexpected body");
        }
        break;
    case MSG_TYPE_ERROR:
        if (strcmp(msg->from, SERVER_PEER_ID) != 0) {
            return protocol_set_err(err, err_len, "protocol: invalid error source");
        }
        if (msg->file_name[0] != '\0') {
            return protocol_set_err(err, err_len, "protocol: unexpected file name");
        }
        if (!omni_utf8_valid(msg->body, msg->body_len)) {
            return protocol_set_err(err, err_len, "protocol: invalid text body");
        }
        break;
    default:
        return protocol_set_err(err, err_len, "protocol: invalid message type");
    }
    return 0;
}
/*
 * Serialize a message's header into compact JSON.
 *
 * The object carries, in order: "type" (wire name), "id", "from", "to",
 * "file_name" (only when non-empty) and "content_length" (= body_len).
 * NOTE(review): numbers are emitted as JSON doubles, so ids above 2^53 lose
 * precision on the wire.
 *
 * On success returns 0, *out_json receives a cJSON-allocated string (release
 * with cJSON_free) and *out_len its strlen. On failure returns -1 with
 * errno = ENOMEM and leaves the outputs untouched.
 */
static int protocol_build_header_json(const message_t *msg, char **out_json, size_t *out_len) {
    cJSON *root;
    char *json;
    int ok;
    root = cJSON_CreateObject();
    if (root == NULL) {
        errno = ENOMEM;
        return -1;
    }
    /* Each Add* helper returns NULL when it cannot allocate the new item;
     * ignoring that would silently produce a header with missing fields. */
    ok = cJSON_AddStringToObject(root, "type", protocol_message_type_name(msg->type)) != NULL &&
         cJSON_AddNumberToObject(root, "id", (double) msg->id) != NULL &&
         cJSON_AddStringToObject(root, "from", msg->from) != NULL &&
         cJSON_AddStringToObject(root, "to", msg->to) != NULL &&
         (msg->file_name[0] == '\0' ||
          cJSON_AddStringToObject(root, "file_name", msg->file_name) != NULL) &&
         cJSON_AddNumberToObject(root, "content_length", (double) msg->body_len) != NULL;
    if (!ok) {
        cJSON_Delete(root);
        errno = ENOMEM;
        return -1;
    }
    json = cJSON_PrintUnformatted(root);
    cJSON_Delete(root);
    if (json == NULL) {
        errno = ENOMEM;
        return -1;
    }
    *out_len = strlen(json);
    *out_json = json;
    return 0;
}
/*
 * Encode a validated message as a single datagram:
 *
 *   [4-byte header length, network byte order][header JSON][body bytes]
 *
 * On success returns 0 and *out receives a malloc'd buffer of *out_len bytes
 * (caller frees with free()). On failure returns -1 with *out = NULL,
 * *out_len = 0 and errno set to:
 *   EINVAL   - NULL out/out_len, or msg fails protocol_validate_message;
 *   EMSGSIZE - the datagram would exceed OMNI_MAX_FRAME_SIZE;
 *   ENOMEM   - allocation failure.
 */
int protocol_encode_message_datagram(const message_t *msg, uint8_t **out, size_t *out_len) {
    const size_t max_frame = (size_t) OMNI_MAX_FRAME_SIZE;
    uint8_t *buffer;
    char *header_json;
    size_t header_len;
    size_t total_len;
    uint32_t net_header_len;
    char err[128];
    if (out == NULL || out_len == NULL) {
        errno = EINVAL;
        return -1;
    }
    *out = NULL;
    *out_len = 0;
    if (protocol_validate_message(msg, err, sizeof(err)) != 0) {
        errno = EINVAL;
        return -1;
    }
    if (protocol_build_header_json(msg, &header_json, &header_len) != 0) {
        return -1;
    }
    /* Check each term against the remaining budget instead of summing first:
     * a huge body_len would otherwise wrap size_t, slip past the limit and
     * lead to an undersized malloc followed by an overflowing memcpy. The
     * header length must also fit the 32-bit prefix. */
    if (header_len > (size_t) UINT32_MAX ||
        header_len > max_frame - 4U ||
        msg->body_len > max_frame - 4U - header_len) {
        cJSON_free(header_json);
        errno = EMSGSIZE;
        return -1;
    }
    total_len = 4U + header_len + msg->body_len;
    buffer = (uint8_t *) malloc(total_len);
    if (buffer == NULL) {
        cJSON_free(header_json);
        errno = ENOMEM;
        return -1;
    }
    net_header_len = htonl((uint32_t) header_len);
    memcpy(buffer, &net_header_len, 4);
    memcpy(buffer + 4, header_json, header_len);
    if (msg->body_len > 0) {
        memcpy(buffer + 4 + header_len, msg->body, msg->body_len);
    }
    cJSON_free(header_json);
    *out = buffer;
    *out_len = total_len;
    return 0;
}
/*
 * Copy a string member of a decoded JSON header into a fixed-size buffer.
 *
 * A missing member is an error when required is non-zero; otherwise dst is
 * set to "". The member must be a JSON string that fits in dst including the
 * terminator: values that would be truncated are rejected, because silently
 * shortening untrusted peer ids or file names could make distinct values
 * collide.
 *
 * Returns 0 on success, -1 with a description in err otherwise.
 */
static int protocol_copy_string_field(char *dst, size_t dst_len, const cJSON *object, const char *field, int required, char *err, size_t err_len) {
    const cJSON *item = cJSON_GetObjectItemCaseSensitive(object, field);
    int written;
    if (item == NULL) {
        if (required) {
            return protocol_set_err(err, err_len, "protocol: missing %s", field);
        }
        dst[0] = '\0';
        return 0;
    }
    if (!cJSON_IsString(item) || item->valuestring == NULL) {
        return protocol_set_err(err, err_len, "protocol: invalid %s", field);
    }
    written = snprintf(dst, dst_len, "%s", item->valuestring);
    if (written < 0 || (size_t) written >= dst_len) {
        /* Do not leave a truncated value behind. */
        if (dst_len > 0) {
            dst[0] = '\0';
        }
        return protocol_set_err(err, err_len, "protocol: %s too long", field);
    }
    return 0;
}
/*
 * Read an unsigned 64-bit integer member from a decoded JSON header.
 *
 * A missing member is an error when required is non-zero; otherwise *dst is
 * set to 0. The member must be a JSON number holding a non-negative integral
 * value below 2^64. Anything else (negative, fractional, NaN, too large) is
 * rejected: converting such a double to uint64_t is undefined behavior in C,
 * and the header comes from the network.
 *
 * Returns 0 on success, -1 with a description in err otherwise (*dst is not
 * written on failure).
 */
static int protocol_copy_u64_field(uint64_t *dst, const cJSON *object, const char *field, int required, char *err, size_t err_len) {
    const cJSON *item = cJSON_GetObjectItemCaseSensitive(object, field);
    double value;
    uint64_t converted;
    if (item == NULL) {
        if (required) {
            return protocol_set_err(err, err_len, "protocol: missing %s", field);
        }
        *dst = 0;
        return 0;
    }
    if (!cJSON_IsNumber(item)) {
        return protocol_set_err(err, err_len, "protocol: invalid %s", field);
    }
    value = item->valuedouble;
    /* 18446744073709551616.0 == 2^64 is exactly representable. Writing the
     * tests as negations also rejects NaN, for which every comparison is
     * false. */
    if (!(value >= 0.0) || !(value < 18446744073709551616.0)) {
        return protocol_set_err(err, err_len, "protocol: invalid %s", field);
    }
    converted = (uint64_t) value;
    if ((double) converted != value) {
        /* Fractional value: refuse rather than silently truncate. */
        return protocol_set_err(err, err_len, "protocol: invalid %s", field);
    }
    *dst = converted;
    return 0;
}
/*
 * Decode one datagram laid out as
 *
 *   [4-byte header length, network byte order][header JSON][body bytes]
 *
 * into out_msg.
 *
 * The header must be exactly one JSON object occupying the declared header
 * length: embedded NUL bytes and trailing non-whitespace data are rejected.
 * Its content_length must equal the number of bytes after the header, and
 * the resulting message must pass protocol_validate_message.
 *
 * Returns 0 on success; out_msg then owns a malloc'd body when body_len > 0.
 * Returns -1 on failure: protocol errors are described in err, allocation
 * failures set errno = ENOMEM. If the initial argument/size checks fail,
 * out_msg is untouched; after that, any body it previously owned is freed
 * and on failure it is left in the empty (protocol_message_init) state.
 */
int protocol_decode_message_datagram(const uint8_t *data, size_t data_len, message_t *out_msg, char *err, size_t err_len) {
    uint32_t net_header_len;
    uint32_t header_len;
    size_t body_len;
    char *header_text = NULL;
    cJSON *header = NULL;
    const cJSON *type_item;
    uint64_t content_length = 0;
    if (data == NULL || out_msg == NULL || data_len < 4U) {
        return protocol_set_err(err, err_len, "protocol: invalid datagram");
    }
    if (data_len > OMNI_MAX_FRAME_SIZE) {
        return protocol_set_err(err, err_len, "protocol: frame too large");
    }
    protocol_message_clear(out_msg);
    memcpy(&net_header_len, data, 4);
    header_len = ntohl(net_header_len);
    if (header_len == 0 || (size_t) header_len > data_len - 4U) {
        return protocol_set_err(err, err_len, "protocol: invalid header length");
    }
    /* An embedded NUL would end the C string early and let everything after
     * it inside the declared header be ignored by the parser. */
    if (memchr(data + 4, '\0', header_len) != NULL) {
        return protocol_set_err(err, err_len, "protocol: invalid header json");
    }
    header_text = (char *) malloc((size_t) header_len + 1U);
    if (header_text == NULL) {
        errno = ENOMEM;
        return -1;
    }
    memcpy(header_text, data + 4, header_len);
    header_text[header_len] = '\0';
    /* require_null_terminated = 1: reject junk following the JSON value. */
    header = cJSON_ParseWithOpts(header_text, NULL, 1);
    free(header_text);
    if (header == NULL || !cJSON_IsObject(header)) {
        if (header != NULL) {
            cJSON_Delete(header);
        }
        return protocol_set_err(err, err_len, "protocol: invalid header json");
    }
    type_item = cJSON_GetObjectItemCaseSensitive(header, "type");
    if (type_item == NULL || !cJSON_IsString(type_item) || protocol_message_type_from_name(type_item->valuestring, &out_msg->type) != 0) {
        cJSON_Delete(header);
        return protocol_set_err(err, err_len, "protocol: invalid message type");
    }
    if (protocol_copy_u64_field(&out_msg->id, header, "id", 1, err, err_len) != 0 ||
        protocol_copy_string_field(out_msg->from, sizeof(out_msg->from), header, "from", 1, err, err_len) != 0 ||
        protocol_copy_string_field(out_msg->to, sizeof(out_msg->to), header, "to", 1, err, err_len) != 0 ||
        protocol_copy_string_field(out_msg->file_name, sizeof(out_msg->file_name), header, "file_name", 0, err, err_len) != 0 ||
        protocol_copy_u64_field(&content_length, header, "content_length", 1, err, err_len) != 0) {
        cJSON_Delete(header);
        protocol_message_clear(out_msg);
        return -1;
    }
    cJSON_Delete(header);
    body_len = data_len - 4U - (size_t) header_len;
    /* Compare in 64 bits: narrowing content_length to size_t first could wrap
     * on 32-bit targets and accept a bogus declared length. */
    if (content_length != (uint64_t) body_len) {
        protocol_message_clear(out_msg);
        return protocol_set_err(err, err_len, "protocol: invalid content length");
    }
    out_msg->body_len = body_len;
    if (body_len > 0) {
        out_msg->body = (uint8_t *) malloc(body_len);
        if (out_msg->body == NULL) {
            protocol_message_clear(out_msg);
            errno = ENOMEM;
            return -1;
        }
        memcpy(out_msg->body, data + 4U + header_len, body_len);
    }
    if (protocol_validate_message(out_msg, err, err_len) != 0) {
        protocol_message_clear(out_msg);
        return -1;
    }
    return 0;
}
/*
 * Encode a message for a byte stream: the datagram encoding from
 * protocol_encode_message_datagram, prefixed with its length as a 4-byte
 * network-byte-order integer.
 *
 * On success returns 0 and *out receives a malloc'd buffer of *out_len bytes
 * (caller frees with free()). On failure returns -1 with *out = NULL and
 * *out_len = 0 (when those pointers are non-NULL). errno is EINVAL for NULL
 * out/out_len, EMSGSIZE if the payload cannot be described by the 32-bit
 * prefix, ENOMEM on allocation failure, or whatever the datagram encoder set.
 */
int protocol_encode_message_stream(const message_t *msg, uint8_t **out, size_t *out_len) {
    uint8_t *payload;
    uint8_t *buffer;
    size_t payload_len;
    uint32_t net_len;
    if (out == NULL || out_len == NULL) {
        errno = EINVAL;
        return -1;
    }
    *out = NULL;
    *out_len = 0;
    if (protocol_encode_message_datagram(msg, &payload, &payload_len) != 0) {
        return -1;
    }
    /* The prefix is 32 bits; never emit a truncated length. */
    if (payload_len > (size_t) UINT32_MAX) {
        free(payload);
        errno = EMSGSIZE;
        return -1;
    }
    buffer = (uint8_t *) malloc(payload_len + 4U);
    if (buffer == NULL) {
        free(payload);
        errno = ENOMEM;
        return -1;
    }
    net_len = htonl((uint32_t) payload_len);
    memcpy(buffer, &net_len, 4);
    memcpy(buffer + 4, payload, payload_len);
    free(payload);
    *out = buffer;
    *out_len = payload_len + 4U;
    return 0;
}
/*
 * Decode the payload of one stream frame, i.e. the bytes that follow the
 * 4-byte length prefix written by protocol_encode_message_stream (presumably
 * obtained from protocol_frame_decoder_next). That payload has exactly the
 * datagram layout, so decoding is delegated to
 * protocol_decode_message_datagram; see it for ownership and error rules.
 */
int protocol_decode_message_stream_payload(const uint8_t *payload, size_t payload_len, message_t *out_msg, char *err, size_t err_len) {
    const int status = protocol_decode_message_datagram(payload, payload_len, out_msg, err, err_len);
    return status;
}
/*
 * Initialize a stream frame decoder to empty: no buffer, zero length and
 * capacity. Does not free anything. NULL is a no-op, matching
 * protocol_message_init.
 */
void protocol_frame_decoder_init(protocol_frame_decoder_t *decoder) {
    if (decoder == NULL) {
        return;
    }
    memset(decoder, 0, sizeof(*decoder));
}
/*
 * Discard all buffered bytes while keeping the allocated buffer for reuse.
 * NULL is a no-op.
 */
void protocol_frame_decoder_reset(protocol_frame_decoder_t *decoder) {
    if (decoder == NULL) {
        return;
    }
    decoder->len = 0;
}
/*
 * Free the decoder's buffer and return it to the zeroed state, so it may be
 * reused after protocol_frame_decoder_init-equivalent state. NULL is a no-op,
 * like free(NULL).
 */
void protocol_frame_decoder_destroy(protocol_frame_decoder_t *decoder) {
    if (decoder == NULL) {
        return;
    }
    free(decoder->buffer);
    memset(decoder, 0, sizeof(*decoder));
}
/*
 * Append raw bytes received from a stream to the decoder's buffer.
 *
 * The buffer starts at 4096 bytes and doubles as needed. At most
 * 2 * OMNI_MAX_FRAME_SIZE bytes may be buffered at once; the capacity never
 * grows past that limit.
 *
 * Returns 0 on success (feeding zero bytes is a no-op). Returns -1 with
 * errno = EINVAL (NULL decoder, or NULL data with a non-zero length),
 * EMSGSIZE (the buffered total would exceed the limit) or ENOMEM. On failure
 * the already-buffered bytes are unchanged.
 */
int protocol_frame_decoder_feed(protocol_frame_decoder_t *decoder, const uint8_t *data, size_t data_len) {
    const size_t limit = (size_t) OMNI_MAX_FRAME_SIZE * 2U;
    uint8_t *next_buffer;
    size_t next_cap;
    size_t needed;
    if (decoder == NULL || (data == NULL && data_len > 0)) {
        errno = EINVAL;
        return -1;
    }
    /* Nothing to do; also avoids memcpy/pointer arithmetic on a NULL buffer
     * for a decoder that has never allocated. */
    if (data_len == 0) {
        return 0;
    }
    /* Compare against the remaining room so len + data_len cannot wrap. */
    if (decoder->len > limit || data_len > limit - decoder->len) {
        errno = EMSGSIZE;
        return -1;
    }
    needed = decoder->len + data_len;
    if (needed > decoder->cap) {
        next_cap = decoder->cap == 0 ? 4096U : decoder->cap;
        while (next_cap < needed) {
            /* Clamp to the limit instead of doubling past it (which also
             * rules out overflowing next_cap). */
            if (next_cap > limit / 2U) {
                next_cap = limit;
                break;
            }
            next_cap *= 2U;
        }
        next_buffer = (uint8_t *) realloc(decoder->buffer, next_cap);
        if (next_buffer == NULL) {
            errno = ENOMEM;
            return -1;
        }
        decoder->buffer = next_buffer;
        decoder->cap = next_cap;
    }
    memcpy(decoder->buffer + decoder->len, data, data_len);
    decoder->len = needed;
    return 0;
}
/*
 * Extract the next complete frame from the decoder, if one is buffered.
 *
 * A frame is a 4-byte network-byte-order length followed by that many
 * payload bytes. Return values:
 *    1 - a frame was extracted: *payload is a malloc'd copy of its payload
 *        (caller frees with free()), *payload_len its size, and the frame's
 *        bytes are removed from the decoder;
 *    0 - not enough data yet (*payload = NULL, *payload_len = 0);
 *   -1 - error, with errno = EINVAL (NULL argument), EMSGSIZE (declared
 *        length is 0 or exceeds OMNI_MAX_FRAME_SIZE; the stream is corrupt
 *        and the decoder should be reset) or ENOMEM.
 */
int protocol_frame_decoder_next(protocol_frame_decoder_t *decoder, uint8_t **payload, size_t *payload_len) {
    uint32_t net_len;
    size_t frame_len;
    size_t consumed;
    uint8_t *frame;
    if (decoder == NULL || payload == NULL || payload_len == NULL) {
        errno = EINVAL;
        return -1;
    }
    *payload = NULL;
    *payload_len = 0;
    if (decoder->len < 4U) {
        return 0;
    }
    memcpy(&net_len, decoder->buffer, 4);
    frame_len = (size_t) ntohl(net_len);
    if (frame_len == 0 || frame_len > OMNI_MAX_FRAME_SIZE) {
        errno = EMSGSIZE;
        return -1;
    }
    consumed = 4U + frame_len; /* prefix + payload */
    if (decoder->len < consumed) {
        return 0;
    }
    frame = (uint8_t *) malloc(frame_len);
    if (frame == NULL) {
        errno = ENOMEM;
        return -1;
    }
    memcpy(frame, decoder->buffer + 4, frame_len);
    /* Shift any following bytes to the front of the buffer.
     * NOTE(review): this costs O(buffered bytes) per extracted frame; an
     * offset-based buffer would avoid repeated moves if it shows up in
     * profiles. */
    memmove(decoder->buffer, decoder->buffer + consumed, decoder->len - consumed);
    decoder->len -= consumed;
    *payload = frame;
    *payload_len = frame_len;
    return 1;
}