Files
tienkung-szu/Deploy_Tienkung/policy/niukua/fsm_beyond_mimic.py
2026-03-27 16:10:51 +08:00

548 lines
24 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from FSM.fsm_base import FSMState, FSMStateName
import numpy as np
import yaml
import os
from types import SimpleNamespace
from typing import Optional
try:
import onnx
except ImportError: # pragma: no cover - runtime fallback when onnx isn't installed
onnx = None
import onnxruntime
try:
import torch
except ImportError: # pragma: no cover - torch is optional for warm start prep
torch = None
from common.robot_data import RobotData
from common.joystick import ControlFlag
import time
# Default ordered list of body (link) names. FSMStateBeyondMimic falls back to
# it when the YAML config has no "body_names" entry. Order: pelvis, left leg,
# right leg, waist, left arm, right arm.
DEFAULT_BODY_NAMES = (
    ["pelvis"]
    + [
        f"{part}_{side}_link"
        for side in ("l", "r")
        for part in (
            "hip_pitch",
            "hip_roll",
            "hip_yaw",
            "knee_pitch",
            "ankle_pitch",
            "ankle_roll",
        )
    ]
    + [f"waist_{axis}_link" for axis in ("yaw", "roll", "pitch")]
    + [
        f"{part}_{side}_link"
        for side in ("l", "r")
        for part in (
            "shoulder_pitch",
            "shoulder_roll",
            "shoulder_yaw",
            "elbow_pitch",
        )
    ]
)
class FSMStateBeyondMimic(FSMState):
def __init__(self, robot_data: RobotData, config_path: Optional[str] = None, variant_name: str = "default"):
    """Load the YAML config and the ONNX policy, and allocate the state buffers.

    Args:
        robot_data: Shared robot data object, forwarded to the base class
            constructor.
        config_path: Path of the YAML config file. When ``None``,
            ``config/BeyondMimic.yaml`` next to this file is used.
        variant_name: Label stored on the instance as ``self.variant_name``.

    Raises:
        KeyError: If a required key is missing from the config.
        ValueError: If the anchor body is not in the body list, the ONNX model
            exposes fewer than two inputs, required ONNX metadata
            (``joint_names``, ``default_joint_pos``, ``action_scale``) is
            missing, or the metadata joints do not match ``self.joint_xml``.
    """
    super().__init__(robot_data)
    self.motion_phase = 0
    self.counter_step = 0
    self.ref_motion_phase = 0
    self.variant_name = variant_name

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    current_dir = os.path.dirname(os.path.abspath(__file__))
    if config_path is None:
        config_path = os.path.join(current_dir, "config", "BeyondMimic.yaml")
    self.config_path = os.path.abspath(config_path)
    # NOTE(security): FullLoader is kept for backward compatibility with
    # existing config files. Only load trusted configs; consider
    # yaml.safe_load if the configs use plain YAML types only.
    with open(self.config_path, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Relative model paths are resolved against the "model" directory that
    # sits next to this file.
    self.onnx_path = config["onnx_path"]
    if not os.path.isabs(self.onnx_path):
        self.onnx_path = os.path.join(current_dir, "model", self.onnx_path)

    self.physical_dt = config["physical_dt"]
    self.decimation_ = config["decimation"]
    self.num_actions = config["num_actions"]
    self.motor_nums = config["motor_nums"]
    self.warm_start_time = config["warm_start_time"]
    self.kps = config["kps"]
    self.kds = config["kds"]
    self.hold_final_reference = config.get("hold_final_reference", False)
    self.motion_length: Optional[int] = config.get("motion_length")
    self.body_names = config.get("body_names", DEFAULT_BODY_NAMES)
    self.locked_joint_map = config["locked_joint_map"]
    self.anchor_body_name = config.get("anchor_body", "pelvis")
    if self.anchor_body_name not in self.body_names:
        raise ValueError(f"Anchor body {self.anchor_body_name} missing from body list.")
    self.anchor_body_index = self.body_names.index(self.anchor_body_name)
    self.num_bodies = len(self.body_names)

    # ------------------------------------------------------------------
    # Runtime buffers
    # ------------------------------------------------------------------
    self._warm_start_from_lab = np.zeros(self.num_actions, dtype=np.float32)
    self._warm_start_to_lab = np.zeros(self.num_actions, dtype=np.float32)
    self._warm_start_prev_target = np.zeros(self.num_actions, dtype=np.float32)
    self.last_run_time = time.perf_counter()
    self.qj_obs = np.zeros(self.num_actions, dtype=np.float32)
    self.dqj_obs = np.zeros(self.num_actions, dtype=np.float32)
    self.num_obs = None  # set below, once the ONNX input shape is known
    self.obs = None
    self.action = np.zeros(self.num_actions, dtype=np.float32)
    self.ref_joint_pos = np.zeros(self.num_actions, dtype=np.float32)
    self.ref_joint_vel = np.zeros(self.num_actions, dtype=np.float32)
    self.ref_body_pos_w = np.zeros((1, self.num_bodies, 3), dtype=np.float32)
    self.ref_body_quat_w = np.zeros((1, self.num_bodies, 4), dtype=np.float32)
    self.ref_body_lin_vel_w = np.zeros((1, self.num_bodies, 3), dtype=np.float32)
    self.ref_body_ang_vel_w = np.zeros((1, self.num_bodies, 3), dtype=np.float32)

    # ------------------------------------------------------------------
    # Policy loading
    # ------------------------------------------------------------------
    # Metadata is read through the onnx package when it is importable and
    # loads the model; otherwise it comes from the onnxruntime session.
    self.onnx_model = None
    metadata_props = []
    if onnx is not None and hasattr(onnx, "load"):
        try:
            self.onnx_model = onnx.load(self.onnx_path)
            metadata_props = getattr(self.onnx_model, "metadata_props", [])
        except Exception as exc:
            print(f"[FSMStateBeyondMimic] Failed to load ONNX model via onnx.load: {exc}")
    else:
        print("[FSMStateBeyondMimic] Python onnx package unavailable, falling back to onnxruntime metadata.")
    self.ort_session = onnxruntime.InferenceSession(self.onnx_path)
    if not metadata_props:
        model_meta = self.ort_session.get_modelmeta()
        custom_map = getattr(model_meta, "custom_metadata_map", {})
        metadata_props = [SimpleNamespace(key=k, value=v) for k, v in custom_map.items()]

    session_inputs = self.ort_session.get_inputs()
    self.input_name = [session_input.name for session_input in session_inputs]
    if len(self.input_name) < 2:
        # The state feeds input 0 (observation) and input 1 (time step).
        raise ValueError(
            f"ONNX model {self.onnx_path} must expose at least 2 inputs, got {len(self.input_name)}."
        )
    # Use the static observation width when the model declares one;
    # otherwise fall back to the config value.
    last_dim = session_inputs[0].shape[-1]
    if isinstance(last_dim, int):
        self.num_obs = last_dim
    else:
        self.num_obs = config.get("num_obs", 154)
    self.obs = np.zeros(self.num_obs, dtype=np.float32)

    # ------------------------------------------------------------------
    # Parameters stored in the ONNX metadata
    # ------------------------------------------------------------------
    self.joint_seq = None
    self.joint_pos_array_seq = None
    self.action_scale = None
    for prop in metadata_props:
        if prop.key == "joint_names":
            self.joint_seq = [name.strip() for name in prop.value.split(",")]
        if prop.key == "default_joint_pos":
            self.joint_pos_array_seq = np.array([float(x) for x in prop.value.split(",")])
        if prop.key == "action_scale":
            self.action_scale = np.array([float(x) for x in prop.value.split(",")])
        if prop.key in ("motion_length", "time_step_total"):
            try:
                self.motion_length = int(float(prop.value))
            except (TypeError, ValueError):
                print(f"[FSMStateBeyondMimic] Invalid motion_length metadata value: {prop.value}")
    if self.motion_length is not None:
        try:
            self.motion_length = int(self.motion_length)
        except (TypeError, ValueError):
            print(f"[FSMStateBeyondMimic] Invalid motion_length config value: {self.motion_length}")
            self.motion_length = None

    # Fail fast with a readable message instead of an opaque error on None
    # further down (or later, inside the control loop).
    missing_keys = [
        key
        for key, value in (
            ("joint_names", self.joint_seq),
            ("default_joint_pos", self.joint_pos_array_seq),
            ("action_scale", self.action_scale),
        )
        if value is None
    ]
    if missing_keys:
        raise ValueError(
            f"ONNX model {self.onnx_path} is missing required metadata: {', '.join(missing_keys)}"
        )

    # Joint order used on the robot/simulator side ("mj" order).
    self.joint_xml = [
        "hip_pitch_l_joint", "hip_roll_l_joint", "hip_yaw_l_joint",
        "knee_pitch_l_joint", "ankle_pitch_l_joint", "ankle_roll_l_joint",
        "hip_pitch_r_joint", "hip_roll_r_joint", "hip_yaw_r_joint",
        "knee_pitch_r_joint", "ankle_pitch_r_joint", "ankle_roll_r_joint",
        "waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
        "shoulder_pitch_l_joint", "shoulder_roll_l_joint", "shoulder_yaw_l_joint",
        "elbow_pitch_l_joint",
        "shoulder_pitch_r_joint", "shoulder_roll_r_joint", "shoulder_yaw_r_joint",
        "elbow_pitch_r_joint",
    ]
    unknown_joints = [joint for joint in self.joint_seq if joint not in self.joint_xml]
    absent_joints = [joint for joint in self.joint_xml if joint not in self.joint_seq]
    if unknown_joints or absent_joints:
        raise ValueError(
            "ONNX joint_names do not match the robot joint list: "
            f"unknown={unknown_joints}, missing={absent_joints}"
        )
    if len(self.joint_pos_array_seq) < len(self.joint_seq):
        raise ValueError(
            f"default_joint_pos has {len(self.joint_pos_array_seq)} entries "
            f"but joint_names has {len(self.joint_seq)}."
        )

    # mj2lab[i] is the position in joint_xml of the i-th policy ("lab") joint,
    # so indexing an mj-ordered array with it yields lab order.
    self.mj2lab = np.array([self.joint_xml.index(joint) for joint in self.joint_seq])
    # Default joint positions re-ordered from lab order to mj order.
    self.joint_pos_array = np.array(
        [self.joint_pos_array_seq[self.joint_seq.index(joint)] for joint in self.joint_xml]
    )
    self.default_angles_lab = self.joint_pos_array_seq
    self.action_scale_lab = self.action_scale
    print("BeyondMimic-like policy initializing ...")

    self._warmup_inference_counter = 0
    self.warm_start_steps = 0
    # Cache for the last motion frame so we can keep sending it after motion ends.
    self._final_ref_cached = False
    self._final_ref_joint_pos = np.zeros_like(self.ref_joint_pos)
    self._final_ref_joint_vel = np.zeros_like(self.ref_joint_vel)
    self._final_ref_body_pos_w = np.zeros_like(self.ref_body_pos_w)
    self._final_ref_body_quat_w = np.zeros_like(self.ref_body_quat_w)
    self._final_ref_body_lin_vel_w = np.zeros_like(self.ref_body_lin_vel_w)
    self._final_ref_body_ang_vel_w = np.zeros_like(self.ref_body_ang_vel_w)
def on_enter(self):
    """Reset per-activation state and prepare the warm-start interpolation.

    Resets the step counters, derives the number of warm-start steps from
    ``warm_start_time``, runs the policy once with an all-zero observation
    and time input 0 to populate the action/reference buffers, and records
    the warm-start start pose (current joints) and end pose (first policy
    target).
    """
    # Fresh playback counters for this activation.
    self.ref_motion_phase = 0.
    self.motion_time = 0
    self.counter_step = 0
    self._warmup_inference_counter = 0
    print(f"[FSMStateBeyondMimic] enter variant={self.variant_name}, config={self.config_path}")

    # A positive warm-start time always yields at least one blending step.
    self.warm_start_steps = (
        max(1, int(self.warm_start_time / (self.decimation_ * self.physical_dt)))
        if self.warm_start_time > 0
        else 0
    )

    # Prime the buffers with one inference on zero inputs.
    zero_inputs = {
        self.input_name[0]: np.zeros((1, self.num_obs), dtype=np.float32),
        self.input_name[1]: np.zeros((1, 1), dtype=np.float32),
    }
    (
        action,
        joint_pos,
        joint_vel,
        body_pos_w,
        body_quat_w,
        body_lin_vel_w,
        body_ang_vel_w,
    ) = self.ort_session.run(None, zero_inputs)
    self.action = action
    self.ref_joint_pos = joint_pos
    self.ref_joint_vel = joint_vel
    self.ref_body_pos_w = body_pos_w
    self.ref_body_quat_w = body_quat_w
    self.ref_body_lin_vel_w = body_lin_vel_w
    self.ref_body_ang_vel_w = body_ang_vel_w

    self.qj_obs = np.zeros(self.num_actions, dtype=np.float32)
    self.dqj_obs = np.zeros(self.num_actions, dtype=np.float32)
    self.obs = np.zeros(self.num_obs)
    self._final_ref_cached = False

    # Warm start blends from the measured pose towards the first policy target.
    start_pose = self._get_current_joint_pos_lab()
    self._warm_start_from_lab = start_pose
    self._warm_start_to_lab = self._get_onnx_first_frame_lab()
    self._warm_start_prev_target = np.array(start_pose, copy=True)
def quat_mul(self, q1, q2):
    """Return the quaternion product ``q1 * q2``.

    Both inputs are indexed as (w, x, y, z); the result is a NumPy array in
    the same order. The product is evaluated through shared partial sums
    rather than the sixteen-term expansion.
    """
    aw, ax, ay, az = q1[0], q1[1], q1[2], q1[3]
    bw, bx, by, bz = q2[0], q2[1], q2[2], q2[3]

    # Shared partial products.
    part_w = (az + ax) * (bx + by)
    part_y = (aw - ay) * (bw + bz)
    part_z = (aw + ay) * (bw - bz)
    part_x = part_w + part_y + part_z
    half = 0.5 * (part_x + (az - ax) * (bx - by))

    return np.array([
        half - part_w + (az - ay) * (by - bz),
        half - part_x + (ax + aw) * (bx + bw),
        half - part_y + (aw - ax) * (by + bz),
        half - part_z + (az + ay) * (bw - bx),
    ])
def matrix_from_quat(self, q):
    """Convert a quaternion unpacked as (w, x, y, z) into a 3x3 rotation matrix.

    The input is not normalized here.
    """
    w, x, y, z = q
    xx, yy, zz = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w

    row0 = [1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw)]
    row1 = [2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw)]
    row2 = [2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy)]
    return np.array([row0, row1, row2])
