Files
tienkung-szu/Deploy_Tienkung/policy/mypolicy/fsm_mypolicy.py
2026-03-27 16:10:51 +08:00

374 lines
17 KiB
Python

"""
FSM state implementation for the standalone MYPOLICY controller.
"""
import os
import numpy as np
import onnxruntime as ort
import yaml
from scipy.spatial.transform import Rotation
from FSM.fsm_base import FSMState, FSMStateName
from common.BasicFunction import gait_phase
from common.joystick import ControlFlag
from common.robot_data import RobotData
class FSMStateMYPOLICY(FSMState):
"""Standalone FSM implementation for the custom ONNX policy."""
def __init__(self, robot_data: RobotData):
    """Load the policy configuration, ONNX model and joint mappings.

    Reads ``config/mypolicy.yaml`` from the directory of this file, allocates
    the observation/action buffers, opens the ONNX session and builds the
    index mapping between the policy joint order (``joint_names`` in the
    YAML) and the robot joint order (``self.joint_xml``).

    Args:
        robot_data: Shared robot data object, passed on to ``FSMState``.

    Raises:
        ValueError: If a required configuration entry is missing, a per-joint
            list has the wrong length, or a joint name is not listed in
            ``self.joint_xml``.
        KeyError: If ``model_path`` is missing from the configuration.
    """
    super().__init__(robot_data)
    self.current_state_name = FSMStateName.MYPOLICY
    self.log_prefix = "FSMStateMYPOLICY"

    current_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(current_dir, "config", "mypolicy.yaml")
    with open(config_path, "r", encoding="utf-8") as f:
        # An empty YAML file loads as None; treat it as an empty mapping so
        # the checks below can report exactly which key is missing.
        policy_config = yaml.safe_load(f) or {}

    def require(section, key, label):
        """Return section[key], or raise the file's usual 'Missing' error."""
        value = section.get(key)
        if value is None:
            raise ValueError(f"[FSMStateMYPOLICY] Missing '{label}' in mypolicy.yaml")
        return value

    # Sizes and timing. These feed np.zeros() and arithmetic below, so a
    # missing value is reported here instead of as a TypeError later on.
    self.action_num_ = require(policy_config, "actions_size", "actions_size")
    self.motor_num_ = require(policy_config, "motor_num", "motor_num")
    self.dt_ = require(policy_config, "dt", "dt")
    size_config = policy_config.get("size") or {}
    self.num_hist_ = require(size_config, "num_hist", "size.num_hist")
    self.obs_size_ = require(size_config, "observations_size", "size.observations_size")

    control_config = policy_config.get("control") or {}
    self.action_scale_ = control_config.get("action_scale")
    # A missing or zero decimation is treated as 1. The warm-start step
    # computation below always assumed that; storing the fallback keeps the
    # modulo in run() from failing on None / 0.
    self.decimation_ = control_config.get("decimation") or 1
    self.warm_start_time_ = control_config.get(
        "warm_start_time",
        policy_config.get("warm_start_time", 0.3),
    )

    norm_config = policy_config.get("normalization") or {}
    clip_config = norm_config.get("clip_scales") or {}
    obs_config = norm_config.get("obs_scales") or {}
    self.clip_obs_ = clip_config.get("clip_observations", 100.0)
    self.clip_act_ = clip_config.get("clip_actions", 100.0)
    self.lin_vel_scale_ = obs_config.get("lin_vel")
    self.ang_vel_scale_ = obs_config.get("ang_vel")
    self.dof_pos_scale_ = obs_config.get("dof_pos")
    self.dof_vel_scale_ = obs_config.get("dof_vel")

    # Policy buffers: num_hist_ stacked frames of obs_size_ values each.
    self.observations_ = np.zeros(self.obs_size_ * self.num_hist_, dtype=np.float32)
    self.proprio_hist_buf_ = np.zeros(self.obs_size_ * self.num_hist_, dtype=np.float32)
    self.last_actions_ = np.zeros(self.action_num_, dtype=np.float32)
    self.actions_ = np.zeros(self.action_num_, dtype=np.float32)
    self._warm_start_pose = np.zeros(self.motor_num_, dtype=np.float32)
    self.is_first_obs_ = True
    self.is_first_action_ = True
    self.is_first_step_ = True
    self.timer_gait_ = 0.0

    gait_config = policy_config.get("gait") or {}
    self.gait_cycle = gait_config.get("gait_cycle", 0.64)
    self.left_phase_ratio = gait_config.get("gait_air_ratio_l", 0.6)
    self.right_phase_ratio = gait_config.get("gait_air_ratio_r", 0.6)
    self.left_theta_offset = gait_config.get("gait_phase_offset_l", 0.6)
    self.right_theta_offset = gait_config.get("gait_phase_offset_r", 0.1)

    # Number of policy inferences over which run() blends from the pose at
    # activation to the policy target.
    step = self.decimation_ * self.dt_
    if self.warm_start_time_ > 0 and step > 0:
        self._warm_start_steps = max(1, int(self.warm_start_time_ / step))
    else:
        self._warm_start_steps = 0
    self._warmup_inference_counter = 0

    # The policy stays idle (pose hold) until a walk command above this
    # threshold is seen.
    self.waiting_for_motion = True
    self.motion_threshold = 1e-3
    self.hold_pose = np.zeros(self.motor_num_, dtype=np.float32)
    self.filtered_x_speed = 0.0

    self.model_path = os.path.join(current_dir, "model", policy_config["model_path"])
    self._init_onnx_session()

    # Per-joint parameters, given in the policy joint order.
    joint_names = policy_config.get("joint_names")
    if joint_names is None:
        raise ValueError("[FSMStateMYPOLICY] Missing 'joint_names' in mypolicy.yaml")
    self.joint_seq = list(joint_names)
    joint_count = len(self.joint_seq)

    if self.action_scale_ is None:
        raise ValueError("[FSMStateMYPOLICY] Missing 'control.action_scale' in mypolicy.yaml")
    if np.isscalar(self.action_scale_):
        # A single scalar applies to every joint.
        self.action_scale = np.full(joint_count, float(self.action_scale_), dtype=np.float32)
    else:
        self.action_scale = np.array(self.action_scale_, dtype=np.float32)
        if len(self.action_scale) != joint_count:
            raise ValueError(
                f"[FSMStateMYPOLICY] control.action_scale length {len(self.action_scale)} does not match joint count {joint_count}"
            )

    init_state_config = policy_config.get("init_state") or {}
    default_joint_angles = init_state_config.get("default_joint_angles")
    if default_joint_angles is None:
        raise ValueError("[FSMStateMYPOLICY] Missing 'init_state.default_joint_angles' in mypolicy.yaml")
    self.joint_pos_array_seq = np.array(default_joint_angles, dtype=np.float32)
    if len(self.joint_pos_array_seq) != joint_count:
        raise ValueError(
            f"[FSMStateMYPOLICY] init_state.default_joint_angles length {len(self.joint_pos_array_seq)} does not match joint count {joint_count}"
        )

    gains_config = policy_config.get("gains") or {}
    kp_values = gains_config.get("kp")
    kd_values = gains_config.get("kd")
    if kp_values is None or kd_values is None:
        raise ValueError("[FSMStateMYPOLICY] Missing 'gains.kp' or 'gains.kd' in mypolicy.yaml")
    self.stiffness_array_seq = np.array(kp_values, dtype=np.float32)
    self.damping_array_seq = np.array(kd_values, dtype=np.float32)
    if len(self.stiffness_array_seq) != joint_count:
        raise ValueError(
            f"[FSMStateMYPOLICY] gains.kp length {len(self.stiffness_array_seq)} does not match joint count {joint_count}"
        )
    if len(self.damping_array_seq) != joint_count:
        raise ValueError(
            f"[FSMStateMYPOLICY] gains.kd length {len(self.damping_array_seq)} does not match joint count {joint_count}"
        )

    # Robot-side joint order used when indexing the robot data arrays.
    self.joint_xml = [
        "hip_pitch_l_joint", "hip_roll_l_joint", "hip_yaw_l_joint",
        "knee_pitch_l_joint", "ankle_pitch_l_joint", "ankle_roll_l_joint",
        "hip_pitch_r_joint", "hip_roll_r_joint", "hip_yaw_r_joint",
        "knee_pitch_r_joint", "ankle_pitch_r_joint", "ankle_roll_r_joint",
        "waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
        "shoulder_pitch_l_joint", "shoulder_roll_l_joint", "shoulder_yaw_l_joint",
        "elbow_pitch_l_joint", "elbow_yaw_l_joint", "wrist_pitch_l_joint", "wrist_roll_l_joint",
        "shoulder_pitch_r_joint", "shoulder_roll_r_joint", "shoulder_yaw_r_joint",
        "elbow_pitch_r_joint", "elbow_yaw_r_joint", "wrist_pitch_r_joint", "wrist_roll_r_joint",
    ]

    # lab2mj[i] is the index in joint_xml of the i-th policy joint.
    xml_index = {name: idx for idx, name in enumerate(self.joint_xml)}
    lab2mj = []
    for name in self.joint_seq:
        if name not in xml_index:
            raise ValueError(f"[FSMStateMYPOLICY] joint '{name}' from mypolicy.yaml not found in joint_xml")
        lab2mj.append(xml_index[name])
    self.lab2mj = np.array(lab2mj, dtype=int)

    # Scatter the policy-ordered parameters into robot-ordered arrays; joints
    # that the policy does not list keep zero pose and zero gains.
    n_mj = len(self.joint_xml)
    self.joint_pos_array = np.zeros(n_mj, dtype=np.float32)
    self.stiffness_array = np.zeros(n_mj, dtype=np.float32)
    self.damping_array = np.zeros(n_mj, dtype=np.float32)
    self.joint_pos_array[self.lab2mj] = self.joint_pos_array_seq
    self.stiffness_array[self.lab2mj] = self.stiffness_array_seq
    self.damping_array[self.lab2mj] = self.damping_array_seq

    # Aliases in policy joint order used by run() / compute_observation().
    self.kps_lab = self.stiffness_array_seq
    self.kds_lab = self.damping_array_seq
    self.default_angles_lab = self.joint_pos_array_seq
    self.action_scale_lab = self.action_scale
