Files
OmniSocketGo/python/tests/test_sessions.py
2026-04-10 11:11:03 +08:00

247 lines
8.5 KiB
Python

from __future__ import annotations

from contextlib import ExitStack, contextmanager
from pathlib import Path
import socket
import subprocess
import sys
import threading
import time

import pytest
# The OmniSocket extension is Linux-only; skip every test in this module elsewhere.
pytestmark = pytest.mark.skipif(sys.platform != 'linux', reason='Linux-only OmniSocket extension')

# Repository root: this file lives at <root>/python/tests/, two levels below it.
ROOT = Path(__file__).resolve().parents[2]
PYTHON_ROOT = ROOT / 'python'

# Put the in-repo python/ directory at the front of sys.path so the local
# package is picked up ahead of any installed copy. This must run before the
# importorskip call below.
if str(PYTHON_ROOT) not in sys.path:
    sys.path.insert(0, str(PYTHON_ROOT))

# Skip the whole module when ``omnisocket`` cannot be imported.
omnisocket = pytest.importorskip('omnisocket')

# Module-level aliases for the package names the tests below use.
CONTROL_DEFAULTS = omnisocket.CONTROL_DEFAULTS
MSG_TYPE_BINARY = omnisocket.MSG_TYPE_BINARY
Session = omnisocket.Session
UdpSession = omnisocket.UdpSession

