fix: 当前的 relay 实现原来只记了一个“最后发包的下游客户端地址”

This commit is contained in:
Mock
2026-04-10 12:13:19 +08:00
parent 6c5d410bdc
commit 6cedf859db
2 changed files with 145 additions and 2 deletions

View File

@@ -56,7 +56,31 @@ def _run_server(binary_name: str, listen_addr: str):
process.wait(timeout=2.0)
def _connect_with_retry(session_cls, *, transport: str, server_addr: str, peer_id: str):
@contextmanager
def _run_relay(listen_addr: str, remote_addr: str):
binary = ROOT / 'bin' / 'kcpserver'
if not binary.exists():
pytest.skip(f'{binary} is not built')
process = subprocess.Popen(
[str(binary), '-mode', 'relay', '-listen', listen_addr, '-relay-remote', remote_addr],
cwd=str(ROOT),
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
try:
time.sleep(0.2)
yield process
finally:
process.terminate()
try:
process.wait(timeout=2.0)
except subprocess.TimeoutExpired:
process.kill()
process.wait(timeout=2.0)
def _connect_with_retry(session_cls, *, transport: str, server_addr: str, peer_id: str, relay_via: str = ''):
deadline = time.monotonic() + 3.0
last_error: Exception | None = None
@@ -69,6 +93,8 @@ def _connect_with_retry(session_cls, *, transport: str, server_addr: str, peer_i
}
if transport == 'kcp':
kwargs.update(CONTROL_DEFAULTS)
if relay_via:
kwargs['relay_via'] = relay_via
else:
kwargs['enable_timestamping'] = False
session.connect(**kwargs)
@@ -210,6 +236,36 @@ def test_kcp_peer_a_video_stale_receiver_is_evicted() -> None:
receiver.close()
def test_kcp_relay_routes_multiple_sessions_by_conv() -> None:
hub_port = _reserve_port()
relay_port = _reserve_port()
hub_addr = f'127.0.0.1:{hub_port}'
relay_addr = f'127.0.0.1:{relay_port}'
with _run_server('kcpserver', hub_addr):
with _run_relay(relay_addr, hub_addr):
sender = _connect_with_retry(Session, transport='kcp', server_addr=hub_addr, peer_id='pytest-relay-sender', relay_via=relay_addr)
receiver = _connect_with_retry(Session, transport='kcp', server_addr=hub_addr, peer_id='pytest-relay-receiver', relay_via=relay_addr)
chatter = _connect_with_retry(Session, transport='kcp', server_addr=hub_addr, peer_id='pytest-relay-chatter', relay_via=relay_addr)
try:
chatter.send(to='pytest-relay-sender', data=b'chatter-primes-last-client')
from_peer, msg_type, recv_payload = sender.recv(timeout_ms=1000)
assert from_peer == 'pytest-relay-chatter'
assert msg_type == MSG_TYPE_BINARY
assert recv_payload == b'chatter-primes-last-client'
sender.send(to='pytest-relay-receiver', data=b'relay-video-frame')
from_peer, msg_type, recv_payload = receiver.recv(timeout_ms=1000)
assert from_peer == 'pytest-relay-sender'
assert msg_type == MSG_TYPE_BINARY
assert recv_payload == b'relay-video-frame'
finally:
sender.close()
receiver.close()
chatter.close()
def test_udp_session_close_interrupts_blocking_recv() -> None:
port = _reserve_port()
listen_addr = f'127.0.0.1:{port}'

View File

@@ -15,6 +15,7 @@ struct udp_relay {
struct sockaddr_storage client_addr;
socklen_t client_addr_len;
int has_client;
struct udp_relay_route *routes;
pthread_mutex_t lock;
pthread_mutex_t log_mu;
pthread_mutex_t state_mu;
@@ -29,6 +30,13 @@ struct udp_relay {
int closed;
};
typedef struct udp_relay_route {
struct udp_relay_route *next;
uint32_t conv;
struct sockaddr_storage client_addr;
socklen_t client_addr_len;
} udp_relay_route_t;
static void udp_relay_parse_kcp_summary(const uint8_t *packet, size_t len, int *has_conv, uint32_t *conv, size_t *segment_count) {
size_t offset = 0;
size_t count = 0;
@@ -139,6 +147,38 @@ static void udp_relay_record_client(udp_relay_t *relay, const struct sockaddr_st
pthread_mutex_unlock(&relay->lock);
}
static int udp_relay_record_route(udp_relay_t *relay, uint32_t conv, const struct sockaddr_storage *addr, socklen_t addr_len) {
udp_relay_route_t *route;
if (relay == NULL || addr == NULL || addr_len == 0) {
errno = EINVAL;
return -1;
}
pthread_mutex_lock(&relay->lock);
for (route = relay->routes; route != NULL; route = route->next) {
if (route->conv == conv) {
memcpy(&route->client_addr, addr, sizeof(*addr));
route->client_addr_len = addr_len;
pthread_mutex_unlock(&relay->lock);
return 0;
}
}
route = (udp_relay_route_t *) calloc(1, sizeof(*route));
if (route == NULL) {
pthread_mutex_unlock(&relay->lock);
return -1;
}
route->conv = conv;
memcpy(&route->client_addr, addr, sizeof(*addr));
route->client_addr_len = addr_len;
route->next = relay->routes;
relay->routes = route;
pthread_mutex_unlock(&relay->lock);
return 0;
}
static int udp_relay_copy_client(udp_relay_t *relay, struct sockaddr_storage *addr, socklen_t *addr_len) {
int has_client;
@@ -152,6 +192,42 @@ static int udp_relay_copy_client(udp_relay_t *relay, struct sockaddr_storage *ad
return has_client;
}
static int udp_relay_copy_route(udp_relay_t *relay, uint32_t conv, struct sockaddr_storage *addr, socklen_t *addr_len) {
udp_relay_route_t *route;
pthread_mutex_lock(&relay->lock);
for (route = relay->routes; route != NULL; route = route->next) {
if (route->conv == conv) {
memcpy(addr, &route->client_addr, sizeof(*addr));
*addr_len = route->client_addr_len;
pthread_mutex_unlock(&relay->lock);
return 1;
}
}
pthread_mutex_unlock(&relay->lock);
return 0;
}
static void udp_relay_clear_routes(udp_relay_t *relay) {
udp_relay_route_t *route;
udp_relay_route_t *next;
if (relay == NULL) {
return;
}
pthread_mutex_lock(&relay->lock);
route = relay->routes;
relay->routes = NULL;
pthread_mutex_unlock(&relay->lock);
while (route != NULL) {
next = route->next;
free(route);
route = next;
}
}
static void *udp_relay_forward_downstream_to_upstream(void *arg) {
udp_relay_t *relay = (udp_relay_t *) arg;
uint8_t buffer[UDP_RELAY_BUF_SIZE];
@@ -160,6 +236,8 @@ static void *udp_relay_forward_downstream_to_upstream(void *arg) {
struct sockaddr_storage source;
socklen_t source_len = sizeof(source);
ssize_t n = recvfrom(relay->downstream_fd, buffer, sizeof(buffer), 0, (struct sockaddr *) &source, &source_len);
int has_conv = 0;
uint32_t conv = 0;
if (n < 0) {
int errnum = errno;
@@ -175,6 +253,10 @@ static void *udp_relay_forward_downstream_to_upstream(void *arg) {
}
udp_relay_record_client(relay, &source, source_len);
udp_relay_parse_kcp_summary(buffer, (size_t) n, &has_conv, &conv, NULL);
if (has_conv) {
(void) udp_relay_record_route(relay, conv, &source, source_len);
}
udp_relay_print_packet(relay, "relay_downstream_rx", relay->downstream_local_addr, &source, source_len, buffer, (size_t) n);
for (;;) {
if (send(relay->upstream_fd, buffer, (size_t) n, 0) >= 0) {
@@ -205,6 +287,8 @@ static void *udp_relay_forward_upstream_to_downstream(void *arg) {
struct sockaddr_storage client_addr;
socklen_t client_addr_len = 0;
ssize_t n = recv(relay->upstream_fd, buffer, sizeof(buffer), 0);
int has_conv = 0;
uint32_t conv = 0;
if (n < 0) {
int errnum = errno;
@@ -220,7 +304,9 @@ static void *udp_relay_forward_upstream_to_downstream(void *arg) {
}
udp_relay_print_packet(relay, "relay_upstream_rx", relay->upstream_local_addr, &relay->upstream_addr, relay->upstream_addr_len, buffer, (size_t) n);
if (!udp_relay_copy_client(relay, &client_addr, &client_addr_len)) {
udp_relay_parse_kcp_summary(buffer, (size_t) n, &has_conv, &conv, NULL);
if ((has_conv && !udp_relay_copy_route(relay, conv, &client_addr, &client_addr_len)) &&
!udp_relay_copy_client(relay, &client_addr, &client_addr_len)) {
udp_relay_print_packet(relay, "relay_upstream_drop_no_client", relay->upstream_local_addr, &relay->upstream_addr, relay->upstream_addr_len, buffer, (size_t) n);
continue;
}
@@ -409,6 +495,7 @@ void udp_relay_free(udp_relay_t *relay) {
}
udp_relay_close(relay);
udp_relay_join_threads(relay);
udp_relay_clear_routes(relay);
pthread_mutex_destroy(&relay->lock);
pthread_mutex_destroy(&relay->log_mu);
pthread_cond_destroy(&relay->state_cond);