def _init_onnx_session(self):
    """Create the ONNX Runtime session for ``self.model_path``.

    The session is stored in ``self.ort_session_``. Any failure while
    configuring or loading is printed and leaves ``self.ort_session_`` set to
    ``None`` instead of raising.
    """
    session = None
    try:
        session_options = ort.SessionOptions()
        # Single-threaded CPU execution with all graph optimizations enabled.
        session_options.intra_op_num_threads = 1
        session_options.inter_op_num_threads = 1
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.enable_mem_pattern = False
        session_options.enable_mem_reuse = True
        session = ort.InferenceSession(
            self.model_path,
            session_options,
            providers=["CPUExecutionProvider"],
        )
        print(f"[{self.log_prefix}-ONNX] ONNX model loaded successfully: {self.model_path}")
    except Exception as e:
        print(f"[{self.log_prefix}] Failed to load ONNX model: {e}")
        session = None
    self.ort_session_ = session
def _reset_internal_state(self):
self.observations_.fill(0.0)
self.proprio_hist_buf_.fill(0.0)
self.last_actions_.fill(0.0)
self.actions_.fill(0.0)
self.is_first_obs_ = True
self.is_first_action_ = True
self.is_first_step_ = True
base = self.robot_data_.q_d_.shape[0] - self.motor_num_
self.robot_data_.q_d_[base:base + len(self.joint_xml)] = self.joint_pos_array
self.robot_data_.q_dot_d_[base:base + len(self.joint_xml)] = 0.0
self.robot_data_.tau_d_[base:base + len(self.joint_xml)] = 0.0
def on_enter(self):
    """Prepare the state for a fresh activation.

    Resets the internal buffers, re-arms the "wait for motion" gate and
    captures the current joint positions as both the hold pose and the
    warm-start pose. The two poses are stored as independent copies so that
    modifying one can never change the other.

    If the joint positions cannot be read, the failure is printed and both
    poses fall back to zeros.
    """
    self._reset_internal_state()
    print(f"[{self.log_prefix}] enter")
    self.timer_gait_ = 0.0
    self.waiting_for_motion = True
    self._warmup_inference_counter = 0

    current_pose = None
    if self.robot_data_ is not None:
        try:
            current_pose = self.robot_data_.get_joint_pos().copy()
        except Exception as e:
            # Best effort: keep the state usable, but do not hide the failure.
            print(f"[{self.log_prefix}] failed to read joint positions on enter: {e}")

    if current_pose is not None:
        self._warm_start_pose = current_pose
        # Separate copy: hold_pose must not alias _warm_start_pose.
        self.hold_pose = current_pose.copy()
    else:
        # NOTE(review): an all-zero hold pose will be commanded while waiting
        # for motion; confirm this is the intended fallback on the real robot.
        self._warm_start_pose.fill(0.0)
        self.hold_pose.fill(0.0)
    print(f"[{self.log_prefix}] waiting for motion command before starting policy")
