# Source file: tienkung-szu/Deploy_Tienkung/udp_loopback/omnisocket_fsm_controller.py
# (356 lines, 14 KiB, Python)

"""Adapter that maps OmniSocket control packets onto a ControlFlag interface."""
from __future__ import annotations

import math
import os
from pathlib import Path
import queue
import struct
import sys
import threading
import time
from typing import Dict, Optional

import yaml

from common.joystick import ControlFlag

try:
    from .omnisocket_control import ControlPacket, MotionFrame, format_motion_frame
except ImportError:  # pragma: no cover - direct script execution fallback
    from omnisocket_control import ControlPacket, MotionFrame, format_motion_frame
def _load_omnisocket_api():
try:
from omnisocket import CONTROL_DEFAULTS, MSG_TYPE_BINARY, Session
except ImportError as exc: # pragma: no cover - environment dependent
raise RuntimeError(
"omnisocket is not installed. Install it before using "
"control_tool=omnisocket_loopback."
) from exc
return CONTROL_DEFAULTS, MSG_TYPE_BINARY, Session
def _load_b_side_control_client():
try:
from omnisocket_b_side.client import BSideControlClient
except ImportError:
workspace_root = Path(__file__).resolve().parents[3]
python_root = workspace_root / "OmniSocketGo" / "python"
if str(python_root) not in sys.path:
sys.path.insert(0, str(python_root))
try:
from omnisocket_b_side.client import BSideControlClient
except ImportError as exc: # pragma: no cover - environment dependent
raise RuntimeError(
"omnisocket_b_side is not installed. Install it before using "
"OMNI_TRANSPORT_BACKEND=daemon."
) from exc
return BSideControlClient
class OmniSocketFSMFlag(ControlFlag):
    """FSM-facing flag produced from decoded OmniSocket control packets.

    Extends ``ControlFlag`` with three velocity commands and a height
    command.
    """

    def __init__(self) -> None:
        """Initialise the base flag, then zero the velocity commands."""
        super().__init__()
        # Height command starts at 0.89.
        self.height_cmd: float = 0.89
        # All three velocity commands start at rest.
        self.x_speed_command = self.y_speed_command = self.yaw_speed_command = 0.0