def yaw_quat(self, q):
    """Return the yaw-only quaternion (w, 0, 0, z) extracted from ``q``.

    ``q`` is unpacked as (w, x, y, z).
    """
    w, x, y, z = q
    sin_term = 2 * (w * z + x * y)
    cos_term = 1 - 2 * (y**2 + z**2)
    half_yaw = np.arctan2(sin_term, cos_term) / 2
    return np.array([np.cos(half_yaw), 0, 0, np.sin(half_yaw)])
def euler_single_axis_to_quat(self, angle, axis, degrees=False):
    """Convert a rotation by ``angle`` about a single axis into a quaternion.

    Args:
        angle: Rotation angle.
        axis: Rotation axis; either ``'x'``, ``'y'`` or ``'z'`` (case
            insensitive) or a 3D vector ``[x, y, z]``, which is normalized
            before use.
        degrees: If True ``angle`` is in degrees, otherwise in radians.

    Returns:
        The quaternion as a NumPy array in (w, x, y, z) order.

    Raises:
        ValueError: If ``axis`` is an unknown string or a zero vector.
    """
    theta = np.radians(angle) if degrees else angle
    half = theta * 0.5
    c, s = np.cos(half), np.sin(half)

    if isinstance(axis, str):
        # Position of the sine term inside (w, x, y, z) for each named axis.
        slot = {"x": 1, "y": 2, "z": 3}.get(axis.lower())
        if slot is None:
            raise ValueError("axis must be 'x', 'y', 'z' or a 3D unit vector")
        components = [c, 0.0, 0.0, 0.0]
        components[slot] = s
        return np.array(components)

    # Otherwise treat axis as a 3D vector and normalize it.
    direction = np.array(axis, dtype=np.float32)
    length = np.linalg.norm(direction)
    if length == 0:
        raise ValueError("axis vector cannot be zero")
    direction = direction / length
    return np.array([c, s * direction[0], s * direction[1], s * direction[2]])
