feat: 基于Python ROS2的控制程序

This commit is contained in:
2026-04-03 20:00:33 +08:00
parent 6ece408d9f
commit 9ffc36f50d
26 changed files with 2193 additions and 38 deletions

View File

@@ -0,0 +1 @@
"""OmniSocket teleop bridge package."""

View File

@@ -0,0 +1,207 @@
"""ROS 2 node that forwards TwistStamped teleop commands over OmniSocket."""
from __future__ import annotations
import threading
import time
from typing import Dict, Optional, Tuple
import rclpy
from geometry_msgs.msg import TwistStamped
from rclpy.node import Node
from .omni_transport import MSG_TYPE_ERROR, OmniTransport
from .protocol import (
DEFAULT_EXIT_ZERO_PACKETS,
DEFAULT_INPUT_TIMEOUT,
DEFAULT_INPUT_TOPIC,
DEFAULT_KEYBOARD_PEER_ID,
DEFAULT_QUEUE_DEPTH,
DEFAULT_SEND_RATE_HZ,
DEFAULT_TARGET_PEER,
DEFAULT_TRANSPORT,
ZERO_COMMAND,
pack_command,
)
CommandTuple = Tuple[float, float, float, float, float, float]
class CmdVelUdpSender(Node):
"""Forward TwistStamped messages to a remote OmniSocket peer."""
def __init__(self) -> None:
super().__init__('cmd_vel_udp_sender')
self.declare_parameter('transport', DEFAULT_TRANSPORT)
self.declare_parameter('server_addr', '')
self.declare_parameter('relay_via', '')
self.declare_parameter('peer_id', DEFAULT_KEYBOARD_PEER_ID)
self.declare_parameter('target_peer', DEFAULT_TARGET_PEER)
self.declare_parameter('input_topic', DEFAULT_INPUT_TOPIC)
self.declare_parameter('send_rate_hz', DEFAULT_SEND_RATE_HZ)
self.declare_parameter('input_timeout', DEFAULT_INPUT_TIMEOUT)
self.declare_parameter('queue_depth', DEFAULT_QUEUE_DEPTH)
self.declare_parameter('exit_zero_packets', DEFAULT_EXIT_ZERO_PACKETS)
self._transport_name = str(self.get_parameter('transport').value)
self._server_addr = str(self.get_parameter('server_addr').value)
self._relay_via = str(self.get_parameter('relay_via').value)
self._peer_id = str(self.get_parameter('peer_id').value)
self._target_peer = str(self.get_parameter('target_peer').value).strip()
self._input_topic = str(self.get_parameter('input_topic').value)
self._send_rate_hz = float(self.get_parameter('send_rate_hz').value)
self._input_timeout = float(self.get_parameter('input_timeout').value)
self._queue_depth = int(self.get_parameter('queue_depth').value)
self._exit_zero_packets = int(self.get_parameter('exit_zero_packets').value)
if self._send_rate_hz <= 0.0:
raise ValueError('send_rate_hz must be > 0')
if self._input_timeout < 0.0:
raise ValueError('input_timeout must be >= 0')
if self._queue_depth <= 0:
raise ValueError('queue_depth must be > 0')
if not self._target_peer:
raise ValueError('target_peer must not be empty')
self._transport = OmniTransport(
transport=self._transport_name,
server_addr=self._server_addr,
relay_via=self._relay_via,
peer_id=self._peer_id,
)
self._last_log_times: Dict[str, float] = {}
self._latest_command: CommandTuple = ZERO_COMMAND
self._last_input_monotonic: Optional[float] = None
self._last_sent_command: Optional[CommandTuple] = None
self._closing = threading.Event()
self.create_subscription(
TwistStamped,
self._input_topic,
self._handle_twist,
self._queue_depth,
)
self.create_timer(1.0 / self._send_rate_hz, self._send_latest_command)
self._drain_thread = threading.Thread(target=self._drain_incoming, daemon=True)
self._drain_thread.start()
self.get_logger().info(
'Forwarding TwistStamped from %s via %s://%s as %s -> %s at %.1f Hz '
'(input timeout %.2f s)'
% (
self._input_topic,
self._transport.transport,
self._transport.server_addr,
self._peer_id,
self._target_peer,
self._send_rate_hz,
self._input_timeout,
)
)
def _should_log(self, key: str, throttle_sec: float) -> bool:
now = time.monotonic()
previous = self._last_log_times.get(key)
if previous is None or (now - previous) >= throttle_sec:
self._last_log_times[key] = now
return True
return False
def _handle_twist(self, msg: TwistStamped) -> None:
self._latest_command = (
float(msg.twist.linear.x),
float(msg.twist.linear.y),
float(msg.twist.linear.z),
float(msg.twist.angular.x),
float(msg.twist.angular.y),
float(msg.twist.angular.z),
)
self._last_input_monotonic = time.monotonic()
def _command_for_current_tick(self) -> CommandTuple:
if self._last_input_monotonic is None:
return ZERO_COMMAND
if self._input_timeout == 0.0:
return self._latest_command
age = time.monotonic() - self._last_input_monotonic
if age > self._input_timeout:
return ZERO_COMMAND
return self._latest_command
def _send_command(self, command: CommandTuple) -> None:
payload = pack_command(command)
try:
self._transport.send(to=self._target_peer, data=payload)
self._last_sent_command = command
except OSError as exc:
if self._should_log('send_error', 2.0):
self.get_logger().error(f'OmniSocket send failed: {exc}')
def _send_latest_command(self) -> None:
self._send_command(self._command_for_current_tick())
def _log_inbound_message(self, from_peer: str, msg_type: int, payload: bytes) -> None:
if msg_type == MSG_TYPE_ERROR:
if self._should_log('server_error', 1.0):
text = payload.decode('utf-8', errors='replace')
self.get_logger().error(f'OmniSocket server error from {from_peer}: {text}')
return
if self._should_log('unexpected_inbound', 2.0):
self.get_logger().warning(
'Ignoring unexpected inbound message type %d from %s (%d bytes)'
% (msg_type, from_peer, len(payload))
)
def _drain_incoming(self) -> None:
while not self._closing.is_set() and rclpy.ok():
try:
result = self._transport.recv(timeout_ms=100)
except OSError as exc:
if not self._closing.is_set() and self._should_log('drain_error', 2.0):
self.get_logger().error(f'OmniSocket receive loop stopped: {exc}')
return
if result is None:
continue
from_peer, msg_type, payload = result
self._log_inbound_message(from_peer, msg_type, payload)
def send_zero_burst(self) -> None:
"""Best-effort stop command sent during shutdown."""
for _ in range(max(1, self._exit_zero_packets)):
self._send_command(ZERO_COMMAND)
time.sleep(0.02)
def close(self) -> None:
self._closing.set()
if hasattr(self, '_transport') and self._transport is not None:
try:
self._transport.close()
except OSError as exc:
if self._should_log('close_error', 2.0):
self.get_logger().warning(f'Closing OmniSocket transport failed: {exc}')
self._transport = None
if hasattr(self, '_drain_thread') and self._drain_thread.is_alive():
self._drain_thread.join(timeout=0.5)
def destroy_node(self) -> bool:
self.close()
return super().destroy_node()
def main(args: Optional[list[str]] = None) -> None:
rclpy.init(args=args)
node = CmdVelUdpSender()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.send_zero_burst()
node.destroy_node()
rclpy.shutdown()

