"""ROS2 Joy -> OmniSocket bridge for Xbox control.""" from __future__ import annotations from pathlib import Path from typing import Dict import rclpy from rclpy.node import Node from rclpy.qos import HistoryPolicy, QoSProfile, ReliabilityPolicy from sensor_msgs.msg import Joy import yaml try: from .omnisocket_control import make_control_packet except ImportError: # pragma: no cover - direct script execution fallback from omnisocket_control import make_control_packet def _load_omnisocket_api(): try: from omnisocket import CONTROL_DEFAULTS, Session except ImportError as exc: # pragma: no cover - environment dependent raise RuntimeError( "omnisocket is not installed. Install it before using " "omnisocket_xbox_sender.py." ) from exc return CONTROL_DEFAULTS, Session class OmniSocketXboxSender(Node): """Subscribe to Joy messages and forward them through OmniSocket.""" def __init__(self) -> None: super().__init__("omnisocket_xbox_sender") self.config: Dict[str, object] = {} self.seq_id = 0 self.last_buttons: Dict[str, int] = {} self.last_dpad_h = 0.0 self.session = None self._load_config() self._init_session() qos_profile = QoSProfile( reliability=ReliabilityPolicy.RELIABLE, history=HistoryPolicy.KEEP_LAST, depth=10, ) self.subscription = self.create_subscription( Joy, self.joy_topic, self._joy_callback, qos_profile ) self.get_logger().info( f"Forwarding {self.joy_topic} -> OmniSocket " f"{self.peer_id} -> {self.target_peer}" ) self.get_logger().info( "Buttons: A=WALKAMP X=ZERO Y=STOP START=reset" ) def destroy_node(self) -> bool: if self.session is not None: self.session.close() self.session = None return super().destroy_node() def _load_config(self) -> None: omni_config_path = ( Path(__file__).resolve().parent / "config" / "omnisocket_demo.yaml" ) main_config_path = Path(__file__).resolve().parents[1] / "config" / "dex_config.yaml" if omni_config_path.exists(): with omni_config_path.open("r", encoding="utf-8") as file: self.config = yaml.safe_load(file) or {} else: self.config = {} with main_config_path.open("r", encoding="utf-8") as file: main_config = yaml.safe_load(file) or {} transport_cfg = self.config.get("transport", {}) sender_cfg = self.config.get("control_sender", {}) xbox_cfg = main_config.get("xbox", {}) self.server_addr = str(transport_cfg.get("server_addr", "127.0.0.1:10909")) self.relay_via = str(transport_cfg.get("relay_via", "")) self.bind_ip = str(transport_cfg.get("bind_ip", "")) self.bind_device = str(transport_cfg.get("bind_device", "")) self.peer_id = str(sender_cfg.get("peer_id", "peer-a-ctrl")) self.target_peer = str(sender_cfg.get("target_peer", "peer-b-ctrl")) self.joy_topic = str(sender_cfg.get("joy_topic", "/xbox_data")) self.deadzone = float(sender_cfg.get("deadzone", 0.10)) self.analog_epsilon = float(sender_cfg.get("analog_epsilon", 0.01)) self.dpad_threshold = float(sender_cfg.get("dpad_threshold", 0.50)) self.trigger_pressed_threshold = float( sender_cfg.get("trigger_pressed_threshold", -0.50) ) self.forward_command_offset = float( sender_cfg.get( "forward_command_offset", xbox_cfg.get("forward_command_offset", 0.0), ) ) self.lateral_command_offset = float( sender_cfg.get( "lateral_command_offset", xbox_cfg.get("lateral_command_offset", 0.0), ) ) self.rotation_command_offset = float( sender_cfg.get( "rotation_command_offset", xbox_cfg.get("rotation_command_offset", 0.0), ) ) self.button_map = { "a": 0, "b": 1, "x": 2, "y": 3, "lb": 4, "rb": 5, "select": 6, "start": 7, "home": 8, } self.axis_map = { "lx": 0, "ly": 1, "l_trigger": 2, "rx": 3, "ry": 4, "r_trigger": 5, "dpad_h": 6, "dpad_v": 7, } self._merge_mapping(self.button_map, xbox_cfg.get("button_map")) self._merge_mapping(self.axis_map, xbox_cfg.get("axis_map")) self._merge_mapping(self.button_map, sender_cfg.get("button_map")) self._merge_mapping(self.axis_map, sender_cfg.get("axis_map")) def _merge_mapping(self, target: Dict[str, int], override: object) -> None: if not isinstance(override, dict): return for name, index in override.items(): if name in target: try: target[name] = int(index) except (TypeError, ValueError): pass def _init_session(self) -> None: control_defaults, session_cls = _load_omnisocket_api() self.session = session_cls() self.session.connect( server_addr=self.server_addr, peer_id=self.peer_id, relay_via=self.relay_via, bind_ip=self.bind_ip, bind_device=self.bind_device, **control_defaults, ) def _joy_callback(self, msg: Joy) -> None: axes = list(msg.axes) + [0.0] * 16 buttons = list(msg.buttons) + [0] * 32 state = { "a": self._button_value(buttons, "a"), "x": self._button_value(buttons, "x"), "y": self._button_value(buttons, "y"), "start": self._button_value(buttons, "start"), "lx": self._axis_value(axes, "lx"), "ly": self._axis_value(axes, "ly"), "rx": self._axis_value(axes, "rx"), "dpad_h": self._axis_value(axes, "dpad_h"), } self._send_mode_events(state) self._send_trim_event(state) self._send_lift_events(state) self._send_analog_events(state) self.last_buttons = { name: int(state[name]) for name in ("a", "x", "y", "start") } self.last_dpad_h = float(state["dpad_h"]) def _button_value(self, buttons: list[int], name: str) -> int: index = self.button_map[name] return int(buttons[index]) if index < len(buttons) else 0 def _axis_value(self, axes: list[float], name: str) -> float: index = self.axis_map[name] return float(axes[index]) if index < len(axes) else 0.0 def _send_mode_events(self, state: Dict[str, float]) -> None: if self._rising_edge(state, "y"): self._send_event("pose_hold", "y") elif self._rising_edge(state, "x"): self._send_event("pose_home", "x") elif self._rising_edge(state, "a"): self._send_event("mode_stride", "a") def _send_trim_event(self, state: Dict[str, float]) -> None: if self._rising_edge(state, "start"): self._send_event("trim_reset", "start") def _send_lift_events(self, state: Dict[str, float]) -> None: dpad_h = float(state["dpad_h"]) if dpad_h <= -self.dpad_threshold and self.last_dpad_h > -self.dpad_threshold: self._send_event("lift_up", "dpad_left") elif dpad_h >= self.dpad_threshold and self.last_dpad_h < self.dpad_threshold: self._send_event("lift_down", "dpad_right") def _send_analog_events(self, state: Dict[str, float]) -> None: surge = self._compute_surge(state["ly"]) sway = self._cleanup_command( self._apply_deadzone(state["lx"]) * -0.4 + self.lateral_command_offset ) spin = self._cleanup_command( self._apply_deadzone(state["rx"]) * -0.4 + self.rotation_command_offset ) self._send_event("set_surge", "left_stick_y", surge) self._send_event("set_sway", "left_stick_x", sway) self._send_event("set_spin", "right_stick_x", spin) def _compute_surge(self, ly: float) -> float: ly = self._apply_deadzone(ly) if ly >= 0.0: value = ly * 0.8 + self.forward_command_offset else: value = ly * 0.5 return self._cleanup_command(value) def _apply_deadzone(self, value: float) -> float: if abs(value) < self.deadzone: return 0.0 return float(value) def _cleanup_command(self, value: float) -> float: if abs(value) < self.analog_epsilon: return 0.0 return float(value) def _rising_edge(self, state: Dict[str, float], name: str) -> bool: previous = int(self.last_buttons.get(name, 0)) return int(state[name]) == 1 and previous == 0 def _send_event( self, event_code: str, key_name: str, drive_value: float = 1.0 ) -> None: if self.session is None: return packet = make_control_packet(self.seq_id, event_code, drive_value) self.seq_id += 1 self.session.send(to=self.target_peer, data=packet.encode()) self.get_logger().debug( f"sent seq={packet.seq_id} event={event_code} key={key_name}" ) def main(args: list[str] | None = None) -> None: rclpy.init(args=args) node = OmniSocketXboxSender() try: rclpy.spin(node) except KeyboardInterrupt: pass finally: node.destroy_node() rclpy.shutdown() if __name__ == "__main__": main()