清理多余策略
This commit is contained in:
@@ -14,8 +14,6 @@ class FSMStateName(Enum):
|
|||||||
STOP = 0 # 停止状态
|
STOP = 0 # 停止状态
|
||||||
ZERO = 1 # 零位状态
|
ZERO = 1 # 零位状态
|
||||||
WALKAMP = 2 # WALKAMP策略状态
|
WALKAMP = 2 # WALKAMP策略状态
|
||||||
MYPOLICY = 3 # 自定义策略状态
|
|
||||||
XSIMRUN = 4 # 更贴近sim2sim的xSIM run状态
|
|
||||||
|
|
||||||
class FSMState(ABC):
|
class FSMState(ABC):
|
||||||
"""FSM状态抽象基类"""
|
"""FSM状态抽象基类"""
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ Complete FSM implementation with state management
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
from .fsm_base import RobotFSM, FSMStateName
|
from .fsm_base import RobotFSM, FSMStateName
|
||||||
from policy.walk_amp.fsm_walkamp import FSMStateWALKAMP
|
from policy.walk_amp.fsm_walkamp import FSMStateWALKAMP
|
||||||
from policy.mypolicy.fsm_mypolicy import FSMStateMYPOLICY
|
|
||||||
from policy.xsim_run.fsm_xsim_run import FSMStateXSIMRUN
|
|
||||||
from policy.zero.fsm_zero import FSMStateZero
|
from policy.zero.fsm_zero import FSMStateZero
|
||||||
from policy.stop.fsm_stop import FSMStateStop
|
from policy.stop.fsm_stop import FSMStateStop
|
||||||
from policy.beyond_mimic.fsm_beyond_mimic import FSMStateBeyondMimic
|
from policy.beyond_mimic.fsm_beyond_mimic import FSMStateBeyondMimic
|
||||||
@@ -52,8 +50,6 @@ class RobotFSMImpl(RobotFSM):
|
|||||||
self.state_objects[FSMStateName.STOP] = FSMStateStop(self.robot_data_)
|
self.state_objects[FSMStateName.STOP] = FSMStateStop(self.robot_data_)
|
||||||
self.state_objects[FSMStateName.ZERO] = FSMStateZero(self.robot_data_)
|
self.state_objects[FSMStateName.ZERO] = FSMStateZero(self.robot_data_)
|
||||||
self.state_objects[FSMStateName.WALKAMP] = FSMStateWALKAMP(self.robot_data_)
|
self.state_objects[FSMStateName.WALKAMP] = FSMStateWALKAMP(self.robot_data_)
|
||||||
self.state_objects[FSMStateName.MYPOLICY] = FSMStateMYPOLICY(self.robot_data_)
|
|
||||||
self.state_objects[FSMStateName.XSIMRUN] = FSMStateXSIMRUN(self.robot_data_)
|
|
||||||
|
|
||||||
# TODO: 添加其他状态对象
|
# TODO: 添加其他状态对象
|
||||||
@timing_decorator
|
@timing_decorator
|
||||||
|
|||||||
@@ -194,8 +194,6 @@ class RobotInterfaceImpl(RobotInterface):
|
|||||||
"STOP": FSMStateName.STOP,
|
"STOP": FSMStateName.STOP,
|
||||||
"ZERO": FSMStateName.ZERO,
|
"ZERO": FSMStateName.ZERO,
|
||||||
"WALKAMP": FSMStateName.WALKAMP,
|
"WALKAMP": FSMStateName.WALKAMP,
|
||||||
"MYPOLICY": FSMStateName.MYPOLICY,
|
|
||||||
"XSIMRUN": FSMStateName.XSIMRUN,
|
|
||||||
}
|
}
|
||||||
self.waist_control_status = [state_to_FSMState[state] for state in config.get('waist_control_status')]
|
self.waist_control_status = [state_to_FSMState[state] for state in config.get('waist_control_status')]
|
||||||
self.legs_control_status = [state_to_FSMState[state] for state in config.get('legs_control_status')]
|
self.legs_control_status = [state_to_FSMState[state] for state in config.get('legs_control_status')]
|
||||||
|
|||||||
@@ -61,8 +61,6 @@ class KeyboardController:
|
|||||||
print(" z - Goto ZERO state")
|
print(" z - Goto ZERO state")
|
||||||
print(" c - Goto STOP state")
|
print(" c - Goto STOP state")
|
||||||
print(" m - Goto WALKAMP state")
|
print(" m - Goto WALKAMP state")
|
||||||
print(" p - Goto MYPOLICY state")
|
|
||||||
print(" n - Goto XSIMRUN state")
|
|
||||||
print(" Left/Right arrows - Adjust height")
|
print(" Left/Right arrows - Adjust height")
|
||||||
print(" w/a/s/d - Movement controls")
|
print(" w/a/s/d - Movement controls")
|
||||||
print(" q/e - Rotation controls (turn left/right)")
|
print(" q/e - Rotation controls (turn left/right)")
|
||||||
@@ -196,10 +194,6 @@ class KeyboardController:
|
|||||||
self._on_x_key()
|
self._on_x_key()
|
||||||
elif key == '4':
|
elif key == '4':
|
||||||
self._on_g_key()
|
self._on_g_key()
|
||||||
elif key == 'p':
|
|
||||||
self._on_p_key()
|
|
||||||
elif key == 'n':
|
|
||||||
self._on_n_key()
|
|
||||||
elif key == '6':
|
elif key == '6':
|
||||||
self._on_o_key()
|
self._on_o_key()
|
||||||
elif key == 'v':
|
elif key == 'v':
|
||||||
@@ -359,26 +353,6 @@ class KeyboardController:
|
|||||||
self.last_fsm_command_time = time.time()
|
self.last_fsm_command_time = time.time()
|
||||||
print("Command: gotoWALKAMP")
|
print("Command: gotoWALKAMP")
|
||||||
|
|
||||||
def _on_p_key(self):
|
|
||||||
"""处理p键 - 切换到MYPOLICY状态"""
|
|
||||||
with self.data_mutex:
|
|
||||||
self.keyboard_flag.x_speed_command = 0.0
|
|
||||||
self.keyboard_flag.y_speed_command = 0.0
|
|
||||||
self.keyboard_flag.yaw_speed_command = 0.0
|
|
||||||
self.keyboard_flag.fsm_state_command = "gotoMYPOLICY"
|
|
||||||
self.last_fsm_command_time = time.time()
|
|
||||||
print("Command: gotoMYPOLICY (movement commands reset to zero)")
|
|
||||||
|
|
||||||
def _on_n_key(self):
|
|
||||||
"""处理n键 - 切换到XSIMRUN状态"""
|
|
||||||
with self.data_mutex:
|
|
||||||
self.keyboard_flag.x_speed_command = 0.0
|
|
||||||
self.keyboard_flag.y_speed_command = 0.0
|
|
||||||
self.keyboard_flag.yaw_speed_command = 0.0
|
|
||||||
self.keyboard_flag.fsm_state_command = "gotoXSIMRUN"
|
|
||||||
self.last_fsm_command_time = time.time()
|
|
||||||
print("Command: gotoXSIMRUN (movement commands reset to zero)")
|
|
||||||
|
|
||||||
def _handle_ctrl_c(self):
|
def _handle_ctrl_c(self):
|
||||||
"""处理Ctrl+C - 发送SIGINT信号给主进程"""
|
"""处理Ctrl+C - 发送SIGINT信号给主进程"""
|
||||||
# 先停止键盘控制器
|
# 先停止键盘控制器
|
||||||
|
|||||||
@@ -42,9 +42,9 @@ xbox:
|
|||||||
|
|
||||||
robot_interface:
|
robot_interface:
|
||||||
clip_actions: false
|
clip_actions: false
|
||||||
waist_control_status: ["ZERO", "STOP", "WALKAMP", "MYPOLICY", "XSIMRUN"] #
|
waist_control_status: ["ZERO", "STOP", "WALKAMP"] #
|
||||||
legs_control_status: [] #空代表都允许控制,仅腿部是这个逻辑
|
legs_control_status: [] #空代表都允许控制,仅腿部是这个逻辑
|
||||||
arms_control_status: ["ZERO", "STOP", "WALKAMP", "MYPOLICY", "XSIMRUN"] #
|
arms_control_status: ["ZERO", "STOP", "WALKAMP"] #
|
||||||
left_arm_only_status: []
|
left_arm_only_status: []
|
||||||
right_arm_only_status: []
|
right_arm_only_status: []
|
||||||
xsense_data_roll_offset: 0.0
|
xsense_data_roll_offset: 0.0
|
||||||
|
|||||||
@@ -1,96 +0,0 @@
|
|||||||
model_path: "policy.onnx"
|
|
||||||
motor_num: 29
|
|
||||||
actions_size: 23
|
|
||||||
dt: 0.01
|
|
||||||
warm_start_time: 0.0
|
|
||||||
xsense_data_roll_offset: 0.0
|
|
||||||
joint_names: [
|
|
||||||
hip_pitch_l_joint, hip_pitch_r_joint, waist_yaw_joint,
|
|
||||||
hip_roll_l_joint, hip_roll_r_joint, waist_roll_joint,
|
|
||||||
hip_yaw_l_joint, hip_yaw_r_joint, waist_pitch_joint,
|
|
||||||
knee_pitch_l_joint, knee_pitch_r_joint,
|
|
||||||
shoulder_pitch_l_joint, shoulder_pitch_r_joint,
|
|
||||||
ankle_pitch_l_joint, ankle_pitch_r_joint,
|
|
||||||
shoulder_roll_l_joint, shoulder_roll_r_joint,
|
|
||||||
ankle_roll_l_joint, ankle_roll_r_joint,
|
|
||||||
shoulder_yaw_l_joint, shoulder_yaw_r_joint,
|
|
||||||
elbow_pitch_l_joint, elbow_pitch_r_joint
|
|
||||||
]
|
|
||||||
zero_pos_offset: [
|
|
||||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
|
||||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
|
||||||
0.0, 0.0, 0.0,
|
|
||||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
|
||||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
|
|
||||||
]
|
|
||||||
ct_scale: [
|
|
||||||
2.1, 2.1, 3.207, 2.673, 2.6, 2.6,
|
|
||||||
2.1, 2.1, 3.207, 2.673, 2.6, 2.6,
|
|
||||||
3.207, 3.207, 3.207,
|
|
||||||
3.207, 2.28, 5.89, 5.89, 3.35, 2.3, 2.3,
|
|
||||||
3.207, 2.28, 5.89, 5.89, 3.35, 2.3, 2.3
|
|
||||||
]
|
|
||||||
control:
|
|
||||||
action_scale: 0.25
|
|
||||||
decimation: 2
|
|
||||||
|
|
||||||
gait:
|
|
||||||
gait_air_ratio_l: 0.6
|
|
||||||
gait_air_ratio_r: 0.6
|
|
||||||
gait_phase_offset_l: 0.6
|
|
||||||
gait_phase_offset_r: 0.1
|
|
||||||
gait_cycle: 0.64
|
|
||||||
|
|
||||||
normalization:
|
|
||||||
clip_scales:
|
|
||||||
clip_observations: 100.0
|
|
||||||
clip_actions: 100.0
|
|
||||||
obs_scales:
|
|
||||||
lin_vel: 1.0
|
|
||||||
ang_vel: 1.0
|
|
||||||
dof_pos: 1.0
|
|
||||||
dof_vel: 1.0
|
|
||||||
|
|
||||||
size:
|
|
||||||
num_hist: 10
|
|
||||||
observations_size: 84
|
|
||||||
|
|
||||||
gains:
|
|
||||||
kp: [
|
|
||||||
300.0, 300.0, 400.0,
|
|
||||||
300.0, 300.0, 400.0,
|
|
||||||
150.0, 150.0, 400.0,
|
|
||||||
350.0, 350.0,
|
|
||||||
150.0, 150.0,
|
|
||||||
30.0, 30.0,
|
|
||||||
50.0, 50.0,
|
|
||||||
16.8, 16.8,
|
|
||||||
50.0, 50.0,
|
|
||||||
150.0, 150.0
|
|
||||||
]
|
|
||||||
kd: [
|
|
||||||
10.0, 10.0, 5.0,
|
|
||||||
10.0, 10.0, 10.0,
|
|
||||||
5.0, 5.0, 10.0,
|
|
||||||
10.0, 10.0,
|
|
||||||
7.5, 7.5,
|
|
||||||
2.5, 2.5,
|
|
||||||
2.5, 2.5,
|
|
||||||
1.4, 1.4,
|
|
||||||
2.5, 2.5,
|
|
||||||
5.0, 5.0
|
|
||||||
]
|
|
||||||
|
|
||||||
init_state:
|
|
||||||
default_joint_angles: [
|
|
||||||
0.0, 0.0, 0.0,
|
|
||||||
-0.5, -0.5, 0.0,
|
|
||||||
0.0, 0.0, 0.0,
|
|
||||||
1.0, 1.0,
|
|
||||||
0.0, 0.0,
|
|
||||||
-0.5, -0.5,
|
|
||||||
0.2, -0.2,
|
|
||||||
0.0, 0.0,
|
|
||||||
0.0, 0.0,
|
|
||||||
-0.3, -0.3
|
|
||||||
]
|
|
||||||
@@ -1,373 +0,0 @@
|
|||||||
"""
|
|
||||||
FSM state implementation for the standalone MYPOLICY controller.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import onnxruntime as ort
|
|
||||||
import yaml
|
|
||||||
from scipy.spatial.transform import Rotation
|
|
||||||
|
|
||||||
from FSM.fsm_base import FSMState, FSMStateName
|
|
||||||
from common.BasicFunction import gait_phase
|
|
||||||
from common.joystick import ControlFlag
|
|
||||||
from common.robot_data import RobotData
|
|
||||||
|
|
||||||
|
|
||||||
class FSMStateMYPOLICY(FSMState):
|
|
||||||
"""Standalone FSM implementation for the custom ONNX policy."""
|
|
||||||
|
|
||||||
def __init__(self, robot_data: RobotData):
    """Load ``config/mypolicy.yaml`` and set up buffers, the ONNX session and joint maps.

    Args:
        robot_data: Robot data object, passed on to ``FSMState.__init__``.

    Raises:
        ValueError: If a required config entry is missing, if a per-joint list
            does not have one entry per ``joint_names`` entry, or if a joint
            name is not listed in ``joint_xml``.
        KeyError: If ``model_path`` is missing from the config.
    """
    super().__init__(robot_data)
    self.current_state_name = FSMStateName.MYPOLICY
    self.log_prefix = "FSMStateMYPOLICY"

    current_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(current_dir, "config", "mypolicy.yaml")
    # Explicit UTF-8 so loading does not depend on the platform default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        # An empty YAML document loads as None; use an empty mapping so the
        # validation below reports what is missing instead of failing on None.
        policy_config = yaml.safe_load(f) or {}

    self.action_num_ = policy_config.get("actions_size")
    self.motor_num_ = policy_config.get("motor_num")
    self.dt_ = policy_config.get("dt")

    size_config = policy_config.get("size", {})
    self.num_hist_ = size_config.get("num_hist")
    self.obs_size_ = size_config.get("observations_size")

    control_config = policy_config.get("control", {})
    self.action_scale_ = control_config.get("action_scale")
    self.decimation_ = control_config.get("decimation")
    # "control.warm_start_time" wins over the top-level key; 0.3 is the fallback.
    self.warm_start_time_ = control_config.get(
        "warm_start_time",
        policy_config.get("warm_start_time", 0.3),
    )

    norm_config = policy_config.get("normalization", {})
    clip_config = norm_config.get("clip_scales", {})
    obs_config = norm_config.get("obs_scales", {})
    self.clip_obs_ = clip_config.get("clip_observations", 100.0)
    self.clip_act_ = clip_config.get("clip_actions", 100.0)
    self.lin_vel_scale_ = obs_config.get("lin_vel")
    self.ang_vel_scale_ = obs_config.get("ang_vel")
    self.dof_pos_scale_ = obs_config.get("dof_pos")
    self.dof_vel_scale_ = obs_config.get("dof_vel")

    # These values size the buffers and drive the timing arithmetic below, so
    # fail with a readable message rather than a TypeError from numpy.
    required = {
        "actions_size": self.action_num_,
        "motor_num": self.motor_num_,
        "dt": self.dt_,
        "size.num_hist": self.num_hist_,
        "size.observations_size": self.obs_size_,
    }
    missing = [label for label, value in required.items() if value is None]
    if missing:
        raise ValueError(f"[FSMStateMYPOLICY] Missing {missing} in mypolicy.yaml")

    hist_len = self.obs_size_ * self.num_hist_
    self.observations_ = np.zeros(hist_len, dtype=np.float32)
    self.proprio_hist_buf_ = np.zeros(hist_len, dtype=np.float32)
    self.last_actions_ = np.zeros(self.action_num_, dtype=np.float32)
    self.actions_ = np.zeros(self.action_num_, dtype=np.float32)
    self._warm_start_pose = np.zeros(self.motor_num_, dtype=np.float32)

    self.is_first_obs_ = True
    self.is_first_action_ = True
    self.is_first_step_ = True
    self.timer_gait_ = 0.0

    gait_config = policy_config.get("gait", {})
    self.gait_cycle = gait_config.get("gait_cycle", 0.64)
    self.left_phase_ratio = gait_config.get("gait_air_ratio_l", 0.6)
    self.right_phase_ratio = gait_config.get("gait_air_ratio_r", 0.6)
    self.left_theta_offset = gait_config.get("gait_phase_offset_l", 0.6)
    self.right_theta_offset = gait_config.get("gait_phase_offset_r", 0.1)

    # Warm-start length is counted in steps of decimation * dt; a missing or
    # zero decimation is treated as 1.
    step = (self.decimation_ if self.decimation_ else 1) * self.dt_
    if self.warm_start_time_ > 0 and step > 0:
        self._warm_start_steps = max(1, int(self.warm_start_time_ / step))
    else:
        self._warm_start_steps = 0
    self._warmup_inference_counter = 0

    self.waiting_for_motion = True
    self.motion_threshold = 1e-3
    self.hold_pose = np.zeros(self.motor_num_, dtype=np.float32)
    self.filtered_x_speed = 0.0

    self.model_path = os.path.join(current_dir, "model", policy_config["model_path"])
    self._init_onnx_session()

    joint_names = policy_config.get("joint_names")
    if joint_names is None:
        raise ValueError("[FSMStateMYPOLICY] Missing 'joint_names' in mypolicy.yaml")
    self.joint_seq = list(joint_names)
    joint_count = len(self.joint_seq)

    def _check_len(label, values):
        # Every per-joint list must line up one-to-one with joint_names.
        if len(values) != joint_count:
            raise ValueError(
                f"[FSMStateMYPOLICY] {label} length {len(values)} does not match joint count {joint_count}"
            )

    if self.action_scale_ is None:
        raise ValueError("[FSMStateMYPOLICY] Missing 'control.action_scale' in mypolicy.yaml")
    if np.isscalar(self.action_scale_):
        # A single number is broadcast to one scale per policy joint.
        self.action_scale = np.full(joint_count, float(self.action_scale_), dtype=np.float32)
    else:
        self.action_scale = np.array(self.action_scale_, dtype=np.float32)
        _check_len("control.action_scale", self.action_scale)

    init_state_config = policy_config.get("init_state", {})
    default_joint_angles = init_state_config.get("default_joint_angles")
    if default_joint_angles is None:
        raise ValueError("[FSMStateMYPOLICY] Missing 'init_state.default_joint_angles' in mypolicy.yaml")
    self.joint_pos_array_seq = np.array(default_joint_angles, dtype=np.float32)
    _check_len("init_state.default_joint_angles", self.joint_pos_array_seq)

    gains_config = policy_config.get("gains", {})
    kp_values = gains_config.get("kp")
    kd_values = gains_config.get("kd")
    if kp_values is None or kd_values is None:
        raise ValueError("[FSMStateMYPOLICY] Missing 'gains.kp' or 'gains.kd' in mypolicy.yaml")
    self.stiffness_array_seq = np.array(kp_values, dtype=np.float32)
    self.damping_array_seq = np.array(kd_values, dtype=np.float32)
    _check_len("gains.kp", self.stiffness_array_seq)
    _check_len("gains.kd", self.damping_array_seq)

    # Joint order used when writing into the robot data arrays.
    self.joint_xml = [
        "hip_pitch_l_joint", "hip_roll_l_joint", "hip_yaw_l_joint",
        "knee_pitch_l_joint", "ankle_pitch_l_joint", "ankle_roll_l_joint",
        "hip_pitch_r_joint", "hip_roll_r_joint", "hip_yaw_r_joint",
        "knee_pitch_r_joint", "ankle_pitch_r_joint", "ankle_roll_r_joint",
        "waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
        "shoulder_pitch_l_joint", "shoulder_roll_l_joint", "shoulder_yaw_l_joint",
        "elbow_pitch_l_joint", "elbow_yaw_l_joint", "wrist_pitch_l_joint", "wrist_roll_l_joint",
        "shoulder_pitch_r_joint", "shoulder_roll_r_joint", "shoulder_yaw_r_joint",
        "elbow_pitch_r_joint", "elbow_yaw_r_joint", "wrist_pitch_r_joint", "wrist_roll_r_joint",
    ]

    # lab2mj[i] is the joint_xml index of the i-th joint in joint_names.
    lab2mj = []
    for name in self.joint_seq:
        if name not in self.joint_xml:
            raise ValueError(f"[FSMStateMYPOLICY] joint '{name}' from mypolicy.yaml not found in joint_xml")
        lab2mj.append(self.joint_xml.index(name))
    self.lab2mj = np.array(lab2mj, dtype=int)

    # Scatter the config-ordered values into joint_xml order; joints that the
    # config does not list keep 0.
    n_mj = len(self.joint_xml)
    self.joint_pos_array = np.zeros(n_mj, dtype=np.float32)
    self.stiffness_array = np.zeros(n_mj, dtype=np.float32)
    self.damping_array = np.zeros(n_mj, dtype=np.float32)
    self.joint_pos_array[self.lab2mj] = self.joint_pos_array_seq
    self.stiffness_array[self.lab2mj] = self.stiffness_array_seq
    self.damping_array[self.lab2mj] = self.damping_array_seq

    # Aliases of the config-ordered arrays (same objects, not copies).
    self.kps_lab = self.stiffness_array_seq
    self.kds_lab = self.damping_array_seq
    self.default_angles_lab = self.joint_pos_array_seq
    self.action_scale_lab = self.action_scale
|
|
||||||
|
|
||||||
def _init_onnx_session(self):
    """Create the ONNX Runtime session for ``self.model_path``.

    On any failure the error is printed and ``self.ort_session_`` is set to
    ``None`` instead of raising.
    """
    try:
        sess_opts = ort.SessionOptions()
        settings = (
            ("graph_optimization_level", ort.GraphOptimizationLevel.ORT_ENABLE_ALL),
            ("intra_op_num_threads", 1),
            ("inter_op_num_threads", 1),
            ("enable_mem_pattern", False),
            ("enable_mem_reuse", True),
        )
        for option_name, option_value in settings:
            setattr(sess_opts, option_name, option_value)
        self.ort_session_ = ort.InferenceSession(
            self.model_path,
            sess_opts,
            providers=["CPUExecutionProvider"],
        )
        print(f"[{self.log_prefix}-ONNX] ONNX model loaded successfully: {self.model_path}")
    except Exception as exc:
        print(f"[{self.log_prefix}] Failed to load ONNX model: {exc}")
        self.ort_session_ = None
|
|
||||||
|
|
||||||
def _reset_internal_state(self):
|
|
||||||
self.observations_.fill(0.0)
|
|
||||||
self.proprio_hist_buf_.fill(0.0)
|
|
||||||
self.last_actions_.fill(0.0)
|
|
||||||
self.actions_.fill(0.0)
|
|
||||||
self.is_first_obs_ = True
|
|
||||||
self.is_first_action_ = True
|
|
||||||
self.is_first_step_ = True
|
|
||||||
|
|
||||||
base = self.robot_data_.q_d_.shape[0] - self.motor_num_
|
|
||||||
self.robot_data_.q_d_[base:base + len(self.joint_xml)] = self.joint_pos_array
|
|
||||||
self.robot_data_.q_dot_d_[base:base + len(self.joint_xml)] = 0.0
|
|
||||||
self.robot_data_.tau_d_[base:base + len(self.joint_xml)] = 0.0
|
|
||||||
|
|
||||||
def on_enter(self):
    """Reset the state on entry and capture the current joint pose as the hold pose.

    If ``robot_data_`` is ``None`` or reading the joint positions raises, the
    warm-start and hold poses are zero-filled in place instead.
    """
    self._reset_internal_state()
    print(f"[{self.log_prefix}] enter")

    self.timer_gait_ = 0.0
    self.waiting_for_motion = True
    self._warmup_inference_counter = 0

    pose_ok = self.robot_data_ is not None
    if pose_ok:
        try:
            snapshot = self.robot_data_.get_joint_pos().copy()
        except Exception:
            pose_ok = False

    if pose_ok:
        # Both attributes refer to the same snapshot array.
        self._warm_start_pose = self.hold_pose = snapshot
    else:
        self._warm_start_pose.fill(0.0)
        self.hold_pose.fill(0.0)

    print(f"[{self.log_prefix}] waiting for motion command before starting policy")
|
|
||||||
|
|
||||||
def run(self, flag: ControlFlag):
    """Execute one control tick of the MYPOLICY state.

    While ``waiting_for_motion`` is set and every walk-command component is
    within ``motion_threshold``, the hold pose and the gains are written and
    the method returns. Once a command exceeds the threshold, the policy is
    evaluated on ticks selected by ``decimation_`` and its scaled output is
    written as the joint position set-point, blended from the warm-start pose
    during the warm-start steps.

    Args:
        flag: Control flags, forwarded to ``compute_observation``.
    """
    robot = self.robot_data_
    n_joints = len(self.joint_xml)

    def write_setpoint(position):
        # Write the position target into the last-motor_num_ region of q_d_
        # and zero the same region of q_dot_d_ and tau_d_.
        base = robot.q_d_.shape[0] - self.motor_num_
        window = slice(base, base + n_joints)
        robot.q_d_[window] = position
        robot.q_dot_d_[window] = 0.0
        robot.tau_d_[window] = 0.0

    def write_gains():
        robot.joint_kp_p_[:n_joints] = self.stiffness_array
        robot.joint_kd_p_[:n_joints] = self.damping_array

    walk_cmd = np.array(robot.get_walk_cmd(), dtype=np.float32)
    if self.waiting_for_motion:
        if np.max(np.abs(walk_cmd)) <= self.motion_threshold:
            # No command yet: keep holding the pose captured on entry.
            write_setpoint(self.hold_pose)
            write_gains()
            return
        self.waiting_for_motion = False
        # Restart the warm-start blend from wherever the joints are right now.
        self._warm_start_pose = robot.get_joint_pos().copy()
        self._warmup_inference_counter = 0
        print(f"[{self.log_prefix}] motion command detected: {walk_cmd}, policy activated")

    print(f"[{self.log_prefix}] run")
    gait = gait_phase(
        self.timer_gait_,
        self.gait_cycle,
        self.left_theta_offset,
        self.right_theta_offset,
        self.left_phase_ratio,
        self.right_phase_ratio,
    ).astype(np.float32)

    # __init__ treats a missing or zero decimation as 1; do the same here so
    # the modulo cannot raise ZeroDivisionError / TypeError.
    decimation = self.decimation_ if self.decimation_ else 1
    # NOTE(review): the source indentation was not recoverable. The set-point
    # update is placed inside the decimation gate because _warm_start_steps is
    # counted in units of decimation * dt -- confirm against the repository.
    if int(robot.time_now_ / self.dt_) % decimation == 0:
        self.compute_observation(flag, gait)
        self.compute_actions()

        target_dof_pos_lab = self.actions_ * self.action_scale_lab + self.default_angles_lab
        # Joints the policy does not drive keep their currently measured position.
        target_dof_pos_mj = robot.get_joint_pos().copy()
        target_dof_pos_mj[self.lab2mj] = target_dof_pos_lab

        commanded_pos = target_dof_pos_mj
        if self._warm_start_steps > 0 and self._warmup_inference_counter < self._warm_start_steps:
            # Linear blend from the warm-start pose to the policy target.
            self._warmup_inference_counter += 1
            blend = self._warmup_inference_counter / float(self._warm_start_steps)
            commanded_pos = (1.0 - blend) * self._warm_start_pose + blend * target_dof_pos_mj

        write_setpoint(commanded_pos)
        self.last_actions_[:] = self.actions_

    self.timer_gait_ += self.dt_
    write_gains()
|
|
||||||
|
|
||||||
def compute_observation(self, flag: ControlFlag, gait):
    """Build the newest proprioceptive frame and push it into the history buffer.

    The frame is the concatenation, in this order, of: angular velocity, the
    vector (0, 0, -1) rotated by the inverse IMU attitude, the walk command,
    joint position offsets from the defaults, joint velocities, the previous
    actions and ``gait``. ``self.observations_`` is then set to the clipped
    history buffer.

    Args:
        flag: Not used by this method.
        gait: Gait-phase values appended to the end of the frame.
    """
    robot = self.robot_data_

    # imu_data_ is read as [yaw, pitch, roll] at indices 0, 1, 2.
    roll, pitch, yaw = (float(robot.imu_data_[idx]) for idx in (2, 1, 0))
    w, x, y, z = self.euler_to_quaternion_scipy(roll, pitch, yaw)
    attitude_xyzw = np.array([x, y, z, w], dtype=np.float32)
    down = np.array([0.0, 0.0, -1.0], dtype=np.float32)
    gravity_vec = self.quat_rotate_inverse_numpy(attitude_xyzw, down)

    vx_cmd, vy_cmd, yaw_rate_cmd = robot.get_walk_cmd()
    # Rate-limit the tracked x speed to +/-0.005 per call.
    # NOTE(review): the rate-limited value is stored, but the raw x command is
    # what enters the observation -- confirm this is intended.
    delta = np.clip(vx_cmd - self.filtered_x_speed, -0.005, 0.005)
    self.filtered_x_speed = self.filtered_x_speed + delta
    command = np.array([vx_cmd, vy_cmd, yaw_rate_cmd], dtype=np.float32)
    print(f"Input command: {command}")

    ang_vel = robot.get_angular_velocity()
    joint_pos = robot.get_joint_pos()
    joint_vel = robot.get_joint_vel()
    # Reorder to the policy's joint order; positions are relative to the defaults.
    pos_offset = joint_pos[self.lab2mj] - self.default_angles_lab
    vel = joint_vel[self.lab2mj]

    frame = np.concatenate(
        [ang_vel, gravity_vec, command, pos_offset, vel, self.last_actions_, gait]
    )

    # View the flat buffer as (num_hist_, obs_size_) rows, oldest first.
    history = self.proprio_hist_buf_.reshape(self.num_hist_, self.obs_size_)
    if self.is_first_obs_:
        # First frame: fill every history slot with it.
        history[:] = frame
        self.is_first_obs_ = False
    else:
        # Drop the oldest row and append the new frame at the end.
        history[:-1] = history[1:]
        history[-1] = frame

    self.observations_ = np.clip(self.proprio_hist_buf_, -self.clip_obs_, self.clip_obs_)
|
|
||||||
|
|
||||||
def compute_actions(self):
    """Run one ONNX inference on ``self.observations_`` and store the clipped actions.

    Does nothing when no session is loaded. The first ``action_num_`` model
    outputs are clipped to ``[-clip_act_, clip_act_]`` and written to
    ``self.actions_`` in a single assignment, so a failed or too-short
    inference leaves the previous actions untouched. Any error is printed
    rather than raised.
    """
    session = self.ort_session_
    if session is None:
        return

    try:
        obs_batch = self.observations_.reshape(1, -1).astype(np.float32)
        feed_name = session.get_inputs()[0].name
        raw_actions = session.run(None, {feed_name: obs_batch})[0][0]

        count = self.action_num_
        if len(raw_actions) < count:
            # Reject short outputs up front; otherwise a length-1 output
            # would silently broadcast across every action.
            raise ValueError(f"model returned {len(raw_actions)} actions, expected {count}")
        self.actions_[:count] = np.clip(raw_actions[:count], -self.clip_act_, self.clip_act_)

        if self.is_first_action_:
            # Dump the newest-slot-sized prefix of the observation once, for debugging.
            print(f"[{self.log_prefix}-ONNX] First Observation:")
            print("".join(f"{value:.6f} " for value in self.observations_[:self.obs_size_]))
            self.is_first_action_ = False
    except Exception as e:
        print(f"[{self.log_prefix}] ONNX Runtime inference error: {e}")
|
|
||||||
|
|
||||||
def on_exit(self):
    """Log the exit and close the observation log file if one is open.

    Errors while flushing or closing are printed, not raised; the file
    attribute is cleared afterwards either way.
    """
    print(f"[{self.log_prefix}] exit")

    log_file = getattr(self, "obs_log_file", None)
    if log_file is None:
        return

    try:
        log_file.flush()
        log_file.close()
        print(f"[{self.log_prefix}] obs log saved to {self.obs_log_path}")
    except Exception as exc:
        print(f"[{self.log_prefix}] failed to close obs log: {exc}")
    self.obs_log_file = None
|
|
||||||
|
|
||||||
def check_transition(self, flag: ControlFlag) -> FSMStateName:
    """Map ``flag.fsm_state_command`` to the state it requests.

    Args:
        flag: Control flags carrying the ``fsm_state_command`` string.

    Returns:
        The matching ``FSMStateName`` member for a ``"goto<STATE>"`` command
        handled here, or ``None`` for any other command.
    """
    command = flag.fsm_state_command
    # Checked in this order; each command is "goto" + the enum member name.
    for state_name in ("STOP", "WALKAMP", "MYPOLICY", "XSIMRUN", "ZERO"):
        if command == f"goto{state_name}":
            return getattr(FSMStateName, state_name)
    return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def euler_to_quaternion_scipy(roll, pitch, yaw, degrees=False):
    """Convert "xyz" Euler angles to a quaternion in (w, x, y, z) order.

    Args:
        roll: Angle about x.
        pitch: Angle about y.
        yaw: Angle about z.
        degrees: If True the angles are in degrees, otherwise radians.

    Returns:
        A float32 array ``[w, x, y, z]``.
    """
    rotation = Rotation.from_euler("xyz", [roll, pitch, yaw], degrees=degrees)
    # SciPy returns scalar-last (x, y, z, w); move the scalar to the front.
    x, y, z, w = rotation.as_quat()
    return np.array([w, x, y, z], dtype=np.float32)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def quat_rotate_inverse_numpy(q_xyzw, v):
    """Rotate ``v`` by the inverse of quaternion ``q_xyzw``.

    Args:
        q_xyzw: Quaternion in (x, y, z, w) order; assumed to be unit length.
        v: 3-vector to rotate.

    Returns:
        The rotated 3-vector.
    """
    scalar = q_xyzw[3]
    axis = q_xyzw[:3]
    # Three terms of the expanded rotation formula; the cross term is
    # subtracted, which is what makes this the inverse rotation.
    scaled_term = v * (2.0 * scalar * scalar - 1.0)
    cross_term = np.cross(axis, v) * (2.0 * scalar)
    projection_term = axis * (2.0 * np.dot(axis, v))
    return scaled_term - cross_term + projection_term
|
|
||||||
Binary file not shown.
@@ -82,10 +82,6 @@ class FSMStateStop(FSMState):
|
|||||||
return FSMStateName.ZERO
|
return FSMStateName.ZERO
|
||||||
elif flag.fsm_state_command == "gotoWALKAMP":
|
elif flag.fsm_state_command == "gotoWALKAMP":
|
||||||
return FSMStateName.WALKAMP
|
return FSMStateName.WALKAMP
|
||||||
elif flag.fsm_state_command == "gotoMYPOLICY":
|
|
||||||
return FSMStateName.MYPOLICY
|
|
||||||
elif flag.fsm_state_command == "gotoXSIMRUN":
|
|
||||||
return FSMStateName.XSIMRUN
|
|
||||||
elif flag.fsm_state_command == "gotoBEYONDZERO":
|
elif flag.fsm_state_command == "gotoBEYONDZERO":
|
||||||
return FSMStateName.BEYONDZERO
|
return FSMStateName.BEYONDZERO
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -446,8 +446,6 @@ class FSMStateWALKAMP(FSMState):
|
|||||||
return FSMStateName.STOP
|
return FSMStateName.STOP
|
||||||
elif flag.fsm_state_command == "gotoWALKAMP":
|
elif flag.fsm_state_command == "gotoWALKAMP":
|
||||||
return FSMStateName.WALKAMP
|
return FSMStateName.WALKAMP
|
||||||
elif flag.fsm_state_command == "gotoXSIMRUN":
|
|
||||||
return FSMStateName.XSIMRUN
|
|
||||||
elif flag.fsm_state_command == "gotoZERO":
|
elif flag.fsm_state_command == "gotoZERO":
|
||||||
return FSMStateName.ZERO
|
return FSMStateName.ZERO
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
model_path: "../mypolicy/model/policy.onnx"
|
|
||||||
motor_num: 29
|
|
||||||
actions_size: 23
|
|
||||||
dt: 0.01
|
|
||||||
warm_start_time: 0.0
|
|
||||||
command_clip: 1.0
|
|
||||||
|
|
||||||
sim:
|
|
||||||
mujoco_timestep: 0.005
|
|
||||||
|
|
||||||
joint_names: [
|
|
||||||
hip_pitch_l_joint, hip_pitch_r_joint, waist_yaw_joint,
|
|
||||||
hip_roll_l_joint, hip_roll_r_joint, waist_roll_joint,
|
|
||||||
hip_yaw_l_joint, hip_yaw_r_joint, waist_pitch_joint,
|
|
||||||
knee_pitch_l_joint, knee_pitch_r_joint,
|
|
||||||
shoulder_pitch_l_joint, shoulder_pitch_r_joint,
|
|
||||||
ankle_pitch_l_joint, ankle_pitch_r_joint,
|
|
||||||
shoulder_roll_l_joint, shoulder_roll_r_joint,
|
|
||||||
ankle_roll_l_joint, ankle_roll_r_joint,
|
|
||||||
shoulder_yaw_l_joint, shoulder_yaw_r_joint,
|
|
||||||
elbow_pitch_l_joint, elbow_pitch_r_joint
|
|
||||||
]
|
|
||||||
|
|
||||||
control:
|
|
||||||
action_scale: 0.25
|
|
||||||
decimation: 2
|
|
||||||
|
|
||||||
gait:
|
|
||||||
gait_air_ratio_l: 0.6
|
|
||||||
gait_air_ratio_r: 0.6
|
|
||||||
gait_phase_offset_l: 0.6
|
|
||||||
gait_phase_offset_r: 0.1
|
|
||||||
gait_cycle: 0.64
|
|
||||||
|
|
||||||
normalization:
|
|
||||||
clip_scales:
|
|
||||||
clip_observations: 100.0
|
|
||||||
clip_actions: 100.0
|
|
||||||
|
|
||||||
size:
|
|
||||||
num_hist: 10
|
|
||||||
observations_size: 84
|
|
||||||
|
|
||||||
gains:
|
|
||||||
kp: [
|
|
||||||
300.0, 300.0, 400.0,
|
|
||||||
300.0, 300.0, 400.0,
|
|
||||||
150.0, 150.0, 400.0,
|
|
||||||
350.0, 350.0,
|
|
||||||
150.0, 150.0,
|
|
||||||
30.0, 30.0,
|
|
||||||
50.0, 50.0,
|
|
||||||
16.8, 16.8,
|
|
||||||
50.0, 50.0,
|
|
||||||
150.0, 150.0
|
|
||||||
]
|
|
||||||
kd: [
|
|
||||||
10.0, 10.0, 5.0,
|
|
||||||
10.0, 10.0, 10.0,
|
|
||||||
5.0, 5.0, 10.0,
|
|
||||||
10.0, 10.0,
|
|
||||||
7.5, 7.5,
|
|
||||||
2.5, 2.5,
|
|
||||||
2.5, 2.5,
|
|
||||||
1.4, 1.4,
|
|
||||||
2.5, 2.5,
|
|
||||||
5.0, 5.0
|
|
||||||
]
|
|
||||||
|
|
||||||
init_state:
|
|
||||||
default_joint_angles: [
|
|
||||||
0.0, 0.0, 0.0,
|
|
||||||
-0.5, -0.5, 0.0,
|
|
||||||
0.0, 0.0, 0.0,
|
|
||||||
1.0, 1.0,
|
|
||||||
0.0, 0.0,
|
|
||||||
-0.5, -0.5,
|
|
||||||
0.2, -0.2,
|
|
||||||
0.0, 0.0,
|
|
||||||
0.0, 0.0,
|
|
||||||
-0.3, -0.3
|
|
||||||
]
|
|
||||||
@@ -1,303 +0,0 @@
|
|||||||
"""
|
|
||||||
FSM state implementation for an xSIM MuJoCo run policy that follows the
|
|
||||||
TienKung-Lab sim2sim observation/action flow more closely.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import onnxruntime as ort
|
|
||||||
import yaml
|
|
||||||
from scipy.spatial.transform import Rotation
|
|
||||||
|
|
||||||
from FSM.fsm_base import FSMState, FSMStateName
|
|
||||||
from common.joystick import ControlFlag
|
|
||||||
from common.robot_data import RobotData
|
|
||||||
|
|
||||||
|
|
||||||
class FSMStateXSIMRUN(FSMState):
|
|
||||||
"""Direct-position run policy for xSIM MuJoCo."""
|
|
||||||
|
|
||||||
def __init__(self, robot_data: RobotData):
    """Load the xSIM run policy: YAML configuration, ONNX model and buffers.

    The configuration is read from ``config/xsim_run.yaml`` next to this
    module. Joint-indexed values in that file (default angles, kp, kd) are
    treated as policy ("lab") order and re-indexed here with the same
    permutation tables that are used for observations and actions.

    Args:
        robot_data: Robot state/command container, forwarded to the base class.

    Raises:
        OSError: If the configuration file cannot be opened.
        KeyError: If a required configuration key is missing.
    """
    super().__init__(robot_data)
    self.current_state_name = FSMStateName.XSIMRUN
    self.log_prefix = "FSMStateXSIMRUN"

    current_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(current_dir, "config", "xsim_run.yaml")
    with open(config_path, "r", encoding="utf-8") as f:
        policy_config = yaml.safe_load(f)

    # --- top-level scalars -------------------------------------------------
    self.action_num_ = int(policy_config["actions_size"])
    self.motor_num_ = int(policy_config["motor_num"])
    self.dt_ = float(policy_config["dt"])
    self.command_clip_ = float(policy_config.get("command_clip", 1.0))

    # --- observation sizes -------------------------------------------------
    size_config = policy_config.get("size", {})
    self.num_hist_ = int(size_config["num_hist"])
    self.obs_size_ = int(size_config["observations_size"])

    # --- control timing ----------------------------------------------------
    control_config = policy_config.get("control", {})
    self.action_scale_ = float(control_config["action_scale"])
    self.decimation_ = int(control_config["decimation"])
    # "warm_start_time" may live under "control" or at the top level;
    # the nested value wins, and 0.0 disables the warm start.
    self.warm_start_time_ = float(
        control_config.get(
            "warm_start_time",
            policy_config.get("warm_start_time", 0.0),
        )
    )

    sim_config = policy_config.get("sim", {})
    self.mujoco_timestep_ = float(sim_config.get("mujoco_timestep", 0.005))
    # One policy inference is run every `decimation_` control steps of `dt_`.
    self.policy_period_ = self.dt_ * self.decimation_

    # --- gait clock --------------------------------------------------------
    gait_config = policy_config.get("gait", {})
    self.gait_cycle_ = float(gait_config["gait_cycle"])
    self.phase_ratio_ = np.array(
        [
            gait_config["gait_air_ratio_l"],
            gait_config["gait_air_ratio_r"],
        ],
        dtype=np.float32,
    )
    self.phase_offset_ = np.array(
        [
            gait_config["gait_phase_offset_l"],
            gait_config["gait_phase_offset_r"],
        ],
        dtype=np.float32,
    )

    # --- clipping ----------------------------------------------------------
    norm_config = policy_config.get("normalization", {})
    clip_config = norm_config.get("clip_scales", {})
    self.clip_obs_ = float(clip_config.get("clip_observations", 100.0))
    self.clip_act_ = float(clip_config.get("clip_actions", 100.0))

    # --- per-joint tables, policy ("lab") order ------------------------------
    self.default_angles_lab_ = np.array(
        policy_config["init_state"]["default_joint_angles"],
        dtype=np.float32,
    )
    self.stiffness_lab_ = np.array(policy_config["gains"]["kp"], dtype=np.float32)
    self.damping_lab_ = np.array(policy_config["gains"]["kd"], dtype=np.float32)

    # --- ONNX model --------------------------------------------------------
    model_rel_path = policy_config["model_path"]
    self.model_path_ = os.path.normpath(os.path.join(current_dir, model_rel_path))
    self._init_onnx_session()

    # sim2sim.py uses policy output in Isaac order and then maps to MuJoCo order.
    # The two tables are inverse permutations of each other:
    #   x_policy = x_mujoco[mujoco_to_policy_idx_]
    #   x_mujoco = x_policy[policy_to_mujoco_idx_]
    self.mujoco_to_policy_idx_ = np.array(
        [0, 6, 12, 1, 7, 13, 2, 8, 14, 3, 9, 15, 19, 4, 10, 16, 20, 5, 11, 17, 21, 18, 22],
        dtype=int,
    )
    self.policy_to_mujoco_idx_ = np.array(
        [0, 3, 6, 9, 13, 17, 1, 4, 7, 10, 14, 18, 2, 5, 8, 11, 15, 19, 21, 12, 16, 20, 22],
        dtype=int,
    )

    # RobotData stores 29 joints in leg -> waist -> arm order.
    # These are the positions of the policy-controlled joints in that vector.
    self.mujoco_control_indices_ = np.array(
        [
            0, 1, 2, 3, 4, 5,
            6, 7, 8, 9, 10, 11,
            12, 13, 14,
            15, 16, 17, 18,
            22, 23, 24, 25,
        ],
        dtype=int,
    )

    # --- working buffers ---------------------------------------------------
    self.default_angles_mujoco23_ = self.default_angles_lab_[self.policy_to_mujoco_idx_]
    self.observations_ = np.zeros(self.obs_size_ * self.num_hist_, dtype=np.float32)
    self.obs_history_ = np.zeros_like(self.observations_)
    self.actions_ = np.zeros(self.action_num_, dtype=np.float32)
    self.last_actions_ = np.zeros(self.action_num_, dtype=np.float32)
    self.current_gait_ = np.zeros(6, dtype=np.float32)
    self.hold_pose_29_ = np.zeros(self.motor_num_, dtype=np.float32)
    self._warm_start_pose_29_ = np.zeros(self.motor_num_, dtype=np.float32)
    self._first_obs = True
    self._policy_step_counter = 0
    self.waiting_for_motion_ = True
    self.motion_threshold_ = 1e-3

    # Warm-start length, counted in policy periods (0 disables the blend).
    if self.warm_start_time_ > 0 and self.policy_period_ > 0:
        self._warm_start_steps = max(1, int(self.warm_start_time_ / self.policy_period_))
    else:
        self._warm_start_steps = 0
    self._warmup_inference_counter = 0

    # Expand the policy-order gains onto the full motor vector. The gather
    # x_lab[policy_to_mujoco_idx_] is the same re-ordering applied to the
    # default angles above and to the actions in run(). Scattering with that
    # table instead (kp[ctrl[p2m[i]]] = kp_lab[i]) applies the inverse
    # permutation and hands each joint some other joint's gains.
    # Motors that the policy does not control keep zero gains.
    self.kp_29_ = np.zeros(self.motor_num_, dtype=np.float32)
    self.kd_29_ = np.zeros(self.motor_num_, dtype=np.float32)
    self.kp_29_[self.mujoco_control_indices_] = self.stiffness_lab_[self.policy_to_mujoco_idx_]
    self.kd_29_[self.mujoco_control_indices_] = self.damping_lab_[self.policy_to_mujoco_idx_]