def inner_run(self):
    """Run one policy step: build the observation, query the ONNX policy, write joint targets.

    During the first two calls after ``counter_step`` was reset, only the
    yaw alignment ``self.init_to_world`` between the robot and the reference
    anchor is computed and no command is written. Afterwards each call:

    1. builds the observation vector (commanded joint pos/vel, anchor
       orientation in the body frame, angular velocity, joint pos/vel and
       the previous action),
    2. runs the ONNX session with that observation and the current time index,
    3. converts the resulting action (or the warm-start blend while warm
       start is active) into desired joint positions on ``robot_data_``.

    Raises:
        RuntimeError: If the assembled observation length differs from
            ``self.num_obs``.
    """
    robot_quat = self.robot_data_.get_robot_quat()

    # Measured joint positions: select the entries listed in
    # locked_joint_map (the 29-DoF data reduced to the layout with the wrist
    # joints locked), reorder to policy ("lab") order, and express them
    # relative to the default pose.
    qj = self.robot_data_.get_joint_pos()
    qj = qj[self.locked_joint_map]
    qj = qj[self.mj2lab]
    qj = qj - self.default_angles_lab

    # IMU mounted on pelvis, so directly use measured orientation.
    _, ref_anchor_ori_w = self._get_ref_anchor_pose()

    # On the first frames, record the yaw offset between the robot's current
    # heading and the reference motion's heading (consistent with BeyondMimic).
    if self.counter_step < 2:
        init_to_anchor = self.matrix_from_quat(self.yaw_quat(ref_anchor_ori_w))
        world_to_anchor = self.matrix_from_quat(self.yaw_quat(robot_quat))
        self.init_to_world = world_to_anchor @ init_to_anchor.T
        print("self.init_to_world: ", self.init_to_world)
        self.counter_step += 1
        return

    robot_rot_mat = self.matrix_from_quat(robot_quat)
    motion_anchor_ori_b = robot_rot_mat.T @ self.init_to_world @ self.matrix_from_quat(ref_anchor_ori_w)
    ang_vel = self.robot_data_.get_angular_velocity()
    dqj = self.robot_data_.get_joint_vel()
    # Reduce to the 23 controlled joints.
    dqj = dqj[self.locked_joint_map]

    use_warmstart = (
        self.warm_start_steps > 0
        and self._warmup_inference_counter < self.warm_start_steps
    )
    blended_target = None
    if use_warmstart:
        # Linear blend from the pose captured on entry towards the first
        # policy target; the commanded velocity is its finite difference.
        blend = (self._warmup_inference_counter + 1) / self.warm_start_steps
        blended_target = (1.0 - blend) * self._warm_start_from_lab + blend * self._warm_start_to_lab
        blended_vel = (blended_target - self._warm_start_prev_target) / (
            self.decimation_ * self.physical_dt
        )
        self._warm_start_prev_target = blended_target
        command_joint_pos = blended_target.reshape(1, -1)
        command_joint_vel = blended_vel.reshape(1, -1)
        # Avoid dividing by zero for joints whose action scale is 0.
        safe_scale = np.where(self.action_scale_lab == 0, 1.0, self.action_scale_lab)
        action_for_history = (blended_target - self.default_angles_lab) / safe_scale
    else:
        command_joint_pos = self.ref_joint_pos
        command_joint_vel = self.ref_joint_vel
        action_for_history = self.action

    command_vec = np.concatenate(
        (
            command_joint_pos.squeeze(0),
            command_joint_vel.squeeze(0),
        ),
        dtype=np.float32,
    )
    mimic_obs_buf = np.concatenate(
        (
            command_vec,
            motion_anchor_ori_b[:, :2].reshape(-1),
            ang_vel,
            qj,
            dqj[self.mj2lab],
            np.asarray(action_for_history, dtype=np.float32).reshape(-1),
        ),
        axis=-1,
        dtype=np.float32,
    )
    if mimic_obs_buf.shape[0] != self.num_obs:
        raise RuntimeError(f"Observation length mismatch. Expected {self.num_obs}, got {mimic_obs_buf.shape[0]}.")

    # Adding the batch axis needs NumPy only; no torch round trip required.
    mimic_obs_tensor = np.expand_dims(mimic_obs_buf, axis=0)

    hold_final = bool(
        self.hold_final_reference
        and self.motion_length is not None
        and self.motion_length > 0
    )
    # Time index excludes the warm-start steps and, when the final reference
    # is held, is clamped to the last motion frame.
    time_index = max(self.counter_step - self.warm_start_steps, 0)
    if hold_final:
        time_index = min(time_index, self.motion_length - 1)

    # Input 0 is the network observation; input 1 is the current time step,
    # used by the model to output the reference motion.
    observation = {
        self.input_name[0]: mimic_obs_tensor,
        self.input_name[1]: np.array([[time_index]], dtype=np.float32),
    }
    outputs_result = self.ort_session.run(None, observation)
    (
        self.action,
        self.ref_joint_pos,
        self.ref_joint_vel,
        self.ref_body_pos_w,
        self.ref_body_quat_w,
        self.ref_body_lin_vel_w,
        self.ref_body_ang_vel_w,
    ) = outputs_result

    if hold_final:
        if time_index == self.motion_length - 1 and not self._final_ref_cached:
            self._cache_final_ref()
        elif self.counter_step >= self.motion_length and self._final_ref_cached:
            self._apply_final_ref()

    if use_warmstart:
        target_dof_pos_lab = blended_target
        # Keep action history aligned with the inserted warm trajectory.
        self.action = np.asarray(action_for_history, dtype=np.float32).reshape(1, -1)
    else:
        target_dof_pos_lab = self.action * self.action_scale_lab + self.default_angles_lab
    if target_dof_pos_lab.ndim > 1:
        target_dof_pos_lab = np.squeeze(target_dof_pos_lab, axis=0)
    if self.warm_start_steps > 0:
        self._warmup_inference_counter += 1

    # Scatter lab-ordered targets back to the 23-joint mj order, then into
    # the 29-slot layout through locked_joint_map; unlisted slots stay 0.
    target_dof_pos_mj = np.zeros(29)
    target_dof_pos_mj_23dof = np.zeros(23)
    target_dof_pos_mj_23dof[self.mj2lab] = target_dof_pos_lab
    target_dof_pos_mj[self.locked_joint_map] = target_dof_pos_mj_23dof

    # Set joint commands exactly like C++: robot_data_->q_d_(35 - motor_num_ + i)
    first_joint_idx = 35 - self.motor_nums
    for i in range(self.motor_nums):
        joint_idx = first_joint_idx + i
        self.robot_data_.q_d_[joint_idx] = target_dof_pos_mj[i]
        self.robot_data_.q_dot_d_[joint_idx] = 0.0
        self.robot_data_.tau_d_[joint_idx] = 0.0

    self.counter_step += 1
