# Source: tienkung-szu/Deploy_Tienkung/policy/walk_amp/fsm_walkamp.py
# Snapshot: 2026-03-30 15:30:30 +08:00 (453 lines, 19 KiB, Python)
# NOTE: the hosting page flagged this file as containing ambiguous Unicode
# characters (characters that can be confused with other characters).
"""
FSM State Implementations
Concrete implementations of different FSM states
"""
import numpy as np
import onnxruntime as ort
from FSM.fsm_base import FSMState, FSMStateName
from common.joystick import ControlFlag
from common.robot_data import RobotData
from common.BasicFunction import clip_vector, gait_phase
import os
import yaml
from scipy.spatial.transform import Rotation
class FSMStateWALKAMP(FSMState):
    """FSM state that runs the WALKAMP walking policy through ONNX Runtime.

    The state reads its settings from a YAML file, builds a history-stacked
    observation vector on every policy tick, runs the ONNX model, and writes
    the resulting joint position targets (plus PD gains) into ``robot_data_``.

    Two joint orderings are used throughout:

    * "lab" order  - the order of ``joint_names`` in the YAML file; this is
      the order of the policy's joint observations and actions.
    * "mj" order   - the order of ``self.joint_xml`` (29 joints); this is the
      order used when reading/writing joint vectors on ``robot_data_``.

    ``self.lab2mj[i]`` is the mj index of the i-th lab joint.
    """

    def __init__(self,
                 robot_data: RobotData,
                 config_path: str | None = None,
                 base_dir: str | None = None,
                 log_prefix: str = "FSMStateWALKAMP"):
        """Load the policy configuration, allocate buffers and load the model.

        Args:
            robot_data: Shared robot data object; passed to ``FSMState``.
            config_path: Path of the policy YAML. Defaults to
                ``config/walk_amp.yaml`` next to this file.
            base_dir: Directory that contains the ``model`` folder. Defaults
                to the grandparent directory of ``config_path``.
            log_prefix: Tag used in the messages printed by this state.

        Raises:
            ValueError: If a required configuration entry is missing, if a
                per-joint list does not match the number of joints, or if a
                configured joint name is not part of ``joint_xml``.
            KeyError: If ``model_path`` is missing from the configuration.
        """
        super().__init__(robot_data)
        self.log_prefix = log_prefix

        # Resolve the configuration file and the base directory.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        if config_path is None:
            config_path = os.path.join(current_dir, "config", "walk_amp.yaml")
        config_path = os.path.abspath(config_path)
        if base_dir is None:
            base_dir = os.path.dirname(os.path.dirname(config_path))

        with open(config_path, 'r', encoding='utf-8') as f:
            # An empty YAML file loads as None; treat it as an empty mapping
            # so the checks below report which key is missing.
            policy_config = yaml.safe_load(f) or {}

        # Top-level sizes and timing.
        self.action_num_ = self._require_config(
            policy_config.get('actions_size'), 'actions_size')
        self.motor_num_ = self._require_config(
            policy_config.get('motor_num'), 'motor_num')
        self.dt_ = self._require_config(policy_config.get('dt'), 'dt')

        # Observation sizes.
        size_config = policy_config.get('size', {})
        self.num_hist_ = self._require_config(
            size_config.get('num_hist'), 'size.num_hist')
        self.obs_size_ = self._require_config(
            size_config.get('observations_size'), 'size.observations_size')

        # Control settings.
        control_config = policy_config.get('control', {})
        self.action_scale_ = control_config.get('action_scale')
        # run() computes "tick % decimation_", so a missing or zero value must
        # never reach it. Fall back to 1 (policy on every tick), which is the
        # same fallback the warm-start step count below has always used.
        decimation = control_config.get('decimation')
        self.decimation_ = decimation if decimation else 1
        self.warm_start_time_ = control_config.get(
            'warm_start_time',
            policy_config.get('warm_start_time', 0.3),
        )

        # Normalization settings.
        norm_config = policy_config.get('normalization', {})
        clip_config = norm_config.get('clip_scales', {})
        obs_config = norm_config.get('obs_scales', {})
        self.clip_obs_ = clip_config.get('clip_observations', 100.0)
        self.clip_act_ = clip_config.get('clip_actions', 100.0)
        self.lin_vel_scale_ = obs_config.get('lin_vel')
        self.ang_vel_scale_ = obs_config.get('ang_vel')
        self.dof_pos_scale_ = obs_config.get('dof_pos')
        self.dof_vel_scale_ = obs_config.get('dof_vel')

        # Buffers: the observation fed to the model is the whole history.
        hist_len = self.obs_size_ * self.num_hist_
        self.observations_ = np.zeros(hist_len, dtype=np.float32)
        self.proprio_hist_buf_ = np.zeros(hist_len, dtype=np.float32)
        self.last_actions_ = np.zeros(self.action_num_, dtype=np.float32)
        self.actions_ = np.zeros(self.action_num_, dtype=np.float32)
        self._warm_start_pose = np.zeros(self.motor_num_, dtype=np.float32)

        # First-use flags.
        self.is_first_obs_ = True
        self.is_first_action_ = True
        self.is_first_step_ = True

        # Gait clock parameters.
        self.timer_gait_ = 0.0
        gait_config = policy_config.get('gait', {})
        self.gait_cycle = gait_config.get('gait_cycle', 0.85)
        self.left_phase_ratio = gait_config.get('gait_air_ratio_l', 0.38)
        self.right_phase_ratio = gait_config.get('gait_air_ratio_r', 0.38)
        self.left_theta_offset = gait_config.get('gait_phase_offset_l', 0.38)
        self.right_theta_offset = gait_config.get('gait_phase_offset_r', 0.88)

        # Number of policy ticks over which run() blends from the pose
        # captured in on_enter() to the policy target.
        step = self.decimation_ * self.dt_
        if self.warm_start_time_ > 0 and step > 0:
            self._warm_start_steps = max(1, int(self.warm_start_time_ / step))
        else:
            self._warm_start_steps = 0
        self._warmup_inference_counter = 0

        # Load the ONNX model.
        self.model_path = os.path.join(base_dir, "model", policy_config["model_path"])
        self._init_onnx_session()

        # Joint names in lab (policy) order.
        joint_names = policy_config.get('joint_names')
        if joint_names is None:
            raise ValueError("[FSMStateWALKAMP] Missing 'joint_names' in walk_amp.yaml")
        self.joint_seq = list(joint_names)
        n_lab = len(self.joint_seq)

        # Action scale: a scalar is broadcast to every joint.
        if self.action_scale_ is None:
            raise ValueError("[FSMStateWALKAMP] Missing 'control.action_scale' in walk_amp.yaml")
        if np.isscalar(self.action_scale_):
            self.action_scale = np.full(n_lab, float(self.action_scale_), dtype=np.float32)
        else:
            self.action_scale = np.array(self.action_scale_, dtype=np.float32)
        if len(self.action_scale) != n_lab:
            raise ValueError(
                f"[FSMStateWALKAMP] control.action_scale length {len(self.action_scale)} does not match joint count {len(self.joint_seq)}"
            )

        # Default joint angles, lab order.
        init_state_config = policy_config.get('init_state', {})
        default_joint_angles = init_state_config.get('default_joint_angles')
        if default_joint_angles is None:
            raise ValueError("[FSMStateWALKAMP] Missing 'init_state.default_joint_angles' in walk_amp.yaml")
        self.joint_pos_array_seq = np.array(default_joint_angles, dtype=np.float32)
        if len(self.joint_pos_array_seq) != n_lab:
            raise ValueError(
                f"[FSMStateWALKAMP] init_state.default_joint_angles length {len(self.joint_pos_array_seq)} does not match joint count {len(self.joint_seq)}"
            )

        # PD gains, lab order.
        gains_config = policy_config.get('gains', {})
        kp_values = gains_config.get('kp')
        kd_values = gains_config.get('kd')
        if kp_values is None or kd_values is None:
            raise ValueError("[FSMStateWALKAMP] Missing 'gains.kp' or 'gains.kd' in walk_amp.yaml")
        self.stiffness_array_seq = np.array(kp_values, dtype=np.float32)
        self.damping_array_seq = np.array(kd_values, dtype=np.float32)
        if len(self.stiffness_array_seq) != n_lab:
            raise ValueError(
                f"[FSMStateWALKAMP] gains.kp length {len(self.stiffness_array_seq)} does not match joint count {len(self.joint_seq)}"
            )
        if len(self.damping_array_seq) != n_lab:
            raise ValueError(
                f"[FSMStateWALKAMP] gains.kd length {len(self.damping_array_seq)} does not match joint count {len(self.joint_seq)}"
            )

        # Joint names in mj (XML) order; 29 entries.
        self.joint_xml = [
            "hip_pitch_l_joint", "hip_roll_l_joint", "hip_yaw_l_joint",
            "knee_pitch_l_joint", "ankle_pitch_l_joint", "ankle_roll_l_joint",
            "hip_pitch_r_joint", "hip_roll_r_joint", "hip_yaw_r_joint",
            "knee_pitch_r_joint", "ankle_pitch_r_joint", "ankle_roll_r_joint",
            "waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
            "shoulder_pitch_l_joint", "shoulder_roll_l_joint", "shoulder_yaw_l_joint",
            "elbow_pitch_l_joint", "elbow_yaw_l_joint", "wrist_pitch_l_joint", "wrist_roll_l_joint",
            "shoulder_pitch_r_joint", "shoulder_roll_r_joint", "shoulder_yaw_r_joint",
            "elbow_pitch_r_joint", "elbow_yaw_r_joint", "wrist_pitch_r_joint", "wrist_roll_r_joint",
        ]

        # lab2mj[i] = position of lab joint i inside joint_xml.
        xml_index = {name: idx for idx, name in enumerate(self.joint_xml)}
        for name in self.joint_seq:
            if name not in xml_index:
                raise ValueError(f"[FSMStateWALKAMP] joint '{name}' from walk_amp.yaml not found in joint_xml!")
        self.lab2mj = np.array([xml_index[name] for name in self.joint_seq], dtype=int)

        # Scatter the lab-order values into mj-order arrays; joints that the
        # policy does not control keep 0 for angle, stiffness and damping.
        n_mj = len(self.joint_xml)
        self.joint_pos_array = np.zeros(n_mj, dtype=np.float32)
        self.stiffness_array = np.zeros(n_mj, dtype=np.float32)
        self.damping_array = np.zeros(n_mj, dtype=np.float32)
        self.joint_pos_array[self.lab2mj] = self.joint_pos_array_seq
        self.stiffness_array[self.lab2mj] = self.stiffness_array_seq
        self.damping_array[self.lab2mj] = self.damping_array_seq

        # Lab-order aliases used by run() / compute_observation().
        self.kps_lab = self.stiffness_array_seq
        self.kds_lab = self.damping_array_seq
        self.default_angles_lab = self.joint_pos_array_seq
        self.action_scale_lab = self.action_scale

        self.filtered_x_speed = 0

    @staticmethod
    def _require_config(value, key):
        """Return ``value``; raise ``ValueError`` naming ``key`` if it is None."""
        if value is None:
            raise ValueError(f"[FSMStateWALKAMP] Missing '{key}' in walk_amp.yaml")
        return value

    def _reset_internal_state(self):
        """Reset every time-varying internal quantity to its initial value."""
        # 1) Clear observation, history and action buffers.
        self.observations_.fill(0.0)
        self.proprio_hist_buf_.fill(0.0)
        self.last_actions_.fill(0.0)
        self.actions_.fill(0.0)

        # 2) Reset the first-use flags.
        self.is_first_obs_ = True
        self.is_first_action_ = True
        self.is_first_step_ = True

        # 3) Reset the desired joint state to the initial pose. The joint
        #    section starts motor_num_ entries before the end of q_d_.
        base = self.robot_data_.q_d_.shape[0] - self.motor_num_
        joints = slice(base, base + len(self.joint_xml))
        # Desired angle = initial angle (mj order).
        self.robot_data_.q_d_[joints] = self.joint_pos_array
        # Desired velocity = 0.
        self.robot_data_.q_dot_d_[joints] = 0.0
        # Desired torque = 0 (position control).
        self.robot_data_.tau_d_[joints] = 0.0

    def _init_onnx_session(self):
        """Create the ONNX Runtime session for ``self.model_path``.

        On any failure the error is printed and ``self.ort_session_`` is set
        to ``None``; ``compute_actions`` then returns without doing anything.
        """
        try:
            options = ort.SessionOptions()
            # Enable every available graph optimization (operator fusion etc.).
            options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            # Single-threaded execution, both inside an operator and across
            # operators.
            options.intra_op_num_threads = 1
            options.inter_op_num_threads = 1
            # Memory-pattern planning off, buffer reuse on.
            options.enable_mem_pattern = False
            options.enable_mem_reuse = True
            self.ort_session_ = ort.InferenceSession(self.model_path, options, providers=['CPUExecutionProvider'])
            print(f"[{self.log_prefix}-ONNX] ONNX model loaded successfully: {self.model_path}")
        except Exception as e:
            print(f"[{self.log_prefix}] Failed to load ONNX model: {e}")
            self.ort_session_ = None

    def on_enter(self):
        """Enter the WALKAMP state: reset buffers and capture the start pose."""
        self._reset_internal_state()
        print(f"[{self.log_prefix}] enter")
        self.is_first_obs_ = True
        self.is_first_action_ = True
        self._warmup_inference_counter = 0
        self.timer_gait_ = 0.0

        # Remember the current joint positions; run() blends from this pose
        # to the policy target during the warm-start window.
        if self.robot_data_ is not None:
            try:
                self._warm_start_pose = self.robot_data_.get_joint_pos().copy()
            except Exception:
                self._warm_start_pose.fill(0.0)
        else:
            self._warm_start_pose.fill(0.0)

    def run(self, flag: ControlFlag):
        """Run one control tick of the WALKAMP state.

        Policy inference happens only on ticks whose index is a multiple of
        ``decimation_``; on those ticks the desired joint position, velocity
        and torque in ``robot_data_`` are updated.

        NOTE(review): the source this was recovered from had lost its
        indentation. The statements that use ``commanded_pos`` must live
        inside the decimation branch; the gait timer increment and the gain
        writes are assumed to run on every tick - confirm against the
        original file.
        """
        # NOTE(review): this prints on every control tick.
        print(f"[{self.log_prefix}] run")

        gait = gait_phase(
            self.timer_gait_,
            self.gait_cycle,
            self.left_theta_offset,
            self.right_theta_offset,
            self.left_phase_ratio,
            self.right_phase_ratio,
        ).astype(np.float32)

        # Round instead of truncating: time_now_ / dt_ is a float quotient
        # and can come out as k - epsilon for a time that is nominally k * dt,
        # which truncation would turn into tick k - 1.
        tick = int(round(self.robot_data_.time_now_ / self.dt_))
        if tick % self.decimation_ == 0:
            self.compute_observation(flag, gait)
            self.compute_actions()

            # Target joint angles in lab order.
            target_dof_pos_lab = self.actions_ * self.action_scale_lab + self.default_angles_lab
            # Start from the current joint positions (mj order) and overwrite
            # only the policy-controlled joints.
            target_dof_pos_mj = self.robot_data_.get_joint_pos().copy()
            target_dof_pos_mj[self.lab2mj] = target_dof_pos_lab

            commanded_pos = target_dof_pos_mj
            if self._warm_start_steps > 0 and self._warmup_inference_counter < self._warm_start_steps:
                # Linear blend from the pose captured in on_enter() to the
                # policy target; blend reaches 1.0 on the last warm-up tick.
                self._warmup_inference_counter += 1
                blend = self._warmup_inference_counter / float(self._warm_start_steps)
                commanded_pos = (1.0 - blend) * self._warm_start_pose + blend * target_dof_pos_mj

            base = self.robot_data_.q_d_.shape[0] - self.motor_num_
            joints = slice(base, base + len(self.joint_xml))
            self.robot_data_.q_d_[joints] = commanded_pos
            self.robot_data_.q_dot_d_[joints] = 0.0
            self.robot_data_.tau_d_[joints] = 0.0
            self.last_actions_[:] = self.actions_

        self.timer_gait_ += self.dt_
        n_joints = len(self.joint_xml)
        self.robot_data_.joint_kp_p_[:n_joints] = self.stiffness_array
        self.robot_data_.joint_kd_p_[:n_joints] = self.damping_array

    def compute_observation(self, flag: ControlFlag, gait):
        """Build the current observation and push it into the history buffer.

        The per-step observation is the concatenation, in this order, of:
        angular velocity, projected gravity (3), velocity command (3),
        joint positions relative to the default pose (lab order), joint
        velocities (lab order), the previous actions, and ``gait``. Its
        length must equal ``obs_size_``. The clipped history is stored in
        ``self.observations_``.

        Args:
            flag: Control flags; not read by this method.
            gait: Gait-phase features appended to the observation.

        Raises:
            ValueError: If the assembled observation length differs from
                ``obs_size_``.
        """
        # imu_data_[0:3] is read as (yaw, pitch, roll).
        roll, pitch, yaw = (
            float(self.robot_data_.imu_data_[2]),
            float(self.robot_data_.imu_data_[1]),
            float(self.robot_data_.imu_data_[0]),
        )
        quat_wxyz = self.euler_to_quaternion_scipy(roll, pitch, yaw)
        q_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]], dtype=np.float32)
        # World "down" expressed in the body frame.
        gravity_init = self.quat_rotate_inverse_numpy(q_xyzw, np.array([0., 0., -1.], dtype=np.float32))

        x_speed_command, y_speed_command, yaw_speed_command = self.robot_data_.get_walk_cmd()

        # Rate-limited copy of the x command (at most 0.005 change per call).
        # NOTE(review): this value is stored but not fed to the policy; the
        # command below uses the raw x_speed_command.
        new_filtered_x_speed = 1 * x_speed_command + (1 - 1) * self.filtered_x_speed
        change = np.clip(new_filtered_x_speed - self.filtered_x_speed, -0.005, 0.005)
        self.filtered_x_speed = self.filtered_x_speed + change

        command = np.concatenate([
            np.array([
                x_speed_command,
                y_speed_command,
                yaw_speed_command,
            ], dtype=np.float32),
        ])
        print(f'Input command: {command}')

        ang_vel = self.robot_data_.get_angular_velocity()
        q_mj = self.robot_data_.get_joint_pos()  # mj order
        dq_mj = self.robot_data_.get_joint_vel()  # mj order

        # Keep only the policy-controlled joints, reordered to lab order, and
        # express positions relative to the default pose.
        qj = q_mj[self.lab2mj] - self.default_angles_lab
        dqj = dq_mj[self.lab2mj]

        proprio = np.concatenate([
            ang_vel,
            gravity_init,
            command,
            qj,
            dqj,
            self.last_actions_,
            gait,
        ])
        if proprio.size != self.obs_size_:
            raise ValueError(
                f"[{self.log_prefix}] observation length {proprio.size} does not match size.observations_size {self.obs_size_}"
            )

        if self.is_first_obs_:
            # First observation: fill every history slot with it.
            self.proprio_hist_buf_[:] = np.tile(proprio, self.num_hist_)
            self.is_first_obs_ = False
        else:
            # Drop the oldest slot (front) and append the newest at the end.
            shift_size = (self.num_hist_ - 1) * self.obs_size_
            self.proprio_hist_buf_[:shift_size] = self.proprio_hist_buf_[self.obs_size_:]
            self.proprio_hist_buf_[shift_size:] = proprio

        np.clip(self.proprio_hist_buf_, -self.clip_obs_, self.clip_obs_, out=self.observations_)

    @staticmethod
    def euler_to_quaternion_scipy(roll, pitch, yaw, degrees=False):
        """Convert 'xyz' Euler angles to a quaternion in (w, x, y, z) order.

        Args:
            roll, pitch, yaw: Angles passed to SciPy as the 'xyz' sequence.
            degrees: If True the angles are in degrees, otherwise radians.

        Returns:
            float32 array ``[w, x, y, z]``.
        """
        r = Rotation.from_euler('xyz', [roll, pitch, yaw], degrees=degrees)
        q_xyzw = r.as_quat()  # SciPy returns scalar-last (x, y, z, w)
        return np.array([q_xyzw[3], q_xyzw[0], q_xyzw[1], q_xyzw[2]], dtype=np.float32)

    @staticmethod
    def quat_rotate_inverse_numpy(q_xyzw, v):
        """Rotate vector ``v`` by the inverse of quaternion ``q_xyzw``.

        Args:
            q_xyzw: Quaternion in (x, y, z, w) order.
            v: 3-vector to rotate.

        Returns:
            The rotated 3-vector.
        """
        q_w = q_xyzw[3]
        q_v = q_xyzw[:3]
        a = v * (2.0 * q_w * q_w - 1.0)
        b = np.cross(q_v, v) * (2.0 * q_w)
        c = q_v * (2.0 * np.dot(q_v, v))
        return a - b + c

    def compute_actions(self):
        """Run the ONNX model on ``observations_`` and store clipped actions.

        Does nothing when the model failed to load. Inference errors are
        printed and leave ``actions_`` as they were.
        """
        if self.ort_session_ is None:
            return
        try:
            # Single-sample batch, float32.
            input_data = self.observations_.reshape(1, -1).astype(np.float32)
            input_name = self.ort_session_.get_inputs()[0].name
            outputs = self.ort_session_.run(None, {input_name: input_data})

            # First output, first batch row; keep action_num_ entries.
            output_data = outputs[0][0]
            self.actions_[:] = np.clip(
                output_data[:self.action_num_], -self.clip_act_, self.clip_act_)

            if self.is_first_action_:
                # One-off dump of the newest-less history slot at the front of
                # the observation vector (the first obs_size_ values).
                print(f"[{self.log_prefix}-ONNX] First Observation:")
                print("".join(f"{value:.6f} " for value in self.observations_[:self.obs_size_]))
                self.is_first_action_ = False
        except Exception as e:
            print(f"[{self.log_prefix}] ONNX Runtime inference error: {e}")

    def on_exit(self):
        """Leave the WALKAMP state and close the observation log, if any.

        ``obs_log_file`` is not created anywhere in this class; it is looked
        up defensively in case it was attached elsewhere.
        """
        print(f"[{self.log_prefix}] exit")
        if getattr(self, "obs_log_file", None) is not None:
            try:
                self.obs_log_file.flush()
                self.obs_log_file.close()
                print(f"[{self.log_prefix}] obs log saved to {self.obs_log_path}")
            except Exception as e:
                print(f"[{self.log_prefix}] failed to close obs log: {e}")
            self.obs_log_file = None

    def check_transition(self, flag: ControlFlag) -> "FSMStateName | None":
        """Map the commanded FSM transition to the next state.

        Returns:
            ``STOP``, ``WALKAMP`` or ``ZERO`` for the matching ``goto...``
            command, or ``None`` when no transition is requested.
        """
        command = flag.fsm_state_command
        if command == "gotoSTOP":
            return FSMStateName.STOP
        if command == "gotoWALKAMP":
            return FSMStateName.WALKAMP
        if command == "gotoZERO":
            return FSMStateName.ZERO
        return None  # no state transition