def run(self, flag: ControlFlag):
    """Advance the controller by one control tick.

    While no motion command has been seen, the pose captured on entry is held
    with the configured gains. Once any walk-command component exceeds
    ``motion_threshold`` the policy is activated: on every ``decimation_``-th
    tick an observation is built, the policy is evaluated and the resulting
    joint targets are written to the robot command arrays. During the first
    ``_warm_start_steps`` inferences the target is blended linearly from the
    pose at activation to the policy output.

    Args:
        flag: Control flags, forwarded to ``compute_observation``.
    """
    robot = self.robot_data_
    walk_cmd = np.array(robot.get_walk_cmd(), dtype=np.float32)
    n_joints = len(self.joint_xml)
    # The joint section starts motor_num_ entries before the end of q_d_.
    base = robot.q_d_.shape[0] - self.motor_num_
    joints = slice(base, base + n_joints)

    if self.waiting_for_motion:
        if np.max(np.abs(walk_cmd)) <= self.motion_threshold:
            # No command yet: keep holding the entry pose.
            robot.q_d_[joints] = self.hold_pose
            robot.q_dot_d_[joints] = 0.0
            robot.tau_d_[joints] = 0.0
            robot.joint_kp_p_[:n_joints] = self.stiffness_array
            robot.joint_kd_p_[:n_joints] = self.damping_array
            return
        self.waiting_for_motion = False
        # Blend from wherever the joints are at the moment of activation.
        self._warm_start_pose = robot.get_joint_pos().copy()
        self._warmup_inference_counter = 0
        print(f"[{self.log_prefix}] motion command detected: {walk_cmd}, policy activated")

    if self.is_first_step_:
        # Announce once per activation; printing on every tick floods stdout
        # from inside the control loop.
        print(f"[{self.log_prefix}] run")
        self.is_first_step_ = False

    gait = gait_phase(
        self.timer_gait_,
        self.gait_cycle,
        self.left_theta_offset,
        self.right_theta_offset,
        self.left_phase_ratio,
        self.right_phase_ratio,
    ).astype(np.float32)

    # Round before converting: time_now_ / dt_ is a float quotient and can
    # land just below an integer (e.g. 0.003 / 0.001 == 2.9999999999999996),
    # which plain truncation would map to the previous tick.
    tick = int(round(robot.time_now_ / self.dt_))
    decimation = self.decimation_ or 1  # a missing/zero decimation means "every tick"
    if tick % decimation == 0:
        self.compute_observation(flag, gait)
        self.compute_actions()

        target_dof_pos_lab = self.actions_ * self.action_scale_lab + self.default_angles_lab
        # Joints not driven by the policy are commanded to their measured position.
        target_dof_pos_mj = robot.get_joint_pos().copy()
        target_dof_pos_mj[self.lab2mj] = target_dof_pos_lab

        commanded_pos = target_dof_pos_mj
        if self._warm_start_steps > 0 and self._warmup_inference_counter < self._warm_start_steps:
            self._warmup_inference_counter += 1
            blend = self._warmup_inference_counter / float(self._warm_start_steps)
            commanded_pos = (1.0 - blend) * self._warm_start_pose + blend * target_dof_pos_mj

        robot.q_d_[joints] = commanded_pos
        robot.q_dot_d_[joints] = 0.0
        robot.tau_d_[joints] = 0.0
        self.last_actions_[:] = self.actions_

    self.timer_gait_ += self.dt_
    robot.joint_kp_p_[:n_joints] = self.stiffness_array
    robot.joint_kd_p_[:n_joints] = self.damping_array
