"""Adapter that maps OmniSocket control packets onto a ControlFlag interface."""

from __future__ import annotations

import math
import os
from pathlib import Path
import queue
import struct
import sys
import threading
import time
from typing import Dict, Optional

import yaml

from common.joystick import ControlFlag

try:
    from .omnisocket_control import ControlPacket, MotionFrame, format_motion_frame
except ImportError:  # pragma: no cover - direct script execution fallback
    from omnisocket_control import ControlPacket, MotionFrame, format_motion_frame


def _load_omnisocket_api():
    """Import the direct-session OmniSocket API lazily.

    Returns:
        Tuple ``(CONTROL_DEFAULTS, MSG_TYPE_BINARY, Session)`` from ``omnisocket``.

    Raises:
        RuntimeError: if the ``omnisocket`` package cannot be imported.
    """
    try:
        from omnisocket import CONTROL_DEFAULTS, MSG_TYPE_BINARY, Session
    except ImportError as exc:  # pragma: no cover - environment dependent
        raise RuntimeError(
            "omnisocket is not installed. Install it before using "
            "control_tool=omnisocket_loopback."
        ) from exc
    return CONTROL_DEFAULTS, MSG_TYPE_BINARY, Session


def _load_b_side_control_client():
    """Import ``BSideControlClient``, falling back to the in-repo checkout.

    If the package is not importable, ``<workspace>/OmniSocketGo/python`` (four
    directory levels above this file) is prepended to ``sys.path`` and the import
    is retried once.

    Raises:
        RuntimeError: if ``omnisocket_b_side`` still cannot be imported.
    """
    try:
        from omnisocket_b_side.client import BSideControlClient
    except ImportError:
        workspace_root = Path(__file__).resolve().parents[3]
        python_root = workspace_root / "OmniSocketGo" / "python"
        if str(python_root) not in sys.path:
            sys.path.insert(0, str(python_root))
        try:
            from omnisocket_b_side.client import BSideControlClient
        except ImportError as exc:  # pragma: no cover - environment dependent
            raise RuntimeError(
                "omnisocket_b_side is not installed. Install it before using "
                "OMNI_TRANSPORT_BACKEND=daemon."
            ) from exc
    return BSideControlClient


class OmniSocketFSMFlag(ControlFlag):
    """FSM-facing flag produced from decoded OmniSocket control packets."""

    def __init__(self) -> None:
        super().__init__()
        self.x_speed_command: float = 0.0
        self.y_speed_command: float = 0.0
        self.yaw_speed_command: float = 0.0
        self.height_cmd: float = 0.89