class OmniSocketFSMController:
    """Receive OmniSocket control packets and expose them as ControlFlag.

    A background thread receives payloads either from a B-side daemon control
    socket (``OMNI_TRANSPORT_BACKEND=daemon``, the default) or from a direct
    OmniSocket session (any other value).  Payloads are decoded into
    ``ControlPacket`` objects and queued; :meth:`update_flag` drains the queue
    and applies each packet to ``motion_frame``, which is then mirrored onto
    ``udp_flag``.
    """

    # Events that select a mode; the event name is stored as the mode tag.
    _MODE_EVENTS = ("pose_home", "pose_hold", "mode_stride")

    # FSM command string written to the flag for each mode tag.
    _FSM_COMMAND_BY_MODE = {
        "pose_home": "gotoZERO",
        "pose_hold": "gotoSTOP",
        "mode_stride": "gotoWALKAMP",
    }

    # "set_*" events and the motion-frame axis each one writes.
    _SET_EVENT_AXIS = {
        "set_surge": "surge",
        "set_sway": "sway",
        "set_spin": "spin",
        "set_lift": "lift",
    }

    def __init__(self) -> None:
        """Load configuration and create idle (not yet started) state."""
        self.config: Dict[str, object] = {}
        self.data_mutex = threading.Lock()
        self._load_config()
        self._init_data_structures()
        self.recv_running = False
        self.recv_thread: Optional[threading.Thread] = None
        self.session = None
        self.daemon_client = None
        self._msg_type_binary = None
        # Empty or whitespace-only values fall back to "daemon".
        self.transport_backend = str(
            os.getenv("OMNI_TRANSPORT_BACKEND", "daemon")
        ).strip().lower() or "daemon"

    @staticmethod
    def _cfg_section(config: object, name: str) -> dict:
        """Return ``config[name]`` if it is a mapping, else an empty dict.

        An empty YAML section (``transport:``) loads as ``None``; treating it
        as an empty mapping lets every key fall back to its default.
        """
        section = config.get(name) if isinstance(config, dict) else None
        return section if isinstance(section, dict) else {}

    @staticmethod
    def _cfg_str(section: dict, key: str, default: str) -> str:
        """Return ``section[key]`` as ``str``; missing or ``None`` gives *default*."""
        value = section.get(key)
        # str(None) would produce the literal "None", which is never intended.
        return str(default if value is None else value)

    @staticmethod
    def _cfg_float(section: dict, key: str, default: float) -> float:
        """Return ``section[key]`` as ``float``; missing or ``None`` gives *default*.

        Raises:
            ValueError: If the configured value is not convertible to float.
        """
        value = section.get(key)
        return float(default if value is None else value)

    def _load_config(self) -> None:
        """Read ``config/omnisocket_demo.yaml`` next to this file, if present.

        Populates transport, receiver and motion-limit attributes, using the
        built-in defaults for anything the file does not provide.
        """
        config_path = Path(__file__).resolve().parent / "config" / "omnisocket_demo.yaml"
        loaded = None
        if config_path.exists():
            with config_path.open("r", encoding="utf-8") as file:
                loaded = yaml.safe_load(file)
        # An empty file loads as None and a scalar/list document is not usable
        # as configuration; both fall back to defaults.
        self.config = loaded if isinstance(loaded, dict) else {}

        transport_cfg = self._cfg_section(self.config, "transport")
        receiver_cfg = self._cfg_section(self.config, "control_receiver")
        motion_cfg = self._cfg_section(self.config, "motion")

        self.server_addr = self._cfg_str(transport_cfg, "server_addr", "127.0.0.1:10909")
        self.relay_via = self._cfg_str(transport_cfg, "relay_via", "")
        self.bind_ip = self._cfg_str(transport_cfg, "bind_ip", "")
        self.bind_device = self._cfg_str(transport_cfg, "bind_device", "")
        self.peer_id = self._cfg_str(receiver_cfg, "peer_id", "peer-b-ctrl")
        self.ctrl_socket_path = str(
            os.getenv("OMNIBDAEMON_CTRL_SOCKET", "/tmp/omnisocket-b-ctrl.sock")
        )

        self.initial_lift = self._cfg_float(motion_cfg, "initial_lift", 0.89)
        self.lift_step = self._cfg_float(motion_cfg, "lift_step", 0.05)
        self.max_surge = self._cfg_float(motion_cfg, "max_surge", 1.0)
        self.max_sway = self._cfg_float(motion_cfg, "max_sway", 0.5)
        self.max_spin = self._cfg_float(motion_cfg, "max_spin", 0.5)
        self.max_lift = self._cfg_float(motion_cfg, "max_lift", 0.90)
        self.min_lift = self._cfg_float(motion_cfg, "min_lift", 0.65)
        self.surge_step = self._cfg_float(motion_cfg, "surge_step", 0.1)
        self.sway_step = self._cfg_float(motion_cfg, "sway_step", 0.1)
        self.spin_step = self._cfg_float(motion_cfg, "spin_step", 0.1)

    def _init_data_structures(self) -> None:
        """Create the packet queue, motion frame and flag at their initial values."""
        self.packet_queue: queue.Queue[ControlPacket] = queue.Queue(maxsize=128)
        self.motion_frame = MotionFrame(lift_goal=self.initial_lift)
        self.udp_flag = OmniSocketFSMFlag()
        self.udp_flag.height_cmd = self.initial_lift
        self.last_seq_id = -1
        self.last_fsm_command_time = 0.0

    def start(self) -> None:
        """Connect the selected transport and start the receive thread.

        Raises:
            RuntimeError: If the package backing the selected transport is
                not importable.
        """
        if self.transport_backend == "daemon":
            daemon_client_cls = _load_b_side_control_client()
            self.daemon_client = daemon_client_cls(socket_path=self.ctrl_socket_path)
            self.daemon_client.connect()
        else:
            control_defaults, msg_type_binary, session_cls = _load_omnisocket_api()
            self._msg_type_binary = msg_type_binary
            self.session = session_cls()
            self.session.connect(
                server_addr=self.server_addr,
                peer_id=self.peer_id,
                relay_via=self.relay_via,
                bind_ip=self.bind_ip,
                bind_device=self.bind_device,
                **control_defaults,
            )

        self.recv_running = True
        self.recv_thread = threading.Thread(target=self._recv_loop, daemon=True)
        self.recv_thread.start()

        if self.transport_backend == "daemon":
            print(
                "OmniSocket FSM controller connected to B-side daemon "
                f"via {self.ctrl_socket_path}"
            )
        else:
            print(
                f"OmniSocket FSM controller listening as {self.peer_id} "
                f"via {self.server_addr}"
            )

    def stop(self) -> None:
        """Ask the receive thread to exit, wait up to 1 s, and close the transport."""
        self.recv_running = False
        if self.recv_thread and self.recv_thread.is_alive():
            self.recv_thread.join(timeout=1.0)
        if self.session is not None:
            self.session.close()
            self.session = None
        if self.daemon_client is not None:
            self.daemon_client.close()
            self.daemon_client = None
        print("OmniSocket FSM controller stopped")

    def _recv_loop(self) -> None:
        """Thread entry point: run the receive loop for the selected backend."""
        if self.transport_backend == "daemon":
            self._recv_loop_daemon()
        else:
            self._recv_loop_direct()

    def _recv_loop_direct(self) -> None:
        """Poll the direct session and enqueue every binary message."""
        while self.recv_running:
            # Snapshot: stop() may set self.session to None while we run.
            session = self.session
            if session is None:
                return
            item = session.recv(timeout_ms=200)
            if item is None:
                continue
            from_peer, msg_type, payload = item
            if msg_type != self._msg_type_binary:
                print(
                    f"[omnisocket_fsm] ignore non-binary message "
                    f"from {from_peer}: {msg_type}"
                )
                continue
            self._enqueue_payload(payload, from_peer=from_peer)

    @staticmethod
    def _close_daemon_client(client) -> None:
        """Close *client*, ignoring ``OSError`` (the socket is being discarded)."""
        try:
            client.close()
        except OSError:
            pass

    def _reconnect_daemon(self) -> bool:
        """Retry connecting to the daemon socket every 0.5 s.

        The new client is stored in ``self.daemon_client`` only once
        ``connect()`` has succeeded, so the receive loop never reads from a
        client that failed to connect.

        Returns:
            ``True`` once reconnected; ``False`` if ``recv_running`` was
            cleared before a connection could be made.
        """
        while self.recv_running:
            time.sleep(0.5)
            if not self.recv_running:
                break
            candidate = None
            try:
                daemon_client_cls = _load_b_side_control_client()
                candidate = daemon_client_cls(socket_path=self.ctrl_socket_path)
                candidate.connect()
            except OSError as reconnect_error:
                print(f"[omnisocket_fsm] reconnect daemon socket failed: {reconnect_error}")
                if candidate is not None:
                    self._close_daemon_client(candidate)
                continue
            if not self.recv_running:
                # stop() ran while we were connecting; do not leak the socket.
                self._close_daemon_client(candidate)
                break
            self.daemon_client = candidate
            return True
        return False

    def _recv_loop_daemon(self) -> None:
        """Poll the daemon control socket, reconnecting after socket errors."""
        while self.recv_running:
            # Snapshot: stop() may set self.daemon_client to None while we run.
            client = self.daemon_client
            if client is None:
                return
            try:
                payload = client.recv_control_packet(timeout_ms=200)
            except OSError as exc:
                print(f"[omnisocket_fsm] daemon control socket error: {exc}")
                self._close_daemon_client(client)
                self.daemon_client = None
                if not self._reconnect_daemon():
                    return
                continue
            if payload is None:
                continue
            self._enqueue_payload(payload, from_peer="daemon")

    def _enqueue_payload(self, payload: bytes, *, from_peer: str) -> None:
        """Decode *payload* and queue it, evicting the oldest packet when full.

        Args:
            payload: Raw bytes passed to ``ControlPacket.decode``.
            from_peer: Sender label, used only in the drop message.
        """
        try:
            packet = ControlPacket.decode(payload)
        except (ValueError, struct.error) as exc:
            print(f"[omnisocket_fsm] drop invalid payload from {from_peer}: {exc}")
            return

        try:
            self.packet_queue.put_nowait(packet)
            return
        except queue.Full:
            pass

        # Queue full: discard the oldest entry to make room.  The consumer may
        # drain the queue in between, in which case there is nothing to evict
        # but the new packet must still be queued.
        try:
            self.packet_queue.get_nowait()
        except queue.Empty:
            pass
        try:
            self.packet_queue.put_nowait(packet)
        except queue.Full:
            print(f"[omnisocket_fsm] queue full, drop packet from {from_peer}")

    def update_flag(self) -> None:
        """Apply every packet currently waiting in the queue."""
        while True:
            try:
                packet = self.packet_queue.get_nowait()
            except queue.Empty:
                break
            self._apply_packet(packet)

    def get_udp_flag(self) -> OmniSocketFSMFlag:
        """Return a shallow copy of the current flag, taken under the mutex."""
        with self.data_mutex:
            flag_copy = OmniSocketFSMFlag()
            flag_copy.__dict__.update(self.udp_flag.__dict__)
            return flag_copy

    def get_last_input_time(self) -> float:
        """Return the ``time.time()`` stamp of the most recently applied packet."""
        with self.data_mutex:
            return self.motion_frame.last_rx_time

    def get_last_fsm_command_time(self) -> float:
        """Return the ``time.time()`` stamp of the last mode-selecting packet."""
        with self.data_mutex:
            return self.last_fsm_command_time

    def init(self) -> int:
        """Print an initialisation notice and return 0."""
        print("OmniSocket FSM controller initialized")
        return 0

    def _apply_packet(self, packet: ControlPacket) -> None:
        """Apply one packet to the motion frame, refresh the flag and log it."""
        now = time.time()
        with self.data_mutex:
            self._update_motion_frame(packet, now)
            self._sync_motion_frame_to_flag()
            summary = (
                f"[omnisocket_fsm] seq={packet.seq_id} event={packet.event_name} "
                f"{format_motion_frame(self.motion_frame)} "
                f"fsm={self.udp_flag.fsm_state_command}"
            )
        # Console I/O happens outside the mutex so readers are not blocked.
        print(summary)

    def _update_motion_frame(self, packet: ControlPacket, now: float) -> None:
        """Update ``motion_frame`` bookkeeping and goals from *packet*.

        The caller must hold ``data_mutex``.  Unknown events only update the
        bookkeeping fields (sequence id, last event, receive time).

        Args:
            packet: Provides ``seq_id``, ``event_name`` and ``drive_value``.
            now: Timestamp recorded as the receive time.
        """
        event_code = packet.event_name
        frame = self.motion_frame
        self.last_seq_id = packet.seq_id
        frame.last_event_code = event_code
        frame.last_rx_time = now

        if event_code in self._MODE_EVENTS:
            frame.mode_tag = event_code
            self.last_fsm_command_time = now
            return
        if event_code == "trim_reset":
            frame.surge_goal = 0.0
            frame.sway_goal = 0.0
            frame.spin_goal = 0.0
            return
        if event_code == "session_quit":
            frame.relay_on = False
            frame.mode_tag = "pose_hold"
            self.recv_running = False
            return

        # (low, high) bounds per axis; read on every call so that limits
        # changed on the instance take effect immediately.
        limits = {
            "surge": (-self.max_surge, self.max_surge),
            "sway": (-self.max_sway, self.max_sway),
            "spin": (-self.max_spin, self.max_spin),
            "lift": (self.min_lift, self.max_lift),
        }
        # Incremental events: axis and signed increment.
        steps = {
            "surge_up": ("surge", self.surge_step),
            "surge_down": ("surge", -self.surge_step),
            "sway_left": ("sway", -self.sway_step),
            "sway_right": ("sway", self.sway_step),
            "spin_left": ("spin", -self.spin_step),
            "spin_right": ("spin", self.spin_step),
            "lift_up": ("lift", self.lift_step),
            "lift_down": ("lift", -self.lift_step),
        }

        if event_code in steps:
            axis, delta = steps[event_code]
            target = getattr(frame, f"{axis}_goal") + delta
        elif event_code in self._SET_EVENT_AXIS:
            axis = self._SET_EVENT_AXIS[event_code]
            target = packet.drive_value
            # NaN compares false against everything, so max(lo, min(hi, nan))
            # would evaluate to the upper limit; keep the current goal instead.
            if math.isnan(target):
                print(f"[omnisocket_fsm] ignore NaN drive value for {event_code}")
                return
        else:
            return

        low, high = limits[axis]
        setattr(frame, f"{axis}_goal", max(low, min(high, target)))

    def _sync_motion_frame_to_flag(self) -> None:
        """Mirror ``motion_frame`` onto ``udp_flag``.

        The FSM command is left unchanged when the mode tag has no mapping.
        """
        frame = self.motion_frame
        flag = self.udp_flag
        flag.enable = frame.relay_on
        flag.fsm_state_command = self._FSM_COMMAND_BY_MODE.get(
            frame.mode_tag, flag.fsm_state_command
        )
        flag.x_speed_command = frame.surge_goal
        flag.y_speed_command = frame.sway_goal
        flag.yaw_speed_command = frame.spin_goal
        flag.height_cmd = frame.lift_goal