"""
FSM state implementation for the standalone MYPOLICY controller.
"""

import os
from typing import Optional

import numpy as np
import onnxruntime as ort
import yaml
from scipy.spatial.transform import Rotation

from FSM.fsm_base import FSMState, FSMStateName
from common.BasicFunction import gait_phase
from common.joystick import ControlFlag
from common.robot_data import RobotData


class FSMStateMYPOLICY(FSMState):
    """Standalone FSM implementation for the custom ONNX policy.

    On construction the state reads ``config/mypolicy.yaml`` (located next to
    this file), allocates the observation/action buffers and opens an ONNX
    Runtime session for ``model/<model_path>``.

    After ``on_enter`` the state holds the joint pose captured on entry until
    a walk command component exceeds ``motion_threshold``.  From then on
    ``run`` evaluates the policy on ticks where
    ``int(time_now_ / dt_) % decimation_ == 0`` and writes the resulting joint
    position targets (blended from the captured pose during the warm start)
    into the shared ``RobotData`` command arrays.
    """

    # FSM command string -> name of the FSMStateName member to switch to.
    # Members are resolved lazily with getattr so a member is only looked up
    # when its command is actually received.
    _TRANSITION_TARGETS = {
        "gotoSTOP": "STOP",
        "gotoWALKAMP": "WALKAMP",
        "gotoMYPOLICY": "MYPOLICY",
        "gotoXSIMRUN": "XSIMRUN",
        "gotoZERO": "ZERO",
    }

    def __init__(self, robot_data: RobotData):
        """Load the policy configuration, buffers, model and joint mapping.

        Args:
            robot_data: Shared robot data object, forwarded to the base class.
                The rest of this class accesses it as ``self.robot_data_``
                (presumably assigned by ``FSMState.__init__`` - confirm).

        Raises:
            ValueError: If a required configuration entry is missing, ``dt``
                is not positive, or a per-joint array does not match the
                number of entries in ``joint_names``.
            KeyError: If ``model_path`` is missing from the configuration.
        """
        super().__init__(robot_data)
        self.current_state_name = FSMStateName.MYPOLICY
        self.log_prefix = "FSMStateMYPOLICY"

        current_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(current_dir, "config", "mypolicy.yaml")
        with open(config_path, "r", encoding="utf-8") as f:
            # An empty YAML document loads as None; treat it as an empty
            # mapping so the required-key checks below report what is missing.
            policy_config = yaml.safe_load(f) or {}

        self.action_num_ = self._require(policy_config.get("actions_size"), "actions_size")
        self.motor_num_ = self._require(policy_config.get("motor_num"), "motor_num")
        self.dt_ = self._require(policy_config.get("dt"), "dt")
        if self.dt_ <= 0:
            # run() divides by dt_, so a non-positive value can never work.
            raise ValueError(f"[FSMStateMYPOLICY] 'dt' must be positive, got {self.dt_}")

        size_config = policy_config.get("size") or {}
        self.num_hist_ = self._require(size_config.get("num_hist"), "size.num_hist")
        self.obs_size_ = self._require(
            size_config.get("observations_size"), "size.observations_size"
        )

        control_config = policy_config.get("control") or {}
        self.action_scale_ = control_config.get("action_scale")
        # A missing or zero decimation means "run the policy on every tick".
        # Normalising it here keeps the modulo in run() consistent with the
        # warm-start step computation below.
        self.decimation_ = control_config.get("decimation") or 1
        self.warm_start_time_ = control_config.get(
            "warm_start_time",
            policy_config.get("warm_start_time", 0.3),
        )

        norm_config = policy_config.get("normalization") or {}
        clip_config = norm_config.get("clip_scales") or {}
        obs_config = norm_config.get("obs_scales") or {}
        self.clip_obs_ = clip_config.get("clip_observations", 100.0)
        self.clip_act_ = clip_config.get("clip_actions", 100.0)
        # NOTE(review): the four observation scales are loaded but never
        # applied in compute_observation(); confirm the policy expects
        # unscaled inputs.
        self.lin_vel_scale_ = obs_config.get("lin_vel")
        self.ang_vel_scale_ = obs_config.get("ang_vel")
        self.dof_pos_scale_ = obs_config.get("dof_pos")
        self.dof_vel_scale_ = obs_config.get("dof_vel")

        hist_len = self.obs_size_ * self.num_hist_
        self.observations_ = np.zeros(hist_len, dtype=np.float32)
        self.proprio_hist_buf_ = np.zeros(hist_len, dtype=np.float32)
        self.last_actions_ = np.zeros(self.action_num_, dtype=np.float32)
        self.actions_ = np.zeros(self.action_num_, dtype=np.float32)
        self._warm_start_pose = np.zeros(self.motor_num_, dtype=np.float32)

        self.is_first_obs_ = True
        self.is_first_action_ = True
        self.is_first_step_ = True
        self.timer_gait_ = 0.0

        gait_config = policy_config.get("gait") or {}
        self.gait_cycle = gait_config.get("gait_cycle", 0.64)
        self.left_phase_ratio = gait_config.get("gait_air_ratio_l", 0.6)
        self.right_phase_ratio = gait_config.get("gait_air_ratio_r", 0.6)
        self.left_theta_offset = gait_config.get("gait_phase_offset_l", 0.6)
        self.right_theta_offset = gait_config.get("gait_phase_offset_r", 0.1)

        # The warm start is measured in policy steps (decimation_ * dt_).
        policy_step = self.decimation_ * self.dt_
        if self.warm_start_time_ > 0 and policy_step > 0:
            self._warm_start_steps = max(1, int(self.warm_start_time_ / policy_step))
        else:
            self._warm_start_steps = 0
        self._warmup_inference_counter = 0

        self.waiting_for_motion = True
        self.motion_threshold = 1e-3
        self.hold_pose = np.zeros(self.motor_num_, dtype=np.float32)
        # NOTE(review): filtered_x_speed is updated in compute_observation()
        # but nothing in this class reads it.
        self.filtered_x_speed = 0.0

        self.model_path = os.path.join(current_dir, "model", policy_config["model_path"])
        self._init_onnx_session()

        joint_names = policy_config.get("joint_names")
        if joint_names is None:
            raise ValueError("[FSMStateMYPOLICY] Missing 'joint_names' in mypolicy.yaml")
        self.joint_seq = list(joint_names)
        n_joints = len(self.joint_seq)

        # run() evaluates actions_ * action_scale + default_angles, so the
        # action vector must have exactly one entry per configured joint.
        if self.action_num_ != n_joints:
            raise ValueError(
                f"[FSMStateMYPOLICY] actions_size {self.action_num_} does not match joint count {n_joints}"
            )

        if self.action_scale_ is None:
            raise ValueError("[FSMStateMYPOLICY] Missing 'control.action_scale' in mypolicy.yaml")
        if np.isscalar(self.action_scale_):
            self.action_scale = np.full(n_joints, float(self.action_scale_), dtype=np.float32)
        else:
            self.action_scale = np.array(self.action_scale_, dtype=np.float32)
        if len(self.action_scale) != n_joints:
            raise ValueError(
                f"[FSMStateMYPOLICY] control.action_scale length {len(self.action_scale)} does not match joint count {len(self.joint_seq)}"
            )

        init_state_config = policy_config.get("init_state") or {}
        default_joint_angles = init_state_config.get("default_joint_angles")
        if default_joint_angles is None:
            raise ValueError("[FSMStateMYPOLICY] Missing 'init_state.default_joint_angles' in mypolicy.yaml")
        self.joint_pos_array_seq = np.array(default_joint_angles, dtype=np.float32)
        if len(self.joint_pos_array_seq) != n_joints:
            raise ValueError(
                f"[FSMStateMYPOLICY] init_state.default_joint_angles length {len(self.joint_pos_array_seq)} does not match joint count {len(self.joint_seq)}"
            )

        gains_config = policy_config.get("gains") or {}
        kp_values = gains_config.get("kp")
        kd_values = gains_config.get("kd")
        if kp_values is None or kd_values is None:
            raise ValueError("[FSMStateMYPOLICY] Missing 'gains.kp' or 'gains.kd' in mypolicy.yaml")
        self.stiffness_array_seq = np.array(kp_values, dtype=np.float32)
        self.damping_array_seq = np.array(kd_values, dtype=np.float32)
        if len(self.stiffness_array_seq) != n_joints:
            raise ValueError(
                f"[FSMStateMYPOLICY] gains.kp length {len(self.stiffness_array_seq)} does not match joint count {len(self.joint_seq)}"
            )
        if len(self.damping_array_seq) != n_joints:
            raise ValueError(
                f"[FSMStateMYPOLICY] gains.kd length {len(self.damping_array_seq)} does not match joint count {len(self.joint_seq)}"
            )

        # Joint order of the arrays written to RobotData (the "mj" order;
        # presumably the MuJoCo XML order - confirm).
        self.joint_xml = [
            "hip_pitch_l_joint", "hip_roll_l_joint", "hip_yaw_l_joint",
            "knee_pitch_l_joint", "ankle_pitch_l_joint", "ankle_roll_l_joint",
            "hip_pitch_r_joint", "hip_roll_r_joint", "hip_yaw_r_joint",
            "knee_pitch_r_joint", "ankle_pitch_r_joint", "ankle_roll_r_joint",
            "waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
            "shoulder_pitch_l_joint", "shoulder_roll_l_joint", "shoulder_yaw_l_joint",
            "elbow_pitch_l_joint", "elbow_yaw_l_joint", "wrist_pitch_l_joint", "wrist_roll_l_joint",
            "shoulder_pitch_r_joint", "shoulder_roll_r_joint", "shoulder_yaw_r_joint",
            "elbow_pitch_r_joint", "elbow_yaw_r_joint", "wrist_pitch_r_joint", "wrist_roll_r_joint",
        ]

        # lab2mj[i] is the joint_xml index of the i-th configured joint.
        xml_index = {name: idx for idx, name in enumerate(self.joint_xml)}
        lab2mj = []
        for name in self.joint_seq:
            if name not in xml_index:
                raise ValueError(f"[FSMStateMYPOLICY] joint '{name}' from mypolicy.yaml not found in joint_xml")
            lab2mj.append(xml_index[name])
        self.lab2mj = np.array(lab2mj, dtype=int)

        # Scatter the config-ordered values into joint_xml order; joints not
        # listed in the config keep a zero default pose and zero gains.
        n_mj = len(self.joint_xml)
        self.joint_pos_array = np.zeros(n_mj, dtype=np.float32)
        self.stiffness_array = np.zeros(n_mj, dtype=np.float32)
        self.damping_array = np.zeros(n_mj, dtype=np.float32)
        self.joint_pos_array[self.lab2mj] = self.joint_pos_array_seq
        self.stiffness_array[self.lab2mj] = self.stiffness_array_seq
        self.damping_array[self.lab2mj] = self.damping_array_seq

        # Aliases in the policy ("lab") joint order.
        self.kps_lab = self.stiffness_array_seq
        self.kds_lab = self.damping_array_seq
        self.default_angles_lab = self.joint_pos_array_seq
        self.action_scale_lab = self.action_scale

    @staticmethod
    def _require(value, key):
        """Return ``value``, raising ``ValueError`` naming ``key`` if it is None."""
        if value is None:
            raise ValueError(f"[FSMStateMYPOLICY] Missing '{key}' in mypolicy.yaml")
        return value

    def _init_onnx_session(self):
        """Create the single-threaded CPU ONNX Runtime session.

        Loading is best-effort: on failure the error is printed and
        ``ort_session_`` is left as ``None``, which makes ``compute_actions``
        a no-op.
        """
        self.ort_session_ = None
        # Resolved lazily on the first inference and then reused.
        self._ort_input_name = None
        try:
            options = ort.SessionOptions()
            options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            options.intra_op_num_threads = 1
            options.inter_op_num_threads = 1
            options.enable_mem_pattern = False
            options.enable_mem_reuse = True
            self.ort_session_ = ort.InferenceSession(
                self.model_path,
                options,
                providers=["CPUExecutionProvider"],
            )
            print(f"[{self.log_prefix}-ONNX] ONNX model loaded successfully: {self.model_path}")
        except Exception as e:
            print(f"[{self.log_prefix}] Failed to load ONNX model: {e}")
            self.ort_session_ = None

    def _write_joint_targets(self, joint_pos):
        """Write position targets and zero velocity/torque targets to RobotData."""
        n = len(self.joint_xml)
        # The joint entries are addressed relative to the end of q_d_.
        base = self.robot_data_.q_d_.shape[0] - self.motor_num_
        self.robot_data_.q_d_[base:base + n] = joint_pos
        self.robot_data_.q_dot_d_[base:base + n] = 0.0
        self.robot_data_.tau_d_[base:base + n] = 0.0

    def _write_joint_gains(self):
        """Write the configured kp/kd gains to RobotData."""
        n = len(self.joint_xml)
        self.robot_data_.joint_kp_p_[:n] = self.stiffness_array
        self.robot_data_.joint_kd_p_[:n] = self.damping_array

    def _reset_internal_state(self):
        """Zero the policy buffers and command the default joint pose."""
        self.observations_.fill(0.0)
        self.proprio_hist_buf_.fill(0.0)
        self.last_actions_.fill(0.0)
        self.actions_.fill(0.0)
        self.is_first_obs_ = True
        self.is_first_action_ = True
        self.is_first_step_ = True

        # on_enter() explicitly supports robot_data_ being None; without this
        # guard that branch could never be reached.
        if self.robot_data_ is None:
            return
        self._write_joint_targets(self.joint_pos_array)

    def on_enter(self):
        """Reset the policy state and capture the current pose to hold."""
        self._reset_internal_state()
        print(f"[{self.log_prefix}] enter")
        self.timer_gait_ = 0.0
        self.waiting_for_motion = True
        self._warmup_inference_counter = 0

        current_pose = None
        if self.robot_data_ is not None:
            try:
                current_pose = self.robot_data_.get_joint_pos().copy()
            except Exception as e:
                # Best-effort: fall back to a zero pose, but report why.
                print(f"[{self.log_prefix}] failed to read joint positions on enter: {e}")

        if current_pose is None:
            self._warm_start_pose = np.zeros(self.motor_num_, dtype=np.float32)
            self.hold_pose = np.zeros(self.motor_num_, dtype=np.float32)
        else:
            self._warm_start_pose = current_pose
            # Separate copy so the two poses never alias the same buffer.
            self.hold_pose = current_pose.copy()
        print(f"[{self.log_prefix}] waiting for motion command before starting policy")

    def run(self, flag: ControlFlag):
        """Execute one control tick.

        While waiting for motion the captured hold pose is commanded.  Once
        any walk command component exceeds ``motion_threshold`` the policy is
        activated: it is evaluated on decimation ticks and its joint targets
        are blended in from the pose captured at activation over
        ``_warm_start_steps`` policy steps.

        Args:
            flag: Control flags, forwarded to ``compute_observation``.
        """
        walk_cmd = np.array(self.robot_data_.get_walk_cmd(), dtype=np.float32)
        if self.waiting_for_motion:
            if np.max(np.abs(walk_cmd)) <= self.motion_threshold:
                self._write_joint_targets(self.hold_pose)
                self._write_joint_gains()
                return
            self.waiting_for_motion = False
            self._warm_start_pose = self.robot_data_.get_joint_pos().copy()
            self._warmup_inference_counter = 0
            print(f"[{self.log_prefix}] motion command detected: {walk_cmd}, policy activated")

        # NOTE(review): this prints on every control tick; consider removing
        # or rate-limiting it if console I/O affects loop timing.
        print(f"[{self.log_prefix}] run")
        gait = gait_phase(
            self.timer_gait_,
            self.gait_cycle,
            self.left_theta_offset,
            self.right_theta_offset,
            self.left_phase_ratio,
            self.right_phase_ratio,
        ).astype(np.float32)

        # NOTE(review): int() truncates, so floating-point error in
        # time_now_ / dt_ can occasionally map a tick to the previous index.
        tick = int(self.robot_data_.time_now_ / self.dt_)
        ran_inference = tick % self.decimation_ == 0
        if ran_inference:
            self.compute_observation(flag, gait)
            self.compute_actions()

        target_dof_pos_lab = self.actions_ * self.action_scale_lab + self.default_angles_lab
        # Joints the policy does not drive keep their measured position.
        target_dof_pos_mj = self.robot_data_.get_joint_pos().copy()
        target_dof_pos_mj[self.lab2mj] = target_dof_pos_lab
        commanded_pos = target_dof_pos_mj

        if self._warm_start_steps > 0 and self._warmup_inference_counter < self._warm_start_steps:
            # _warm_start_steps is expressed in policy steps, so the counter
            # advances only on inference ticks; otherwise the blend would
            # finish decimation_ times sooner than warm_start_time_.
            if ran_inference:
                self._warmup_inference_counter += 1
            blend = self._warmup_inference_counter / float(self._warm_start_steps)
            commanded_pos = (1.0 - blend) * self._warm_start_pose + blend * target_dof_pos_mj

        self._write_joint_targets(commanded_pos)
        self.last_actions_[:] = self.actions_

        self.timer_gait_ += self.dt_
        self._write_joint_gains()

    def compute_observation(self, flag: ControlFlag, gait):
        """Assemble the newest observation frame and update the history.

        The frame is the concatenation of angular velocity, projected
        gravity, walk command, joint position offsets from the default pose,
        joint velocities, previous actions and the gait signal.  The history
        buffer holds ``num_hist_`` frames, oldest first; its clipped contents
        become ``observations_``.

        Args:
            flag: Control flags (not used here).
            gait: Gait signal appended to the frame.

        Raises:
            ValueError: If the assembled frame length differs from
                ``obs_size_``.
        """
        imu = self.robot_data_.imu_data_
        # imu_data_[0..2] are read as yaw, pitch, roll respectively.
        roll, pitch, yaw = float(imu[2]), float(imu[1]), float(imu[0])
        quat_wxyz = self.euler_to_quaternion_scipy(roll, pitch, yaw)
        q_xyzw = np.array(
            [quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]], dtype=np.float32
        )
        projected_gravity = self.quat_rotate_inverse_numpy(
            q_xyzw, np.array([0.0, 0.0, -1.0], dtype=np.float32)
        )

        x_speed_command, y_speed_command, yaw_speed_command = self.robot_data_.get_walk_cmd()
        # Rate-limited copy of the x command.
        # NOTE(review): the policy input below uses the raw command, not this
        # filtered value - confirm which one is intended.
        change = np.clip(x_speed_command - self.filtered_x_speed, -0.005, 0.005)
        self.filtered_x_speed = self.filtered_x_speed + change
        command = np.array(
            [x_speed_command, y_speed_command, yaw_speed_command],
            dtype=np.float32,
        )
        # NOTE(review): printed on every inference; consider rate-limiting.
        print(f"Input command: {command}")

        ang_vel = self.robot_data_.get_angular_velocity()
        q_mj = self.robot_data_.get_joint_pos()
        dq_mj = self.robot_data_.get_joint_vel()
        qj = q_mj[self.lab2mj] - self.default_angles_lab
        dqj = dq_mj[self.lab2mj]

        proprio = np.concatenate([
            ang_vel,
            projected_gravity,
            command,
            qj,
            dqj,
            self.last_actions_,
            gait,
        ])
        if proprio.shape[0] != self.obs_size_:
            raise ValueError(
                f"[FSMStateMYPOLICY] observation length {proprio.shape[0]} does not match "
                f"size.observations_size {self.obs_size_}"
            )

        if self.is_first_obs_:
            # Seed every history slot with the first frame.
            self.proprio_hist_buf_[:] = np.tile(proprio, self.num_hist_)
            self.is_first_obs_ = False
        else:
            # Drop the oldest frame and append the newest at the end.
            shift_size = (self.num_hist_ - 1) * self.obs_size_
            self.proprio_hist_buf_[:shift_size] = self.proprio_hist_buf_[self.obs_size_:]
            self.proprio_hist_buf_[shift_size:] = proprio

        np.clip(self.proprio_hist_buf_, -self.clip_obs_, self.clip_obs_, out=self.observations_)

    def compute_actions(self):
        """Run the ONNX policy on ``observations_`` and store clipped actions.

        Does nothing when no session is available.  Inference errors are
        printed and leave ``actions_`` unchanged.
        """
        if self.ort_session_ is None:
            return

        try:
            if self._ort_input_name is None:
                self._ort_input_name = self.ort_session_.get_inputs()[0].name
            input_data = self.observations_.reshape(1, -1).astype(np.float32)
            outputs = self.ort_session_.run(None, {self._ort_input_name: input_data})
            output_data = np.atleast_1d(outputs[0][0])
            if output_data.ndim != 1 or output_data.shape[0] < self.action_num_:
                # Guard against silent broadcasting of a too-short output.
                raise ValueError(
                    f"model output shape {output_data.shape} provides fewer than "
                    f"{self.action_num_} actions"
                )
            self.actions_[:] = np.clip(
                output_data[:self.action_num_], -self.clip_act_, self.clip_act_
            )

            if self.is_first_action_:
                print(f"[{self.log_prefix}-ONNX] First Observation:")
                # Only the oldest frame of the history is printed.
                print("".join(f"{value:.6f} " for value in self.observations_[:self.obs_size_]))
                self.is_first_action_ = False
        except Exception as e:
            print(f"[{self.log_prefix}] ONNX Runtime inference error: {e}")

    def on_exit(self):
        """Log the exit and close the observation log file if one is open."""
        print(f"[{self.log_prefix}] exit")
        # obs_log_file is not created in this class; it is only closed here
        # if something else attached it.
        if getattr(self, "obs_log_file", None) is not None:
            try:
                self.obs_log_file.flush()
                self.obs_log_file.close()
                print(f"[{self.log_prefix}] obs log saved to {self.obs_log_path}")
            except Exception as e:
                print(f"[{self.log_prefix}] failed to close obs log: {e}")
            self.obs_log_file = None

    def check_transition(self, flag: ControlFlag) -> Optional[FSMStateName]:
        """Map the pending FSM command to the next state.

        Args:
            flag: Control flags; only ``fsm_state_command`` is read.

        Returns:
            The requested ``FSMStateName``, or ``None`` when the command is
            not one of the recognised ``goto*`` strings.
        """
        target = self._TRANSITION_TARGETS.get(flag.fsm_state_command)
        if target is None:
            return None
        return getattr(FSMStateName, target)

    @staticmethod
    def euler_to_quaternion_scipy(roll, pitch, yaw, degrees=False):
        """Convert Euler angles to a quaternion in ``[w, x, y, z]`` order.

        The angles are passed to ``Rotation.from_euler`` with the lowercase
        ``"xyz"`` sequence (extrinsic rotations in SciPy's convention).

        Args:
            roll: Rotation about x.
            pitch: Rotation about y.
            yaw: Rotation about z.
            degrees: Interpret the angles as degrees instead of radians.

        Returns:
            ``np.float32`` array ``[w, x, y, z]``.
        """
        r = Rotation.from_euler("xyz", [roll, pitch, yaw], degrees=degrees)
        # SciPy returns scalar-last [x, y, z, w]; reorder to scalar-first.
        q_xyzw = r.as_quat()
        return np.array([q_xyzw[3], q_xyzw[0], q_xyzw[1], q_xyzw[2]], dtype=np.float32)

    @staticmethod
    def quat_rotate_inverse_numpy(q_xyzw, v):
        """Rotate ``v`` by the inverse of the quaternion ``q_xyzw``.

        Args:
            q_xyzw: Quaternion in ``[x, y, z, w]`` order (assumed unit norm).
            v: 3-vector to rotate.

        Returns:
            The rotated 3-vector.
        """
        q_w = q_xyzw[3]
        q_v = q_xyzw[:3]
        a = v * (2.0 * q_w * q_w - 1.0)
        b = np.cross(q_v, v) * (2.0 * q_w)
        c = q_v * (2.0 * np.dot(q_v, v))
        return a - b + c