"""
FSM state implementation for an xSIM MuJoCo run policy that follows the
TienKung-Lab sim2sim observation/action flow more closely.
"""

import os
from typing import Optional

import numpy as np
import onnxruntime as ort
import yaml
from scipy.spatial.transform import Rotation

from FSM.fsm_base import FSMState, FSMStateName
from common.joystick import ControlFlag
from common.robot_data import RobotData


class FSMStateXSIMRUN(FSMState):
    """Direct-position run policy for xSIM MuJoCo.

    The state loads an ONNX policy plus its YAML configuration, builds a
    stacked observation history on every policy step, and writes joint
    position targets (with PD gains) into the shared ``RobotData`` buffers.

    Three joint layouts are used in this class:

    * "lab"/"policy" order: the order of the policy inputs and outputs.
    * "mujoco23" order: the 23 controlled joints in MuJoCo order.
    * 29-joint order: the layout returned by ``RobotData.get_joint_pos()``.

    ``policy_to_mujoco_idx_`` and ``mujoco_to_policy_idx_`` are inverse
    permutations of each other; ``mujoco_control_indices_`` selects the 23
    controlled joints out of the 29-joint vector.
    """

    # FSM command string -> name of the FSMStateName member to switch to.
    # Member names are resolved lazily so that only the requested member is
    # ever looked up.
    _TRANSITION_TARGETS = {
        "gotoSTOP": "STOP",
        "gotoZERO": "ZERO",
        "gotoWALKAMP": "WALKAMP",
        "gotoMYPOLICY": "MYPOLICY",
        "gotoXSIMRUN": "XSIMRUN",
    }

    def __init__(self, robot_data: RobotData):
        """Load ``config/xsim_run.yaml``, the ONNX model, and allocate buffers.

        Args:
            robot_data: Shared robot state/command container handed to the
                base ``FSMState``.

        Raises:
            KeyError: If a required configuration key is missing.
            OSError: If the configuration file cannot be opened.
        """
        super().__init__(robot_data)
        self.current_state_name = FSMStateName.XSIMRUN
        self.log_prefix = "FSMStateXSIMRUN"

        current_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(current_dir, "config", "xsim_run.yaml")
        with open(config_path, "r", encoding="utf-8") as f:
            policy_config = yaml.safe_load(f)

        # --- Top-level sizes and timing -----------------------------------
        self.action_num_ = int(policy_config["actions_size"])
        self.motor_num_ = int(policy_config["motor_num"])
        self.dt_ = float(policy_config["dt"])
        self.command_clip_ = float(policy_config.get("command_clip", 1.0))

        size_config = policy_config.get("size", {})
        self.num_hist_ = int(size_config["num_hist"])
        self.obs_size_ = int(size_config["observations_size"])

        control_config = policy_config.get("control", {})
        self.action_scale_ = float(control_config["action_scale"])
        self.decimation_ = int(control_config["decimation"])
        # "warm_start_time" may live under "control" or at the top level;
        # the "control" entry wins when both are present.
        self.warm_start_time_ = float(
            control_config.get(
                "warm_start_time",
                policy_config.get("warm_start_time", 0.0),
            )
        )

        sim_config = policy_config.get("sim", {})
        self.mujoco_timestep_ = float(sim_config.get("mujoco_timestep", 0.005))
        # One policy inference is performed every `decimation_` control ticks.
        self.policy_period_ = self.dt_ * self.decimation_

        # --- Gait clock parameters (left, right) --------------------------
        gait_config = policy_config.get("gait", {})
        self.gait_cycle_ = float(gait_config["gait_cycle"])
        self.phase_ratio_ = np.array(
            [
                gait_config["gait_air_ratio_l"],
                gait_config["gait_air_ratio_r"],
            ],
            dtype=np.float32,
        )
        self.phase_offset_ = np.array(
            [
                gait_config["gait_phase_offset_l"],
                gait_config["gait_phase_offset_r"],
            ],
            dtype=np.float32,
        )

        # --- Clipping ------------------------------------------------------
        norm_config = policy_config.get("normalization", {})
        clip_config = norm_config.get("clip_scales", {})
        self.clip_obs_ = float(clip_config.get("clip_observations", 100.0))
        self.clip_act_ = float(clip_config.get("clip_actions", 100.0))

        # --- Per-joint defaults and gains, in lab/policy order -------------
        self.default_angles_lab_ = np.array(
            policy_config["init_state"]["default_joint_angles"],
            dtype=np.float32,
        )
        self.stiffness_lab_ = np.array(policy_config["gains"]["kp"], dtype=np.float32)
        self.damping_lab_ = np.array(policy_config["gains"]["kd"], dtype=np.float32)

        model_rel_path = policy_config["model_path"]
        self.model_path_ = os.path.normpath(os.path.join(current_dir, model_rel_path))
        self._init_onnx_session()

        # sim2sim.py uses policy output in Isaac order and then maps to MuJoCo order.
        # policy_vec = mujoco_vec[mujoco_to_policy_idx_]
        self.mujoco_to_policy_idx_ = np.array(
            [0, 6, 12, 1, 7, 13, 2, 8, 14, 3, 9, 15, 19, 4, 10, 16, 20, 5, 11, 17, 21, 18, 22],
            dtype=int,
        )
        # mujoco_vec = policy_vec[policy_to_mujoco_idx_]
        self.policy_to_mujoco_idx_ = np.array(
            [0, 3, 6, 9, 13, 17, 1, 4, 7, 10, 14, 18, 2, 5, 8, 11, 15, 19, 21, 12, 16, 20, 22],
            dtype=int,
        )

        # RobotData stores 29 joints in leg -> waist -> arm order.
        self.mujoco_control_indices_ = np.array(
            [
                0, 1, 2, 3, 4, 5,
                6, 7, 8, 9, 10, 11,
                12, 13, 14, 15, 16, 17, 18,
                22, 23, 24, 25,
            ],
            dtype=int,
        )

        self.default_angles_mujoco23_ = self.default_angles_lab_[self.policy_to_mujoco_idx_]

        # --- Runtime buffers ----------------------------------------------
        self.observations_ = np.zeros(self.obs_size_ * self.num_hist_, dtype=np.float32)
        self.obs_history_ = np.zeros_like(self.observations_)
        self.actions_ = np.zeros(self.action_num_, dtype=np.float32)
        self.last_actions_ = np.zeros(self.action_num_, dtype=np.float32)
        self.current_gait_ = np.zeros(6, dtype=np.float32)
        self.hold_pose_29_ = np.zeros(self.motor_num_, dtype=np.float32)
        self._warm_start_pose_29_ = np.zeros(self.motor_num_, dtype=np.float32)
        self._first_obs = True
        self._policy_step_counter = 0
        self.waiting_for_motion_ = True
        self.motion_threshold_ = 1e-3

        # Number of policy inferences over which the command is blended from
        # the pose at motion start to the policy target.
        if self.warm_start_time_ > 0 and self.policy_period_ > 0:
            self._warm_start_steps = max(1, int(self.warm_start_time_ / self.policy_period_))
        else:
            self._warm_start_steps = 0
        self._warmup_inference_counter = 0

        # Scatter the lab-ordered gains into the 29-joint layout.  This must
        # use the same lab -> mujoco23 permutation as default_angles_mujoco23_
        # and the action targets (mujoco_vec = lab_vec[policy_to_mujoco_idx_]);
        # otherwise each joint would be driven with another joint's gains.
        # Joints that are not policy-controlled keep zero gains.
        self.kp_29_ = np.zeros(self.motor_num_, dtype=np.float32)
        self.kd_29_ = np.zeros(self.motor_num_, dtype=np.float32)
        self.kp_29_[self.mujoco_control_indices_] = self.stiffness_lab_[self.policy_to_mujoco_idx_]
        self.kd_29_[self.mujoco_control_indices_] = self.damping_lab_[self.policy_to_mujoco_idx_]

    def _init_onnx_session(self) -> None:
        """Create the single-threaded CPU ONNX Runtime session for the policy."""
        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        self.ort_session_ = ort.InferenceSession(
            self.model_path_,
            options,
            providers=["CPUExecutionProvider"],
        )
        print(f"[{self.log_prefix}] ONNX model loaded: {self.model_path_}")

    def _write_joint_command(self, q_cmd_29: np.ndarray) -> None:
        """Write a 29-joint position command and this state's gains to RobotData.

        The desired-position/velocity/torque buffers are addressed through
        their trailing ``motor_num_`` entries; the gain buffers through their
        leading ``motor_num_`` entries.  Desired velocity and feed-forward
        torque are always commanded to zero.

        Args:
            q_cmd_29: Desired joint positions in the 29-joint layout.
        """
        base = self.robot_data_.q_d_.shape[0] - self.motor_num_
        joints = slice(base, base + self.motor_num_)
        self.robot_data_.q_d_[joints] = q_cmd_29
        self.robot_data_.q_dot_d_[joints] = 0.0
        self.robot_data_.tau_d_[joints] = 0.0
        self.robot_data_.joint_kp_p_[:self.motor_num_] = self.kp_29_
        self.robot_data_.joint_kd_p_[:self.motor_num_] = self.kd_29_

    def on_enter(self):
        """Reset all policy state and hold the currently measured joint pose."""
        self.observations_.fill(0.0)
        self.obs_history_.fill(0.0)
        self.actions_.fill(0.0)
        self.last_actions_.fill(0.0)
        self.current_gait_.fill(0.0)
        self._first_obs = True
        self._policy_step_counter = 0
        self._warmup_inference_counter = 0
        self.waiting_for_motion_ = True

        current_q = self.robot_data_.get_joint_pos().copy()
        self.hold_pose_29_ = current_q
        # Separate copy so the two poses never share storage.
        self._warm_start_pose_29_ = current_q.copy()

        self._write_joint_command(current_q)
        print(f"[{self.log_prefix}] enter")

    def run(self, flag: ControlFlag):
        """Execute one control tick.

        While no walk command above ``motion_threshold_`` has been seen since
        entering the state, the pose captured in ``on_enter`` is held.  After
        that, every ``decimation_``-th tick (derived from
        ``robot_data_.time_now_``) runs one policy inference and writes the
        resulting joint targets; on the other ticks nothing is written.

        Args:
            flag: Control flags for this tick (not used by this method).
        """
        walk_cmd = np.clip(
            np.array(self.robot_data_.get_walk_cmd(), dtype=np.float32),
            -self.command_clip_,
            self.command_clip_,
        )

        if self.waiting_for_motion_:
            if np.max(np.abs(walk_cmd)) <= self.motion_threshold_:
                self._write_joint_command(self.hold_pose_29_)
                return

            # First non-zero command: start blending from the measured pose.
            self.waiting_for_motion_ = False
            self._warm_start_pose_29_ = self.robot_data_.get_joint_pos().copy()
            print(f"[{self.log_prefix}] motion command detected: {walk_cmd}")

        # Only run the policy on decimation boundaries.
        if int(self.robot_data_.time_now_ / self.dt_) % self.decimation_ != 0:
            return

        self.current_gait_ = self._compute_gait_features()
        self.compute_observation(walk_cmd)
        self.compute_actions()

        # Policy output (lab order) -> MuJoCo-ordered joint targets.
        target_mujoco23 = (
            self.actions_[self.policy_to_mujoco_idx_] * self.action_scale_
            + self.default_angles_mujoco23_
        )

        # Joints the policy does not control keep their hold pose.
        target_q_29 = self.hold_pose_29_.copy()
        target_q_29[self.mujoco_control_indices_] = target_mujoco23

        commanded_q_29 = target_q_29
        if self._warm_start_steps > 0 and self._warmup_inference_counter < self._warm_start_steps:
            # Linear blend from the pose at motion start to the policy target.
            self._warmup_inference_counter += 1
            blend = self._warmup_inference_counter / float(self._warm_start_steps)
            commanded_q_29 = (1.0 - blend) * self._warm_start_pose_29_ + blend * target_q_29

        self._write_joint_command(commanded_q_29)
        self.last_actions_[:] = self.actions_

    def _compute_gait_features(self) -> np.ndarray:
        """Return the 6-element gait clock feature and advance the step counter.

        Returns:
            ``[sin(phase_l), sin(phase_r), cos(phase_l), cos(phase_r),
            air_ratio_l, air_ratio_r]`` as ``float32``, where each phase is
            the normalised gait time plus the per-leg offset, wrapped to
            ``[0, 1)`` and scaled by ``2*pi``.
        """
        t = self._policy_step_counter * self.policy_period_ / self.gait_cycle_
        gait_phase = (t + self.phase_offset_) % 1.0
        self._policy_step_counter += 1

        angle = 2.0 * np.pi * gait_phase
        return np.concatenate(
            [np.sin(angle), np.cos(angle), self.phase_ratio_],
            axis=0,
        ).astype(np.float32)

    def compute_observation(self, walk_cmd: np.ndarray):
        """Build the current feature vector and update the stacked history.

        The per-step feature vector is the concatenation of angular velocity,
        projected gravity, the walk command, joint position offsets and joint
        velocities (both in policy order), the clipped previous actions, and
        the gait features.  The result is pushed into ``obs_history_`` (oldest
        first, newest last) and a clipped copy is stored in ``observations_``.

        Args:
            walk_cmd: Already clipped walk command for this step.

        Raises:
            ValueError: If the feature vector length differs from ``obs_size_``.
        """
        raw_quat = self.robot_data_.imu_quat_
        quat_norm = np.linalg.norm(raw_quat)
        if quat_norm > 0.0:
            # The rotation formula below assumes a unit quaternion, so
            # normalise whatever the IMU reports.
            q_wxyz = (np.asarray(raw_quat, dtype=np.float64) / quat_norm).astype(np.float32)
        else:
            # All-zero quaternion: fall back to the Euler angles, which are
            # read from imu_data_ as [yaw, pitch, roll].
            roll = float(self.robot_data_.imu_data_[2])
            pitch = float(self.robot_data_.imu_data_[1])
            yaw = float(self.robot_data_.imu_data_[0])
            q_wxyz = self.euler_to_quaternion_scipy(roll, pitch, yaw)
        q_xyzw = np.array([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]], dtype=np.float32)

        gravity = self.quat_rotate_inverse_numpy(
            q_xyzw,
            np.array([0.0, 0.0, -1.0], dtype=np.float32),
        )

        q_29 = self.robot_data_.get_joint_pos()
        dq_29 = self.robot_data_.get_joint_vel()
        q_23 = q_29[self.mujoco_control_indices_]
        dq_23 = dq_29[self.mujoco_control_indices_]

        proprio = np.concatenate(
            [
                self.robot_data_.get_angular_velocity(),
                gravity,
                walk_cmd,
                (q_23 - self.default_angles_mujoco23_)[self.mujoco_to_policy_idx_],
                dq_23[self.mujoco_to_policy_idx_],
                np.clip(self.last_actions_, -self.clip_act_, self.clip_act_),
                self.current_gait_,
            ],
            axis=0,
        ).astype(np.float32)

        if proprio.shape[0] != self.obs_size_:
            raise ValueError(
                f"[{self.log_prefix}] observation has {proprio.shape[0]} entries, "
                f"expected {self.obs_size_}"
            )

        if self._first_obs:
            # Prime every history slot with the first observation.
            self.obs_history_[:] = np.tile(proprio, self.num_hist_)
            self._first_obs = False
        else:
            # Drop the oldest frame and append the newest at the end.
            self.obs_history_ = np.roll(self.obs_history_, -self.obs_size_)
            self.obs_history_[-self.obs_size_:] = proprio

        self.observations_ = np.clip(self.obs_history_, -self.clip_obs_, self.clip_obs_)

    def compute_actions(self):
        """Run the ONNX policy on ``observations_`` and store clipped actions."""
        input_name = self.ort_session_.get_inputs()[0].name
        input_data = self.observations_.reshape(1, -1).astype(np.float32)
        outputs = self.ort_session_.run(None, {input_name: input_data})
        # First output, first batch row, truncated to the configured size.
        raw_actions = outputs[0][0][: self.action_num_]
        self.actions_[:] = np.clip(raw_actions, -self.clip_act_, self.clip_act_)

    def on_exit(self):
        """Log that the state is being left."""
        print(f"[{self.log_prefix}] exit")

    def check_transition(self, flag: ControlFlag) -> Optional[FSMStateName]:
        """Map the requested FSM command to the next state.

        Args:
            flag: Control flags; only ``fsm_state_command`` is read.

        Returns:
            The requested ``FSMStateName``, or ``None`` when the command is
            not one of the commands handled by this state.
        """
        target = self._TRANSITION_TARGETS.get(flag.fsm_state_command)
        if target is None:
            return None
        return getattr(FSMStateName, target)

    @staticmethod
    def euler_to_quaternion_scipy(roll, pitch, yaw, degrees=False):
        """Convert extrinsic x-y-z Euler angles to a ``[w, x, y, z]`` quaternion.

        Args:
            roll: Rotation about x.
            pitch: Rotation about y.
            yaw: Rotation about z.
            degrees: Interpret the angles as degrees instead of radians.

        Returns:
            ``float32`` array ``[w, x, y, z]``.
        """
        r = Rotation.from_euler("xyz", [roll, pitch, yaw], degrees=degrees)
        # SciPy returns scalar-last [x, y, z, w]; reorder to scalar-first.
        q_xyzw = r.as_quat()
        return np.array([q_xyzw[3], q_xyzw[0], q_xyzw[1], q_xyzw[2]], dtype=np.float32)

    @staticmethod
    def quat_rotate_inverse_numpy(q_xyzw, v):
        """Rotate vector ``v`` by the inverse of quaternion ``q_xyzw``.

        Args:
            q_xyzw: Quaternion as ``[x, y, z, w]``; the formula is only a
                pure rotation for a unit quaternion.
            v: 3-element vector.

        Returns:
            The rotated 3-element vector.
        """
        q_w = q_xyzw[3]
        q_v = q_xyzw[:3]
        a = v * (2.0 * q_w * q_w - 1.0)
        b = np.cross(q_v, v) * (2.0 * q_w)
        c = q_v * (2.0 * np.dot(q_v, v))
        return a - b + c