diff --git a/python/tests/test_sessions.py b/python/tests/test_sessions.py index ecd02c1..e9785dc 100644 --- a/python/tests/test_sessions.py +++ b/python/tests/test_sessions.py @@ -56,7 +56,31 @@ def _run_server(binary_name: str, listen_addr: str): process.wait(timeout=2.0) -def _connect_with_retry(session_cls, *, transport: str, server_addr: str, peer_id: str): +@contextmanager +def _run_relay(listen_addr: str, remote_addr: str): + binary = ROOT / 'bin' / 'kcpserver' + if not binary.exists(): + pytest.skip(f'{binary} is not built') + + process = subprocess.Popen( + [str(binary), '-mode', 'relay', '-listen', listen_addr, '-relay-remote', remote_addr], + cwd=str(ROOT), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + try: + time.sleep(0.2) + yield process + finally: + process.terminate() + try: + process.wait(timeout=2.0) + except subprocess.TimeoutExpired: + process.kill() + process.wait(timeout=2.0) + + +def _connect_with_retry(session_cls, *, transport: str, server_addr: str, peer_id: str, relay_via: str = ''): deadline = time.monotonic() + 3.0 last_error: Exception | None = None @@ -69,6 +93,8 @@ def _connect_with_retry(session_cls, *, transport: str, server_addr: str, peer_i } if transport == 'kcp': kwargs.update(CONTROL_DEFAULTS) + if relay_via: + kwargs['relay_via'] = relay_via else: kwargs['enable_timestamping'] = False session.connect(**kwargs) @@ -210,6 +236,36 @@ def test_kcp_peer_a_video_stale_receiver_is_evicted() -> None: receiver.close() +def test_kcp_relay_routes_multiple_sessions_by_conv() -> None: + hub_port = _reserve_port() + relay_port = _reserve_port() + hub_addr = f'127.0.0.1:{hub_port}' + relay_addr = f'127.0.0.1:{relay_port}' + + with _run_server('kcpserver', hub_addr): + with _run_relay(relay_addr, hub_addr): + sender = _connect_with_retry(Session, transport='kcp', server_addr=hub_addr, peer_id='pytest-relay-sender', relay_via=relay_addr) + receiver = _connect_with_retry(Session, transport='kcp', server_addr=hub_addr, peer_id='pytest-relay-receiver', relay_via=relay_addr) + chatter = _connect_with_retry(Session, transport='kcp', server_addr=hub_addr, peer_id='pytest-relay-chatter', relay_via=relay_addr) + + try: + chatter.send(to='pytest-relay-sender', data=b'chatter-primes-last-client') + from_peer, msg_type, recv_payload = sender.recv(timeout_ms=1000) + assert from_peer == 'pytest-relay-chatter' + assert msg_type == MSG_TYPE_BINARY + assert recv_payload == b'chatter-primes-last-client' + + sender.send(to='pytest-relay-receiver', data=b'relay-video-frame') + from_peer, msg_type, recv_payload = receiver.recv(timeout_ms=1000) + assert from_peer == 'pytest-relay-sender' + assert msg_type == MSG_TYPE_BINARY + assert recv_payload == b'relay-video-frame' + finally: + sender.close() + receiver.close() + chatter.close() + + def test_udp_session_close_interrupts_blocking_recv() -> None: port = _reserve_port() listen_addr = f'127.0.0.1:{port}' diff --git a/src/server_udp_relay.c b/src/server_udp_relay.c index ecf0eca..3ca2ce5 100644 --- a/src/server_udp_relay.c +++ b/src/server_udp_relay.c @@ -15,6 +15,7 @@ struct udp_relay { struct sockaddr_storage client_addr; socklen_t client_addr_len; int has_client; + struct udp_relay_route *routes; pthread_mutex_t lock; pthread_mutex_t log_mu; pthread_mutex_t state_mu; @@ -29,6 +30,13 @@ struct udp_relay { int closed; }; +typedef struct udp_relay_route { + struct udp_relay_route *next; + uint32_t conv; + struct sockaddr_storage client_addr; + socklen_t client_addr_len; +} udp_relay_route_t; + static void udp_relay_parse_kcp_summary(const uint8_t *packet, size_t len, int *has_conv, uint32_t *conv, size_t *segment_count) { size_t offset = 0; size_t count = 0; @@ -139,6 +147,38 @@ static void udp_relay_record_client(udp_relay_t *relay, const struct sockaddr_st pthread_mutex_unlock(&relay->lock); } +static int udp_relay_record_route(udp_relay_t *relay, uint32_t conv, const struct sockaddr_storage *addr, socklen_t addr_len) { + udp_relay_route_t *route; + + if (relay == NULL || addr == NULL || addr_len == 0) { + errno = EINVAL; + return -1; + } + + pthread_mutex_lock(&relay->lock); + for (route = relay->routes; route != NULL; route = route->next) { + if (route->conv == conv) { + memcpy(&route->client_addr, addr, sizeof(*addr)); + route->client_addr_len = addr_len; + pthread_mutex_unlock(&relay->lock); + return 0; + } + } + + route = (udp_relay_route_t *) calloc(1, sizeof(*route)); + if (route == NULL) { + pthread_mutex_unlock(&relay->lock); + return -1; + } + route->conv = conv; + memcpy(&route->client_addr, addr, sizeof(*addr)); + route->client_addr_len = addr_len; + route->next = relay->routes; + relay->routes = route; + pthread_mutex_unlock(&relay->lock); + return 0; +} + static int udp_relay_copy_client(udp_relay_t *relay, struct sockaddr_storage *addr, socklen_t *addr_len) { int has_client; @@ -152,6 +192,42 @@ static int udp_relay_copy_client(udp_relay_t *relay, struct sockaddr_storage *ad return has_client; } +static int udp_relay_copy_route(udp_relay_t *relay, uint32_t conv, struct sockaddr_storage *addr, socklen_t *addr_len) { + udp_relay_route_t *route; + + pthread_mutex_lock(&relay->lock); + for (route = relay->routes; route != NULL; route = route->next) { + if (route->conv == conv) { + memcpy(addr, &route->client_addr, sizeof(*addr)); + *addr_len = route->client_addr_len; + pthread_mutex_unlock(&relay->lock); + return 1; + } + } + pthread_mutex_unlock(&relay->lock); + return 0; +} + +static void udp_relay_clear_routes(udp_relay_t *relay) { + udp_relay_route_t *route; + udp_relay_route_t *next; + + if (relay == NULL) { + return; + } + + pthread_mutex_lock(&relay->lock); + route = relay->routes; + relay->routes = NULL; + pthread_mutex_unlock(&relay->lock); + + while (route != NULL) { + next = route->next; + free(route); + route = next; + } +} + static void *udp_relay_forward_downstream_to_upstream(void *arg) { udp_relay_t *relay = (udp_relay_t *) arg; uint8_t buffer[UDP_RELAY_BUF_SIZE]; @@ -160,6 +236,8 @@ static void *udp_relay_forward_downstream_to_upstream(void *arg) { struct sockaddr_storage source; socklen_t source_len = sizeof(source); ssize_t n = recvfrom(relay->downstream_fd, buffer, sizeof(buffer), 0, (struct sockaddr *) &source, &source_len); + int has_conv = 0; + uint32_t conv = 0; if (n < 0) { int errnum = errno; @@ -175,6 +253,10 @@ static void *udp_relay_forward_downstream_to_upstream(void *arg) { } udp_relay_record_client(relay, &source, source_len); + udp_relay_parse_kcp_summary(buffer, (size_t) n, &has_conv, &conv, NULL); + if (has_conv) { + (void) udp_relay_record_route(relay, conv, &source, source_len); + } udp_relay_print_packet(relay, "relay_downstream_rx", relay->downstream_local_addr, &source, source_len, buffer, (size_t) n); for (;;) { if (send(relay->upstream_fd, buffer, (size_t) n, 0) >= 0) { @@ -205,6 +287,8 @@ static void *udp_relay_forward_upstream_to_downstream(void *arg) { struct sockaddr_storage client_addr; socklen_t client_addr_len = 0; ssize_t n = recv(relay->upstream_fd, buffer, sizeof(buffer), 0); + int has_conv = 0; + uint32_t conv = 0; if (n < 0) { int errnum = errno; @@ -220,7 +304,9 @@ static void *udp_relay_forward_upstream_to_downstream(void *arg) { } udp_relay_print_packet(relay, "relay_upstream_rx", relay->upstream_local_addr, &relay->upstream_addr, relay->upstream_addr_len, buffer, (size_t) n); - if (!udp_relay_copy_client(relay, &client_addr, &client_addr_len)) { + udp_relay_parse_kcp_summary(buffer, (size_t) n, &has_conv, &conv, NULL); + if ((has_conv && !udp_relay_copy_route(relay, conv, &client_addr, &client_addr_len)) && + !udp_relay_copy_client(relay, &client_addr, &client_addr_len)) { udp_relay_print_packet(relay, "relay_upstream_drop_no_client", relay->upstream_local_addr, &relay->upstream_addr, relay->upstream_addr_len, buffer, (size_t) n); continue; } @@ -409,6 +495,7 @@ void udp_relay_free(udp_relay_t *relay) { } udp_relay_close(relay); udp_relay_join_threads(relay); + udp_relay_clear_routes(relay); pthread_mutex_destroy(&relay->lock); pthread_mutex_destroy(&relay->log_mu); pthread_cond_destroy(&relay->state_cond);