|
|
||||||
|
|
||||||
def _init_onnx_session(self) -> None:
    """Create the ONNX Runtime session for ``self.model_path_``.

    The session runs on the CPU execution provider with one intra-op and
    one inter-op thread and all graph optimizations enabled. The result is
    stored in ``self.ort_session_``.
    """
    session_options = ort.SessionOptions()
    # Thread counts first, optimization level last: the three settings are
    # independent attributes of the options object.
    session_options.inter_op_num_threads = 1
    session_options.intra_op_num_threads = 1
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    cpu_only = ["CPUExecutionProvider"]
    self.ort_session_ = ort.InferenceSession(self.model_path_, session_options, providers=cpu_only)
    print(f"[{self.log_prefix}] ONNX model loaded: {self.model_path_}")
|
|
||||||
|
|
||||||
def on_enter(self):
|
|
||||||
self.observations_.fill(0.0)
|
|
||||||
self.obs_history_.fill(0.0)
|
|
||||||
self.actions_.fill(0.0)
|
|
||||||
self.last_actions_.fill(0.0)
|
|
||||||
self.current_gait_.fill(0.0)
|
|
||||||
self._first_obs = True
|
|
||||||
self._policy_step_counter = 0
|
|
||||||
self._warmup_inference_counter = 0
|
|
||||||
self.waiting_for_motion_ = True
|
|
||||||
|
|
||||||
current_q = self.robot_data_.get_joint_pos().copy()
|
|
||||||
self.hold_pose_29_ = current_q
|
|
||||||
self._warm_start_pose_29_ = current_q
|
|
||||||
|
|
||||||
base = self.robot_data_.q_d_.shape[0] - self.motor_num_
|
|
||||||
self.robot_data_.q_d_[base:base + self.motor_num_] = current_q
|
|
||||||
self.robot_data_.q_dot_d_[base:base + self.motor_num_] = 0.0
|
|
||||||
self.robot_data_.tau_d_[base:base + self.motor_num_] = 0.0
|
|
||||||
self.robot_data_.joint_kp_p_[:self.motor_num_] = self.kp_29_
|
|
||||||
self.robot_data_.joint_kd_p_[:self.motor_num_] = self.kd_29_
|
|
||||||
print(f"[{self.log_prefix}] enter")
|
|
||||||
|
|
||||||
def run(self, flag: ControlFlag):
|
|
||||||
walk_cmd = np.clip(
|
|
||||||
np.array(self.robot_data_.get_walk_cmd(), dtype=np.float32),
|
|
||||||
-self.command_clip_,
|
|
||||||
self.command_clip_,
|
|
||||||
)
|
|
||||||
base = self.robot_data_.q_d_.shape[0] - self.motor_num_
|
|
||||||
|
|
||||||
if self.waiting_for_motion_:
|
|
||||||
if np.max(np.abs(walk_cmd)) <= self.motion_threshold_:
|
|
||||||
self.robot_data_.q_d_[base:base + self.motor_num_] = self.hold_pose_29_
|
|
||||||
self.robot_data_.q_dot_d_[base:base + self.motor_num_] = 0.0
|
|
||||||
self.robot_data_.tau_d_[base:base + self.motor_num_] = 0.0
|
|
||||||
self.robot_data_.joint_kp_p_[:self.motor_num_] = self.kp_29_
|
|
||||||
self.robot_data_.joint_kd_p_[:self.motor_num_] = self.kd_29_
|
|
||||||
return
|
|
||||||
self.waiting_for_motion_ = False
|
|
||||||
self._warm_start_pose_29_ = self.robot_data_.get_joint_pos().copy()
|
|
||||||
print(f"[{self.log_prefix}] motion command detected: {walk_cmd}")
|
|
||||||
|
|
||||||
if int(self.robot_data_.time_now_ / self.dt_) % self.decimation_ == 0:
|
|
||||||
self.current_gait_ = self._compute_gait_features()
|
|
||||||
self.compute_observation(walk_cmd)
|
|
||||||
self.compute_actions()
|
|
||||||
|
|
||||||
target_mujoco23 = (
|
|
||||||
self.actions_[self.policy_to_mujoco_idx_] * self.action_scale_
|
|
||||||
+ self.default_angles_mujoco23_
|
|
||||||
)
|
|
||||||
target_q_29 = self.hold_pose_29_.copy()
|
|
||||||
target_q_29[self.mujoco_control_indices_] = target_mujoco23
|
|
||||||
|
|
||||||
commanded_q_29 = target_q_29
|
|
||||||
if self._warm_start_steps > 0 and self._warmup_inference_counter < self._warm_start_steps:
|
|
||||||
self._warmup_inference_counter += 1
|
|
||||||
blend = self._warmup_inference_counter / float(self._warm_start_steps)
|
|
||||||
commanded_q_29 = (1.0 - blend) * self._warm_start_pose_29_ + blend * target_q_29
|
|
||||||
|
|
||||||
self.robot_data_.q_d_[base:base + self.motor_num_] = commanded_q_29
|
|
||||||
self.robot_data_.q_dot_d_[base:base + self.motor_num_] = 0.0
|
|
||||||
self.robot_data_.tau_d_[base:base + self.motor_num_] = 0.0
|
|
||||||
self.robot_data_.joint_kp_p_[:self.motor_num_] = self.kp_29_
|
|
||||||
self.robot_data_.joint_kd_p_[:self.motor_num_] = self.kd_29_
|
|
||||||
self.last_actions_[:] = self.actions_
|
|
||||||
|
|
||||||
def _compute_gait_features(self) -> np.ndarray:
|
|
||||||
t = self._policy_step_counter * self.policy_period_ / self.gait_cycle_
|
|
||||||
gait_phase = (t + self.phase_offset_) % 1.0
|
|
||||||
self._policy_step_counter += 1
|
|
||||||
return np.concatenate(
|
|
||||||
[
|
|
||||||
np.sin(2.0 * np.pi * gait_phase),
|
|
||||||
np.cos(2.0 * np.pi * gait_phase),
|
|
||||||
self.phase_ratio_,
|
|
||||||
],
|
|
||||||
axis=0,
|
|
||||||
).astype(np.float32)
|
|
||||||
|
|
||||||
def compute_observation(self, walk_cmd: np.ndarray):
|
|
||||||
if np.linalg.norm(self.robot_data_.imu_quat_) > 0.0:
|
|
||||||
q_wxyz = self.robot_data_.imu_quat_.astype(np.float32)
|
|
||||||
q_xyzw = np.array([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]], dtype=np.float32)
|
|
||||||
else:
|
|
||||||
roll = float(self.robot_data_.imu_data_[2])
|
|
||||||
pitch = float(self.robot_data_.imu_data_[1])
|
|
||||||
yaw = float(self.robot_data_.imu_data_[0])
|
|
||||||
q_wxyz = self.euler_to_quaternion_scipy(roll, pitch, yaw)
|
|
||||||
q_xyzw = np.array([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]], dtype=np.float32)
|
|
||||||
|
|
||||||
gravity = self.quat_rotate_inverse_numpy(q_xyzw, np.array([0.0, 0.0, -1.0], dtype=np.float32))
|
|
||||||
q_29 = self.robot_data_.get_joint_pos()
|
|
||||||
dq_29 = self.robot_data_.get_joint_vel()
|
|
||||||
q_23 = q_29[self.mujoco_control_indices_]
|
|
||||||
dq_23 = dq_29[self.mujoco_control_indices_]
|
|
||||||
|
|
||||||
proprio = np.concatenate(
|
|
||||||
[
|
|
||||||
self.robot_data_.get_angular_velocity(),
|
|
||||||
gravity,
|
|
||||||
walk_cmd,
|
|
||||||
(q_23 - self.default_angles_mujoco23_)[self.mujoco_to_policy_idx_],
|
|
||||||
dq_23[self.mujoco_to_policy_idx_],
|
|
||||||
np.clip(self.last_actions_, -self.clip_act_, self.clip_act_),
|
|
||||||
self.current_gait_,
|
|
||||||
],
|
|
||||||
axis=0,
|
|
||||||
).astype(np.float32)
|
|
||||||
|
|
||||||
if self._first_obs:
|
|
||||||
for i in range(self.num_hist_):
|
|
||||||
start = i * self.obs_size_
|
|
||||||
self.obs_history_[start:start + self.obs_size_] = proprio
|
|
||||||
self._first_obs = False
|
|
||||||
else:
|
|
||||||
self.obs_history_ = np.roll(self.obs_history_, -self.obs_size_)
|
|
||||||
self.obs_history_[-self.obs_size_:] = proprio
|
|
||||||
|
|
||||||
self.observations_ = np.clip(self.obs_history_, -self.clip_obs_, self.clip_obs_)
|
|
||||||
|
|
||||||
def compute_actions(self):
|
|
||||||
input_name = self.ort_session_.get_inputs()[0].name
|
|
||||||
input_data = self.observations_.reshape(1, -1).astype(np.float32)
|
|
||||||
outputs = self.ort_session_.run(None, {input_name: input_data})
|
|
||||||
self.actions_[:] = np.clip(outputs[0][0][: self.action_num_], -self.clip_act_, self.clip_act_)
|
|
||||||
|
|
||||||
def on_exit(self):
    """Log that this state is being left; no buffers or commands are changed here."""
    print(f"[{self.log_prefix}] exit")