def compute_observation(self, flag: ControlFlag, gait):
    """Build the newest observation frame and refresh the stacked input.

    One frame is the concatenation of: angular velocity, gravity direction
    expressed through the inverse IMU rotation, the walk command, joint
    position offsets from the default pose (policy joint order), joint
    velocities (policy joint order), the previous actions and ``gait``. The
    first frame after a reset fills every history slot; later frames push the
    history one slot towards the front and take the last slot. The clipped
    history is stored in ``self.observations_``.

    Args:
        flag: Control flags; not read by this method.
        gait: Gait-phase values appended at the end of the frame.
    """
    robot = self.robot_data_

    # imu_data_ indices 0 / 1 / 2 are read as yaw / pitch / roll.
    yaw_angle = float(robot.imu_data_[0])
    pitch_angle = float(robot.imu_data_[1])
    roll_angle = float(robot.imu_data_[2])
    w, x, y, z = self.euler_to_quaternion_scipy(roll_angle, pitch_angle, yaw_angle)
    orientation_xyzw = np.array([x, y, z, w], dtype=np.float32)
    world_down = np.array([0.0, 0.0, -1.0], dtype=np.float32)
    projected_gravity = self.quat_rotate_inverse_numpy(orientation_xyzw, world_down)

    vx_cmd, vy_cmd, wz_cmd = robot.get_walk_cmd()
    # Rate-limited copy of the forward command (at most 0.005 change per call).
    # NOTE(review): only stored on self; the raw command below is what the
    # policy receives - confirm whether the filtered value was meant to be used.
    limited_step = np.clip(vx_cmd - self.filtered_x_speed, -0.005, 0.005)
    self.filtered_x_speed = self.filtered_x_speed + limited_step
    command = np.array([vx_cmd, vy_cmd, wz_cmd], dtype=np.float32)
    # NOTE(review): printed on every inference.
    print(f"Input command: {command}")

    ang_vel = robot.get_angular_velocity()
    joint_pos = robot.get_joint_pos()
    joint_vel = robot.get_joint_vel()
    frame = np.concatenate((
        ang_vel,
        projected_gravity,
        command,
        joint_pos[self.lab2mj] - self.default_angles_lab,
        joint_vel[self.lab2mj],
        self.last_actions_,
        gait,
    ))

    # View the flat history buffer as (num_hist_, obs_size_) rows; writes
    # through this view land in proprio_hist_buf_ itself.
    history = self.proprio_hist_buf_.reshape(self.num_hist_, self.obs_size_)
    if self.is_first_obs_:
        history[:] = frame
        self.is_first_obs_ = False
    else:
        history[:-1] = history[1:]
        history[-1] = frame

    limit = self.clip_obs_
    self.observations_ = np.clip(self.proprio_hist_buf_, -limit, limit)