def _cache_final_ref(self):
if not self.hold_final_reference:
return
self._final_ref_cached = True
self._final_ref_joint_pos = np.array(self.ref_joint_pos, copy=True)
self._final_ref_joint_vel = np.array(self.ref_joint_vel, copy=True)
self._final_ref_body_pos_w = np.array(self.ref_body_pos_w, copy=True)
self._final_ref_body_quat_w = np.array(self.ref_body_quat_w, copy=True)
self._final_ref_body_lin_vel_w = np.array(self.ref_body_lin_vel_w, copy=True)
self._final_ref_body_ang_vel_w = np.array(self.ref_body_ang_vel_w, copy=True)
def _apply_final_ref(self):
if not self.hold_final_reference or not self._final_ref_cached:
return
self.ref_joint_pos = np.array(self._final_ref_joint_pos, copy=True)
self.ref_joint_vel = np.array(self._final_ref_joint_vel, copy=True)
self.ref_body_pos_w = np.array(self._final_ref_body_pos_w, copy=True)
self.ref_body_quat_w = np.array(self._final_ref_body_quat_w, copy=True)
self.ref_body_lin_vel_w = np.array(self._final_ref_body_lin_vel_w, copy=True)
self.ref_body_ang_vel_w = np.array(self._final_ref_body_ang_vel_w, copy=True)
def run(self, flag: ControlFlag):
    """Run the policy on every ``decimation_``-th physics tick and refresh the PD gains.

    Args:
        flag: Control flags for the current tick; not read by this method.
    """
    # Round to the nearest tick: plain int() truncation turns quotients such
    # as 0.3 / 0.1 == 2.9999999999999996 into 2, which can skip or repeat
    # inference ticks.
    tick = int(round(self.robot_data_.time_now_ / self.physical_dt))
    if tick % self.decimation_ == 0:
        current_time = time.perf_counter()
        elapsed = current_time - self.last_run_time
        # Guard against a zero interval between two consecutive calls.
        if elapsed > 0:
            print(f"Inference hz: {1/elapsed}")
        self.last_run_time = current_time
        self.inner_run()
    # Gains are constant, so rewriting them on every tick is idempotent.
    self.set_kp_kd()
