Skip to content

Commit

Permalink
feat: add AttitudeProtect
Browse files Browse the repository at this point in the history
  • Loading branch information
fan-ziqi committed Dec 21, 2024
1 parent 44f8f53 commit 2d4c0d2
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 0 deletions.
49 changes: 49 additions & 0 deletions src/rl_sar/library/rl_sdk/rl_sdk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,55 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
}
}

void RL::AttitudeProtect(const std::vector<double> &quaternion, float pitch_threshold, float roll_threshold)
{
float rad2deg = 57.2958;
float w, x, y, z;

if (this->params.framework == "isaacgym")
{
w = quaternion[3];
x = quaternion[0];
y = quaternion[1];
z = quaternion[2];
}
else if (this->params.framework == "isaacsim")
{
w = quaternion[0];
x = quaternion[1];
y = quaternion[2];
z = quaternion[3];
}

// Calculate roll (rotation around the X-axis)
float sinr_cosp = 2 * (w * x + y * z);
float cosr_cosp = 1 - 2 * (x * x + y * y);
float roll = std::atan2(sinr_cosp, cosr_cosp) * rad2deg;

// Calculate pitch (rotation around the Y-axis)
float sinp = 2 * (w * y - z * x);
float pitch;
if (std::fabs(sinp) >= 1)
{
pitch = std::copysign(90.0, sinp); // Clamp to avoid out-of-range values
}
else
{
pitch = std::asin(sinp) * rad2deg;
}

if (std::fabs(roll) > roll_threshold)
{
// this->control.control_state = STATE_POS_GETDOWN;
std::cout << LOGGER::WARNING << "Roll exceeds " << roll_threshold << " degrees. Current: " << roll << " degrees." << std::endl;
}
if (std::fabs(pitch) > pitch_threshold)
{
// this->control.control_state = STATE_POS_GETDOWN;
std::cout << LOGGER::WARNING << "Pitch exceeds " << pitch_threshold << " degrees. Current: " << pitch << " degrees." << std::endl;
}
}

#include <termios.h>
#include <sys/ioctl.h>
static bool kbhit()
Expand Down
1 change: 1 addition & 0 deletions src/rl_sar/library/rl_sdk/rl_sdk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class RL

// protect func
void TorqueProtect(torch::Tensor origin_output_torques);
void AttitudeProtect(const std::vector<double> &quaternion, float pitch_threshold, float roll_threshold);

protected:
// rl module
Expand Down
1 change: 1 addition & 0 deletions src/rl_sar/src/rl_real_a1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ void RL_Real::RunModel()
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);

this->TorqueProtect(origin_output_torques);
this->AttitudeProtect(this->robot_state.imu.quaternion, 60.0f, 60.0f);

this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
this->output_dof_pos = this->ComputePosition(this->obs.actions);
Expand Down
1 change: 1 addition & 0 deletions src/rl_sar/src/rl_real_go2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ void RL_Real::RunModel()
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);

this->TorqueProtect(origin_output_torques);
this->AttitudeProtect(this->robot_state.imu.quaternion, 60.0f, 60.0f);

this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
this->output_dof_pos = this->ComputePosition(this->obs.actions);
Expand Down

0 comments on commit 2d4c0d2

Please sign in to comment.