View File

@@ -0,0 +1,95 @@
"""Helpers for working with OmniSocket transport sessions."""
from __future__ import annotations
from .protocol import default_server_addr_for_transport, normalize_transport
try:
from omnisocket import (
CONTROL_DEFAULTS,
MSG_TYPE_BINARY,
MSG_TYPE_ERROR,
Session,
UdpSession,
)
except ImportError as exc: # pragma: no cover - depends on external build/install
raise RuntimeError(
'omnisocket is not installed for this Python environment; run '
'`make python-ext && make python-install` on a Linux host first'
) from exc
def _normalize_optional(value: object) -> str:
return str(value).strip()
class OmniTransport:
"""Small wrapper that normalizes OmniSocket UDP/KCP session setup."""
def __init__(
self,
*,
transport: object,
server_addr: object,
peer_id: object,
relay_via: object = '',
bind_ip: object = '',
bind_device: object = '',
enable_timestamping: bool = False,
) -> None:
self.transport = normalize_transport(transport)
self.server_addr = _normalize_optional(server_addr) or default_server_addr_for_transport(self.transport)
self.peer_id = _normalize_optional(peer_id)
self.relay_via = _normalize_optional(relay_via)
self.bind_ip = _normalize_optional(bind_ip)
self.bind_device = _normalize_optional(bind_device)
if not self.peer_id:
raise ValueError('peer_id must not be empty')
session_cls = Session if self.transport == 'kcp' else UdpSession
self._session = session_cls()
connect_kwargs: dict[str, object] = {
'server_addr': self.server_addr,
'peer_id': self.peer_id,
}
if self.bind_ip:
connect_kwargs['bind_ip'] = self.bind_ip
if self.bind_device:
connect_kwargs['bind_device'] = self.bind_device
if self.transport == 'kcp':
if self.relay_via:
connect_kwargs['relay_via'] = self.relay_via
connect_kwargs.update(CONTROL_DEFAULTS)
else:
connect_kwargs['enable_timestamping'] = bool(enable_timestamping)
self._session.connect(**connect_kwargs)
def send(self, *, to: str, data: bytes) -> None:
self._session.send(to=to, data=data)
def recv(self, *, timeout_ms: int = -1):
return self._session.recv(timeout_ms=timeout_ms)
def recv_into(self, *, buffer, timeout_ms: int = -1):
return self._session.recv_into(buffer=buffer, timeout_ms=timeout_ms)
def close(self) -> None:
self._session.close()
def stats(self) -> dict[str, int]:
return self._session.stats()
__all__ = [
'CONTROL_DEFAULTS',
'MSG_TYPE_BINARY',
'MSG_TYPE_ERROR',
'OmniTransport',
'Session',
'UdpSession',
]