class OmniSocketFSMController:
    """Receive OmniSocket control packets and expose them as ControlFlag.

    A background thread (started by :meth:`start`) receives raw payloads, decodes
    them into ``ControlPacket`` objects and pushes them onto a bounded queue.
    The owner calls :meth:`update_flag` to apply queued packets to the motion
    frame, then reads a snapshot via :meth:`get_udp_flag`.

    The transport is chosen by ``OMNI_TRANSPORT_BACKEND``: ``"daemon"`` (the
    default) talks to a local B-side daemon over ``OMNIBDAEMON_CTRL_SOCKET``;
    any other value opens a direct ``omnisocket`` session.
    """

    # Mode tag stored on the motion frame -> FSM command string on the flag.
    _MODE_TO_FSM_COMMAND: Dict[str, str] = {
        "pose_home": "gotoZERO",
        "pose_hold": "gotoSTOP",
        "mode_stride": "gotoWALKAMP",
    }

    # Events whose goal is overwritten with ``packet.drive_value``.
    _ABSOLUTE_SET_EVENTS = frozenset({"set_surge", "set_sway", "set_spin", "set_lift"})

    def __init__(self) -> None:
        self.config: Dict[str, object] = {}
        self.data_mutex = threading.Lock()
        self._load_config()
        self._init_data_structures()
        self.recv_running = False
        self.recv_thread: Optional[threading.Thread] = None
        self.session = None
        self.daemon_client = None
        self._msg_type_binary = None
        self.transport_backend = str(
            os.getenv("OMNI_TRANSPORT_BACKEND", "daemon")
        ).strip().lower() or "daemon"

    @staticmethod
    def _config_section(config: Dict[str, object], name: str) -> Dict[str, object]:
        """Return ``config[name]`` as a mapping; a missing or empty section is ``{}``.

        Raises:
            ValueError: if the section exists but is not a mapping.
        """
        section = config.get(name)
        if section is None:
            # e.g. "transport:" with every child commented out parses as None.
            return {}
        if not isinstance(section, dict):
            raise ValueError(
                f"config section {name!r} must be a mapping, "
                f"got {type(section).__name__}"
            )
        return section

    def _load_config(self) -> None:
        """Read ``config/omnisocket_demo.yaml`` (if present) and derive settings.

        Missing files, keys or empty sections fall back to the hard-coded
        defaults below.

        Raises:
            ValueError: if the YAML root or a used section is not a mapping.
        """
        config_path = Path(__file__).resolve().parent / "config" / "omnisocket_demo.yaml"
        loaded = None
        if config_path.exists():
            with config_path.open("r", encoding="utf-8") as file:
                loaded = yaml.safe_load(file)
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            raise ValueError(
                f"{config_path} must contain a mapping at the top level, "
                f"got {type(loaded).__name__}"
            )
        self.config = loaded

        transport_cfg = self._config_section(self.config, "transport")
        receiver_cfg = self._config_section(self.config, "control_receiver")
        motion_cfg = self._config_section(self.config, "motion")

        self.server_addr = str(transport_cfg.get("server_addr", "127.0.0.1:10909"))
        self.relay_via = str(transport_cfg.get("relay_via", ""))
        self.bind_ip = str(transport_cfg.get("bind_ip", ""))
        self.bind_device = str(transport_cfg.get("bind_device", ""))
        self.peer_id = str(receiver_cfg.get("peer_id", "peer-b-ctrl"))
        self.ctrl_socket_path = str(
            os.getenv("OMNIBDAEMON_CTRL_SOCKET", "/tmp/omnisocket-b-ctrl.sock")
        )

        self.initial_lift = float(motion_cfg.get("initial_lift", 0.89))
        self.lift_step = float(motion_cfg.get("lift_step", 0.05))
        self.max_surge = float(motion_cfg.get("max_surge", 1.0))
        self.max_sway = float(motion_cfg.get("max_sway", 0.5))
        self.max_spin = float(motion_cfg.get("max_spin", 0.5))
        self.max_lift = float(motion_cfg.get("max_lift", 0.90))
        self.min_lift = float(motion_cfg.get("min_lift", 0.65))
        self.surge_step = float(motion_cfg.get("surge_step", 0.1))
        self.sway_step = float(motion_cfg.get("sway_step", 0.1))
        self.spin_step = float(motion_cfg.get("spin_step", 0.1))

    def _init_data_structures(self) -> None:
        """Create the packet queue, motion frame and flag from loaded config."""
        self.packet_queue: queue.Queue[ControlPacket] = queue.Queue(maxsize=128)
        self.motion_frame = MotionFrame(lift_goal=self.initial_lift)
        self.udp_flag = OmniSocketFSMFlag()
        self.udp_flag.height_cmd = self.initial_lift
        self.last_seq_id = -1
        self.last_fsm_command_time = 0.0

    def start(self) -> None:
        """Connect the configured transport and start the receive thread.

        Raises:
            RuntimeError: if the transport library cannot be imported.
            Any error raised by the client/session ``connect()`` propagates.
        """
        if self.transport_backend == "daemon":
            daemon_client_cls = _load_b_side_control_client()
            self.daemon_client = daemon_client_cls(socket_path=self.ctrl_socket_path)
            self.daemon_client.connect()
        else:
            control_defaults, msg_type_binary, session_cls = _load_omnisocket_api()
            self._msg_type_binary = msg_type_binary
            self.session = session_cls()
            self.session.connect(
                server_addr=self.server_addr,
                peer_id=self.peer_id,
                relay_via=self.relay_via,
                bind_ip=self.bind_ip,
                bind_device=self.bind_device,
                **control_defaults,
            )

        self.recv_running = True
        self.recv_thread = threading.Thread(target=self._recv_loop, daemon=True)
        self.recv_thread.start()
        if self.transport_backend == "daemon":
            print(
                "OmniSocket FSM controller connected to B-side daemon "
                f"via {self.ctrl_socket_path}"
            )
        else:
            print(
                f"OmniSocket FSM controller listening as {self.peer_id} "
                f"via {self.server_addr}"
            )

    def stop(self) -> None:
        """Stop the receive thread (waiting up to 1 s) and close the transport."""
        self.recv_running = False
        if self.recv_thread and self.recv_thread.is_alive():
            self.recv_thread.join(timeout=1.0)
        if self.session is not None:
            self.session.close()
            self.session = None
        if self.daemon_client is not None:
            self.daemon_client.close()
            self.daemon_client = None
        print("OmniSocket FSM controller stopped")

    def _recv_loop(self) -> None:
        """Receive-thread entry point; dispatch on the configured backend."""
        if self.transport_backend == "daemon":
            self._recv_loop_daemon()
        else:
            self._recv_loop_direct()

    def _recv_loop_direct(self) -> None:
        """Poll the direct session (200 ms timeout) and enqueue binary payloads."""
        while self.recv_running and self.session is not None:
            item = self.session.recv(timeout_ms=200)
            if item is None:
                continue
            from_peer, msg_type, payload = item
            if msg_type != self._msg_type_binary:
                print(
                    f"[omnisocket_fsm] ignore non-binary message "
                    f"from {from_peer}: {msg_type}"
                )
                continue
            self._enqueue_payload(payload, from_peer=from_peer)

    def _recv_loop_daemon(self) -> None:
        """Poll the daemon control socket, reconnecting on ``OSError`` until stopped."""
        while self.recv_running:
            client = self.daemon_client
            if client is None:
                # A previous reconnect failed; keep retrying instead of letting
                # the receive thread die silently.
                self._reconnect_daemon_client()
                continue
            try:
                payload = client.recv_control_packet(timeout_ms=200)
            except OSError as exc:
                print(f"[omnisocket_fsm] daemon control socket error: {exc}")
                try:
                    client.close()
                except OSError:
                    pass
                self.daemon_client = None
                if not self.recv_running:
                    return
                time.sleep(0.5)
                self._reconnect_daemon_client()
                # Always restart the loop: ``payload`` was not assigned this round.
                continue
            if payload is None:
                continue
            self._enqueue_payload(payload, from_peer="daemon")

    def _reconnect_daemon_client(self) -> None:
        """Try once to open a new daemon client; publish it only if connected.

        On failure the error is logged, ``self.daemon_client`` is left as
        ``None`` and the call sleeps 0.5 s so the caller's retry loop does not
        spin.
        """
        if not self.recv_running:
            return
        candidate = None
        try:
            daemon_client_cls = _load_b_side_control_client()
            candidate = daemon_client_cls(socket_path=self.ctrl_socket_path)
            candidate.connect()
        except OSError as reconnect_error:
            print(f"[omnisocket_fsm] reconnect daemon socket failed: {reconnect_error}")
            if candidate is not None:
                try:
                    candidate.close()
                except OSError:
                    pass
            time.sleep(0.5)
            return
        if not self.recv_running:
            # stop() ran while we were connecting; do not leak the new socket.
            try:
                candidate.close()
            except OSError:
                pass
            return
        self.daemon_client = candidate

    def _enqueue_payload(self, payload: bytes, *, from_peer: str) -> None:
        """Decode ``payload`` and queue it, evicting the oldest packet when full.

        Payloads that ``ControlPacket.decode`` rejects (``ValueError`` /
        ``struct.error``) are logged and dropped.
        """
        try:
            packet = ControlPacket.decode(payload)
        except (ValueError, struct.error) as exc:
            print(f"[omnisocket_fsm] drop invalid payload from {from_peer}: {exc}")
            return

        try:
            self.packet_queue.put_nowait(packet)
            return
        except queue.Full:
            pass

        # Queue saturated: discard the oldest packet so the newest command wins.
        try:
            self.packet_queue.get_nowait()
        except queue.Empty:
            # The consumer drained the queue concurrently, so there is room now;
            # fall through and still enqueue the new packet.
            pass
        try:
            self.packet_queue.put_nowait(packet)
        except queue.Full:
            print(f"[omnisocket_fsm] drop packet from {from_peer}: queue full")

    def update_flag(self) -> None:
        """Apply every packet currently queued, in arrival order."""
        while True:
            try:
                packet = self.packet_queue.get_nowait()
            except queue.Empty:
                break
            self._apply_packet(packet)

    def get_udp_flag(self) -> OmniSocketFSMFlag:
        """Return a shallow copy of the current flag, taken under ``data_mutex``."""
        with self.data_mutex:
            flag_copy = OmniSocketFSMFlag()
            flag_copy.__dict__.update(self.udp_flag.__dict__)
            return flag_copy

    def get_last_input_time(self) -> float:
        """Return the ``time.time()`` stamp of the last applied packet.

        Before any packet is applied this is whatever ``MotionFrame`` initialises
        ``last_rx_time`` to (not visible here).
        """
        with self.data_mutex:
            return self.motion_frame.last_rx_time

    def get_last_fsm_command_time(self) -> float:
        """Return the ``time.time()`` stamp of the last mode-change event (0.0 if none)."""
        with self.data_mutex:
            return self.last_fsm_command_time

    def init(self) -> int:
        """Log initialisation and return 0 (success)."""
        print("OmniSocket FSM controller initialized")
        return 0

    @staticmethod
    def _clamp(value: float, low: float, high: float) -> float:
        """Return ``value`` limited to ``[low, high]``."""
        return max(low, min(high, value))

    def _apply_packet(self, packet: ControlPacket) -> None:
        """Apply one decoded packet to the motion frame and mirror it onto the flag.

        * Mode events (``pose_home``, ``pose_hold``, ``mode_stride``) change
          ``mode_tag`` and stamp ``last_fsm_command_time``.
        * ``*_up``/``*_down``/``*_left``/``*_right`` nudge a goal by its
          configured step, bounded by the configured limit.
        * ``set_*`` overwrite a goal with the clamped ``packet.drive_value``.
        * ``trim_reset`` zeroes surge/sway/spin.
        * ``session_quit`` disables the relay, selects ``pose_hold`` and clears
          ``recv_running``.

        Unknown events only update the sequence/receive bookkeeping.
        """
        event_code = packet.event_name
        if event_code in self._ABSOLUTE_SET_EVENTS and not math.isfinite(
            packet.drive_value
        ):
            # Clamping cannot sanitise NaN: max(lo, min(hi, nan)) evaluates to
            # ``hi``, so a corrupt value would become a full-scale command.
            print(
                f"[omnisocket_fsm] drop non-finite drive_value "
                f"seq={packet.seq_id} event={event_code}: {packet.drive_value}"
            )
            return

        now = time.time()
        frame = self.motion_frame
        with self.data_mutex:
            self.last_seq_id = packet.seq_id
            frame.last_event_code = event_code
            frame.last_rx_time = now

            if event_code in ("pose_home", "pose_hold", "mode_stride"):
                frame.mode_tag = event_code
                self.last_fsm_command_time = now
            elif event_code == "surge_up":
                frame.surge_goal = min(self.max_surge, frame.surge_goal + self.surge_step)
            elif event_code == "surge_down":
                frame.surge_goal = max(-self.max_surge, frame.surge_goal - self.surge_step)
            elif event_code == "sway_left":
                frame.sway_goal = max(-self.max_sway, frame.sway_goal - self.sway_step)
            elif event_code == "sway_right":
                frame.sway_goal = min(self.max_sway, frame.sway_goal + self.sway_step)
            elif event_code == "spin_left":
                frame.spin_goal = max(-self.max_spin, frame.spin_goal - self.spin_step)
            elif event_code == "spin_right":
                frame.spin_goal = min(self.max_spin, frame.spin_goal + self.spin_step)
            elif event_code == "set_surge":
                frame.surge_goal = self._clamp(
                    packet.drive_value, -self.max_surge, self.max_surge
                )
            elif event_code == "set_sway":
                frame.sway_goal = self._clamp(
                    packet.drive_value, -self.max_sway, self.max_sway
                )
            elif event_code == "set_spin":
                frame.spin_goal = self._clamp(
                    packet.drive_value, -self.max_spin, self.max_spin
                )
            elif event_code == "set_lift":
                frame.lift_goal = self._clamp(
                    packet.drive_value, self.min_lift, self.max_lift
                )
            elif event_code == "lift_up":
                frame.lift_goal = min(self.max_lift, frame.lift_goal + self.lift_step)
            elif event_code == "lift_down":
                frame.lift_goal = max(self.min_lift, frame.lift_goal - self.lift_step)
            elif event_code == "trim_reset":
                frame.surge_goal = 0.0
                frame.sway_goal = 0.0
                frame.spin_goal = 0.0
            elif event_code == "session_quit":
                frame.relay_on = False
                frame.mode_tag = "pose_hold"
                # Makes the receive loop exit on its next iteration; the
                # transport itself stays open until stop() is called.
                self.recv_running = False

            self._sync_motion_frame_to_flag()
            print(
                f"[omnisocket_fsm] seq={packet.seq_id} event={event_code} "
                f"{format_motion_frame(frame)} "
                f"fsm={self.udp_flag.fsm_state_command}"
            )

    def _sync_motion_frame_to_flag(self) -> None:
        """Copy motion-frame goals onto ``udp_flag``; caller holds ``data_mutex``.

        An unrecognised ``mode_tag`` leaves ``fsm_state_command`` unchanged.
        """
        frame = self.motion_frame
        flag = self.udp_flag
        flag.enable = frame.relay_on
        flag.fsm_state_command = self._MODE_TO_FSM_COMMAND.get(
            frame.mode_tag, flag.fsm_state_command
        )
        flag.x_speed_command = frame.surge_goal
        flag.y_speed_command = frame.sway_goal
        flag.yaw_speed_command = frame.spin_goal
        flag.height_cmd = frame.lift_goal