def compute_actions(self):
    """Run one policy inference and store the clipped actions.

    Feeds ``self.observations_`` to the ONNX session and writes the first
    ``action_num_`` outputs, clipped to ``±clip_act_``, into ``self.actions_``.
    ``self.actions_`` is left untouched when no session is loaded, when the
    inference fails, when the output is shorter than ``action_num_`` or when
    it contains NaN/inf values; the failure is printed in those cases (except
    for the missing session, which was already reported at load time).
    """
    if self.ort_session_ is None:
        return
    try:
        input_data = self.observations_.reshape(1, -1).astype(np.float32)
        input_name = self.ort_session_.get_inputs()[0].name
        outputs = self.ort_session_.run(None, {input_name: input_data})

        raw_actions = np.asarray(outputs[0][0])[: self.action_num_]
        if raw_actions.shape[0] != self.action_num_:
            # Reject before writing anything so actions_ is never half-updated.
            raise ValueError(
                f"model returned {raw_actions.shape[0]} actions, expected {self.action_num_}"
            )
        if not np.all(np.isfinite(raw_actions)):
            # np.clip lets NaN through, which would become a NaN joint target.
            print(
                f"[{self.log_prefix}] ONNX Runtime inference error: "
                "non-finite action output, keeping previous actions"
            )
            return
        self.actions_[:] = np.clip(raw_actions, -self.clip_act_, self.clip_act_)

        if self.is_first_action_:
            # One-time dump of the newest... first obs_size_ entries of the input.
            print(f"[{self.log_prefix}-ONNX] First Observation:")
            print("".join(f"{value:.6f} " for value in self.observations_[: self.obs_size_]))
            self.is_first_action_ = False
    except Exception as e:
        print(f"[{self.log_prefix}] ONNX Runtime inference error: {e}")