def set_kp_kd(self):
    """Write the configured kp/kd gains into the first ``motor_nums`` gain slots."""
    active = slice(None, self.motor_nums)
    data = self.robot_data_
    data.joint_kp_p_[active] = self.kps
    data.joint_kd_p_[active] = self.kds
def on_exit(self):
    """Clear the action buffer and reset the playback counters when leaving the state."""
    self.action = np.zeros(self.num_actions, dtype=np.float32)
    resets = (
        ("ref_motion_phase", 0.),
        ("motion_time", 0),
        ("counter_step", 0),
        ("_final_ref_cached", False),
    )
    for attr, value in resets:
        setattr(self, attr, value)
    print("exited")
def check_transition(self, flag: ControlFlag) -> FSMStateName:
    """Check for a state transition.

    Maps ``flag.fsm_state_command`` to the matching ``FSMStateName`` member.

    Returns:
        The target state, or ``None`` when the command requests no transition.
    """
    routes = (
        ("gotoSTOP", "STOP"),
        ("gotoZERO", "ZERO"),
        ("gotoBEYONDMIMIC", "BEYONDMIMIC"),
        ("gotoBEYONDZERO", "BEYONDZERO"),
        ("gotoWALKAMP", "WALKAMP"),
    )
    for command, member in routes:
        if flag.fsm_state_command == command:
            return getattr(FSMStateName, member)
    return None  # no state transition
def _get_ref_anchor_pose(self):
ref_pos = self.ref_body_pos_w[:, self.anchor_body_index].squeeze(0)
ref_quat = self.ref_body_quat_w[:, self.anchor_body_index].squeeze(0)
return ref_pos.astype(np.float32), ref_quat.astype(np.float32)
def _get_current_joint_pos_lab(self) -> np.ndarray:
try:
current_q = self.robot_data_.get_joint_pos()
current_q = current_q[self.locked_joint_map]
current_q_lab = current_q[self.mj2lab]
return current_q_lab.astype(np.float32)
except Exception as exc:
print(f"[FSMStateBeyondMimic] Failed to read current joint pose: {exc}")
return np.array(self.default_angles_lab, copy=True)
def _get_onnx_first_frame_lab(self) -> np.ndarray:
try:
action = self.action
if action is None:
raise RuntimeError("ONNX action output is None")
if action.ndim > 1:
action = np.squeeze(action, axis=0)
first_frame = action * self.action_scale_lab + self.default_angles_lab
return first_frame.astype(np.float32)
except Exception as exc:
print(f"[FSMStateBeyondMimic] Failed to read ONNX first frame: {exc}")
return np.array(self.default_angles_lab, copy=True)