def _reserve_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(('127.0.0.1', 0))
return int(sock.getsockname()[1])
@contextmanager
def _run_server(binary_name: str, listen_addr: str):
    """Run a prebuilt server binary for the duration of a ``with`` block.

    The calling test is skipped when ``ROOT/bin/<binary_name>`` does not exist.
    The binary is started as ``<binary> -listen <listen_addr>`` with the
    repository root as its working directory and its output discarded.

    After a short fixed grace period, the process is checked for an early exit
    (for example, a failed bind). If it has exited, the test fails immediately
    with the exit code, rather than failing later with an unrelated-looking
    connect timeout.

    On exit, the process is sent SIGTERM and given 2 s to stop; if it does not,
    it is killed.

    Args:
        binary_name: File name of the server under ``ROOT/bin``.
        listen_addr: ``host:port`` passed via ``-listen``.

    Yields:
        The running ``subprocess.Popen`` handle.
    """
    binary = ROOT / 'bin' / binary_name
    if not binary.exists():
        pytest.skip(f'{binary} is not built')
    process = subprocess.Popen(
        [str(binary), '-listen', listen_addr],
        cwd=str(ROOT),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    try:
        # Fixed grace period only; this does not prove the server accepts peers yet.
        time.sleep(0.2)
        returncode = process.poll()
        if returncode is not None:
            pytest.fail(
                f'{binary_name} exited during startup with code {returncode} '
                f'(listen address {listen_addr})'
            )
        yield process
    finally:
        # Graceful shutdown first, then force-kill if it hangs.
        process.terminate()
        try:
            process.wait(timeout=2.0)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait(timeout=2.0)


def _connect_with_retry(session_cls, *, transport: str, server_addr: str, peer_id: str):
deadline = time.monotonic() + 3.0
last_error: Exception | None = None
while time.monotonic() < deadline:
session = session_cls()
try:
kwargs: dict[str, object] = {
'server_addr': server_addr,
'peer_id': peer_id,
}
if transport == 'kcp':
kwargs.update(CONTROL_DEFAULTS)
else:
kwargs['enable_timestamping'] = False
session.connect(**kwargs)
return session
except OSError as exc:
last_error = exc
time.sleep(0.1)
raise AssertionError(f'failed to connect {peer_id} to {server_addr}: {last_error}')
@pytest.mark.parametrize(
    ('transport', 'binary_name', 'session_cls'),
    [
        ('udp', 'udpserver', UdpSession),
        ('kcp', 'kcpserver', Session),
    ],
)
def test_control_sessions_smoke(transport: str, binary_name: str, session_cls) -> None:
    """Exchange two binary messages between two peers and sanity-check session stats.

    Steps against a freshly started server for the given transport:

    1. Connect a sender and a receiver peer.
    2. Check that a short poll on the receiver returns ``None`` (nothing sent yet).
    3. Deliver one message via ``recv`` and one via ``recv_into``.
    4. Check ``stats()`` on both sides, plus ``kcp_stats()`` for KCP.
    """
    port = _reserve_port()
    listen_addr = f'127.0.0.1:{port}'
    sender_id = f'pytest-{transport}-sender'
    receiver_id = f'pytest-{transport}-receiver'
    with _run_server(binary_name, listen_addr), ExitStack() as sessions:
        # Register each close() as soon as the session exists, so a failure to
        # connect the second peer still closes the first one. The sessions are
        # closed before the server is shut down.
        sender = _connect_with_retry(session_cls, transport=transport, server_addr=listen_addr, peer_id=sender_id)
        sessions.callback(sender.close)
        receiver = _connect_with_retry(session_cls, transport=transport, server_addr=listen_addr, peer_id=receiver_id)
        sessions.callback(receiver.close)

        # Nothing has been sent yet, so a short poll must time out empty.
        assert receiver.recv(timeout_ms=20) is None

        payload = b'control-packet-1'
        sender.send(to=receiver_id, data=payload)
        message = receiver.recv(timeout_ms=1000)
        assert message is not None, 'receiver timed out waiting for the first packet'
        from_peer, msg_type, recv_payload = message
        assert from_peer == sender_id
        assert msg_type == MSG_TYPE_BINARY
        assert recv_payload == payload

        payload2 = b'control-packet-2'
        sender.send(to=receiver_id, data=payload2)
        recv_buffer = bytearray(128)
        meta = receiver.recv_into(buffer=recv_buffer, timeout_ms=1000)
        assert meta is not None, 'receiver timed out waiting for the second packet'
        assert meta['from'] == sender_id
        assert meta['msg_type'] == MSG_TYPE_BINARY
        assert meta['body_len'] == len(payload2)
        assert bytes(recv_buffer[: meta['body_len']]) == payload2

        sender_stats = sender.stats()
        receiver_stats = receiver.stats()
        assert sender_stats['connected'] == 1
        assert receiver_stats['connected'] == 1
        assert sender_stats['registered'] == 1
        assert receiver_stats['registered'] == 1
        # Lower bounds only: at least the two sends / receives made above.
        assert sender_stats['send_calls'] >= 2
        assert receiver_stats['recv_calls'] >= 2
        if transport == 'kcp':
            sender_kcp_stats = sender.kcp_stats()
            receiver_kcp_stats = receiver.kcp_stats()
            assert sender_kcp_stats['connected'] == 1
            assert receiver_kcp_stats['connected'] == 1
            assert 'srtt_ms' in sender_kcp_stats
            assert 'snd_queue' in receiver_kcp_stats


def test_kcp_duplicate_peer_new_instance_wins() -> None:
    """A second KCP session registering an in-use peer id replaces the first.

    After the replacement connects and reports itself registered:

    - the displaced session's ``recv`` must raise ``OSError``, not time out
      quietly;
    - a message addressed to the shared peer id must reach the replacement.
    """
    port = _reserve_port()
    listen_addr = f'127.0.0.1:{port}'
    shared_peer_id = 'pytest-kcp-shared-peer'
    sender_id = 'pytest-kcp-unique-sender'
    with _run_server('kcpserver', listen_addr), ExitStack() as sessions:
        # Each session's close() is registered as soon as it exists, so any
        # later connect failure still closes the earlier sessions.
        original = _connect_with_retry(Session, transport='kcp', server_addr=listen_addr, peer_id=shared_peer_id)
        sessions.callback(original.close)
        sender = _connect_with_retry(Session, transport='kcp', server_addr=listen_addr, peer_id=sender_id)
        sessions.callback(sender.close)
        replacement = _connect_with_retry(Session, transport='kcp', server_addr=listen_addr, peer_id=shared_peer_id)
        sessions.callback(replacement.close)

        replacement_stats = replacement.stats()
        assert replacement_stats['connected'] == 1
        assert replacement_stats['registered'] == 1

        # The displaced instance must surface an error rather than a timeout.
        with pytest.raises(OSError):
            original.recv(timeout_ms=1000)

        payload = b'registered-replacement'
        sender.send(to=shared_peer_id, data=payload)
        message = replacement.recv(timeout_ms=1000)
        assert message is not None, 'replacement timed out waiting for the packet'
        from_peer, msg_type, recv_payload = message
        assert from_peer == sender_id
        assert msg_type == MSG_TYPE_BINARY
        assert recv_payload == payload


def test_kcp_idle_video_peers_survive_without_receive_loop() -> None:
    """Two KCP peers left idle for 5 s, with nobody calling ``recv``, still exchange data.

    Both peers connect, then the test sleeps without driving either session.
    A message sent afterwards must still be delivered within 1 s.

    NOTE(review): 5 s is presumably chosen to exceed some idle or keepalive
    interval in the server or session; confirm against the configured timeouts.
    """
    port = _reserve_port()
    listen_addr = f'127.0.0.1:{port}'
    sender_id = 'peer-b-video'
    receiver_id = 'peer-a-video'
    with _run_server('kcpserver', listen_addr), ExitStack() as sessions:
        # Register close() immediately so a failed second connect still closes the first.
        sender = _connect_with_retry(Session, transport='kcp', server_addr=listen_addr, peer_id=sender_id)
        sessions.callback(sender.close)
        receiver = _connect_with_retry(Session, transport='kcp', server_addr=listen_addr, peer_id=receiver_id)
        sessions.callback(receiver.close)

        # Deliberately idle: no recv loop runs on either session during this sleep.
        time.sleep(5.0)

        payload = b'idle-video-session-still-alive'
        sender.send(to=receiver_id, data=payload)
        message = receiver.recv(timeout_ms=1000)
        assert message is not None, 'receiver timed out after the idle period'
        from_peer, msg_type, recv_payload = message
        assert from_peer == sender_id
        assert msg_type == MSG_TYPE_BINARY
        assert recv_payload == payload


def test_udp_session_close_interrupts_blocking_recv() -> None:
    """``UdpSession.close()`` must neither block on nor leave blocked a pending ``recv()``.

    One worker thread calls ``recv()`` with no timeout; a second thread then
    calls ``close()`` on the same session. Both calls must finish within 1 s.
    ``close()`` must not raise; ``recv()`` may either return or raise ``OSError``.
    """
    port = _reserve_port()
    listen_addr = f'127.0.0.1:{port}'
    receiver_id = 'pytest-udp-blocking-recv'
    with _run_server('udpserver', listen_addr):
        receiver = _connect_with_retry(
            UdpSession,
            transport='udp',
            server_addr=listen_addr,
            peer_id=receiver_id,
        )
        recv_error: list[BaseException] = []
        close_error: list[BaseException] = []
        recv_started = threading.Event()
        recv_done = threading.Event()
        close_done = threading.Event()

        def recv_worker() -> None:
            recv_started.set()
            try:
                receiver.recv()
            except BaseException as exc:  # pragma: no cover - assertion is on thread completion
                recv_error.append(exc)
            finally:
                recv_done.set()

        def close_worker() -> None:
            try:
                receiver.close()
            except BaseException as exc:  # pragma: no cover - assertion is on thread completion
                close_error.append(exc)
            finally:
                close_done.set()

        recv_thread = threading.Thread(target=recv_worker, daemon=True)
        recv_thread.start()
        assert recv_started.wait(timeout=1.0)
        # recv_started fires just before recv() is entered; give the worker a
        # moment to actually block inside recv() before closing.
        time.sleep(0.05)
        close_thread = threading.Thread(target=close_worker, daemon=True)
        close_thread.start()
        assert close_done.wait(timeout=1.0), 'UdpSession.close() blocked while recv() was waiting'
        assert recv_done.wait(timeout=1.0), 'UdpSession.recv() stayed blocked after close()'
        # The events are set in the workers' finally blocks, slightly before the
        # threads terminate. Join first so the is_alive() checks are not racy.
        close_thread.join(timeout=1.0)
        recv_thread.join(timeout=1.0)
        assert not close_thread.is_alive()
        assert not recv_thread.is_alive()
        assert not close_error
        assert not recv_error or isinstance(recv_error[0], OSError)