def on_exit(self):
    """Log the state exit and close the observation log file, if one is open.

    Errors raised while flushing/closing are printed rather than propagated;
    the ``obs_log_file`` attribute is cleared afterwards in either case.
    """
    print(f"[{self.log_prefix}] exit")
    if getattr(self, "obs_log_file", None) is None:
        # Nothing was opened, so there is nothing to close.
        return
    try:
        log_handle = self.obs_log_file
        log_handle.flush()
        log_handle.close()
        print(f"[{self.log_prefix}] obs log saved to {self.obs_log_path}")
    except Exception as e:
        print(f"[{self.log_prefix}] failed to close obs log: {e}")
    self.obs_log_file = None
def check_transition(self, flag: ControlFlag) -> FSMStateName:
    """Translate the pending FSM command into the state to switch to.

    Args:
        flag: Control flags; only ``fsm_state_command`` is read.

    Returns:
        The ``FSMStateName`` member selected by the command, or ``None`` when
        the command is not one of the transitions handled here.
    """
    # (command string, FSMStateName member name), checked in this order.
    routes = (
        ("gotoSTOP", "STOP"),
        ("gotoWALKAMP", "WALKAMP"),
        ("gotoMYPOLICY", "MYPOLICY"),
        ("gotoXSIMRUN", "XSIMRUN"),
        ("gotoZERO", "ZERO"),
    )
    requested = flag.fsm_state_command
    for command, member_name in routes:
        if requested == command:
            # Resolve the enum member only for the matching command.
            return getattr(FSMStateName, member_name)
    return None
@staticmethod
def euler_to_quaternion_scipy(roll, pitch, yaw, degrees=False):
    """Convert roll/pitch/yaw angles to a scalar-first quaternion.

    Args:
        roll: Angle about x.
        pitch: Angle about y.
        yaw: Angle about z.
        degrees: Interpret the angles as degrees instead of radians.

    Returns:
        ``float32`` array ``[w, x, y, z]`` built with scipy's ``"xyz"`` Euler
        sequence.
    """
    rotation = Rotation.from_euler("xyz", [roll, pitch, yaw], degrees=degrees)
    # scipy returns the scalar component last; move it to the front.
    x, y, z, w = rotation.as_quat()
    return np.array((w, x, y, z), dtype=np.float32)
@staticmethod
def quat_rotate_inverse_numpy(q_xyzw, v):
    """Rotate ``v`` by the inverse of the quaternion ``q_xyzw``.

    Args:
        q_xyzw: Quaternion with the scalar component last (index 3).
            Presumably unit length; the formula does not normalise it.
        v: 3-vector to rotate.

    Returns:
        The rotated 3-vector.
    """
    vector_part, scalar_part = q_xyzw[:3], q_xyzw[3]
    scaled_input = (2.0 * scalar_part * scalar_part - 1.0) * v
    cross_term = (2.0 * scalar_part) * np.cross(vector_part, v)
    axis_term = (2.0 * np.dot(vector_part, v)) * vector_part
    # The cross term is subtracted (not added) to apply the inverse rotation.
    return scaled_input - cross_term + axis_term