View File

@@ -0,0 +1,74 @@
"""Shared teleop protocol helpers and transport defaults."""
from __future__ import annotations
import math
import struct
from typing import Iterable, Tuple
COMMAND_STRUCT = struct.Struct('<6f')
PACKET_SIZE = COMMAND_STRUCT.size
SUPPORTED_TRANSPORTS = ('udp', 'kcp')
DEFAULT_TRANSPORT = 'udp'
DEFAULT_OMNI_UDP_SERVER_ADDR = '127.0.0.1:9001'
DEFAULT_OMNI_KCP_SERVER_ADDR = '127.0.0.1:9002'
DEFAULT_KEYBOARD_PEER_ID = 'ros-keyboard-ctrl'
DEFAULT_GAMEPAD_PEER_ID = 'ros-gamepad-ctrl'
DEFAULT_BRIDGE_PEER_ID = 'ros-bridge-ctrl'
DEFAULT_TARGET_PEER = DEFAULT_BRIDGE_PEER_ID
DEFAULT_FRAME_ID = 'pelvis'
DEFAULT_INPUT_TOPIC = '/teleop/cmd_vel'
DEFAULT_OUTPUT_TOPIC = '/hric/robot/cmd_vel'
DEFAULT_SEND_RATE_HZ = 20.0
DEFAULT_INPUT_TIMEOUT = 0.75
DEFAULT_WATCHDOG_TIMEOUT = 0.5
DEFAULT_PUBLISH_RATE_HZ = 100.0
DEFAULT_QUEUE_DEPTH = 10
DEFAULT_EXIT_ZERO_PACKETS = 3
DEFAULT_RECV_BUFFER_BYTES = 2048
ZERO_COMMAND = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
def normalize_transport(value: object) -> str:
"""Return a supported transport name."""
transport = str(value).strip().lower()
if transport not in SUPPORTED_TRANSPORTS:
supported = ', '.join(SUPPORTED_TRANSPORTS)
raise ValueError(f"Unsupported transport '{transport}', expected one of: {supported}")
return transport
def default_server_addr_for_transport(transport: str) -> str:
"""Return the default OmniSocket server for the chosen transport."""
transport = normalize_transport(transport)
if transport == 'udp':
return DEFAULT_OMNI_UDP_SERVER_ADDR
return DEFAULT_OMNI_KCP_SERVER_ADDR
def normalize_command(values: Iterable[float]) -> Tuple[float, float, float, float, float, float]:
"""Return a finite six-float command tuple."""
command = tuple(float(value) for value in values)
if len(command) != 6:
raise ValueError(f'Expected 6 command values, got {len(command)}')
if any(not math.isfinite(value) for value in command):
raise ValueError('Command contains a non-finite value')
return command
def pack_command(values: Iterable[float]) -> bytes:
"""Pack six floats into the wire format."""
return COMMAND_STRUCT.pack(*normalize_command(values))
def unpack_command(payload: bytes) -> Tuple[float, float, float, float, float, float]:
"""Decode a control packet into a six-float command tuple."""
if len(payload) != PACKET_SIZE:
raise ValueError(f'Expected {PACKET_SIZE} bytes, got {len(payload)}')
return normalize_command(COMMAND_STRUCT.unpack(payload))