|
|
||||||
|
|
||||||
def check_transition(self, flag: ControlFlag) -> FSMStateName:
|
|
||||||
if flag.fsm_state_command == "gotoSTOP":
|
|
||||||
return FSMStateName.STOP
|
|
||||||
if flag.fsm_state_command == "gotoZERO":
|
|
||||||
return FSMStateName.ZERO
|
|
||||||
if flag.fsm_state_command == "gotoWALKAMP":
|
|
||||||
return FSMStateName.WALKAMP
|
|
||||||
if flag.fsm_state_command == "gotoMYPOLICY":
|
|
||||||
return FSMStateName.MYPOLICY
|
|
||||||
if flag.fsm_state_command == "gotoXSIMRUN":
|
|
||||||
return FSMStateName.XSIMRUN
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def euler_to_quaternion_scipy(roll, pitch, yaw, degrees=False):
|
|
||||||
r = Rotation.from_euler("xyz", [roll, pitch, yaw], degrees=degrees)
|
|
||||||
q_xyzw = r.as_quat()
|
|
||||||
return np.array([q_xyzw[3], q_xyzw[0], q_xyzw[1], q_xyzw[2]], dtype=np.float32)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def quat_rotate_inverse_numpy(q_xyzw, v):
|
|
||||||
q_w = q_xyzw[3]
|
|
||||||
q_v = q_xyzw[:3]
|
|
||||||
a = v * (2.0 * q_w * q_w - 1.0)
|
|
||||||
b = np.cross(q_v, v) * (2.0 * q_w)
|
|
||||||
c = q_v * (2.0 * np.dot(q_v, v))
|
|
||||||
return a - b + c
|
|
||||||
@@ -75,10 +75,6 @@ class FSMStateZero(FSMState):
|
|||||||
return FSMStateName.STOP
|
return FSMStateName.STOP
|
||||||
elif flag.fsm_state_command == "gotoWALKAMP":
|
elif flag.fsm_state_command == "gotoWALKAMP":
|
||||||
return FSMStateName.WALKAMP
|
return FSMStateName.WALKAMP
|
||||||
elif flag.fsm_state_command == "gotoMYPOLICY":
|
|
||||||
return FSMStateName.MYPOLICY
|
|
||||||
elif flag.fsm_state_command == "gotoXSIMRUN":
|
|
||||||
return FSMStateName.XSIMRUN
|
|
||||||
elif flag.fsm_state_command == "gotoZERO":
|
elif flag.fsm_state_command == "gotoZERO":
|
||||||
return FSMStateName.ZERO
|
return FSMStateName.ZERO
|
||||||
elif flag.fsm_state_command == "gotoBEYONDZERO":
|
elif flag.fsm_state_command == "gotoBEYONDZERO":
|
||||||
|
|||||||
@@ -374,8 +374,6 @@ class XMIGCSControlNode(Node):
|
|||||||
FSMStateName.STOP: "gotoSTOP",
|
FSMStateName.STOP: "gotoSTOP",
|
||||||
FSMStateName.ZERO: "gotoZERO",
|
FSMStateName.ZERO: "gotoZERO",
|
||||||
FSMStateName.WALKAMP: "gotoWALKAMP",
|
FSMStateName.WALKAMP: "gotoWALKAMP",
|
||||||
FSMStateName.MYPOLICY: "gotoMYPOLICY",
|
|
||||||
FSMStateName.XSIMRUN: "gotoXSIMRUN",
|
|
||||||
}
|
}
|
||||||
current_state = self.robot_fsm.get_current_state()
|
current_state = self.robot_fsm.get_current_state()
|
||||||
return state_to_command.get(current_state, self.control_flag.fsm_state_command)
|
return state_to_command.get(current_state, self.control_flag.fsm_state_command)
|
||||||
|
|||||||
@@ -84,7 +84,6 @@ python3 udp_loopback/udp_xbox_sender.py
|
|||||||
- `A -> mode_stride -> gotoWALKAMP`
|
- `A -> mode_stride -> gotoWALKAMP`
|
||||||
- `X -> pose_home -> gotoZERO`
|
- `X -> pose_home -> gotoZERO`
|
||||||
- `Y -> pose_hold -> gotoSTOP`
|
- `Y -> pose_hold -> gotoSTOP`
|
||||||
- `B -> mode_dash -> gotoMYPOLICY`
|
|
||||||
- `START -> trim_reset`
|
- `START -> trim_reset`
|
||||||
- 左摇杆 Y -> 连续前后速度
|
- 左摇杆 Y -> 连续前后速度
|
||||||
- 左摇杆 X -> 连续横移速度
|
- 左摇杆 X -> 连续横移速度
|
||||||
@@ -96,7 +95,6 @@ python3 udp_loopback/udp_xbox_sender.py
|
|||||||
- `pose_home -> gotoZERO`
|
- `pose_home -> gotoZERO`
|
||||||
- `pose_hold -> gotoSTOP`
|
- `pose_hold -> gotoSTOP`
|
||||||
- `mode_stride -> gotoWALKAMP`
|
- `mode_stride -> gotoWALKAMP`
|
||||||
- `mode_dash -> gotoMYPOLICY`
|
|
||||||
- `surge/sway/spin -> x/y/yaw speed command`
|
- `surge/sway/spin -> x/y/yaw speed command`
|
||||||
|
|
||||||
当前事件码不是原工程里的 `gotoZERO` / `gotoSTOP` / `x_speed_command` 这一套,而是:
|
当前事件码不是原工程里的 `gotoZERO` / `gotoSTOP` / `x_speed_command` 这一套,而是:
|
||||||
@@ -104,7 +102,6 @@ python3 udp_loopback/udp_xbox_sender.py
|
|||||||
- `pose_home`
|
- `pose_home`
|
||||||
- `pose_hold`
|
- `pose_hold`
|
||||||
- `mode_stride`
|
- `mode_stride`
|
||||||
- `mode_dash`
|
|
||||||
- `surge_up`
|
- `surge_up`
|
||||||
- `surge_down`
|
- `surge_down`
|
||||||
- `sway_left`
|
- `sway_left`
|
||||||
|
|||||||
@@ -161,12 +161,6 @@ class UDPFSMController:
|
|||||||
elif event_code == "mode_stride":
|
elif event_code == "mode_stride":
|
||||||
self.motion_frame.mode_tag = "mode_stride"
|
self.motion_frame.mode_tag = "mode_stride"
|
||||||
self.last_fsm_command_time = packet.sent_at
|
self.last_fsm_command_time = packet.sent_at
|
||||||
elif event_code == "mode_dash":
|
|
||||||
self.motion_frame.mode_tag = "mode_dash"
|
|
||||||
self.last_fsm_command_time = packet.sent_at
|
|
||||||
elif event_code == "mode_xrun":
|
|
||||||
self.motion_frame.mode_tag = "mode_xrun"
|
|
||||||
self.last_fsm_command_time = packet.sent_at
|
|
||||||
elif event_code == "surge_up":
|
elif event_code == "surge_up":
|
||||||
self.motion_frame.surge_goal = min(
|
self.motion_frame.surge_goal = min(
|
||||||
self.max_surge, self.motion_frame.surge_goal + self.surge_step
|
self.max_surge, self.motion_frame.surge_goal + self.surge_step
|
||||||
@@ -236,8 +230,6 @@ class UDPFSMController:
|
|||||||
"pose_home": "gotoZERO",
|
"pose_home": "gotoZERO",
|
||||||
"pose_hold": "gotoSTOP",
|
"pose_hold": "gotoSTOP",
|
||||||
"mode_stride": "gotoWALKAMP",
|
"mode_stride": "gotoWALKAMP",
|
||||||
"mode_dash": "gotoMYPOLICY",
|
|
||||||
"mode_xrun": "gotoXSIMRUN",
|
|
||||||
}
|
}
|
||||||
self.udp_flag.enable = self.motion_frame.relay_on
|
self.udp_flag.enable = self.motion_frame.relay_on
|
||||||
self.udp_flag.fsm_state_command = mode_to_fsm_command.get(
|
self.udp_flag.fsm_state_command = mode_to_fsm_command.get(
|
||||||
|
|||||||
@@ -56,8 +56,6 @@ class UDPKeyboardSender:
|
|||||||
print(" z -> pose_home")
|
print(" z -> pose_home")
|
||||||
print(" c -> pose_hold")
|
print(" c -> pose_hold")
|
||||||
print(" m -> mode_stride")
|
print(" m -> mode_stride")
|
||||||
print(" p -> mode_dash")
|
|
||||||
print(" n -> mode_xrun")
|
|
||||||
print(" w/s -> surge +/-")
|
print(" w/s -> surge +/-")
|
||||||
print(" a/d -> sway +/-")
|
print(" a/d -> sway +/-")
|
||||||
print(" q/e -> spin +/-")
|
print(" q/e -> spin +/-")
|
||||||
@@ -118,8 +116,6 @@ class UDPKeyboardSender:
|
|||||||
"z": ("pose_home", "z", 1.0),
|
"z": ("pose_home", "z", 1.0),
|
||||||
"c": ("pose_hold", "c", 1.0),
|
"c": ("pose_hold", "c", 1.0),
|
||||||
"m": ("mode_stride", "m", 1.0),
|
"m": ("mode_stride", "m", 1.0),
|
||||||
"p": ("mode_dash", "p", 1.0),
|
|
||||||
"n": ("mode_xrun", "n", 1.0),
|
|
||||||
"r": ("trim_reset", "r", 1.0),
|
"r": ("trim_reset", "r", 1.0),
|
||||||
"4": ("set_surge", "4", 0.0),
|
"4": ("set_surge", "4", 0.0),
|
||||||
"5": ("set_sway", "5", 0.0),
|
"5": ("set_sway", "5", 0.0),
|
||||||
|
|||||||
@@ -138,10 +138,6 @@ class UDPLoopbackNode:
|
|||||||
self.motion_frame.mode_tag = "pose_hold"
|
self.motion_frame.mode_tag = "pose_hold"
|
||||||
elif event_code == "mode_stride":
|
elif event_code == "mode_stride":
|
||||||
self.motion_frame.mode_tag = "mode_stride"
|
self.motion_frame.mode_tag = "mode_stride"
|
||||||
elif event_code == "mode_dash":
|
|
||||||
self.motion_frame.mode_tag = "mode_dash"
|
|
||||||
elif event_code == "mode_xrun":
|
|
||||||
self.motion_frame.mode_tag = "mode_xrun"
|
|
||||||
elif event_code == "surge_up":
|
elif event_code == "surge_up":
|
||||||
self.motion_frame.surge_goal = min(
|
self.motion_frame.surge_goal = min(
|
||||||
self.max_surge, self.motion_frame.surge_goal + self.surge_step
|
self.max_surge, self.motion_frame.surge_goal + self.surge_step
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class UDPXboxSender(Node):
|
|||||||
f"Forwarding {self.joy_topic} -> udp://{self.target_host}:{self.target_port}"
|
f"Forwarding {self.joy_topic} -> udp://{self.target_host}:{self.target_port}"
|
||||||
)
|
)
|
||||||
self.get_logger().info(
|
self.get_logger().info(
|
||||||
"Buttons: A=WALKAMP X=ZERO Y=STOP B=MYPOLICY START=reset"
|
"Buttons: A=WALKAMP X=ZERO Y=STOP START=reset"
|
||||||
)
|
)
|
||||||
|
|
||||||
def destroy_node(self) -> bool:
|
def destroy_node(self) -> bool:
|
||||||
@@ -167,13 +167,6 @@ class UDPXboxSender(Node):
|
|||||||
self._send_event("pose_home", "x")
|
self._send_event("pose_home", "x")
|
||||||
elif self._rising_edge(state, "a"):
|
elif self._rising_edge(state, "a"):
|
||||||
self._send_event("mode_stride", "a")
|
self._send_event("mode_stride", "a")
|
||||||
elif self._rising_edge(state, "b"):
|
|
||||||
self._send_event("mode_dash", "b")
|
|
||||||
elif (
|
|
||||||
self._rising_edge(state, "home")
|
|
||||||
and state["l_trigger"] < self.trigger_pressed_threshold
|
|
||||||
):
|
|
||||||
self._send_event("mode_xrun", "home")
|
|
||||||
|
|
||||||
def _send_trim_event(self, state: Dict[str, float]) -> None:
|
def _send_trim_event(self, state: Dict[str, float]) -> None:
|
||||||
if self._rising_edge(state, "start"):
|
if self._rising_edge(state, "start"):
|
||||||
|
|||||||
Reference in New Issue
Block a user