From 79dba2a664e7770a140ac0a95bd1bb6b63e714bc Mon Sep 17 00:00:00 2001 From: Mock Date: Fri, 10 Apr 2026 11:11:03 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=95=BF=E4=BF=9D=E6=8C=81=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=EF=BC=8C=E6=8E=A7=E5=88=B6=E7=AB=AF=E5=8F=AF=E9=87=8D?= =?UTF-8?q?=E5=90=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/b_side_omnid.c | 89 +++- cmd/kcppeer.c | 4 +- include/peer_kcp_client.h | 6 + python/omnisocket/_omnisocket.c | 8 +- python/omnisocket/omnisocket_client.c | 38 +- python/omnisocket/omnisocket_client.h | 2 + python/tests/test_sessions.py | 59 +++ .../udp_teleop_bridge/udp_cmd_vel_receiver.py | 80 +++- src/peer_kcp_client.c | 380 ++++++++++++++++-- src/server_kcp_hub.c | 369 +++++++++++++++-- 10 files changed, 930 insertions(+), 105 deletions(-) diff --git a/cmd/b_side_omnid.c b/cmd/b_side_omnid.c index f47c4e0..cd39799 100644 --- a/cmd/b_side_omnid.c +++ b/cmd/b_side_omnid.c @@ -10,6 +10,7 @@ #include #include "control_protocol.h" +#include "protocol.h" #include "video_pipeline.h" #define CONTROL_DEFAULT_PEER_ID "peer-b-ctrl" @@ -19,6 +20,7 @@ typedef struct unix_dgram_client { int fd; char bind_path[108]; + char dest_path[108]; struct sockaddr_un dest_addr; socklen_t dest_len; } unix_dgram_client_t; @@ -28,7 +30,9 @@ typedef struct control_bridge_stats { uint64_t packets_forwarded; uint64_t invalid_packets; uint64_t unix_send_errors; - int connected; + uint64_t reconnect_count; + int ever_connected; + int registered; char last_error[256]; kcp_runtime_stats_t transport; } control_bridge_stats_t; @@ -109,6 +113,8 @@ static void control_bridge_stats_destroy(control_bridge_stats_t *stats) { pthread_mutex_destroy(&stats->mutex); } +static void unix_dgram_client_close(unix_dgram_client_t *client); + static void control_bridge_set_error(control_bridge_stats_t *stats, const char *message) { if (stats == NULL) { return; @@ -142,7 +148,8 @@ static void control_bridge_stats_snapshot(control_bridge_stats_t *stats, control out_stats->packets_forwarded = stats->packets_forwarded; out_stats->invalid_packets = stats->invalid_packets; out_stats->unix_send_errors = stats->unix_send_errors; - out_stats->connected = stats->connected; + out_stats->reconnect_count = stats->reconnect_count; + out_stats->registered = stats->registered; snprintf(out_stats->last_error, sizeof(out_stats->last_error), "%s", stats->last_error); out_stats->transport = stats->transport; pthread_mutex_unlock(&stats->mutex); @@ -178,6 +185,7 @@ static int unix_dgram_client_init(unix_dgram_client_t *client, const char *dest_ memset(&client->dest_addr, 0, sizeof(client->dest_addr)); client->dest_addr.sun_family = AF_UNIX; + snprintf(client->dest_path, sizeof(client->dest_path), "%s", dest_path); snprintf(client->dest_addr.sun_path, sizeof(client->dest_addr.sun_path), "%s", dest_path); client->dest_len = (socklen_t) sizeof(client->dest_addr); return 0; @@ -199,6 +207,22 @@ static int unix_dgram_client_send(unix_dgram_client_t *client, const void *data, return 0; } +static int unix_dgram_client_reopen(unix_dgram_client_t *client) { + char dest_path[sizeof(client->dest_path)]; + + if (client == NULL || client->dest_path[0] == '\0') { + errno = EINVAL; + return -1; + } + snprintf(dest_path, sizeof(dest_path), "%s", client->dest_path); + unix_dgram_client_close(client); + return unix_dgram_client_init(client, dest_path); +} + +static int unix_dgram_client_should_reopen(int error_code) { + return error_code == ENOENT || error_code == ECONNREFUSED || error_code == EBADF || error_code == ENOTCONN; +} + static void unix_dgram_client_close(unix_dgram_client_t *client) { if (client == NULL) { return; @@ -253,27 +277,56 @@ static void *control_thread_main(void *arg) { continue; } - pthread_mutex_lock(&state->control_stats.mutex); - state->control_stats.connected = 1; - state->control_stats.last_error[0] = '\0'; - pthread_mutex_unlock(&state->control_stats.mutex); + { + kcp_client_state_t client_state; + + memset(&client_state, 0, sizeof(client_state)); + kcp_client_state_snapshot(client, &client_state); + pthread_mutex_lock(&state->control_stats.mutex); + if (state->control_stats.ever_connected) { + state->control_stats.reconnect_count += 1; + } else { + state->control_stats.ever_connected = 1; + } + state->control_stats.registered = client_state.registered; + snprintf(state->control_stats.last_error, sizeof(state->control_stats.last_error), "%s", client_state.last_server_error); + pthread_mutex_unlock(&state->control_stats.mutex); + } while (!*state->stop_requested) { message_t msg; int rc; + kcp_client_state_t client_state; protocol_message_init(&msg); rc = kcp_client_receive_timed(client, &msg, 100); if (rc == 1) { protocol_message_clear(&msg); + memset(&client_state, 0, sizeof(client_state)); + kcp_client_state_snapshot(client, &client_state); + pthread_mutex_lock(&state->control_stats.mutex); + state->control_stats.registered = client_state.registered; + snprintf(state->control_stats.last_error, sizeof(state->control_stats.last_error), "%s", client_state.last_server_error); + pthread_mutex_unlock(&state->control_stats.mutex); continue; } if (rc != 0) { - control_bridge_set_errno_error(&state->control_stats, "control receive loop stopped"); + memset(&client_state, 0, sizeof(client_state)); + kcp_client_state_snapshot(client, &client_state); + if (client_state.last_server_error[0] != '\0') { + control_bridge_set_error(&state->control_stats, client_state.last_server_error); + } else { + control_bridge_set_errno_error(&state->control_stats, "control receive loop stopped"); + } protocol_message_clear(&msg); break; } + if (msg.type == MSG_TYPE_ERROR && strcmp(msg.from, SERVER_PEER_ID) == 0) { + control_bridge_set_error(&state->control_stats, (const char *) msg.body); + protocol_message_clear(&msg); + continue; + } if (state->control_expected_sender[0] != '\0' && strcmp(msg.from, state->control_expected_sender) != 0) { pthread_mutex_lock(&state->control_stats.mutex); state->control_stats.invalid_packets += 1; @@ -291,6 +344,21 @@ static void *control_thread_main(void *arg) { } if (unix_dgram_client_send(&state->unix_client, msg.body, msg.body_len) != 0) { + int send_errno = errno; + int recovered = 0; + + if (unix_dgram_client_should_reopen(send_errno) && unix_dgram_client_reopen(&state->unix_client) == 0) { + recovered = unix_dgram_client_send(&state->unix_client, msg.body, msg.body_len) == 0; + } + if (recovered) { + pthread_mutex_lock(&state->control_stats.mutex); + state->control_stats.packets_forwarded += 1; + kcp_client_runtime_stats_snapshot(client, &state->control_stats.transport); + pthread_mutex_unlock(&state->control_stats.mutex); + protocol_message_clear(&msg); + continue; + } + errno = send_errno; pthread_mutex_lock(&state->control_stats.mutex); state->control_stats.unix_send_errors += 1; pthread_mutex_unlock(&state->control_stats.mutex); @@ -307,7 +375,7 @@ static void *control_thread_main(void *arg) { } pthread_mutex_lock(&state->control_stats.mutex); - state->control_stats.connected = 0; + state->control_stats.registered = 0; pthread_mutex_unlock(&state->control_stats.mutex); kcp_client_close(client); kcp_client_free(client); @@ -330,12 +398,13 @@ static void print_stats(daemon_state_t *state) { fprintf( stderr, - "[b_side_omnid] video connected=%d frames=%llu bytes=%llu srtt=%dms | control connected=%d forwarded=%llu invalid=%llu unix_err=%llu srtt=%dms\n", + "[b_side_omnid] video registered=%d frames=%llu bytes=%llu srtt=%dms | control registered=%d reconnects=%llu forwarded=%llu invalid=%llu unix_err=%llu srtt=%dms\n", video_stats.connected, (unsigned long long) video_stats.frames_sent, (unsigned long long) video_stats.bytes_sent, video_stats.transport.srtt_ms, - control_stats.connected, + control_stats.registered, + (unsigned long long) control_stats.reconnect_count, (unsigned long long) control_stats.packets_forwarded, (unsigned long long) control_stats.invalid_packets, (unsigned long long) control_stats.unix_send_errors, diff --git a/cmd/kcppeer.c b/cmd/kcppeer.c index 7913a21..e71981c 100644 --- a/cmd/kcppeer.c +++ b/cmd/kcppeer.c @@ -262,9 +262,9 @@ int main(int argc, char **argv) { goto cleanup; } if (relay_via[0] != '\0') { - fprintf(stderr, "opened KCP session as %s; logical server=%s, actual dial target=%s via relay; register not yet confirmed\n", kcp_client_id(client), server_addr, actual_dial_target); + fprintf(stderr, "opened KCP session as %s; logical server=%s, actual dial target=%s via relay; registration confirmed\n", kcp_client_id(client), server_addr, actual_dial_target); } else { - fprintf(stderr, "opened KCP session as %s; logical server=%s, actual dial target=%s; register not yet confirmed\n", kcp_client_id(client), server_addr, actual_dial_target); + fprintf(stderr, "opened KCP session as %s; logical server=%s, actual dial target=%s; registration confirmed\n", kcp_client_id(client), server_addr, actual_dial_target); } receive_ctx.client = client; diff --git a/include/peer_kcp_client.h b/include/peer_kcp_client.h index e492941..97e9529 100644 --- a/include/peer_kcp_client.h +++ b/include/peer_kcp_client.h @@ -16,6 +16,11 @@ typedef struct kcp_client_recv_meta { char file_name[OMNI_MAX_FILE_NAME]; size_t body_len; } kcp_client_recv_meta_t; +typedef struct kcp_client_state { + int connected; + int registered; + char last_server_error[256]; +} kcp_client_state_t; kcp_client_t *kcp_client_dial_with_options(const char *server_addr, const char *dial_addr, const char *peer_id, const char *bind_ip, const char *bind_device, const kcp_conn_options_t *options, latency_logger_t *logger, kcp_packet_debug_logger_t *packet_logger, kcp_session_stats_logger_t *stats_logger, int stats_interval_ms); kcp_client_t *kcp_client_dial(const char *server_addr, const char *dial_addr, const char *peer_id, const char *bind_ip, const char *bind_device, latency_logger_t *logger, kcp_packet_debug_logger_t *packet_logger, kcp_session_stats_logger_t *stats_logger, int stats_interval_ms); @@ -27,6 +32,7 @@ int kcp_client_receive_timed(kcp_client_t *client, message_t *out_msg, int timeo int kcp_client_receive(kcp_client_t *client, message_t *out_msg); int kcp_client_receive_binary_into(kcp_client_t *client, void *buffer, size_t buffer_len, kcp_client_recv_meta_t *out_meta, int timeout_ms); int kcp_client_persist_message(kcp_client_t *client, const message_t *msg, const char *inbox_dir, char *out_path, size_t out_path_len); +void kcp_client_state_snapshot(kcp_client_t *client, kcp_client_state_t *out_state); void kcp_client_runtime_stats_snapshot(kcp_client_t *client, kcp_runtime_stats_t *out_stats); int kcp_client_close(kcp_client_t *client); void kcp_client_free(kcp_client_t *client); diff --git a/python/omnisocket/_omnisocket.c b/python/omnisocket/_omnisocket.c index 0538ec4..cd55a6c 100644 --- a/python/omnisocket/_omnisocket.c +++ b/python/omnisocket/_omnisocket.c @@ -67,7 +67,7 @@ static PyObject *build_recv_meta_dict( static PyObject *build_stats_dict(const omnisocket_session_stats_t *stats) { return Py_BuildValue( - "{s:K,s:K,s:K,s:K,s:K,s:K,s:K,s:i}", + "{s:K,s:K,s:K,s:K,s:K,s:K,s:K,s:i,s:i,s:s}", "send_calls", (unsigned long long) stats->send_calls, "send_bytes", @@ -83,7 +83,11 @@ static PyObject *build_stats_dict(const omnisocket_session_stats_t *stats) { "recv_errors", (unsigned long long) stats->recv_errors, "connected", - stats->connected + stats->connected, + "registered", + stats->registered, + "last_server_error", + stats->last_server_error ); } diff --git a/python/omnisocket/omnisocket_client.c b/python/omnisocket/omnisocket_client.c index b181b23..79eff06 100644 --- a/python/omnisocket/omnisocket_client.c +++ b/python/omnisocket/omnisocket_client.c @@ -1,5 +1,33 @@ #include "omnisocket_client.h" +static void omnisocket_session_sync_client_state_locked(omnisocket_session_t *session, kcp_client_t *client) { + kcp_client_state_t client_state; + + if (session == NULL) { + return; + } + memset(&client_state, 0, sizeof(client_state)); + if (client != NULL) { + kcp_client_state_snapshot(client, &client_state); + } + session->stats.connected = client_state.connected; + session->stats.registered = client_state.registered; + snprintf( + session->stats.last_server_error, + sizeof(session->stats.last_server_error), + "%s", + client_state.last_server_error + ); +} + +static void omnisocket_session_mark_disconnected_locked(omnisocket_session_t *session) { + if (session == NULL) { + return; + } + session->stats.connected = 0; + session->stats.registered = 0; +} + int omnisocket_session_init(omnisocket_session_t *session) { int rc; @@ -97,7 +125,7 @@ int omnisocket_session_connect( return -1; } session->client = client; - session->stats.connected = 1; + omnisocket_session_sync_client_state_locked(session, client); pthread_mutex_unlock(&session->mutex); return 0; } @@ -119,7 +147,7 @@ int omnisocket_session_close(omnisocket_session_t *session) { session->closing = 1; session->client = NULL; } - session->stats.connected = 0; + omnisocket_session_mark_disconnected_locked(session); pthread_mutex_unlock(&session->mutex); if (client != NULL) { @@ -158,6 +186,7 @@ int omnisocket_session_send(omnisocket_session_t *session, const char *to, const } else { session->stats.send_errors += 1; } + omnisocket_session_sync_client_state_locked(session, client); if (session->active_ops > 0) { session->active_ops -= 1; } @@ -190,6 +219,7 @@ int omnisocket_session_recv(omnisocket_session_t *session, message_t *out_msg, i } else { session->stats.recv_errors += 1; } + omnisocket_session_sync_client_state_locked(session, client); if (session->active_ops > 0) { session->active_ops -= 1; } @@ -228,6 +258,7 @@ int omnisocket_session_recv_into( } else { session->stats.recv_errors += 1; } + omnisocket_session_sync_client_state_locked(session, client); if (session->active_ops > 0) { session->active_ops -= 1; } @@ -376,6 +407,8 @@ int omnisocket_udp_session_connect( } session->client = client; session->stats.connected = 1; + session->stats.registered = 1; + session->stats.last_server_error[0] = '\0'; pthread_mutex_unlock(&session->mutex); return 0; } @@ -398,6 +431,7 @@ int omnisocket_udp_session_close(omnisocket_udp_session_t *session) { session->client = NULL; } session->stats.connected = 0; + session->stats.registered = 0; pthread_mutex_unlock(&session->mutex); if (client != NULL) { diff --git a/python/omnisocket/omnisocket_client.h b/python/omnisocket/omnisocket_client.h index d24c7af..a6842ab 100644 --- a/python/omnisocket/omnisocket_client.h +++ b/python/omnisocket/omnisocket_client.h @@ -13,6 +13,8 @@ typedef struct omnisocket_session_stats { uint64_t recv_timeouts; uint64_t recv_errors; int connected; + int registered; + char last_server_error[256]; } omnisocket_session_stats_t; typedef struct omnisocket_session_kcp_stats { diff --git a/python/tests/test_sessions.py b/python/tests/test_sessions.py index 68ff920..04fca02 100644 --- a/python/tests/test_sessions.py +++ b/python/tests/test_sessions.py @@ -121,6 +121,8 @@ def test_control_sessions_smoke(transport: str, binary_name: str, session_cls) - receiver_stats = receiver.stats() assert sender_stats['connected'] == 1 assert receiver_stats['connected'] == 1 + assert sender_stats['registered'] == 1 + assert receiver_stats['registered'] == 1 assert sender_stats['send_calls'] >= 2 assert receiver_stats['recv_calls'] >= 2 if transport == 'kcp': @@ -135,6 +137,63 @@ def test_control_sessions_smoke(transport: str, binary_name: str, session_cls) - receiver.close() +def test_kcp_duplicate_peer_new_instance_wins() -> None: + port = _reserve_port() + listen_addr = f'127.0.0.1:{port}' + shared_peer_id = 'pytest-kcp-shared-peer' + sender_id = 'pytest-kcp-unique-sender' + + with _run_server('kcpserver', listen_addr): + original = _connect_with_retry(Session, transport='kcp', server_addr=listen_addr, peer_id=shared_peer_id) + sender = _connect_with_retry(Session, transport='kcp', server_addr=listen_addr, peer_id=sender_id) + replacement = None + + try: + replacement = _connect_with_retry(Session, transport='kcp', server_addr=listen_addr, peer_id=shared_peer_id) + replacement_stats = replacement.stats() + assert replacement_stats['connected'] == 1 + assert replacement_stats['registered'] == 1 + + with pytest.raises(OSError): + original.recv(timeout_ms=1000) + + payload = b'registered-replacement' + sender.send(to=shared_peer_id, data=payload) + from_peer, msg_type, recv_payload = replacement.recv(timeout_ms=1000) + assert from_peer == sender_id + assert msg_type == MSG_TYPE_BINARY + assert recv_payload == payload + finally: + original.close() + sender.close() + if replacement is not None: + replacement.close() + + +def test_kcp_idle_video_peers_survive_without_receive_loop() -> None: + port = _reserve_port() + listen_addr = f'127.0.0.1:{port}' + sender_id = 'peer-b-video' + receiver_id = 'peer-a-video' + + with _run_server('kcpserver', listen_addr): + sender = _connect_with_retry(Session, transport='kcp', server_addr=listen_addr, peer_id=sender_id) + receiver = _connect_with_retry(Session, transport='kcp', server_addr=listen_addr, peer_id=receiver_id) + + try: + time.sleep(5.0) + + payload = b'idle-video-session-still-alive' + sender.send(to=receiver_id, data=payload) + from_peer, msg_type, recv_payload = receiver.recv(timeout_ms=1000) + assert from_peer == sender_id + assert msg_type == MSG_TYPE_BINARY + assert recv_payload == payload + finally: + sender.close() + receiver.close() + + def test_udp_session_close_interrupts_blocking_recv() -> None: port = _reserve_port() listen_addr = f'127.0.0.1:{port}' diff --git a/ros-control-py/udp_teleop_bridge/udp_teleop_bridge/udp_cmd_vel_receiver.py b/ros-control-py/udp_teleop_bridge/udp_teleop_bridge/udp_cmd_vel_receiver.py index e620516..b954357 100644 --- a/ros-control-py/udp_teleop_bridge/udp_teleop_bridge/udp_cmd_vel_receiver.py +++ b/ros-control-py/udp_teleop_bridge/udp_teleop_bridge/udp_cmd_vel_receiver.py @@ -81,12 +81,7 @@ class UdpCmdVelReceiver(Node): self._msg_type_binary = MSG_TYPE_BINARY self._msg_type_error = MSG_TYPE_ERROR - self._transport = OmniTransport( - transport=self._transport_name, - server_addr=self._server_addr, - relay_via=self._relay_via, - peer_id=self._peer_id, - ) + self._transport = self._create_transport() self._lock = threading.Lock() self._last_log_times: Dict[str, float] = {} @@ -151,6 +146,61 @@ class UdpCmdVelReceiver(Node): self._unix_socket.bind(self._local_socket_path) self._unix_socket.settimeout(0.1) + def _close_unix_socket(self) -> None: + if self._unix_socket is not None: + try: + self._unix_socket.close() + except OSError: + pass + self._unix_socket = None + + def _create_transport(self): + from .omni_transport import OmniTransport + + return OmniTransport( + transport=self._transport_name, + server_addr=self._server_addr, + relay_via=self._relay_via, + peer_id=self._peer_id, + ) + + def _reconnect_transport(self) -> bool: + while not self._closing.is_set() and rclpy.ok(): + current_transport = self._transport + if current_transport is not None: + try: + current_transport.close() + except OSError: + pass + try: + self._transport = self._create_transport() + if self._should_log('transport_reconnected', 1.0): + self.get_logger().info( + 'Reconnected OmniSocket transport %s://%s as %s' + % (self._transport_name, self._server_addr, self._peer_id) + ) + return True + except OSError as exc: + self._transport = None + if self._should_log('transport_reconnect_error', 2.0): + self.get_logger().error(f'Failed to reconnect OmniSocket transport: {exc}') + time.sleep(0.5) + return False + + def _rebind_unix_socket(self) -> bool: + while not self._closing.is_set() and rclpy.ok(): + self._close_unix_socket() + try: + self._setup_unix_socket() + if self._should_log('unix_rebound', 1.0): + self.get_logger().info(f'Rebound unix datagram socket at {self._local_socket_path}') + return True + except OSError as exc: + if self._should_log('unix_rebind_error', 2.0): + self.get_logger().error(f'Failed to rebind unix datagram socket: {exc}') + time.sleep(0.5) + return False + def _should_log(self, key: str, throttle_sec: float) -> bool: now = time.monotonic() previous = self._last_log_times.get(key) @@ -189,7 +239,9 @@ class UdpCmdVelReceiver(Node): except OSError as exc: if not self._closing.is_set() and self._should_log('recv_error', 2.0): self.get_logger().error(f'OmniSocket receive loop stopped: {exc}') - return + if not self._reconnect_transport(): + return + continue if meta is None: continue @@ -244,11 +296,20 @@ class UdpCmdVelReceiver(Node): try: payload = self._unix_socket.recv(DEFAULT_RECV_BUFFER_BYTES) except socket.timeout: + if not os.path.exists(self._local_socket_path): + if self._should_log('unix_socket_missing', 2.0): + self.get_logger().warning( + f'Unix datagram socket path disappeared, rebinding {self._local_socket_path}' + ) + if not self._rebind_unix_socket(): + return continue except OSError as exc: if not self._closing.is_set() and self._should_log('unix_recv_error', 2.0): self.get_logger().error(f'Unix datagram receive loop stopped: {exc}') - return + if not self._rebind_unix_socket(): + return + continue if len(payload) != PACKET_SIZE: if self._should_log('unix_packet_size', 2.0): @@ -305,11 +366,10 @@ class UdpCmdVelReceiver(Node): self._transport = None if self._unix_socket is not None: try: - self._unix_socket.close() + self._close_unix_socket() except OSError as exc: if self._should_log('unix_close_error', 2.0): self.get_logger().warning(f'Closing unix socket failed: {exc}') - self._unix_socket = None try: os.unlink(self._local_socket_path) except FileNotFoundError: diff --git a/src/peer_kcp_client.c b/src/peer_kcp_client.c index 9db5aa8..b7ee51b 100644 --- a/src/peer_kcp_client.c +++ b/src/peer_kcp_client.c @@ -1,24 +1,315 @@ #include "peer_kcp_client.h" #include +#include +#include #include +#define KCP_CLIENT_REGISTER_TIMEOUT_MS 3000 +#define KCP_CLIENT_CTRL_REGISTER_OK "{\"type\":\"server_register_ok\"}" +#define KCP_CLIENT_CTRL_PEER_REPLACED "{\"type\":\"server_peer_replaced\",\"reason\":\"new_instance_wins\"}" +#define KCP_CLIENT_CTRL_HEARTBEAT "{\"type\":\"server_heartbeat\"}" +#define KCP_CLIENT_CTRL_HEARTBEAT_ACK "{\"type\":\"server_heartbeat_ack\"}" + struct kcp_client { char id[OMNI_MAX_PEER_ID]; char server_addr[OMNI_MAX_ADDR_TEXT]; kcp_conn_t *conn; latency_logger_t *logger; - pthread_mutex_t id_mu; + pthread_mutex_t state_mu; uint64_t next_message_id; + int registered; + char last_server_error[256]; }; static int kcp_client_next_message_id(kcp_client_t *client, uint64_t *out_id) { - pthread_mutex_lock(&client->id_mu); + if (client == NULL || out_id == NULL) { + errno = EINVAL; + return -1; + } + pthread_mutex_lock(&client->state_mu); *out_id = ++client->next_message_id; - pthread_mutex_unlock(&client->id_mu); + pthread_mutex_unlock(&client->state_mu); return 0; } +static void kcp_client_set_registered(kcp_client_t *client, int registered) { + if (client == NULL) { + return; + } + pthread_mutex_lock(&client->state_mu); + client->registered = registered != 0; + pthread_mutex_unlock(&client->state_mu); +} + +static void kcp_client_set_last_server_error(kcp_client_t *client, const char *message) { + if (client == NULL) { + return; + } + pthread_mutex_lock(&client->state_mu); + snprintf(client->last_server_error, sizeof(client->last_server_error), "%s", message == NULL ? "" : message); + pthread_mutex_unlock(&client->state_mu); +} + +static void kcp_client_clear_last_server_error(kcp_client_t *client) { + kcp_client_set_last_server_error(client, ""); +} + +static int kcp_client_is_registered(kcp_client_t *client) { + int registered; + + if (client == NULL) { + return 0; + } + pthread_mutex_lock(&client->state_mu); + registered = client->registered; + pthread_mutex_unlock(&client->state_mu); + return registered; +} + +static int kcp_client_text_body_equals(const message_t *msg, const char *payload) { + size_t expected_len; + + if (msg == NULL || payload == NULL || msg->body == NULL) { + return 0; + } + expected_len = strlen(payload); + return msg->body_len == expected_len && memcmp(msg->body, payload, expected_len) == 0; +} + +static void kcp_client_copy_server_error_body(const message_t *msg, char *buffer, size_t buffer_len) { + size_t copy_len; + + if (buffer == NULL || buffer_len == 0) { + return; + } + buffer[0] = '\0'; + if (msg == NULL || msg->body == NULL || msg->body_len == 0) { + return; + } + copy_len = msg->body_len < (buffer_len - 1U) ? msg->body_len : (buffer_len - 1U); + memcpy(buffer, msg->body, copy_len); + buffer[copy_len] = '\0'; +} + +static int kcp_client_registration_errno_from_message(const char *message) { + if (message == NULL || message[0] == '\0') { + return ECONNREFUSED; + } + if (strstr(message, "duplicate peer id") != NULL) { + return EEXIST; + } + if (strstr(message, "first message must be register") != NULL) { + return EPROTO; + } + return ECONNREFUSED; +} + +static int kcp_client_send_text_internal(kcp_client_t *client, const char *to, const char *text, int log_business_event) { + message_t msg; + uint64_t id; + + if (client == NULL || to == NULL || text == NULL || client->conn == NULL) { + errno = EINVAL; + return -1; + } + + protocol_message_init(&msg); + if (kcp_client_next_message_id(client, &id) != 0) { + return -1; + } + msg.type = MSG_TYPE_TEXT; + msg.id = id; + snprintf(msg.from, sizeof(msg.from), "%s", client->id); + snprintf(msg.to, sizeof(msg.to), "%s", to); + msg.body = (uint8_t *) omni_strdup(text); + if (msg.body == NULL) { + return -1; + } + msg.body_len = strlen((const char *) msg.body); + if (log_business_event) { + latencylog_log_message_event(client->logger, OMNI_NODE_ROLE_PEER, client->id, EVENT_A_APP_PREP_BEGIN, &msg); + } + if (kcp_conn_send(client->conn, &msg) != 0) { + protocol_message_clear(&msg); + return -1; + } + protocol_message_clear(&msg); + return 0; +} + +static int kcp_client_send_business_preflight(kcp_client_t *client) { + if (client == NULL || client->conn == NULL) { + errno = ENOTCONN; + return -1; + } + if (!kcp_client_is_registered(client)) { + errno = ENOTCONN; + return -1; + } + return 0; +} + +static int kcp_client_handle_reserved_server_message(kcp_client_t *client, const message_t *msg) { + if (client == NULL || msg == NULL) { + errno = EINVAL; + return -1; + } + if (msg->type != MSG_TYPE_TEXT || strcmp(msg->from, SERVER_PEER_ID) != 0) { + return 0; + } + if (kcp_client_text_body_equals(msg, KCP_CLIENT_CTRL_REGISTER_OK)) { + kcp_client_set_registered(client, 1); + kcp_client_clear_last_server_error(client); + return 1; + } + if (kcp_client_text_body_equals(msg, KCP_CLIENT_CTRL_HEARTBEAT)) { + if (kcp_client_send_text_internal(client, SERVER_PEER_ID, KCP_CLIENT_CTRL_HEARTBEAT_ACK, 0) != 0) { + kcp_client_set_registered(client, 0); + kcp_client_set_last_server_error(client, "failed to acknowledge server heartbeat"); + (void) kcp_conn_close(client->conn); + return -1; + } + return 1; + } + if (kcp_client_text_body_equals(msg, KCP_CLIENT_CTRL_HEARTBEAT_ACK)) { + return 1; + } + if (kcp_client_text_body_equals(msg, KCP_CLIENT_CTRL_PEER_REPLACED)) { + kcp_client_set_registered(client, 0); + kcp_client_set_last_server_error(client, "server peer replaced this session"); + (void) kcp_conn_close(client->conn); + errno = ECONNRESET; + return -1; + } + return 0; +} + +static int kcp_client_remaining_timeout_ms(int original_timeout_ms, uint32_t start_ms) { + uint32_t elapsed_ms; + + if (original_timeout_ms < 0) { + return -1; + } + elapsed_ms = omni_now_millis32() - start_ms; + if (elapsed_ms >= (uint32_t) original_timeout_ms) { + return 0; + } + return original_timeout_ms - (int) elapsed_ms; +} + +static int kcp_client_wait_for_register_ok(kcp_client_t *client) { + uint32_t start_ms; + + if (client == NULL || client->conn == NULL) { + errno = EINVAL; + return -1; + } + + start_ms = omni_now_millis32(); + for (;;) { + message_t msg; + int rc; + int remaining_timeout_ms = kcp_client_remaining_timeout_ms(KCP_CLIENT_REGISTER_TIMEOUT_MS, start_ms); + + if (remaining_timeout_ms <= 0) { + kcp_client_set_registered(client, 0); + kcp_client_set_last_server_error(client, "timed out waiting for server_register_ok"); + (void) kcp_conn_close(client->conn); + errno = ETIMEDOUT; + return -1; + } + + protocol_message_init(&msg); + rc = kcp_conn_receive_timed(client->conn, &msg, remaining_timeout_ms); + if (rc == 1) { + protocol_message_clear(&msg); + kcp_client_set_registered(client, 0); + kcp_client_set_last_server_error(client, "timed out waiting for server_register_ok"); + (void) kcp_conn_close(client->conn); + errno = ETIMEDOUT; + return -1; + } + if (rc != 0) { + protocol_message_clear(&msg); + kcp_client_set_registered(client, 0); + return -1; + } + if (msg.type == MSG_TYPE_ERROR && strcmp(msg.from, SERVER_PEER_ID) == 0) { + char error_text[256]; + + kcp_client_copy_server_error_body(&msg, error_text, sizeof(error_text)); + kcp_client_set_registered(client, 0); + kcp_client_set_last_server_error(client, error_text); + protocol_message_clear(&msg); + (void) kcp_conn_close(client->conn); + errno = kcp_client_registration_errno_from_message(error_text); + return -1; + } + rc = kcp_client_handle_reserved_server_message(client, &msg); + protocol_message_clear(&msg); + if (rc < 0) { + return -1; + } + if (rc > 0 && kcp_client_is_registered(client)) { + return 0; + } + + kcp_client_set_registered(client, 0); + kcp_client_set_last_server_error(client, "unexpected message while waiting for server_register_ok"); + (void) kcp_conn_close(client->conn); + errno = EPROTO; + return -1; + } +} + +static int kcp_client_receive_business_timed(kcp_client_t *client, message_t *out_msg, int timeout_ms) { + uint32_t start_ms; + + if (client == NULL || out_msg == NULL || client->conn == NULL) { + errno = EINVAL; + return -1; + } + + start_ms = omni_now_millis32(); + protocol_message_init(out_msg); + for (;;) { + int rc; + int reserved_rc; + int effective_timeout_ms = timeout_ms < 0 ? -1 : kcp_client_remaining_timeout_ms(timeout_ms, start_ms); + + if (timeout_ms >= 0 && effective_timeout_ms <= 0) { + return 1; + } + protocol_message_clear(out_msg); + rc = kcp_conn_receive_timed(client->conn, out_msg, effective_timeout_ms); + if (rc != 0) { + if (rc != 1) { + kcp_client_set_registered(client, 0); + } + return rc; + } + + reserved_rc = kcp_client_handle_reserved_server_message(client, out_msg); + if (reserved_rc < 0) { + protocol_message_clear(out_msg); + return -1; + } + if (reserved_rc > 0) { + protocol_message_clear(out_msg); + continue; + } + if (out_msg->type == MSG_TYPE_ERROR && strcmp(out_msg->from, SERVER_PEER_ID) == 0) { + char error_text[256]; + + kcp_client_copy_server_error_body(out_msg, error_text, sizeof(error_text)); + kcp_client_set_last_server_error(client, error_text); + } + latencylog_log_message_event(client->logger, OMNI_NODE_ROLE_PEER, client->id, EVENT_B_APP_RECV, out_msg); + return 0; + } +} + static int kcp_client_persist_message_to_disk(const message_t *msg, const char *inbox_dir, char *out_path, size_t out_path_len) { char path[512]; @@ -107,7 +398,7 @@ kcp_client_t *kcp_client_dial_with_options(const char *server_addr, const char * } snprintf(client->id, sizeof(client->id), "%s", peer_id); snprintf(client->server_addr, sizeof(client->server_addr), "%s", server_addr == NULL ? "" : server_addr); - pthread_mutex_init(&client->id_mu, NULL); + pthread_mutex_init(&client->state_mu, NULL); client->logger = logger; client->conn = kcp_conn_dial_with_options(actual_dial_addr, bind_ip, bind_device, options, packet_logger, logger, OMNI_NODE_ROLE_PEER, peer_id, stats_logger, stats_interval_ms); if (client->conn == NULL) { @@ -128,6 +419,12 @@ kcp_client_t *kcp_client_dial_with_options(const char *server_addr, const char * errno = saved_errno; return NULL; } + if (kcp_client_wait_for_register_ok(client) != 0) { + saved_errno = errno; + kcp_client_free(client); + errno = saved_errno; + return NULL; + } return client; } @@ -140,31 +437,14 @@ const char *kcp_client_id(const kcp_client_t *client) { } int kcp_client_send_text(kcp_client_t *client, const char *to, const char *text) { - message_t msg; - uint64_t id; - if (client == NULL || to == NULL || text == NULL) { errno = EINVAL; return -1; } - protocol_message_init(&msg); - kcp_client_next_message_id(client, &id); - msg.type = MSG_TYPE_TEXT; - msg.id = id; - snprintf(msg.from, sizeof(msg.from), "%s", client->id); - snprintf(msg.to, sizeof(msg.to), "%s", to); - msg.body = (uint8_t *) omni_strdup(text); - if (msg.body == NULL) { + if (kcp_client_send_business_preflight(client) != 0) { return -1; } - msg.body_len = strlen((const char *) msg.body); - latencylog_log_message_event(client->logger, OMNI_NODE_ROLE_PEER, client->id, EVENT_A_APP_PREP_BEGIN, &msg); - if (kcp_conn_send(client->conn, &msg) != 0) { - protocol_message_clear(&msg); - return -1; - } - protocol_message_clear(&msg); - return 0; + return kcp_client_send_text_internal(client, to, text, 1); } int kcp_client_send_binary(kcp_client_t *client, const char *to, const void *data, size_t data_len) { @@ -175,8 +455,13 @@ int kcp_client_send_binary(kcp_client_t *client, const char *to, const void *dat errno = EINVAL; return -1; } + if (kcp_client_send_business_preflight(client) != 0) { + return -1; + } protocol_message_init(&msg); - kcp_client_next_message_id(client, &id); + if (kcp_client_next_message_id(client, &id) != 0) { + return -1; + } msg.type = MSG_TYPE_BINARY; msg.id = id; snprintf(msg.from, sizeof(msg.from), "%s", client->id); @@ -209,11 +494,17 @@ int kcp_client_send_file_path(kcp_client_t *client, const char *to, const char * errno = EINVAL; return -1; } + if (kcp_client_send_business_preflight(client) != 0) { + return -1; + } if (omni_read_file(path, &body, &body_len) != 0) { return -1; } protocol_message_init(&msg); - kcp_client_next_message_id(client, &id); + if (kcp_client_next_message_id(client, &id) != 0) { + free(body); + return -1; + } msg.type = MSG_TYPE_FILE; msg.id = id; snprintf(msg.from, sizeof(msg.from), "%s", client->id); @@ -231,18 +522,11 @@ int kcp_client_send_file_path(kcp_client_t *client, const char *to, const char * } int kcp_client_receive_timed(kcp_client_t *client, message_t *out_msg, int timeout_ms) { - int rc; - if (client == NULL || out_msg == NULL) { errno = EINVAL; return -1; } - rc = kcp_conn_receive_timed(client->conn, out_msg, timeout_ms); - if (rc != 0) { - return rc; - } - latencylog_log_message_event(client->logger, OMNI_NODE_ROLE_PEER, client->id, EVENT_B_APP_RECV, out_msg); - return 0; + return kcp_client_receive_business_timed(client, out_msg, timeout_ms); } int kcp_client_receive(kcp_client_t *client, message_t *out_msg) { @@ -264,6 +548,7 @@ int kcp_client_receive_binary_into(kcp_client_t *client, void *buffer, size_t bu protocol_message_init(&msg); rc = kcp_client_receive_timed(client, &msg, timeout_ms); if (rc != 0) { + protocol_message_clear(&msg); return rc; } @@ -294,6 +579,27 @@ int kcp_client_persist_message(kcp_client_t *client, const message_t *msg, const return 0; } +void kcp_client_state_snapshot(kcp_client_t *client, kcp_client_state_t *out_state) { + kcp_runtime_stats_t runtime_stats; + + if (out_state == NULL) { + return; + } + memset(out_state, 0, sizeof(*out_state)); + if (client == NULL) { + return; + } + memset(&runtime_stats, 0, sizeof(runtime_stats)); + if (client->conn != NULL) { + kcp_conn_runtime_stats_snapshot(client->conn, &runtime_stats); + out_state->connected = runtime_stats.connected; + } + pthread_mutex_lock(&client->state_mu); + out_state->registered = client->registered; + snprintf(out_state->last_server_error, sizeof(out_state->last_server_error), "%s", client->last_server_error); + pthread_mutex_unlock(&client->state_mu); +} + void kcp_client_runtime_stats_snapshot(kcp_client_t *client, kcp_runtime_stats_t *out_stats) { if (out_stats == NULL) { return; @@ -307,7 +613,11 @@ void kcp_client_runtime_stats_snapshot(kcp_client_t *client, kcp_runtime_stats_t } int kcp_client_close(kcp_client_t *client) { - return client == NULL ? 0 : kcp_conn_close(client->conn); + if (client == NULL) { + return 0; + } + kcp_client_set_registered(client, 0); + return kcp_conn_close(client->conn); } void kcp_client_free(kcp_client_t *client) { @@ -315,6 +625,6 @@ void kcp_client_free(kcp_client_t *client) { return; } kcp_conn_free(client->conn); - pthread_mutex_destroy(&client->id_mu); + pthread_mutex_destroy(&client->state_mu); free(client); } diff --git a/src/server_kcp_hub.c b/src/server_kcp_hub.c index 91802c4..f422c38 100644 --- a/src/server_kcp_hub.c +++ b/src/server_kcp_hub.c @@ -3,17 +3,29 @@ #include "cJSON.h" #include +#include +#include +#include #include #define KCP_RELAY_MAX_DATAGRAM_SIZE (60 * 1024) +#define KCP_HUB_MAINTENANCE_INTERVAL_MS 250 #define KCP_HUB_DEFAULT_TELEMETRY_INTERVAL_MS 500 +#define KCP_HUB_DEFAULT_HEARTBEAT_INTERVAL_MS 1000 +#define KCP_HUB_DEFAULT_LEASE_TIMEOUT_MS 4000 #define KCP_HUB_TELEMETRY_NODE_ID "hub-telemetry" #define KCP_HUB_DEFAULT_NODE_ID "hub" +#define KCP_HUB_CTRL_REGISTER_OK "{\"type\":\"server_register_ok\"}" +#define KCP_HUB_CTRL_PEER_REPLACED "{\"type\":\"server_peer_replaced\",\"reason\":\"new_instance_wins\"}" +#define KCP_HUB_CTRL_HEARTBEAT "{\"type\":\"server_heartbeat\"}" +#define KCP_HUB_CTRL_HEARTBEAT_ACK "{\"type\":\"server_heartbeat_ack\"}" typedef struct kcp_peer_entry { struct kcp_peer_entry *next; char peer_id[OMNI_MAX_PEER_ID]; kcp_conn_t *conn; + uint32_t last_seen_ms; + uint32_t last_heartbeat_sent_ms; } kcp_peer_entry_t; typedef struct kcp_session_thread_ctx { @@ -21,6 +33,12 @@ typedef struct kcp_session_thread_ctx { kcp_conn_t *conn; } kcp_session_thread_ctx_t; +typedef struct kcp_hub_pending_action { + struct kcp_hub_pending_action *next; + char peer_id[OMNI_MAX_PEER_ID]; + kcp_conn_t *conn; +} kcp_hub_pending_action_t; + struct kcp_hub { pthread_rwlock_t lock; kcp_peer_entry_t *peers; @@ -29,6 +47,8 @@ struct kcp_hub { int stats_interval_ms; char telemetry_peer_id[OMNI_MAX_PEER_ID]; int telemetry_interval_ms; + int heartbeat_interval_ms; + int lease_timeout_ms; pthread_t telemetry_thread; int telemetry_thread_started; int relay_fd; @@ -41,11 +61,65 @@ struct kcp_hub { static int kcp_hub_peer_id_has_suffix(const char *peer_id, const char *suffix); static int kcp_hub_deliver_to_local_peer(kcp_hub_t *hub, const message_t *msg); +static int kcp_hub_send_server_text(kcp_conn_t *conn, const char *to, const char *payload); +static void kcp_hub_touch_peer(kcp_hub_t *hub, const char *peer_id, kcp_conn_t *conn); +static void kcp_hub_run_maintenance(kcp_hub_t *hub); + +static uint32_t kcp_hub_now_ms(void) { + return omni_now_millis32(); +} + +static uint32_t kcp_hub_elapsed_ms(uint32_t now_ms, uint32_t then_ms) { + return now_ms - then_ms; +} + +static int kcp_hub_text_body_equals(const message_t *msg, const char *payload) { + size_t expected_len; + + if (msg == NULL || payload == NULL) { + return 0; + } + expected_len = strlen(payload); + return msg->body_len == expected_len && msg->body != NULL && memcmp(msg->body, payload, expected_len) == 0; +} + +static int kcp_hub_append_pending_action(kcp_hub_pending_action_t **head, const char *peer_id, kcp_conn_t *conn) { + kcp_hub_pending_action_t *action; + + if (head == NULL || peer_id == NULL || conn == NULL) { + errno = EINVAL; + return -1; + } + action = (kcp_hub_pending_action_t *) calloc(1, sizeof(*action)); + if (action == NULL) { + return -1; + } + snprintf(action->peer_id, sizeof(action->peer_id), "%s", peer_id); + action->conn = conn; + action->next = *head; + *head = action; + return 0; +} + +static void kcp_hub_free_pending_actions(kcp_hub_pending_action_t *head) { + while (head != NULL) { + kcp_hub_pending_action_t *next = head->next; + free(head); + head = next; + } +} static int kcp_hub_peer_is_telemetry(const char *peer_id) { return kcp_hub_peer_id_has_suffix(peer_id, "-telemetry"); } +static int kcp_hub_peer_uses_server_lease(const char *peer_id) { + if (peer_id == NULL || peer_id[0] == '\0') { + return 0; + } + return kcp_hub_peer_id_has_suffix(peer_id, "-ctrl") || kcp_hub_peer_is_telemetry(peer_id); +} + static const char *kcp_hub_peer_node_id(const char *peer_id) { return kcp_hub_peer_is_telemetry(peer_id) ? KCP_HUB_TELEMETRY_NODE_ID : KCP_HUB_DEFAULT_NODE_ID; } @@ -118,6 +192,27 @@ static int kcp_hub_configure_peer_transport(kcp_conn_t *conn, const char *peer_i return 0; } +static void kcp_hub_touch_peer_locked(kcp_hub_t *hub, const char *peer_id, kcp_conn_t *conn) { + kcp_peer_entry_t *entry; + + if (hub == NULL || peer_id == NULL || peer_id[0] == '\0') { + return; + } + entry = kcp_hub_find_peer(hub, peer_id); + if (entry != NULL && (conn == NULL || entry->conn == conn)) { + entry->last_seen_ms = kcp_hub_now_ms(); + } +} + +static void kcp_hub_touch_peer(kcp_hub_t *hub, const char *peer_id, kcp_conn_t *conn) { + if (hub == NULL || peer_id == NULL || peer_id[0] == '\0') { + return; + } + pthread_rwlock_wrlock(&hub->lock); + kcp_hub_touch_peer_locked(hub, peer_id, conn); + pthread_rwlock_unlock(&hub->lock); +} + static int kcp_hub_add_runtime_stats_json(cJSON *object, const kcp_runtime_stats_t *stats) { if (object == NULL || stats == NULL) { errno = EINVAL; @@ -287,25 +382,53 @@ static int kcp_hub_push_telemetry_snapshot(kcp_hub_t *hub) { static void *kcp_hub_telemetry_thread_main(void *arg) { kcp_hub_t *hub = (kcp_hub_t *) arg; + uint32_t last_telemetry_push_ms = 0; while (!atomic_load(&hub->closed)) { int interval_ms = KCP_HUB_DEFAULT_TELEMETRY_INTERVAL_MS; + uint32_t now_ms = kcp_hub_now_ms(); + int telemetry_enabled = 0; pthread_rwlock_rdlock(&hub->lock); - if (hub->telemetry_interval_ms > 0) { + telemetry_enabled = hub->telemetry_peer_id[0] != '\0'; + if (telemetry_enabled && hub->telemetry_interval_ms > 0) { interval_ms = hub->telemetry_interval_ms; } pthread_rwlock_unlock(&hub->lock); - (void) kcp_hub_push_telemetry_snapshot(hub); + if (telemetry_enabled && (last_telemetry_push_ms == 0 || kcp_hub_elapsed_ms(now_ms, last_telemetry_push_ms) >= (uint32_t) interval_ms)) { + (void) kcp_hub_push_telemetry_snapshot(hub); + last_telemetry_push_ms = now_ms; + } + kcp_hub_run_maintenance(hub); if (atomic_load(&hub->closed)) { break; } - usleep((useconds_t) interval_ms * 1000U); + usleep((useconds_t) KCP_HUB_MAINTENANCE_INTERVAL_MS * 1000U); } return NULL; } +static int kcp_hub_send_server_text(kcp_conn_t *conn, const char *to, const char *payload) { + message_t msg; + + protocol_message_init(&msg); + msg.type = MSG_TYPE_TEXT; + snprintf(msg.from, sizeof(msg.from), "%s", SERVER_PEER_ID); + snprintf(msg.to, sizeof(msg.to), "%s", (to == NULL || to[0] == '\0') ? "unknown" : to); + msg.body = (uint8_t *) omni_strdup(payload == NULL ? "" : payload); + if (msg.body == NULL) { + return -1; + } + msg.body_len = strlen((const char *) msg.body); + if (kcp_conn_send(conn, &msg) != 0) { + protocol_message_clear(&msg); + return -1; + } + protocol_message_clear(&msg); + return 0; +} + static int kcp_hub_send_server_error(kcp_conn_t *conn, const char *to, const char *message) { message_t msg; protocol_message_init(&msg); @@ -495,8 +618,56 @@ static int kcp_hub_handle_peer_message(kcp_hub_t *hub, const char *peer_id, kcp_ char *error_text = NULL; int relay_status = 0; + kcp_hub_touch_peer(hub, peer_id, conn); switch (msg->type) { case MSG_TYPE_TEXT: + if (strcmp(msg->to, SERVER_PEER_ID) == 0) { + if (kcp_hub_text_body_equals(msg, KCP_HUB_CTRL_HEARTBEAT_ACK)) { + return 0; + } + if (kcp_hub_send_server_error(conn, peer_id, "unsupported server control message") != 0) { + return -1; + } + errno = EPROTO; + return -1; + } + snprintf(msg->from, sizeof(msg->from), "%s", peer_id); + if (kcp_hub_deliver_to_local_peer(hub, msg) == 0) { + return 0; + } + if (errno != ENOENT) { + error_text = omni_strdup_printf("failed to forward to %s", msg->to); + if (error_text == NULL) { + return -1; + } + if (kcp_hub_send_server_error(conn, peer_id, error_text) != 0) { + free(error_text); + return -1; + } + free(error_text); + return 0; + } + if (kcp_hub_forward_to_relay(hub, msg, &relay_status) == 0) { + return 0; + } + if (relay_status == 1) { + error_text = omni_strdup_printf("unknown target: %s", msg->to); + } else if (relay_status == 2) { + error_text = omni_strdup("failed to relay to remote peer"); + } else if (relay_status == 3) { + error_text = omni_strdup("message too large for relay udp"); + } else { + error_text = omni_strdup("failed to relay to remote peer"); + } + if (error_text == NULL) { + return -1; + } + if (kcp_hub_send_server_error(conn, peer_id, error_text) != 0) { + free(error_text); + return -1; + } + free(error_text); + return 0; case MSG_TYPE_FILE: case MSG_TYPE_BINARY: snprintf(msg->from, sizeof(msg->from), "%s", peer_id); @@ -558,9 +729,55 @@ static int kcp_hub_handle_peer_message(kcp_hub_t *hub, const char *peer_id, kcp_ } } +static int kcp_hub_commit_registered_conn( + kcp_hub_t *hub, + const char *peer_id, + kcp_conn_t *conn, + uint32_t now_ms, + kcp_conn_t **out_old_conn +) { + kcp_peer_entry_t *entry; + + if (hub == NULL || peer_id == NULL || peer_id[0] == '\0' || conn == NULL) { + errno = EINVAL; + return -1; + } + if (out_old_conn != NULL) { + *out_old_conn = NULL; + } + + pthread_rwlock_wrlock(&hub->lock); + entry = kcp_hub_find_peer(hub, peer_id); + if (entry != NULL) { + if (out_old_conn != NULL) { + *out_old_conn = entry->conn; + } + entry->conn = conn; + entry->last_seen_ms = now_ms; + entry->last_heartbeat_sent_ms = 0; + pthread_rwlock_unlock(&hub->lock); + return 0; + } + + entry = (kcp_peer_entry_t *) calloc(1, sizeof(*entry)); + if (entry == NULL) { + pthread_rwlock_unlock(&hub->lock); + return -1; + } + snprintf(entry->peer_id, sizeof(entry->peer_id), "%s", peer_id); + entry->conn = conn; + entry->last_seen_ms = now_ms; + entry->last_heartbeat_sent_ms = 0; + entry->next = hub->peers; + hub->peers = entry; + pthread_rwlock_unlock(&hub->lock); + return 0; +} + static int kcp_hub_register_conn(kcp_hub_t *hub, kcp_conn_t *conn, char *peer_id, size_t peer_id_len) { message_t msg; - kcp_peer_entry_t *entry; + kcp_conn_t *old_conn = NULL; + uint32_t now_ms; protocol_message_init(&msg); if (kcp_conn_receive(conn, &msg) != 0) { @@ -574,34 +791,22 @@ static int kcp_hub_register_conn(kcp_hub_t *hub, kcp_conn_t *conn, char *peer_id return -1; } - pthread_rwlock_wrlock(&hub->lock); - entry = kcp_hub_find_peer(hub, msg.from); - if (entry != NULL) { - char *error_text; - pthread_rwlock_unlock(&hub->lock); - error_text = omni_strdup_printf("duplicate peer id: %s", msg.from); - if (error_text != NULL) { - (void) kcp_hub_send_server_error(conn, msg.from, error_text); - free(error_text); - } - protocol_message_clear(&msg); - errno = EEXIST; - return -1; - } - - entry = (kcp_peer_entry_t *) calloc(1, sizeof(*entry)); - if (entry == NULL) { - pthread_rwlock_unlock(&hub->lock); - protocol_message_clear(&msg); - return -1; - } - snprintf(entry->peer_id, sizeof(entry->peer_id), "%s", msg.from); - entry->conn = conn; - entry->next = hub->peers; - hub->peers = entry; - pthread_rwlock_unlock(&hub->lock); - snprintf(peer_id, peer_id_len, "%s", msg.from); + if (kcp_hub_send_server_text(conn, msg.from, KCP_HUB_CTRL_REGISTER_OK) != 0) { + protocol_message_clear(&msg); + return -1; + } + + now_ms = kcp_hub_now_ms(); + if (kcp_hub_commit_registered_conn(hub, msg.from, conn, now_ms, &old_conn) != 0) { + protocol_message_clear(&msg); + return -1; + } + + if (old_conn != NULL && old_conn != conn) { + (void) kcp_hub_send_server_text(old_conn, msg.from, KCP_HUB_CTRL_PEER_REPLACED); + kcp_conn_close(old_conn); + } protocol_message_clear(&msg); return 0; } @@ -623,8 +828,16 @@ kcp_hub_t *kcp_hub_new(latency_logger_t *logger, kcp_session_stats_logger_t *sta hub->stats_logger = stats_logger; hub->stats_interval_ms = stats_interval_ms > 0 ? stats_interval_ms : KCP_DEFAULT_STATS_INTERVAL_MS; hub->telemetry_interval_ms = KCP_HUB_DEFAULT_TELEMETRY_INTERVAL_MS; + hub->heartbeat_interval_ms = KCP_HUB_DEFAULT_HEARTBEAT_INTERVAL_MS; + hub->lease_timeout_ms = KCP_HUB_DEFAULT_LEASE_TIMEOUT_MS; hub->relay_fd = -1; atomic_init(&hub->closed, 0); + if (pthread_create(&hub->telemetry_thread, NULL, kcp_hub_telemetry_thread_main, hub) != 0) { + pthread_rwlock_destroy(&hub->lock); + free(hub); + return NULL; + } + hub->telemetry_thread_started = 1; return hub; } @@ -732,8 +945,6 @@ int kcp_hub_set_relay(kcp_hub_t *hub, int relay_fd, const struct sockaddr *peer_ } int kcp_hub_set_telemetry(kcp_hub_t *hub, const char *peer_id, int interval_ms) { - int start_thread = 0; - if (hub == NULL || peer_id == NULL) { errno = EINVAL; return -1; @@ -741,22 +952,92 @@ int kcp_hub_set_telemetry(kcp_hub_t *hub, const char *peer_id, int interval_ms) pthread_rwlock_wrlock(&hub->lock); snprintf(hub->telemetry_peer_id, sizeof(hub->telemetry_peer_id), "%s", peer_id); hub->telemetry_interval_ms = interval_ms > 0 ? interval_ms : KCP_HUB_DEFAULT_TELEMETRY_INTERVAL_MS; - if (!hub->telemetry_thread_started && hub->telemetry_peer_id[0] != '\0') { - start_thread = 1; - hub->telemetry_thread_started = 1; + pthread_rwlock_unlock(&hub->lock); + return 0; +} + +static void kcp_hub_run_maintenance(kcp_hub_t *hub) { + kcp_hub_pending_action_t *heartbeat_actions = NULL; + kcp_hub_pending_action_t *close_actions = NULL; + uint32_t now_ms; + int heartbeat_interval_ms; + int lease_timeout_ms; + + if (hub == NULL) { + return; + } + + now_ms = kcp_hub_now_ms(); + heartbeat_interval_ms = KCP_HUB_DEFAULT_HEARTBEAT_INTERVAL_MS; + lease_timeout_ms = KCP_HUB_DEFAULT_LEASE_TIMEOUT_MS; + + pthread_rwlock_wrlock(&hub->lock); + if (hub->heartbeat_interval_ms > 0) { + heartbeat_interval_ms = hub->heartbeat_interval_ms; + } + if (hub->lease_timeout_ms > 0) { + lease_timeout_ms = hub->lease_timeout_ms; + } + { + kcp_peer_entry_t *prev = NULL; + kcp_peer_entry_t *entry = hub->peers; + + while (entry != NULL) { + kcp_peer_entry_t *next = entry->next; + uint32_t idle_ms = kcp_hub_elapsed_ms(now_ms, entry->last_seen_ms); + int uses_server_lease = kcp_hub_peer_uses_server_lease(entry->peer_id); + + if (entry->conn == NULL || entry->peer_id[0] == '\0') { + prev = entry; + entry = next; + continue; + } + if (uses_server_lease && lease_timeout_ms > 0 && idle_ms >= (uint32_t) lease_timeout_ms) { + if (prev == NULL) { + hub->peers = next; + } else { + prev->next = next; + } + (void) kcp_hub_append_pending_action(&close_actions, entry->peer_id, entry->conn); + free(entry); + entry = next; + continue; + } + if ( + uses_server_lease + && + heartbeat_interval_ms > 0 + && idle_ms >= (uint32_t) heartbeat_interval_ms + && (entry->last_heartbeat_sent_ms == 0 || kcp_hub_elapsed_ms(now_ms, entry->last_heartbeat_sent_ms) >= (uint32_t) heartbeat_interval_ms) + ) { + entry->last_heartbeat_sent_ms = now_ms; + (void) kcp_hub_append_pending_action(&heartbeat_actions, entry->peer_id, entry->conn); + } + prev = entry; + entry = next; + } } pthread_rwlock_unlock(&hub->lock); - if (start_thread) { - if (pthread_create(&hub->telemetry_thread, NULL, kcp_hub_telemetry_thread_main, hub) != 0) { - pthread_rwlock_wrlock(&hub->lock); - hub->telemetry_thread_started = 0; - hub->telemetry_peer_id[0] = '\0'; - pthread_rwlock_unlock(&hub->lock); - return -1; + while (heartbeat_actions != NULL) { + kcp_hub_pending_action_t *next = heartbeat_actions->next; + if (kcp_hub_send_server_text(heartbeat_actions->conn, heartbeat_actions->peer_id, KCP_HUB_CTRL_HEARTBEAT) != 0) { + kcp_hub_unregister(hub, heartbeat_actions->peer_id, heartbeat_actions->conn); + kcp_conn_close(heartbeat_actions->conn); } + free(heartbeat_actions); + heartbeat_actions = next; } - return 0; + + while (close_actions != NULL) { + kcp_hub_pending_action_t *next = close_actions->next; + kcp_conn_close(close_actions->conn); + free(close_actions); + close_actions = next; + } + + kcp_hub_free_pending_actions(heartbeat_actions); + kcp_hub_free_pending_actions(close_actions); } int kcp_hub_serve_relay(kcp_hub_t *hub) {