View File

@@ -0,0 +1,238 @@
"""ROS 2 node that receives OmniSocket teleop packets and republishes TwistStamped."""
from __future__ import annotations
import threading
import time
from typing import Dict, Optional, Tuple
import rclpy
from geometry_msgs.msg import TwistStamped
from rclpy.node import Node
from .omni_transport import MSG_TYPE_BINARY, MSG_TYPE_ERROR, OmniTransport
from .protocol import (
DEFAULT_BRIDGE_PEER_ID,
DEFAULT_FRAME_ID,
DEFAULT_OUTPUT_TOPIC,
DEFAULT_PUBLISH_RATE_HZ,
DEFAULT_QUEUE_DEPTH,
DEFAULT_RECV_BUFFER_BYTES,
DEFAULT_TRANSPORT,
DEFAULT_WATCHDOG_TIMEOUT,
PACKET_SIZE,
ZERO_COMMAND,
unpack_command,
)
CommandTuple = Tuple[float, float, float, float, float, float]
class UdpCmdVelReceiver(Node):
"""Publish TwistStamped commands from the OmniSocket control wire format."""
def __init__(self) -> None:
super().__init__('udp_cmd_vel_receiver')
self.declare_parameter('transport', DEFAULT_TRANSPORT)
self.declare_parameter('server_addr', '')
self.declare_parameter('relay_via', '')
self.declare_parameter('peer_id', DEFAULT_BRIDGE_PEER_ID)
self.declare_parameter('expected_sender', '')
self.declare_parameter('output_topic', DEFAULT_OUTPUT_TOPIC)
self.declare_parameter('frame_id', DEFAULT_FRAME_ID)
self.declare_parameter('watchdog_timeout', DEFAULT_WATCHDOG_TIMEOUT)
self.declare_parameter('publish_rate_hz', DEFAULT_PUBLISH_RATE_HZ)
self.declare_parameter('queue_depth', DEFAULT_QUEUE_DEPTH)
self._transport_name = str(self.get_parameter('transport').value)
self._server_addr = str(self.get_parameter('server_addr').value)
self._relay_via = str(self.get_parameter('relay_via').value)
self._peer_id = str(self.get_parameter('peer_id').value)
self._expected_sender = str(self.get_parameter('expected_sender').value).strip()
self._output_topic = str(self.get_parameter('output_topic').value)
self._frame_id = str(self.get_parameter('frame_id').value)
self._watchdog_timeout = float(self.get_parameter('watchdog_timeout').value)
self._publish_rate_hz = float(self.get_parameter('publish_rate_hz').value)
self._queue_depth = int(self.get_parameter('queue_depth').value)
if self._watchdog_timeout <= 0.0:
raise ValueError('watchdog_timeout must be > 0')
if self._publish_rate_hz <= 0.0:
raise ValueError('publish_rate_hz must be > 0')
if self._queue_depth <= 0:
raise ValueError('queue_depth must be > 0')
self._publisher = self.create_publisher(TwistStamped, self._output_topic, self._queue_depth)
self._transport = OmniTransport(
transport=self._transport_name,
server_addr=self._server_addr,
relay_via=self._relay_via,
peer_id=self._peer_id,
)
self._lock = threading.Lock()
self._last_log_times: Dict[str, float] = {}
self._latest_command: CommandTuple = ZERO_COMMAND
self._last_packet_monotonic: Optional[float] = None
self._last_published_command: CommandTuple = ZERO_COMMAND
self._closing = threading.Event()
self._recv_buffer = bytearray(DEFAULT_RECV_BUFFER_BYTES)
self.create_timer(1.0 / self._publish_rate_hz, self._publish_tick)
self._recv_thread = threading.Thread(target=self._recv_loop, daemon=True)
self._recv_thread.start()
self.get_logger().info(
'Receiving teleop commands via %s://%s as %s and publishing TwistStamped to %s '
'at %.1f Hz (frame_id=%s, watchdog %.2f s)'
% (
self._transport.transport,
self._transport.server_addr,
self._peer_id,
self._output_topic,
self._publish_rate_hz,
self._frame_id,
self._watchdog_timeout,
)
)
def _should_log(self, key: str, throttle_sec: float) -> bool:
now = time.monotonic()
previous = self._last_log_times.get(key)
if previous is None or (now - previous) >= throttle_sec:
self._last_log_times[key] = now
return True
return False
def _publish_command(self, command: CommandTuple) -> None:
msg = TwistStamped()
msg.header.stamp = self.get_clock().now().to_msg()
msg.header.frame_id = self._frame_id
msg.twist.linear.x = command[0]
msg.twist.linear.y = command[1]
msg.twist.linear.z = command[2]
msg.twist.angular.x = command[3]
msg.twist.angular.y = command[4]
msg.twist.angular.z = command[5]
self._publisher.publish(msg)
self._last_published_command = command
def _handle_error_message(self, from_peer: str, body_len: int) -> None:
if self._should_log('server_error', 1.0):
text = bytes(self._recv_buffer[:body_len]).decode('utf-8', errors='replace')
self.get_logger().error(f'OmniSocket server error from {from_peer}: {text}')
def _recv_loop(self) -> None:
while not self._closing.is_set() and rclpy.ok():
try:
meta = self._transport.recv_into(buffer=self._recv_buffer, timeout_ms=100)
except BufferError as exc:
if self._should_log('buffer_error', 2.0):
self.get_logger().warning(f'Dropped oversized OmniSocket frame: {exc}')
continue
except OSError as exc:
if not self._closing.is_set() and self._should_log('recv_error', 2.0):
self.get_logger().error(f'OmniSocket receive loop stopped: {exc}')
return
if meta is None:
continue
from_peer = str(meta['from'])
msg_type = int(meta['msg_type'])
body_len = int(meta['body_len'])
if msg_type == MSG_TYPE_ERROR:
self._handle_error_message(from_peer, body_len)
continue
if self._expected_sender and from_peer != self._expected_sender:
if self._should_log('unexpected_sender', 2.0):
self.get_logger().warning(
'Ignoring message from unexpected sender %s (expected %s)'
% (from_peer, self._expected_sender)
)
continue
if msg_type != MSG_TYPE_BINARY:
if self._should_log('unexpected_type', 2.0):
self.get_logger().warning(
'Ignoring unexpected message type %d from %s (%d bytes)'
% (msg_type, from_peer, body_len)
)
continue
if body_len != PACKET_SIZE:
if self._should_log('packet_size', 2.0):
self.get_logger().warning(
'Dropped binary payload from %s with invalid size %d (expected %d)'
% (from_peer, body_len, PACKET_SIZE)
)
continue
try:
command = unpack_command(self._recv_buffer[:PACKET_SIZE])
except ValueError as exc:
if self._should_log('decode_error', 2.0):
self.get_logger().warning(f'Dropped malformed command payload: {exc}')
continue
with self._lock:
self._latest_command = command
self._last_packet_monotonic = time.monotonic()
def _command_for_publish_tick(self) -> tuple[CommandTuple, Optional[float], bool]:
with self._lock:
latest_command = self._latest_command
last_packet_monotonic = self._last_packet_monotonic
if last_packet_monotonic is None:
return ZERO_COMMAND, None, False
age = time.monotonic() - last_packet_monotonic
if age > self._watchdog_timeout:
return ZERO_COMMAND, age, True
return latest_command, age, False
def _publish_tick(self) -> None:
publish_command, age, timed_out = self._command_for_publish_tick()
if timed_out and self._last_published_command != ZERO_COMMAND:
if self._should_log('watchdog_stop', 2.0):
self.get_logger().warning(
'Command stream timed out after %.2f s, publishing zero velocity stop'
% age
)
self._publish_command(publish_command)
def close(self) -> None:
self._closing.set()
if hasattr(self, '_transport') and self._transport is not None:
try:
self._transport.close()
except OSError as exc:
if self._should_log('close_error', 2.0):
self.get_logger().warning(f'Closing OmniSocket transport failed: {exc}')
self._transport = None
if hasattr(self, '_recv_thread') and self._recv_thread.is_alive():
self._recv_thread.join(timeout=0.5)
def destroy_node(self) -> bool:
self.close()
return super().destroy_node()
def main(args: Optional[list[str]] = None) -> None:
rclpy.init(args=args)
node = UdpCmdVelReceiver()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()