Skip to content
Snippets Groups Projects
Commit cf431c48 authored by Cyrille Berger's avatar Cyrille Berger
Browse files

add trajectory critic/generator

parent ca849461
No related branches found
No related tags found
No related merge requests found
......@@ -39,7 +39,9 @@ set(dependencies
add_library(air_navigation SHARED
src/python_controller.cpp
src/python_interpreter.cpp)
src/python_interpreter.cpp
src/python_trajectory_critic.cpp
src/python_trajectory_generator.cpp)
target_compile_definitions(air_navigation PUBLIC "PLUGINLIB__DISABLE_BOOST_FUNCTIONS")
target_link_libraries(air_navigation ${PYTHON_LIBRARY} )
ament_target_dependencies(air_navigation
......@@ -70,5 +72,6 @@ ament_export_libraries(air_navigation)
ament_export_dependencies(${dependencies})
pluginlib_export_plugin_description_file(nav2_core air_navigation_plugins.xml)
pluginlib_export_plugin_description_file(dwb_core air_navigation_plugins.xml)
ament_package()
......@@ -5,5 +5,11 @@
air_navigation_python_controller
</description>
</class>
</library>
<class type="air_navigation::PythonTrajectoryGenerator" base_class_type="dwb_core::TrajectoryGenerator">
<description></description>
</class>
<class type="air_navigation::PythonTrajectoryCritic" base_class_type="dwb_core::TrajectoryCritic">
<description></description>
</class>
</library>
</class_libraries>
......@@ -29,6 +29,10 @@ namespace air_navigation::py_interface
struct Twist { Vector3 linear; Vector3 angular; };
struct TwistStamped { Header header; Twist twist; };
struct Path { Header header; std::vector<PoseStamped> poses; };
struct Pose2D { double x, y, theta; };
struct Twist2D { double x, y, theta; };
struct Path2D { Header header; std::vector<Pose2D> poses; };
struct Trajectory2D { Twist2D velocity; std::vector<Time> time_offsets; std::vector<Pose2D> poses; };
template<typename _TS_, typename _TD_>
inline void copy_header(const _TS_& s, _TD_* d)
......@@ -80,6 +84,46 @@ namespace air_navigation::py_interface
}
}
template<typename _TS_, typename _TD_>
inline void copy_pose_2d(const _TS_& s, _TD_* d)
{
d->x = s.x;
d->y = s.y;
d->theta = s.theta;
}
template<typename _TS_, typename _TD_>
inline void copy_twist_2d(const _TS_& s, _TD_* d)
{
copy_pose_2d(s, d);
}
template<typename _TPS_, typename _TS_, typename _TD_>
inline void copy_path_2d(const _TS_& s, _TD_* d)
{
copy_header(s.header, &d->header);
for(auto p : s.poses)
{
_TPS_ dp;
copy_pose_2d(p, &dp);
d->poses.push_back(dp);
}
} template<typename _TPS_, typename _TPT_, typename _TS_, typename _TD_>
inline void copy_trajectory_2d(const _TS_& s, _TD_* d)
{
copy_twist_2d(s.velocity, &d->velocity);
for(auto t : s.time_offsets)
{
_TPT_ dt;
dt.sec = t.sec;
dt.nanosec = t.nanosec;
d->time_offsets.push_back(dt);
}
for(auto p : s.poses)
{
_TPS_ dp;
copy_pose_2d(p, &dp);
d->poses.push_back(dp);
}
}
class GoalChecker
{
public:
......@@ -148,4 +192,32 @@ namespace air_navigation::py_interface
std::cout << _text << std::endl;
}
};
class TrajectoryCritic
{
public:
TrajectoryCritic() = default;
virtual ~TrajectoryCritic() = default;
virtual void onInit() = 0;
virtual void reset() = 0;
virtual bool prepare(const Pose2D&, const Twist2D&, const Pose2D&, const Path2D&) = 0;
virtual double scoreTrajectory(const Trajectory2D & traj) = 0;
virtual void debrief(const Twist2D &) = 0;
};
class TrajectoryGenerator
{
public:
TrajectoryGenerator() = default;
virtual ~TrajectoryGenerator() = default;
virtual void initialize(const std::string& name) = 0;
virtual void reset() = 0;
virtual void startNewIteration(const Twist2D & current_velocity) = 0;
virtual bool hasMoreTwists() = 0;
virtual Twist2D nextTwist() = 0;
virtual Trajectory2D generateTrajectory(const Pose2D & start_pose, const Twist2D & start_vel, const Twist2D & cmd_vel) = 0;
virtual void setSpeedLimit(double speed_limit, bool percentage) = 0;
};
}
......@@ -5,8 +5,7 @@
#include <memory>
#include "nav2_core/controller.hpp"
#include "pluginlib/class_loader.hpp"
#include "pluginlib/class_list_macros.hpp"
namespace air_navigation
{
......@@ -41,9 +40,7 @@ public:
void setSpeedLimit(const double & speed_limit, const bool & percentage) override;
protected:
void* wtf = nullptr;
py_interface::Controller* m_python_controller_interface = nullptr;
};
} // namespace nav2_pure_pursuit_controller
......@@ -5,6 +5,8 @@ namespace air_navigation
namespace py_interface
{
class Controller;
class TrajectoryCritic;
class TrajectoryGenerator;
}
class python_interpreter
{
......@@ -15,6 +17,8 @@ namespace air_navigation
~python_interpreter();
public:
py_interface::Controller* create_controller(const std::string& _module, const std::string& _class);
py_interface::TrajectoryCritic* create_trajectory_critic(const std::string& _module, const std::string& _class);
py_interface::TrajectoryGenerator* create_trajectory_generator(const std::string& _module, const std::string& _class);
void release_controller(py_interface::Controller* _controller);
void acquire_lock();
void release_lock();
......
#include <dwb_core/trajectory_critic.hpp>
namespace air_navigation
{
namespace py_interface
{
class TrajectoryCritic;
}
class PythonTrajectoryCritic : public dwb_core::TrajectoryCritic
{
public:
PythonTrajectoryCritic() = default;
~PythonTrajectoryCritic() = default;
void onInit() override;
void reset() override;
bool prepare(const geometry_msgs::msg::Pose2D & pose, const nav_2d_msgs::msg::Twist2D & vel, const geometry_msgs::msg::Pose2D & goal, const nav_2d_msgs::msg::Path2D & global_plan) override;
double scoreTrajectory(const dwb_msgs::msg::Trajectory2D & traj) override;
void debrief(const nav_2d_msgs::msg::Twist2D &) override;
private:
py_interface::TrajectoryCritic* m_python_trajectory_generator_interface = nullptr;
};
}
#include <dwb_core/trajectory_generator.hpp>
namespace air_navigation
{
namespace py_interface
{
class TrajectoryGenerator;
}
class PythonTrajectoryGenerator : public dwb_core::TrajectoryGenerator
{
public:
PythonTrajectoryGenerator() = default;
~PythonTrajectoryGenerator() = default;
void initialize(const nav2_util::LifecycleNode::SharedPtr & nh, const std::string & plugin_name) override;
void reset() override;
void startNewIteration(const nav_2d_msgs::msg::Twist2D & current_velocity) override;
bool hasMoreTwists() override;
nav_2d_msgs::msg::Twist2D nextTwist() override;
dwb_msgs::msg::Trajectory2D generateTrajectory(const geometry_msgs::msg::Pose2D & start_pose, const nav_2d_msgs::msg::Twist2D & start_vel, const nav_2d_msgs::msg::Twist2D & cmd_vel) override;
void setSpeedLimit(const double & speed_limit, const bool & percentage) override;
private:
py_interface::TrajectoryGenerator* m_python_trajectory_generator_interface = nullptr;
};
}
import air_navigation_
Controller = air_navigation_.air_navigation.py_interface.Controller
TrajectoryCritic = air_navigation_.air_navigation.py_interface.TrajectoryCritic
TrajectoryGenerator = air_navigation_.air_navigation.py_interface.TrajectoryGenerator
TwistStamped = air_navigation_.air_navigation.py_interface.TwistStamped
Twist2D = air_navigation_.air_navigation.py_interface.Twist2D
Pose2D = air_navigation_.air_navigation.py_interface.Pose2D
Trajectory2D = air_navigation_.air_navigation.py_interface.Trajectory2D
......@@ -30,6 +30,10 @@ namespace air_navigation
struct Twist { Vector3 linear; Vector3 angular; };
struct TwistStamped { Header header; Twist twist; };
struct Path { Header header; std::vector<air_navigation::py_interface::PoseStamped> poses; };
struct Pose2D { double x; double y; double theta; };
struct Twist2D { double x; double y; double theta; };
struct Path2D { Header header; std::vector<air_navigation::py_interface::Pose2D> poses; };
struct Trajectory2D { Twist2D velocity; std::vector<air_navigation::py_interface::Time> time_offsets; std::vector<air_navigation::py_interface::Pose2D> poses; };
class GoalChecker /NoDefaultCtors/
{
......@@ -65,5 +69,31 @@ namespace air_navigation
void print(const std::string& _text);
};
class TrajectoryCritic
{
public:
TrajectoryCritic();
virtual ~TrajectoryCritic();
virtual void onInit() = 0;
virtual void reset() = 0;
virtual bool prepare(const Pose2D&, const Twist2D&, const Pose2D&, const Path2D&) = 0;
virtual double scoreTrajectory(const Trajectory2D & traj) = 0;
virtual void debrief(const Twist2D &) = 0;
};
class TrajectoryGenerator
{
public:
TrajectoryGenerator();
virtual ~TrajectoryGenerator();
virtual void initialize(const std::string& name) = 0;
virtual void reset() = 0;
virtual void startNewIteration(const Twist2D & current_velocity) = 0;
virtual bool hasMoreTwists() = 0;
virtual Twist2D nextTwist() = 0;
virtual Trajectory2D generateTrajectory(const Pose2D & start_pose, const Twist2D & start_vel, const Twist2D & cmd_vel) = 0;
virtual void setSpeedLimit(double speed_limit, bool percentage) = 0;
};
};
};
......@@ -4,8 +4,6 @@
#include <memory>
#include <iostream>
#include "Python.h"
#include "nav2_util/node_utils.hpp"
#include "air_navigation/python_controller.h"
#include "air_navigation/py_interface.h"
......@@ -105,5 +103,8 @@ void PythonController::setSpeedLimit(const double & speed_limit, const bool & pe
}
}
#include "pluginlib/class_loader.hpp"
#include "pluginlib/class_list_macros.hpp"
// Register this controller as a nav2_core plugin
PLUGINLIB_EXPORT_CLASS(air_navigation::PythonController, nav2_core::Controller)
......@@ -22,6 +22,9 @@ struct python_interpreter::Private
// PyThreadState* state = 0;
PyGILState_STATE state;
bool self_initialise = false;
template<typename _T_>
_T_* create_object(const std::string& _module, const std::string& _class, const char* _cpp_name);
};
python_interpreter* python_interpreter::instance()
......@@ -56,9 +59,10 @@ python_interpreter::~python_interpreter()
delete d;
}
py_interface::Controller* python_interpreter::create_controller(const std::string& _module, const std::string& _class)
template<typename _T_>
_T_* python_interpreter::Private::create_object(const std::string& _module, const std::string& _class, const char* _cpp_name)
{
acquire_lock();
python_lock pl;
// PyRun_SimpleString("import air_navigation");
PyObject *pModule = PyImport_Import(PyUnicode_DecodeFSDefault(_module.c_str()));
if(pModule)
......@@ -73,10 +77,9 @@ py_interface::Controller* python_interpreter::create_controller(const std::strin
Py_DECREF(pFunc);
int isErr = 0;
int state;
const sipTypeDef *td = get_sip_api()->api_find_type("air_navigation::py_interface::Controller");
py_interface::Controller* controller = reinterpret_cast<py_interface::Controller*>(get_sip_api()->api_convert_to_type(pValue, td, NULL, SIP_NOT_NONE, &state, &isErr));
const sipTypeDef *td = get_sip_api()->api_find_type(_cpp_name);
_T_* controller = reinterpret_cast<_T_*>(get_sip_api()->api_convert_to_type(pValue, td, NULL, SIP_NOT_NONE, &state, &isErr));
get_sip_api()->api_transfer_to(pValue, Py_None);
release_lock();
return controller;
} else {
PyErr_Print();
......@@ -87,8 +90,24 @@ py_interface::Controller* python_interpreter::create_controller(const std::strin
PyErr_Print();
RCLCPP_ERROR_STREAM(rclcpp::get_logger("python_interpreter"), "Failed to load module '" << _module << "'");
}
release_lock();
return nullptr;
}
py_interface::Controller* python_interpreter::create_controller(const std::string& _module, const std::string& _class)
{
return d->create_object<py_interface::Controller>(_module, _class, "air_navigation::py_interface::Controller");
}
py_interface::TrajectoryCritic* python_interpreter::create_trajectory_critic(const std::string& _module, const std::string& _class)
{
return d->create_object<py_interface::TrajectoryCritic>(_module, _class, "air_navigation::py_interface::TrajectoryCritic");
}
py_interface::TrajectoryGenerator* python_interpreter::create_trajectory_generator(const std::string& _module, const std::string& _class)
{
return d->create_object<py_interface::TrajectoryGenerator>(_module, _class, "air_navigation::py_interface::TrajectoryGenerator");
}
void python_interpreter::release_controller(py_interface::Controller* )
......
#include "air_navigation/python_trajectory_critic.h"
#include "air_navigation/py_interface.h"
#include "air_navigation/python_interpreter.h"
#include "nav_2d_utils/parameters.hpp"
using namespace air_navigation;
void PythonTrajectoryCritic::onInit()
{
auto node = node_.lock();
if (!node) {
throw std::runtime_error{"Failed to lock node"};
}
std::string module_name = nav_2d_utils::searchAndGetParam(
node,
dwb_plugin_name_ + "." + name_ + ".module_name", std::string());
std::string class_name = nav_2d_utils::searchAndGetParam(
node,
dwb_plugin_name_ + "." + name_ + ".class_name", std::string());
RCLCPP_INFO_STREAM(node->get_logger(), "Python critic will use: " << module_name << "." << class_name);
m_python_trajectory_generator_interface = air_navigation::python_interpreter::instance()->create_trajectory_critic(module_name, class_name);;
if(not m_python_trajectory_generator_interface)
{
RCLCPP_ERROR_STREAM(node->get_logger(), "Failed to construct critic!");
} else {
python_lock pl;
m_python_trajectory_generator_interface->onInit();
}
}
void PythonTrajectoryCritic::reset()
{
python_lock pl;
if(m_python_trajectory_generator_interface) m_python_trajectory_generator_interface->reset();
}
bool PythonTrajectoryCritic::prepare(const geometry_msgs::msg::Pose2D & pose, const nav_2d_msgs::msg::Twist2D & vel, const geometry_msgs::msg::Pose2D & goal, const nav_2d_msgs::msg::Path2D & global_plan)
{
python_lock pl;
if(m_python_trajectory_generator_interface)
{
py_interface::Pose2D py_pose;
py_interface::Twist2D py_vel;
py_interface::Pose2D py_goal;
py_interface::Path2D py_global_plan;
py_interface::copy_pose_2d(pose, &py_pose);
py_interface::copy_twist_2d(vel, &py_vel);
py_interface::copy_pose_2d(goal, &py_goal);
py_interface::copy_path_2d<py_interface::Pose2D>(global_plan, &py_global_plan);
return m_python_trajectory_generator_interface->prepare(py_pose, py_vel, py_goal, py_global_plan);
}
return false;
}
double PythonTrajectoryCritic::scoreTrajectory(const dwb_msgs::msg::Trajectory2D & traj)
{
python_lock pl;
if(m_python_trajectory_generator_interface)
{
py_interface::Trajectory2D py_traj;
py_interface::copy_trajectory_2d<py_interface::Pose2D, py_interface::Time>(traj, &py_traj);
return m_python_trajectory_generator_interface->scoreTrajectory(py_traj);
}
return 0.0;
}
void PythonTrajectoryCritic::debrief(const nav_2d_msgs::msg::Twist2D & vel)
{
python_lock pl;
if(m_python_trajectory_generator_interface)
{
py_interface::Twist2D py_vel;
py_interface::copy_twist_2d(vel, &py_vel);
return m_python_trajectory_generator_interface->debrief(py_vel);
}
}
#include "pluginlib/class_loader.hpp"
#include "pluginlib/class_list_macros.hpp"
PLUGINLIB_EXPORT_CLASS(air_navigation::PythonTrajectoryCritic, dwb_core::TrajectoryCritic)
#include "air_navigation/python_trajectory_generator.h"
#include "air_navigation/py_interface.h"
#include "air_navigation/python_interpreter.h"
#include "nav_2d_utils/parameters.hpp"
using namespace air_navigation;
void PythonTrajectoryGenerator::initialize(const nav2_util::LifecycleNode::SharedPtr & nh, const std::string & plugin_name)
{
std::string module_name = nav_2d_utils::searchAndGetParam(
nh,
plugin_name + ".air_navigation::PythonTrajectoryGenerator.module_name", std::string());
std::string class_name = nav_2d_utils::searchAndGetParam(
nh,
plugin_name + ".air_navigation::PythonTrajectoryGenerator.class_name", std::string());
RCLCPP_INFO_STREAM(nh->get_logger(), "Python trajectory generator will use: " << plugin_name);
RCLCPP_INFO_STREAM(nh->get_logger(), "Python trajectory generator will use: " << module_name << "." << class_name);
m_python_trajectory_generator_interface = air_navigation::python_interpreter::instance()->create_trajectory_generator(module_name, class_name);;
if(not m_python_trajectory_generator_interface)
{
RCLCPP_ERROR_STREAM(nh->get_logger(), "Failed to trajectory generator critic!");
} else {
python_lock pl;
m_python_trajectory_generator_interface->initialize(plugin_name);
}
}
void PythonTrajectoryGenerator::reset()
{
python_lock pl;
if(m_python_trajectory_generator_interface) m_python_trajectory_generator_interface->reset();
}
void PythonTrajectoryGenerator::startNewIteration(const nav_2d_msgs::msg::Twist2D & current_velocity)
{
python_lock pl;
if(m_python_trajectory_generator_interface)
{
py_interface::Twist2D py_current_velocity;
py_interface::copy_twist_2d(current_velocity, &py_current_velocity);
m_python_trajectory_generator_interface->startNewIteration(py_current_velocity);
}
}
bool PythonTrajectoryGenerator::hasMoreTwists()
{
python_lock pl;
if(m_python_trajectory_generator_interface) return m_python_trajectory_generator_interface->hasMoreTwists();
return false;
}
nav_2d_msgs::msg::Twist2D PythonTrajectoryGenerator::nextTwist()
{
python_lock pl;
nav_2d_msgs::msg::Twist2D twist;
if(m_python_trajectory_generator_interface)
{
py_interface::Twist2D py_twist = m_python_trajectory_generator_interface->nextTwist();
py_interface::copy_twist_2d(py_twist, &twist);
}
return twist;
}
dwb_msgs::msg::Trajectory2D PythonTrajectoryGenerator::generateTrajectory(const geometry_msgs::msg::Pose2D & start_pose, const nav_2d_msgs::msg::Twist2D & start_vel, const nav_2d_msgs::msg::Twist2D & cmd_vel)
{
python_lock pl;
dwb_msgs::msg::Trajectory2D trajectory;
if(m_python_trajectory_generator_interface)
{
py_interface::Pose2D py_start_pose;
py_interface::Twist2D py_start_vel, py_cmd_vel;
py_interface::copy_pose_2d(start_pose, &py_start_pose);
py_interface::copy_twist_2d(start_vel, &py_start_vel);
py_interface::copy_twist_2d(cmd_vel, &py_cmd_vel);
py_interface::Trajectory2D py_trajectory = m_python_trajectory_generator_interface->generateTrajectory(py_start_pose, py_start_vel, py_cmd_vel);
py_interface::copy_trajectory_2d<geometry_msgs::msg::Pose2D, builtin_interfaces::msg::Duration>(py_trajectory, &trajectory);
}
return trajectory;
}
void PythonTrajectoryGenerator::setSpeedLimit(const double & speed_limit, const bool & percentage)
{
python_lock pl;
if(m_python_trajectory_generator_interface) m_python_trajectory_generator_interface->setSpeedLimit(speed_limit, percentage);
}
#include "pluginlib/class_loader.hpp"
#include "pluginlib/class_list_macros.hpp"
PLUGINLIB_EXPORT_CLASS(air_navigation::PythonTrajectoryGenerator, dwb_core::TrajectoryGenerator)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment