from __future__ import annotations import json import socket import sys import threading import time from typing import Any from .common import ( CONTROL_PACKET_SIZE, CONTROL_SOURCE_NATIVE_UDP, CONTROL_SOURCE_PRIORITY, JsonlRunLogger, ZERO_CONTROL_PAYLOAD, WORKSPACE_ROOT, load_omnisocket_config, parse_host_port, ) from .video import safe_kcp_stats class ControlAckTracker: def __init__(self) -> None: self._lock = threading.Lock() self._event_logger = JsonlRunLogger("BLITZ_A_CONTROL_EVENTS_LOG_PATH", "a-control-events") self._ack_logger = JsonlRunLogger("BLITZ_A_CONTROL_ACKS_LOG_PATH", "a-control-acks") self._pending: dict[int, dict[str, Any]] = {} self._latest_estimate: dict[str, Any] = { "ack_available": False, "updated_at": None, "received_mono_ns": 0, "control_loop_rtt_ms": None, "b_recv_to_persist_ms": None, "control_oneway_network_est_ms": None, "control_to_persist_est_ms": None, "sample_reason": None, } def register_send( self, *, message_id: int, issued_at_unix_ns: int, issued_at_mono_ns: int, source: str, payload: bytes, send_call_latency_us: int, ) -> None: event = { "ts_unix_nano": issued_at_unix_ns, "message_id": message_id, "issued_at_unix_ns": issued_at_unix_ns, "issued_at_mono_ns": issued_at_mono_ns, "source": source, "command_signature": payload.hex(), "payload_size": len(payload), "send_call_latency_us": send_call_latency_us, } with self._lock: self._pending[message_id] = event self._prune_locked(issued_at_mono_ns) self._event_logger.write(event) def handle_ack(self, ack_payload: dict[str, Any], received_unix_ns: int, received_mono_ns: int) -> None: try: message_id = int(ack_payload["message_id"]) except (KeyError, TypeError, ValueError): return with self._lock: event = self._pending.pop(message_id, None) self._prune_locked(received_mono_ns) if event is None: return try: control_loop_rtt_ms = round((received_unix_ns - int(event["issued_at_unix_ns"])) / 1_000_000.0, 3) b_recv_to_persist_ms = round(float(ack_payload.get("b_recv_to_persist_us", 0)) / 1000.0, 3) except (TypeError, ValueError): return control_oneway_network_est_ms = round(max(0.0, (control_loop_rtt_ms - b_recv_to_persist_ms) / 2.0), 3) control_to_persist_est_ms = round(control_oneway_network_est_ms + b_recv_to_persist_ms, 3) ack_record = { "ts_unix_nano": received_unix_ns, "received_unix_ns": received_unix_ns, "received_mono_ns": received_mono_ns, "message_id": message_id, "ack_phase": str(ack_payload.get("ack_phase") or "persist_end"), "sample_reason": str(ack_payload.get("sample_reason") or ""), "b_recv_to_persist_us": ack_payload.get("b_recv_to_persist_us"), "unix_send_ok": bool(ack_payload.get("unix_send_ok", False)), "issued_at_unix_ns": event["issued_at_unix_ns"], "source": event["source"], "control_loop_rtt_ms": control_loop_rtt_ms, "b_recv_to_persist_ms": b_recv_to_persist_ms, "control_oneway_network_est_ms": control_oneway_network_est_ms, "control_to_persist_est_ms": control_to_persist_est_ms, } self._ack_logger.write(ack_record) with self._lock: self._latest_estimate = { "ack_available": True, "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(received_unix_ns / 1_000_000_000)), "received_mono_ns": received_mono_ns, "control_loop_rtt_ms": control_loop_rtt_ms, "b_recv_to_persist_ms": b_recv_to_persist_ms, "control_oneway_network_est_ms": control_oneway_network_est_ms, "control_to_persist_est_ms": control_to_persist_est_ms, "sample_reason": ack_record["sample_reason"], } def get_latest_estimate(self) -> dict[str, Any]: with self._lock: estimate = dict(self._latest_estimate) if int(estimate.get("received_mono_ns", 0) or 0) > 0 and time.monotonic_ns() - int(estimate["received_mono_ns"]) > 10_000_000_000: estimate["ack_available"] = False estimate["control_loop_rtt_ms"] = None estimate["b_recv_to_persist_ms"] = None estimate["control_oneway_network_est_ms"] = None estimate["control_to_persist_est_ms"] = None estimate["sample_reason"] = None estimate.pop("received_mono_ns", None) return estimate def close(self) -> None: self._event_logger.close() self._ack_logger.close() def _prune_locked(self, now_mono_ns: int) -> None: stale_ids = [ message_id for message_id, event in self._pending.items() if now_mono_ns - int(event.get("issued_at_mono_ns", 0)) > 60_000_000_000 ] for message_id in stale_ids: self._pending.pop(message_id, None) class OmniSocketControlSender: def __init__(self, ack_tracker: ControlAckTracker) -> None: self._lock = threading.Lock() self._ack_tracker = ack_tracker self._session = None self._session_cls = None self._msg_type_error = None self._control_defaults: dict[str, Any] = {} self._started = False self._drain_thread: threading.Thread | None = None self._closing = threading.Event() self._target_peer = "" self._send_count = 0 self._send_errors = 0 self._drain_errors = 0 self._last_error = "" self._reconnect_count = 0 self._ever_connected = False self._registered = False self._supports_send_with_id = False self._load_backend() def _load_backend(self) -> None: try: self._import_backend() except Exception as error: # pragma: no cover - optional runtime dependency self._last_error = f"omnisocket import failed: {error}" def _import_backend(self) -> None: try: from omnisocket import CONTROL_DEFAULTS, MSG_TYPE_ERROR, Session # type: ignore except ImportError: python_dir = WORKSPACE_ROOT / "OmniSocketGo" / "python" if python_dir.exists(): sys.path.insert(0, str(python_dir)) from omnisocket import CONTROL_DEFAULTS, MSG_TYPE_ERROR, Session # type: ignore self._session_cls = Session self._msg_type_error = MSG_TYPE_ERROR self._control_defaults = dict(CONTROL_DEFAULTS) def _connect_session(self): assert self._session_cls is not None config = load_omnisocket_config() transport_cfg = config.get("transport", {}) control_cfg = config.get("control_sender", {}) session = self._session_cls() session.connect( server_addr=str(transport_cfg.get("server_addr", "127.0.0.1:10909")), peer_id=str(control_cfg.get("peer_id", "peer-a-ctrl")), relay_via=str(transport_cfg.get("relay_via", "")), bind_ip=str(transport_cfg.get("bind_ip", "")), bind_device=str(transport_cfg.get("bind_device", "")), **self._control_defaults, ) target_peer = str(control_cfg.get("target_peer", "peer-b-ctrl")) return session, target_peer def ensure_started(self) -> None: if self._session_cls is None: return with self._lock: if self._closing.is_set(): return if self._started and self._session is not None: return session, target_peer = self._connect_session() self._session = session self._target_peer = target_peer self._closing.clear() self._started = True self._last_error = "" self._registered = bool(dict(session.stats()).get("registered", 0)) self._supports_send_with_id = hasattr(session, "send_with_id") if self._ever_connected: self._reconnect_count += 1 else: self._ever_connected = True self._drain_thread = threading.Thread( target=self._drain_loop, name="omnisocket-control-drain", daemon=True, ) self._drain_thread.start() def _reset_session(self, session: Any | None) -> None: with self._lock: if session is not None and session is not self._session: return current = self._session self._session = None self._started = False self._registered = False self._supports_send_with_id = False if current is not None: try: current.close() except Exception: pass def send_payload(self, payload: bytes, *, source: str) -> None: if len(payload) != CONTROL_PACKET_SIZE: raise ValueError(f"expected {CONTROL_PACKET_SIZE} bytes, got {len(payload)}") self.ensure_started() with self._lock: session = self._session target_peer = self._target_peer supports_send_with_id = self._supports_send_with_id if session is None: raise RuntimeError("control session is not available") try: issued_at_unix_ns = time.time_ns() issued_at_mono_ns = time.monotonic_ns() send_started_ns = time.perf_counter_ns() message_id: int | None = None if supports_send_with_id: message_id = int(session.send_with_id(to=target_peer, data=payload)) else: session.send(to=target_peer, data=payload) send_call_latency_us = max(0, int((time.perf_counter_ns() - send_started_ns) / 1000)) except Exception as error: with self._lock: self._send_errors += 1 self._last_error = str(error) self._reset_session(session) raise if message_id is not None: self._ack_tracker.register_send( message_id=message_id, issued_at_unix_ns=issued_at_unix_ns, issued_at_mono_ns=issued_at_mono_ns, source=source, payload=payload, send_call_latency_us=send_call_latency_us, ) with self._lock: self._send_count += 1 def send_zero_burst(self, count: int) -> None: for _ in range(max(0, count)): try: self.send_payload(ZERO_CONTROL_PAYLOAD, source="zero_burst") except Exception: return def _drain_loop(self) -> None: while not self._closing.is_set(): with self._lock: session = self._session if session is None: return try: result = session.recv(timeout_ms=100) except Exception as error: last_server_error = "" try: last_server_error = str(dict(session.stats()).get("last_server_error", "") or "") except Exception: last_server_error = "" with self._lock: self._drain_errors += 1 self._registered = False self._last_error = last_server_error or str(error) if not self._closing.is_set(): self._reset_session(session) return if result is None: try: stats = dict(session.stats()) except Exception: stats = {} with self._lock: self._registered = bool(stats.get("registered", 0)) if stats.get("last_server_error"): self._last_error = str(stats.get("last_server_error")) continue from_peer, msg_type, payload = result if msg_type == self._msg_type_error: text = payload.decode("utf-8", errors="replace") try: stats = dict(session.stats()) except Exception: stats = {} with self._lock: self._last_error = f"server error from {from_peer}: {text}" self._registered = bool(stats.get("registered", 0)) def session_stats(self) -> dict[str, Any]: with self._lock: session = self._session if session is None: return {"connected": 0, "registered": 0, "last_server_error": self._last_error} try: return dict(session.stats()) except Exception: return {"connected": 0, "registered": 0, "last_server_error": self._last_error} def session_kcp_stats(self) -> dict[str, Any]: with self._lock: session = self._session return safe_kcp_stats(session) def get_status(self) -> dict[str, Any]: config = load_omnisocket_config() control_cfg = config.get("control_sender", {}) session_stats = self.session_stats() with self._lock: return { "backend_ready": self._session_cls is not None, "started": self._started, "connected": self._session is not None, "registered": bool(session_stats.get("registered", 0)), "peer_id": str(control_cfg.get("peer_id", "")), "target_peer": str(control_cfg.get("target_peer", "")), "send_count": self._send_count, "send_errors": self._send_errors, "drain_errors": self._drain_errors, "reconnect_count": self._reconnect_count, "last_server_error": str(session_stats.get("last_server_error", "") or ""), "last_error": self._last_error, } def close(self) -> None: self._closing.set() self.send_zero_burst(1) self._reset_session(None) drain_thread = self._drain_thread if drain_thread is not None and drain_thread.is_alive(): drain_thread.join(timeout=0.5) class OmniSocketControlAckReceiver: def __init__(self, ack_tracker: ControlAckTracker) -> None: self._ack_tracker = ack_tracker self._lock = threading.Lock() self._thread: threading.Thread | None = None self._started = False self._session = None self._session_cls = None self._msg_type_text = None self._msg_type_error = None self._control_defaults: dict[str, Any] = {} self._closing = threading.Event() self._last_error = "" self._reconnect_count = 0 self._ever_connected = False self._load_backend() def _load_backend(self) -> None: try: self._import_backend() except Exception as error: # pragma: no cover self._last_error = f"omnisocket import failed: {error}" def _import_backend(self) -> None: try: from omnisocket import CONTROL_DEFAULTS, MSG_TYPE_ERROR, MSG_TYPE_TEXT, Session # type: ignore except ImportError: python_dir = WORKSPACE_ROOT / "OmniSocketGo" / "python" if python_dir.exists(): sys.path.insert(0, str(python_dir)) from omnisocket import CONTROL_DEFAULTS, MSG_TYPE_ERROR, MSG_TYPE_TEXT, Session # type: ignore self._session_cls = Session self._msg_type_text = MSG_TYPE_TEXT self._msg_type_error = MSG_TYPE_ERROR self._control_defaults = dict(CONTROL_DEFAULTS) def _connect_session(self): assert self._session_cls is not None config = load_omnisocket_config() transport_cfg = config.get("transport", {}) ack_cfg = config.get("control_ack_receiver", {}) session = self._session_cls() session.connect( server_addr=str(transport_cfg.get("server_addr", "127.0.0.1:10909")), peer_id=str(ack_cfg.get("peer_id", "peer-a-ctrl-ack")), relay_via=str(transport_cfg.get("relay_via", "")), bind_ip=str(transport_cfg.get("bind_ip", "")), bind_device=str(transport_cfg.get("bind_device", "")), **self._control_defaults, ) return session, str(ack_cfg.get("expected_sender", "peer-b-ctrl-ack")) def ensure_started(self) -> None: if self._session_cls is None: return with self._lock: if self._started or self._closing.is_set(): return self._started = True self._thread = threading.Thread(target=self._run, name="omnisocket-control-ack", daemon=True) self._thread.start() def _run(self) -> None: while not self._closing.is_set(): expected_sender = "" try: session, expected_sender = self._connect_session() with self._lock: self._session = session self._last_error = "" if self._ever_connected: self._reconnect_count += 1 else: self._ever_connected = True while not self._closing.is_set(): result = session.recv(timeout_ms=1000) if result is None: continue from_peer, msg_type, payload = result if msg_type == self._msg_type_error: with self._lock: self._last_error = f"ack session error from {from_peer}: {payload.decode('utf-8', errors='replace')}" continue if msg_type != self._msg_type_text: continue if expected_sender and from_peer != expected_sender: continue try: ack_payload = json.loads(payload.decode("utf-8")) except (UnicodeDecodeError, json.JSONDecodeError): continue self._ack_tracker.handle_ack(ack_payload, time.time_ns(), time.monotonic_ns()) except Exception as error: # pragma: no cover if not self._closing.is_set(): with self._lock: self._last_error = str(error) time.sleep(2) finally: if self._session is not None: try: self._session.close() except Exception: pass with self._lock: self._session = None if self._closing.is_set(): self._started = False def get_status(self) -> dict[str, Any]: config = load_omnisocket_config().get("control_ack_receiver", {}) with self._lock: return { "backend_ready": self._session_cls is not None, "started": self._started, "connected": self._session is not None, "peer_id": str(config.get("peer_id", "")), "expected_sender": str(config.get("expected_sender", "")), "reconnect_count": self._reconnect_count, "last_error": self._last_error, } def close(self) -> None: self._closing.set() with self._lock: session = self._session if session is not None: try: session.close() except Exception: pass thread = self._thread if thread is not None and thread.is_alive(): thread.join(timeout=0.5) class ControlArbiter: def __init__(self, sender: OmniSocketControlSender) -> None: self._sender = sender self._lock = threading.Lock() self._thread: threading.Thread | None = None self._closing = threading.Event() self._started = False self._source_lease_ms = 300 self._send_rate_hz = 20.0 self._zero_burst_packets = 3 self._latest_by_source: dict[str, tuple[bytes, float]] = {} self._packet_counts = {source: 0 for source in CONTROL_SOURCE_PRIORITY} self._last_payload = ZERO_CONTROL_PAYLOAD self._last_sent_at = 0.0 self._active_source: str | None = None self._last_error = "" def _load_config(self) -> None: cfg = load_omnisocket_config().get("control_ingress", {}) self._source_lease_ms = max(50, int(cfg.get("source_lease_ms", 300))) self._send_rate_hz = max(1.0, float(cfg.get("send_rate_hz", 20.0))) self._zero_burst_packets = max(1, int(cfg.get("zero_burst_packets", 3))) def ensure_started(self) -> None: self._load_config() with self._lock: if self._closing.is_set(): return if self._started: return self._started = True self._thread = threading.Thread( target=self._send_loop, name="control-arbiter", daemon=True, ) self._thread.start() def ingest_command(self, source: str, payload: bytes) -> None: if source not in CONTROL_SOURCE_PRIORITY: raise ValueError(f"unsupported control source: {source}") if len(payload) != CONTROL_PACKET_SIZE: raise ValueError(f"expected {CONTROL_PACKET_SIZE} bytes, got {len(payload)}") self.ensure_started() now = time.monotonic() with self._lock: self._latest_by_source[source] = (payload, now) self._packet_counts[source] += 1 def _resolve_active_locked(self, now: float) -> tuple[str | None, bytes, int]: lease_seconds = self._source_lease_ms / 1000.0 expired_sources = [ source for source, (_, updated_at) in self._latest_by_source.items() if (now - updated_at) > lease_seconds ] for source in expired_sources: self._latest_by_source.pop(source, None) for source in CONTROL_SOURCE_PRIORITY: entry = self._latest_by_source.get(source) if entry is None: continue payload, updated_at = entry remaining_ms = max(0, int((lease_seconds - (now - updated_at)) * 1000)) return source, payload, remaining_ms return None, ZERO_CONTROL_PAYLOAD, 0 def _send_loop(self) -> None: interval = 1.0 / max(self._send_rate_hz, 1.0) previous_active: str | None = None while not self._closing.is_set(): now = time.monotonic() with self._lock: active_source, payload, _lease_ms = self._resolve_active_locked(now) self._active_source = active_source self._last_payload = payload if previous_active is not None and active_source is None: try: self._sender.send_zero_burst(self._zero_burst_packets) except Exception as error: with self._lock: self._last_error = str(error) elif active_source is not None: try: self._sender.send_payload(payload, source=active_source) with self._lock: self._last_sent_at = time.monotonic() self._last_error = "" except Exception as error: with self._lock: self._last_error = str(error) previous_active = active_source self._closing.wait(interval) try: self._sender.send_zero_burst(self._zero_burst_packets) except Exception: pass def get_status(self) -> dict[str, Any]: self.ensure_started() now = time.monotonic() with self._lock: active_source, _payload, lease_ms = self._resolve_active_locked(now) return { "active_source": active_source, "control_lease_remaining_ms": lease_ms, "packet_counts": dict(self._packet_counts), "send_rate_hz": self._send_rate_hz, "source_lease_ms": self._source_lease_ms, "zero_burst_packets": self._zero_burst_packets, "last_error": self._last_error, "last_sent_at_monotonic": self._last_sent_at, } def close(self) -> None: self._closing.set() thread = self._thread if thread is not None and thread.is_alive(): thread.join(timeout=0.5) class NativeUdpControlIngress: def __init__(self, arbiter: ControlArbiter) -> None: self._arbiter = arbiter self._lock = threading.Lock() self._thread: threading.Thread | None = None self._closing = threading.Event() self._started = False self._bind_addr = "127.0.0.1:10921" self._packets_received = 0 self._invalid_packets = 0 self._last_sender = "" self._last_error = "" def ensure_started(self) -> None: bind_addr = str(load_omnisocket_config().get("control_ingress", {}).get("native_udp_bind", "127.0.0.1:10921")) with self._lock: self._bind_addr = bind_addr if self._closing.is_set(): return if self._thread is not None and self._thread.is_alive(): return self._started = True self._thread = threading.Thread( target=self._run, name="native-udp-control-ingress", daemon=True, ) self._thread.start() def _run(self) -> None: try: try: host, port = parse_host_port(self._bind_addr) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind((host, port)) sock.settimeout(0.1) except Exception as error: with self._lock: self._last_error = str(error) return with sock: while not self._closing.is_set(): try: payload, sender_addr = sock.recvfrom(CONTROL_PACKET_SIZE + 64) except socket.timeout: continue except OSError as error: with self._lock: if not self._closing.is_set(): self._last_error = str(error) return with self._lock: self._last_sender = f"{sender_addr[0]}:{sender_addr[1]}" if len(payload) != CONTROL_PACKET_SIZE: with self._lock: self._invalid_packets += 1 continue try: self._arbiter.ingest_command(CONTROL_SOURCE_NATIVE_UDP, payload) except Exception as error: with self._lock: self._last_error = str(error) continue with self._lock: self._packets_received += 1 finally: with self._lock: self._started = False self._thread = None def get_status(self) -> dict[str, Any]: self.ensure_started() with self._lock: return { "started": self._started, "bind_addr": self._bind_addr, "packets_received": self._packets_received, "invalid_packets": self._invalid_packets, "last_sender": self._last_sender, "last_error": self._last_error, } def close(self) -> None: self._closing.set() thread = self._thread if thread is not None and thread.is_alive(): thread.join(timeout=0.5)