diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json index e60653295..01a5afdbe 100644 --- a/.vscode/c_cpp_properties.json +++ b/.vscode/c_cpp_properties.json @@ -13,7 +13,7 @@ ], "compilerPath": "/usr/bin/gcc", "cStandard": "c17", - "cppStandard": "gnu++17", + "cppStandard": "c++20", "intelliSenseMode": "linux-gcc-x64" } ], diff --git a/.vscode/settings.json b/.vscode/settings.json index 5605ecfb3..6cee7217e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -6,6 +6,20 @@ "editor.rulers": [ 120 ], - "autopep8.args": ["--max-line-length", "120", "--experimental"], - "pylint.args": ["--generate-members", "--max-line-length", "120", "-d", "C0114", "-d", "C0115", "-d", "C0116"] + "autopep8.args": [ + "--max-line-length", + "120", + "--experimental" + ], + "pylint.args": [ + "--generate-members", + "--max-line-length", + "120", + "-d", + "C0114", + "-d", + "C0115", + "-d", + "C0116" + ], } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 427b673a4..89789a3b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,7 @@ Release Versions: - chore: format repository (#142) - docs: update schema path in component descriptions (#154) - feat(utils): add binary reader and recorder for encoded states (#152) +- feat!: add support for custom inputs and outputs (#133) ## 4.2.2 diff --git a/source/modulo_components/CMakeLists.txt b/source/modulo_components/CMakeLists.txt index 8a1dfd1e6..2a0d73375 100644 --- a/source/modulo_components/CMakeLists.txt +++ b/source/modulo_components/CMakeLists.txt @@ -6,9 +6,9 @@ if(NOT CMAKE_C_STANDARD) set(CMAKE_C_STANDARD 99) endif() -# default to C++17 +# default to C++20 if(NOT CMAKE_CXX_STANDARD) - set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD 20) endif() if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") diff --git a/source/modulo_components/include/modulo_components/Component.hpp b/source/modulo_components/include/modulo_components/Component.hpp index 0b896dbb6..eade5f7bc 100644 --- a/source/modulo_components/include/modulo_components/Component.hpp +++ b/source/modulo_components/include/modulo_components/Component.hpp @@ -155,6 +155,15 @@ inline void Component::add_output( ->create_publisher_interface(message_pair); break; } + case MessageType::CUSTOM_MESSAGE: { + if constexpr (modulo_core::concepts::CustomT) { + auto publisher = this->create_publisher(topic_name, this->get_qos()); + this->outputs_.at(parsed_signal_name) = + std::make_shared, DataT>>(PublisherType::PUBLISHER, publisher) + ->create_publisher_interface(message_pair); + } + break; + } } } catch (const std::exception& ex) { RCLCPP_ERROR_STREAM(this->get_logger(), "Failed to add output '" << signal_name << "': " << ex.what()); diff --git a/source/modulo_components/include/modulo_components/ComponentInterface.hpp b/source/modulo_components/include/modulo_components/ComponentInterface.hpp index 385f76447..59751fd62 100644 --- a/source/modulo_components/include/modulo_components/ComponentInterface.hpp +++ b/source/modulo_components/include/modulo_components/ComponentInterface.hpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -646,7 +647,8 @@ inline void ComponentInterface::add_input( std::shared_ptr subscription_interface; switch (message_pair->get_type()) { case MessageType::BOOL: { - auto subscription_handler = std::make_shared>(message_pair); + auto subscription_handler = + std::make_shared>(message_pair, this->node_logging_->get_logger()); auto subscription = rclcpp::create_subscription( this->node_parameters_, this->node_topics_, topic_name, this->qos_, subscription_handler->get_callback(user_callback)); @@ -654,7 +656,8 @@ inline void ComponentInterface::add_input( break; } case MessageType::FLOAT64: { - auto subscription_handler = std::make_shared>(message_pair); + auto subscription_handler = std::make_shared>( + message_pair, this->node_logging_->get_logger()); auto subscription = rclcpp::create_subscription( this->node_parameters_, this->node_topics_, topic_name, this->qos_, subscription_handler->get_callback(user_callback)); @@ -662,8 +665,8 @@ inline void ComponentInterface::add_input( break; } case MessageType::FLOAT64_MULTI_ARRAY: { - auto subscription_handler = - std::make_shared>(message_pair); + auto subscription_handler = std::make_shared>( + message_pair, this->node_logging_->get_logger()); auto subscription = rclcpp::create_subscription( this->node_parameters_, this->node_topics_, topic_name, this->qos_, subscription_handler->get_callback(user_callback)); @@ -671,7 +674,8 @@ inline void ComponentInterface::add_input( break; } case MessageType::INT32: { - auto subscription_handler = std::make_shared>(message_pair); + auto subscription_handler = std::make_shared>( + message_pair, this->node_logging_->get_logger()); auto subscription = rclcpp::create_subscription( this->node_parameters_, this->node_topics_, topic_name, this->qos_, subscription_handler->get_callback(user_callback)); @@ -679,7 +683,8 @@ inline void ComponentInterface::add_input( break; } case MessageType::STRING: { - auto subscription_handler = std::make_shared>(message_pair); + auto subscription_handler = std::make_shared>( + message_pair, this->node_logging_->get_logger()); auto subscription = rclcpp::create_subscription( this->node_parameters_, this->node_topics_, topic_name, this->qos_, subscription_handler->get_callback(user_callback)); @@ -687,13 +692,25 @@ inline void ComponentInterface::add_input( break; } case MessageType::ENCODED_STATE: { - auto subscription_handler = std::make_shared>(message_pair); + auto subscription_handler = std::make_shared>( + message_pair, this->node_logging_->get_logger()); auto subscription = rclcpp::create_subscription( this->node_parameters_, this->node_topics_, topic_name, this->qos_, subscription_handler->get_callback(user_callback)); subscription_interface = subscription_handler->create_subscription_interface(subscription); break; } + case MessageType::CUSTOM_MESSAGE: { + if constexpr (modulo_core::concepts::CustomT) { + auto subscription_handler = + std::make_shared>(message_pair, this->node_logging_->get_logger()); + auto subscription = rclcpp::create_subscription( + this->node_parameters_, this->node_topics_, topic_name, this->qos_, + subscription_handler->get_callback(user_callback)); + subscription_interface = subscription_handler->create_subscription_interface(subscription); + } + break; + } } this->inputs_.insert_or_assign(parsed_signal_name, subscription_interface); } catch (const std::exception& ex) { diff --git a/source/modulo_components/include/modulo_components/LifecycleComponent.hpp b/source/modulo_components/include/modulo_components/LifecycleComponent.hpp index 739bde7c0..a97a4fade 100644 --- a/source/modulo_components/include/modulo_components/LifecycleComponent.hpp +++ b/source/modulo_components/include/modulo_components/LifecycleComponent.hpp @@ -273,6 +273,11 @@ class LifecycleComponent : public rclcpp_lifecycle::LifecycleNode, public Compon using ComponentInterface::publish_outputs; using ComponentInterface::publish_predicates; using rclcpp_lifecycle::LifecycleNode::get_parameter; + + std::map< + std::string, + std::function(const std::string& topic_name)>> + custom_output_configuration_callables_;///< Map of custom output configuration callables }; template @@ -287,9 +292,24 @@ inline void LifecycleComponent::add_output( return; } try { - this->create_output( - modulo_core::communication::PublisherType::LIFECYCLE_PUBLISHER, signal_name, data, default_topic, fixed_topic, - publish_on_step); + using modulo_core::communication::PublisherHandler; + using modulo_core::communication::PublisherType; + + auto parsed_signal_name = this->create_output( + PublisherType::LIFECYCLE_PUBLISHER, signal_name, data, default_topic, fixed_topic, publish_on_step); + + auto message_pair = this->outputs_.at(parsed_signal_name)->get_message_pair(); + if (message_pair->get_type() == modulo_core::communication::MessageType::CUSTOM_MESSAGE) { + if constexpr (modulo_core::concepts::CustomT) { + this->custom_output_configuration_callables_.insert_or_assign( + parsed_signal_name, [this, message_pair](const std::string& topic_name) { + auto publisher = this->create_publisher(topic_name, this->get_qos()); + return std::make_shared, DataT>>( + PublisherType::LIFECYCLE_PUBLISHER, publisher) + ->create_publisher_interface(message_pair); + }); + } + } } catch (const modulo_core::exceptions::AddSignalException& ex) { RCLCPP_ERROR_STREAM(this->get_logger(), "Failed to add output '" << signal_name << "': " << ex.what()); } diff --git a/source/modulo_components/modulo_components/component_interface.py b/source/modulo_components/modulo_components/component_interface.py index a50bf171c..52d2e115c 100644 --- a/source/modulo_components/modulo_components/component_interface.py +++ b/source/modulo_components/modulo_components/component_interface.py @@ -448,6 +448,11 @@ def _create_output(self, signal_name: str, data: str, message_type: MsgT, clprot elif message_type == EncodedState: translator = partial(modulo_writers.write_clproto_message, clproto_message_type=clproto_message_type) + elif hasattr(message_type, 'get_fields_and_field_types'): + def write_ros_msg(message, data): + for field in message.get_fields_and_field_types().keys(): + setattr(message, field, getattr(data, field)) + translator = write_ros_msg else: raise AddSignalError("The provided message type is not supported to create a component output.") self._outputs[parsed_signal_name] = {"attribute": data, "message_type": message_type, @@ -469,7 +474,17 @@ def remove_input(self, signal_name: str): return self.get_logger().debug(f"Removing signal '{signal_name}'.") - def __subscription_callback(self, message: MsgT, attribute_name: str, reader: Callable, user_callback: Callable): + def __read_translated_message(self, message: MsgT, attribute_name: str, reader: Callable): + obj_type = type(self.__getattribute__(attribute_name)) + decoded_message = reader(message) + self.__setattr__(attribute_name, obj_type(decoded_message)) + + def __read_custom_message(self, message: MsgT, attribute_name: str): + for field in message.get_fields_and_field_types().keys(): + setattr(self.__getattribute__(attribute_name), field, getattr(message, field)) + + def __subscription_callback( + self, message: MsgT, attribute_name: str, read_message: Callable, user_callback: Callable): """ Subscription callback for the ROS subscriptions. @@ -478,9 +493,7 @@ def __subscription_callback(self, message: MsgT, attribute_name: str, reader: Ca :param reader: A callable that can read the ROS message and translate to the desired type """ try: - obj_type = type(self.__getattribute__(attribute_name)) - decoded_message = reader(message) - self.__setattr__(attribute_name, obj_type(decoded_message)) + read_message(message, attribute_name) except (AttributeError, MessageTranslationError, TypeError) as e: self.get_logger().warn(f"Failed to read message for attribute {attribute_name}: {e}", throttle_duration_sec=1.0) @@ -491,7 +504,7 @@ def __subscription_callback(self, message: MsgT, attribute_name: str, reader: Ca self.get_logger().error(f"Failed to execute user callback in subscription for attribute" f" '{attribute_name}': {e}", throttle_duration_sec=1.0) - def declare_signal(self, signal_name: str, signal_type: str, default_topic="", fixed_topic=False): + def __declare_signal(self, signal_name: str, signal_type: str, default_topic="", fixed_topic=False): """ Declare an input to create the topic parameter without adding it to the map of inputs yet. @@ -505,7 +518,9 @@ def declare_signal(self, signal_name: str, signal_type: str, default_topic="", f if not parsed_signal_name: raise AddSignalError(topic_validation_warning(signal_name, signal_type)) if signal_name != parsed_signal_name: - self.get_logger().warn(topic_validation_warning(signal_name, signal_type)) + self.get_logger().warn( + f"The parsed name for {signal_type} '{signal_name}' is '{parsed_signal_name}'." + "Use the parsed name to refer to this {signal_type}.") if parsed_signal_name in self._inputs.keys(): raise AddSignalError(f"Signal with name '{parsed_signal_name}' already exists as input.") if parsed_signal_name in self._outputs.keys(): @@ -529,7 +544,7 @@ def declare_input(self, signal_name: str, default_topic="", fixed_topic=False): :param fixed_topic: If true, the topic name of the signal is fixed :raises AddSignalError: if the input could not be declared (empty name or already created) """ - self.declare_signal(signal_name, "input", default_topic, fixed_topic) + self.__declare_signal(signal_name, "input", default_topic, fixed_topic) def declare_output(self, signal_name: str, default_topic="", fixed_topic=False): """ @@ -540,7 +555,7 @@ def declare_output(self, signal_name: str, default_topic="", fixed_topic=False): :param fixed_topic: If true, the topic name of the signal is fixed :raises AddSignalError: if the output could not be declared (empty name or already created) """ - self.declare_signal(signal_name, "output", default_topic, fixed_topic) + self.__declare_signal(signal_name, "output", default_topic, fixed_topic) def add_input(self, signal_name: str, subscription: Union[str, Callable], message_type: MsgT, default_topic="", fixed_topic=False, user_callback: Callable = None): @@ -581,19 +596,31 @@ def default_callback(): user_callback = default_callback if message_type == Bool or message_type == Float64 or \ message_type == Float64MultiArray or message_type == Int32 or message_type == String: + read_message = partial(self.__read_translated_message, + reader=modulo_readers.read_std_message) self._inputs[parsed_signal_name] = \ self.create_subscription(message_type, topic_name, partial(self.__subscription_callback, attribute_name=subscription, - reader=modulo_readers.read_std_message, + read_message=read_message, user_callback=user_callback), self._qos) elif message_type == EncodedState: + read_message = partial(self.__read_translated_message, + reader=modulo_readers.read_clproto_message) + self._inputs[parsed_signal_name] = \ + self.create_subscription(message_type, topic_name, + partial(self.__subscription_callback, + attribute_name=subscription, + read_message=read_message, + user_callback=user_callback), + self._qos) + elif hasattr(message_type, 'get_fields_and_field_types'): self._inputs[parsed_signal_name] = \ self.create_subscription(message_type, topic_name, partial(self.__subscription_callback, attribute_name=subscription, - reader=modulo_readers.read_clproto_message, + read_message=self.__read_custom_message, user_callback=user_callback), self._qos) else: diff --git a/source/modulo_components/src/ComponentInterface.cpp b/source/modulo_components/src/ComponentInterface.cpp index e6258071a..191fc528d 100644 --- a/source/modulo_components/src/ComponentInterface.cpp +++ b/source/modulo_components/src/ComponentInterface.cpp @@ -303,7 +303,9 @@ void ComponentInterface::declare_signal( } if (signal_name != parsed_signal_name) { RCLCPP_WARN_STREAM( - this->node_logging_->get_logger(), modulo_utils::parsing::topic_validation_warning(signal_name, type)); + this->node_logging_->get_logger(), + "The parsed name for " + type + " '" + signal_name + "' is '" + parsed_signal_name + + "'. Use the parsed name to refer to this " + type); } if (this->inputs_.find(parsed_signal_name) != this->inputs_.cend()) { throw exceptions::AddSignalException("Signal with name '" + parsed_signal_name + "' already exists as input."); diff --git a/source/modulo_components/src/LifecycleComponent.cpp b/source/modulo_components/src/LifecycleComponent.cpp index 7d1cf59fb..2b5c503ae 100644 --- a/source/modulo_components/src/LifecycleComponent.cpp +++ b/source/modulo_components/src/LifecycleComponent.cpp @@ -292,6 +292,10 @@ bool LifecycleComponent::configure_outputs() { ->create_publisher_interface(message_pair); break; } + case MessageType::CUSTOM_MESSAGE: { + interface = this->custom_output_configuration_callables_.at(name)(topic_name); + break; + } } } catch (const modulo_core::exceptions::CoreException& ex) { success = false; diff --git a/source/modulo_components/test/cpp/include/test_modulo_components/communication_components.hpp b/source/modulo_components/test/cpp/include/test_modulo_components/communication_components.hpp index 67aee016e..11367d90a 100644 --- a/source/modulo_components/test/cpp/include/test_modulo_components/communication_components.hpp +++ b/source/modulo_components/test/cpp/include/test_modulo_components/communication_components.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -47,6 +48,38 @@ class MinimalCartesianInput : public ComponentT { std::shared_ptr input; std::shared_future received_future; +private: + std::promise received_; +}; + +template +class MinimalTwistOutput : public ComponentT { +public: + MinimalTwistOutput( + const rclcpp::NodeOptions& node_options, const std::string& topic, + std::shared_ptr twist, bool publish_on_step) + : ComponentT(node_options, "minimal_twist_output"), output_(twist) { + this->add_output("twist", this->output_, topic, true, publish_on_step); + } + + void publish() { this->publish_output("twist"); } + +private: + std::shared_ptr output_; +}; + +template +class MinimalTwistInput : public ComponentT { +public: + MinimalTwistInput(const rclcpp::NodeOptions& node_options, const std::string& topic) + : ComponentT(node_options, "minimal_twist_input"), input(std::make_shared()) { + this->received_future = this->received_.get_future(); + this->add_input("twist", this->input, [this]() { this->received_.set_value(); }, topic); + } + + std::shared_ptr input; + std::shared_future received_future; + private: std::promise received_; }; diff --git a/source/modulo_components/test/cpp/test_component.cpp b/source/modulo_components/test/cpp/test_component.cpp index c227fe266..d9047ce7e 100644 --- a/source/modulo_components/test/cpp/test_component.cpp +++ b/source/modulo_components/test/cpp/test_component.cpp @@ -1,6 +1,8 @@ #include +#include #include +#include #include @@ -70,5 +72,24 @@ TEST_F(ComponentTest, AddRemoveOutput) { EXPECT_NO_THROW(component_->publish_output("8_teEsTt_#1@3")); EXPECT_NO_THROW(component_->publish_output("test_13")); EXPECT_THROW(component_->publish_output(""), modulo_core::exceptions::CoreException); + + auto std_msg_data = std::make_shared(); + std_msg_data->data = "foo"; + component_->add_output("custom_msg_test", std_msg_data); + EXPECT_TRUE(component_->outputs_.find("custom_msg_test") != component_->outputs_.end()); + EXPECT_NO_THROW(component_->outputs_.at("custom_msg_test")->publish()); + EXPECT_THROW(component_->publish_output("custom_msg_test"), modulo_core::exceptions::CoreException); + + auto geometry_msg_data = std::make_shared(); + component_->add_output("geometry_msg_test", geometry_msg_data); + EXPECT_TRUE(component_->outputs_.find("geometry_msg_test") != component_->outputs_.end()); + EXPECT_NO_THROW(component_->outputs_.at("geometry_msg_test")->publish()); + EXPECT_THROW(component_->publish_output("geometry_msg_test"), modulo_core::exceptions::CoreException); + + auto sensor_msg_data = std::make_shared(); + component_->add_output("sensor_msg_test", sensor_msg_data); + EXPECT_TRUE(component_->outputs_.find("sensor_msg_test") != component_->outputs_.end()); + EXPECT_NO_THROW(component_->outputs_.at("sensor_msg_test")->publish()); + EXPECT_THROW(component_->publish_output("sensor_msg_test"), modulo_core::exceptions::CoreException); } }// namespace modulo_components diff --git a/source/modulo_components/test/cpp/test_component_communication.cpp b/source/modulo_components/test/cpp/test_component_communication.cpp index 099c0d91a..4a2afb6fd 100644 --- a/source/modulo_components/test/cpp/test_component_communication.cpp +++ b/source/modulo_components/test/cpp/test_component_communication.cpp @@ -58,6 +58,20 @@ TEST_F(ComponentCommunicationTest, InputOutputManual) { EXPECT_TRUE(cartesian_state.data().isApprox(input_node->input->data())); } +TEST_F(ComponentCommunicationTest, TwistInputOutput) { + auto twist = std::make_shared(); + twist->linear.x = 1.0; + auto input_node = std::make_shared>(rclcpp::NodeOptions(), "/topic"); + auto output_node = + std::make_shared>(rclcpp::NodeOptions(), "/topic", twist, true); + this->exec_->add_node(input_node); + this->exec_->add_node(output_node); + auto return_code = this->exec_->spin_until_future_complete(input_node->received_future, 500ms); + ASSERT_EQ(return_code, rclcpp::FutureReturnCode::SUCCESS); + EXPECT_EQ(twist->linear.x, input_node->input->linear.x); + EXPECT_THROW(output_node->publish(), modulo_core::exceptions::CoreException); +} + TEST_F(ComponentCommunicationTest, Trigger) { auto trigger = std::make_shared(rclcpp::NodeOptions()); auto listener = diff --git a/source/modulo_components/test/cpp/test_component_interface.cpp b/source/modulo_components/test/cpp/test_component_interface.cpp index 13e31fb63..305fce693 100644 --- a/source/modulo_components/test/cpp/test_component_interface.cpp +++ b/source/modulo_components/test/cpp/test_component_interface.cpp @@ -8,6 +8,8 @@ #include "test_modulo_components/component_public_interfaces.hpp" +#include + namespace modulo_components { using namespace std::chrono_literals; @@ -107,6 +109,9 @@ TYPED_TEST(ComponentInterfaceTest, AddRemoveInput) { // remove input this->component_->remove_input("test_13"); EXPECT_TRUE(this->component_->inputs_.find("test_13") == this->component_->inputs_.end()); + + EXPECT_NO_THROW(this->component_->add_input("sensor_msg_data", std::make_shared())); + EXPECT_FALSE(this->component_->inputs_.find("sensor_msg_data") == this->component_->inputs_.end()); } TYPED_TEST(ComponentInterfaceTest, AddInputWithUserCallback) { diff --git a/source/modulo_components/test/cpp/test_lifecycle_component_communication.cpp b/source/modulo_components/test/cpp/test_lifecycle_component_communication.cpp index 61be340f6..ffe72deb7 100644 --- a/source/modulo_components/test/cpp/test_lifecycle_component_communication.cpp +++ b/source/modulo_components/test/cpp/test_lifecycle_component_communication.cpp @@ -65,6 +65,20 @@ TEST_F(LifecycleComponentCommunicationTest, InputOutputManual) { EXPECT_TRUE(cartesian_state.data().isApprox(input_node->input->data())); } +TEST_F(LifecycleComponentCommunicationTest, TwistInputOutput) { + auto twist = std::make_shared(); + twist->linear.x = 1.0; + auto input_node = std::make_shared>(rclcpp::NodeOptions(), "/topic"); + auto output_node = + std::make_shared>(rclcpp::NodeOptions(), "/topic", twist, true); + add_configure_activate(this->exec_, input_node); + add_configure_activate(this->exec_, output_node); + auto return_code = this->exec_->spin_until_future_complete(input_node->received_future, 500ms); + ASSERT_EQ(return_code, rclcpp::FutureReturnCode::SUCCESS); + EXPECT_EQ(twist->linear.x, input_node->input->linear.x); + EXPECT_THROW(output_node->publish(), modulo_core::exceptions::CoreException); +} + TEST_F(LifecycleComponentCommunicationTest, Trigger) { auto trigger = std::make_shared(rclcpp::NodeOptions()); auto listener = diff --git a/source/modulo_components/test/python/conftest.py b/source/modulo_components/test/python/conftest.py index 033e9aa9c..c2c1dd1ed 100644 --- a/source/modulo_components/test/python/conftest.py +++ b/source/modulo_components/test/python/conftest.py @@ -4,6 +4,7 @@ from modulo_components.component import Component from modulo_core import EncodedState from rclpy.task import Future +from sensor_msgs.msg import JointState pytest_plugins = ["modulo_utils.testutils.ros", "modulo_utils.testutils.lifecycle_change_client", "modulo_utils.testutils.service_client", "modulo_utils.testutils.predicates_listener"] @@ -19,6 +20,17 @@ def random_joint(): return sr.JointState().Random("test", 3) +@pytest.fixture +def random_sensor(): + random = sr.JointState().Random("test", 3) + msg = JointState() + msg.name = random.get_names() + msg.position = random.get_positions().tolist() + msg.velocity = random.get_velocities().tolist() + msg.effort = random.get_torques().tolist() + return msg + + @pytest.fixture def minimal_cartesian_output(request, random_pose): def _make_minimal_cartesian_output(component_type, topic, publish_on_step): @@ -51,6 +63,21 @@ def publish(self): yield _make_minimal_cartesian_output(request.param[0], request.param[1], request.param[2]) +@pytest.fixture +def minimal_sensor_output(request, random_sensor): + def _make_minimal_sensor_output(component_type, topic, publish_on_step): + def publish(self): + self.publish_output("sensor_state") + + component = component_type("minimal_sensor_output") + component._output = random_sensor + component.add_output("sensor_state", "_output", JointState, default_topic=topic, publish_on_step=publish_on_step) + component.publish = publish.__get__(component) + return component + + yield _make_minimal_sensor_output(request.param[0], request.param[1], request.param[2]) + + class MinimalInvalidEncodedStatePublisher(Component): def __init__(self, topic, *args, **kwargs): super().__init__("minimal_invalid_encoded_state_publisher", *args, **kwargs) @@ -83,3 +110,16 @@ def _make_minimal_cartesian_input(component_type, topic): return component yield _make_minimal_cartesian_input(request.param[0], request.param[1]) + + +@pytest.fixture +def minimal_sensor_input(request): + def _make_minimal_sensor_input(component_type, topic): + component = component_type("minimal_sensor_input") + component.received_future = Future() + component.input = JointState() + component.add_input("sensor_state", "input", JointState, topic, + user_callback=lambda: component.received_future.set_result(True)) + return component + + yield _make_minimal_sensor_input(request.param[0], request.param[1]) diff --git a/source/modulo_components/test/python/test_component.py b/source/modulo_components/test/python/test_component.py index 7156ecb09..a7d8e2834 100644 --- a/source/modulo_components/test/python/test_component.py +++ b/source/modulo_components/test/python/test_component.py @@ -24,7 +24,7 @@ def test_add_remove_output(component): assert component._outputs["test_13"]["message_type"] == Bool component.remove_output("test_13") - assert "test_13" not in component._inputs.keys() + assert "test_13" not in component._outputs.keys() component.add_output("8_teEsTt_#1@3", "test", Bool, publish_on_step=False) assert not component._periodic_outputs["test_13"] diff --git a/source/modulo_components/test/python/test_component_communication.py b/source/modulo_components/test/python/test_component_communication.py index 713ce2ea4..57a77b83b 100644 --- a/source/modulo_components/test/python/test_component_communication.py +++ b/source/modulo_components/test/python/test_component_communication.py @@ -39,6 +39,20 @@ def test_input_output_manual(ros_exec, random_pose, minimal_cartesian_output, mi assert random_pose.dist(minimal_cartesian_input.input) < 1e-3 +@pytest.mark.parametrize("minimal_sensor_input", [[Component, "/topic"]], indirect=True) +@pytest.mark.parametrize("minimal_sensor_output", [[Component, "/topic", False]], indirect=True) +def test_input_output_manual(ros_exec, random_sensor, minimal_sensor_output, minimal_sensor_input): + ros_exec.add_node(minimal_sensor_input) + ros_exec.add_node(minimal_sensor_output) + ros_exec.spin_until_future_complete(minimal_sensor_input.received_future, timeout_sec=0.5) + assert not minimal_sensor_input.received_future.done() + minimal_sensor_output.publish() + ros_exec.spin_until_future_complete(minimal_sensor_input.received_future, timeout_sec=0.5) + assert minimal_sensor_input.received_future.done() and minimal_sensor_input.received_future.result() + for key in random_sensor.get_fields_and_field_types().keys(): + assert getattr(random_sensor, key) == getattr(minimal_sensor_input.input, key) + + @pytest.mark.parametrize("minimal_cartesian_input", [[Component, "/topic"]], indirect=True) @pytest.mark.parametrize("minimal_joint_output", [[Component, "/topic", True]], indirect=True) def test_input_output_invalid_type(ros_exec, minimal_joint_output, minimal_cartesian_input): diff --git a/source/modulo_components/test/python/test_component_interface.py b/source/modulo_components/test/python/test_component_interface.py index d3579226f..4c9ea809b 100644 --- a/source/modulo_components/test/python/test_component_interface.py +++ b/source/modulo_components/test/python/test_component_interface.py @@ -10,6 +10,7 @@ from modulo_core.exceptions import CoreError, LookupTransformError from rclpy.qos import QoSProfile from std_msgs.msg import Bool, String +from sensor_msgs.msg import JointState def raise_(ex): @@ -152,19 +153,25 @@ def test_create_output(component_interface): assert component_interface._periodic_outputs["test"] component_interface._create_output( - "8_teEsTt_#1@3", - "test", - Bool, - clproto.MessageType.UNKNOWN_MESSAGE, - "", - True, - False) + "8_teEsTt_#1@3", "test", Bool, clproto.MessageType.UNKNOWN_MESSAGE, "", True, False) assert not component_interface._periodic_outputs["test_13"] component_interface.publish_output("8_teEsTt_#1@3") component_interface.publish_output("test_13") with pytest.raises(CoreError): component_interface.publish_output("") + component_interface._create_output("test_custom", "test", JointState, + clproto.MessageType.UNKNOWN_MESSAGE, "/topic", True, True) + assert "test_custom" in component_interface._outputs.keys() + assert component_interface.get_parameter_value("test_custom_topic") == "/topic" + assert component_interface._outputs["test_custom"]["message_type"] == JointState + data = JointState() + data.name = ["joint_1", "joint_2"] + msg = JointState() + component_interface._outputs["test_custom"]["translator"](msg, data) + assert msg.name == data.name + assert component_interface._periodic_outputs["test_custom"] + def test_tf(component_interface): component_interface.add_tf_broadcaster() diff --git a/source/modulo_components/test/python/test_lifecycle_component.py b/source/modulo_components/test/python/test_lifecycle_component.py index 22c52e4ba..278743471 100644 --- a/source/modulo_components/test/python/test_lifecycle_component.py +++ b/source/modulo_components/test/python/test_lifecycle_component.py @@ -24,7 +24,7 @@ def test_add_remove_output(lifecycle_component): assert lifecycle_component._outputs["test_13"]["message_type"] == Bool lifecycle_component.remove_output("test_13") - assert "test_13" not in lifecycle_component._inputs.keys() + assert "test_13" not in lifecycle_component._outputs.keys() lifecycle_component.add_output("8_teEsTt_#1@3", "test", Bool, publish_on_step=False) assert not lifecycle_component._periodic_outputs["test_13"] diff --git a/source/modulo_controllers/CMakeLists.txt b/source/modulo_controllers/CMakeLists.txt index 60c285ec1..2fdbdfd88 100644 --- a/source/modulo_controllers/CMakeLists.txt +++ b/source/modulo_controllers/CMakeLists.txt @@ -6,10 +6,10 @@ if (NOT CMAKE_C_STANDARD) set(CMAKE_C_STANDARD 99) endif () -# default to C++17 -if (NOT CMAKE_CXX_STANDARD) - set(CMAKE_CXX_STANDARD 17) -endif () +# default to C++20 +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 20) +endif() if (CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") add_compile_options(-Wall -Wextra -Wpedantic) diff --git a/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp b/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp index 0b49f6289..bde578168 100644 --- a/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp +++ b/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -22,6 +23,8 @@ #include +#include + namespace modulo_controllers { typedef std::variant< @@ -30,7 +33,7 @@ typedef std::variant< std::shared_ptr>, std::shared_ptr>, std::shared_ptr>, - std::shared_ptr>> + std::shared_ptr>, std::any> SubscriptionVariant; typedef std::variant< @@ -39,7 +42,7 @@ typedef std::variant< realtime_tools::RealtimeBuffer>, realtime_tools::RealtimeBuffer>, realtime_tools::RealtimeBuffer>, - realtime_tools::RealtimeBuffer>> + realtime_tools::RealtimeBuffer>, std::any> BufferVariant; typedef std::tuple< @@ -66,9 +69,11 @@ typedef std::pair< std::shared_ptr>, realtime_tools::RealtimePublisherSharedPtr> StringPublishers; +typedef std::pair CustomPublishers; typedef std::variant< - EncodedStatePublishers, BoolPublishers, DoublePublishers, DoubleVecPublishers, IntPublishers, StringPublishers> + EncodedStatePublishers, BoolPublishers, DoublePublishers, DoubleVecPublishers, IntPublishers, StringPublishers, + CustomPublishers> PublisherVariant; /** @@ -76,6 +81,7 @@ typedef std::variant< * @brief Input structure to save topic data in a realtime buffer and timestamps in one object. */ struct ControllerInput { + ControllerInput() = default; ControllerInput(BufferVariant buffer_variant) : buffer(std::move(buffer_variant)) {} BufferVariant buffer; std::chrono::time_point timestamp; @@ -471,6 +477,11 @@ class BaseControllerInterface : public controller_interface::ControllerInterface std::shared_ptr predicate_timer_; std::timed_mutex command_mutex_; + + std::map> + custom_output_configuration_callables_; + std::map> + custom_input_configuration_callables_; }; template @@ -515,11 +526,36 @@ inline void BaseControllerInterface::set_parameter_value(const std::string& name template inline void BaseControllerInterface::add_input(const std::string& name, const std::string& topic_name) { - auto buffer = realtime_tools::RealtimeBuffer>(); - auto input = ControllerInput(buffer); - create_input(input, name, topic_name); - input_message_pairs_.insert_or_assign( - name, modulo_core::communication::make_shared_message_pair(std::make_shared(), get_node()->get_clock())); + if constexpr (modulo_core::concepts::CustomT) { + auto buffer = std::make_shared>>(); + auto input = ControllerInput(buffer); + auto parsed_name = validate_and_declare_signal(name, "input", topic_name); + if (!parsed_name.empty()) { + inputs_.insert_or_assign(parsed_name, input); + custom_input_configuration_callables_.insert_or_assign( + name, [this](const std::string& name, const std::string& topic) { + auto subscription = + get_node()->create_subscription(topic, qos_, [this, name](const std::shared_ptr message) { + auto buffer_variant = std::get(inputs_.at(name).buffer); + auto buffer = std::any_cast>>>( + buffer_variant); + buffer->writeFromNonRT(message); + inputs_.at(name).timestamp = std::chrono::steady_clock::now(); + }); + subscriptions_.push_back(subscription); + }); + } + } else { + auto buffer = realtime_tools::RealtimeBuffer>(); + auto input = ControllerInput(buffer); + auto parsed_name = validate_and_declare_signal(name, "input", topic_name); + if (!parsed_name.empty()) { + inputs_.insert_or_assign(parsed_name, input); + input_message_pairs_.insert_or_assign( + parsed_name, + modulo_core::communication::make_shared_message_pair(std::make_shared(), get_node()->get_clock())); + } + } } template<> @@ -569,8 +605,22 @@ BaseControllerInterface::create_subscription(const std::string& name, const std: template inline void BaseControllerInterface::add_output(const std::string& name, const std::string& topic_name) { - std::shared_ptr state_ptr = std::make_shared(); - create_output(EncodedStatePublishers(state_ptr, {}, {}), name, topic_name); + if constexpr (modulo_core::concepts::CustomT) { + typedef std::pair>, realtime_tools::RealtimePublisherSharedPtr> PublisherT; + auto parsed_name = validate_and_declare_signal(name, "output", topic_name); + if (!parsed_name.empty()) { + outputs_.insert_or_assign(parsed_name, PublisherT()); + custom_output_configuration_callables_.insert_or_assign( + name, [this](CustomPublishers& pub, const std::string& topic) { + auto publisher = get_node()->create_publisher(topic, qos_); + pub.first = publisher; + pub.second = std::make_shared>(publisher); + }); + } + } else { + std::shared_ptr state_ptr = std::make_shared(); + create_output(EncodedStatePublishers(state_ptr, {}, {}), name, topic_name); + } } template<> @@ -604,33 +654,45 @@ inline std::optional BaseControllerInterface::read_input(const std::string& n if (!check_input_valid(name)) { return {}; } - auto message = - **std::get>>(inputs_.at(name).buffer) - .readFromNonRT(); - std::shared_ptr state; - try { - auto message_pair = input_message_pairs_.at(name); - message_pair->read(message); - state = message_pair->get_message_pair()->get_data(); - } catch (const std::exception& ex) { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Could not read EncodedState message on input '%s': %s", name.c_str(), ex.what()); - return {}; - } - if (state->is_empty()) { + + if constexpr (modulo_core::concepts::CustomT) { + try { + auto buffer_variant = std::get(inputs_.at(name).buffer); + auto buffer = std::any_cast>>>(buffer_variant); + return **(buffer->readFromNonRT()); + } catch (const std::bad_any_cast& ex) { + RCLCPP_ERROR(get_node()->get_logger(), "Failed to read custom input: %s", ex.what()); + } return {}; - } - auto cast_ptr = std::dynamic_pointer_cast(state); - if (cast_ptr != nullptr) { - return *cast_ptr; } else { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Dynamic cast of message on input '%s' from type '%s' to type '%s' failed.", name.c_str(), - get_state_type_name(state->get_type()).c_str(), get_state_type_name(T().get_type()).c_str()); + auto message = + **std::get>>(inputs_.at(name).buffer) + .readFromNonRT(); + std::shared_ptr state; + try { + auto message_pair = input_message_pairs_.at(name); + message_pair->read(message); + state = message_pair->get_message_pair()->get_data(); + } catch (const std::exception& ex) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Could not read EncodedState message on input '%s': %s", name.c_str(), ex.what()); + return {}; + } + if (state->is_empty()) { + return {}; + } + auto cast_ptr = std::dynamic_pointer_cast(state); + if (cast_ptr != nullptr) { + return *cast_ptr; + } else { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Dynamic cast of message on input '%s' from type '%s' to type '%s' failed.", name.c_str(), + get_state_type_name(state->get_type()).c_str(), get_state_type_name(T().get_type()).c_str()); + } + return {}; } - return {}; } template<> @@ -689,44 +751,71 @@ inline std::optional BaseControllerInterface::read_input inline void BaseControllerInterface::write_output(const std::string& name, const T& data) { - if (data.is_empty()) { - RCLCPP_DEBUG_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Skipping publication of output '%s' due to emptiness of state", name.c_str()); - return; - } if (outputs_.find(name) == outputs_.end()) { RCLCPP_WARN_THROTTLE( get_node()->get_logger(), *get_node()->get_clock(), 1000, "Could not find output '%s'", name.c_str()); return; } - EncodedStatePublishers publishers; - try { - publishers = std::get(outputs_.at(name)); - } catch (const std::bad_variant_access&) { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Could not retrieve publisher for output '%s': Invalid output type", name.c_str()); - return; - } - if (const auto output_type = std::get<0>(publishers)->get_type(); output_type != data.get_type()) { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Skipping publication of output '%s' due to wrong data type (expected '%s', got '%s')", - state_representation::get_state_type_name(output_type).c_str(), - state_representation::get_state_type_name(data.get_type()).c_str(), name.c_str()); - return; - } - auto rt_pub = std::get<2>(publishers); - if (rt_pub && rt_pub->trylock()) { + + if constexpr (modulo_core::concepts::CustomT) { + CustomPublishers publishers; try { - modulo_core::translators::write_message(rt_pub->msg_, data, get_node()->get_clock()->now()); - } catch (const modulo_core::exceptions::MessageTranslationException& ex) { + publishers = std::get(outputs_.at(name)); + } catch (const std::bad_variant_access&) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Could not retrieve publisher for output '%s': Invalid output type", name.c_str()); + return; + } + + std::shared_ptr> rt_pub; + try { + rt_pub = std::any_cast>>(publishers.second); + } catch (const std::bad_any_cast& ex) { RCLCPP_ERROR_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, "Failed to publish output '%s': %s", name.c_str(), - ex.what()); + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Skipping publication of output '%s' due to wrong data type: %s", name.c_str(), ex.what()); + return; + } + if (rt_pub && rt_pub->trylock()) { + rt_pub->msg_ = data; + rt_pub->unlockAndPublish(); + } + } else { + if (data.is_empty()) { + RCLCPP_DEBUG_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Skipping publication of output '%s' due to emptiness of state", name.c_str()); + return; + } + EncodedStatePublishers publishers; + try { + publishers = std::get(outputs_.at(name)); + } catch (const std::bad_variant_access&) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Could not retrieve publisher for output '%s': Invalid output type", name.c_str()); + return; + } + if (const auto output_type = std::get<0>(publishers)->get_type(); output_type != data.get_type()) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Skipping publication of output '%s' due to wrong data type (expected '%s', got '%s')", + state_representation::get_state_type_name(output_type).c_str(), + state_representation::get_state_type_name(data.get_type()).c_str(), name.c_str()); + return; + } + auto rt_pub = std::get<2>(publishers); + if (rt_pub && rt_pub->trylock()) { + try { + modulo_core::translators::write_message(rt_pub->msg_, data, get_node()->get_clock()->now()); + } catch (const modulo_core::exceptions::MessageTranslationException& ex) { + RCLCPP_ERROR_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, "Failed to publish output '%s': %s", name.c_str(), + ex.what()); + } + rt_pub->unlockAndPublish(); } - rt_pub->unlockAndPublish(); } } diff --git a/source/modulo_controllers/src/BaseControllerInterface.cpp b/source/modulo_controllers/src/BaseControllerInterface.cpp index 54b2600fd..e07ef5e26 100644 --- a/source/modulo_controllers/src/BaseControllerInterface.cpp +++ b/source/modulo_controllers/src/BaseControllerInterface.cpp @@ -330,7 +330,7 @@ void BaseControllerInterface::create_input( const ControllerInput& input, const std::string& name, const std::string& topic_name) { auto parsed_name = validate_and_declare_signal(name, "input", topic_name); if (!parsed_name.empty()) { - inputs_.insert_or_assign(name, input); + inputs_.insert_or_assign(parsed_name, input); } } @@ -357,6 +357,9 @@ void BaseControllerInterface::add_inputs() { }, [&](const realtime_tools::RealtimeBuffer>&) { subscriptions_.push_back(create_subscription(name, topic)); + }, + [&](const std::any&) { + custom_input_configuration_callables_.at(name)(name, topic); }}, input.buffer); } catch (const std::exception& ex) { @@ -369,7 +372,7 @@ void BaseControllerInterface::create_output( const PublisherVariant& publishers, const std::string& name, const std::string& topic_name) { auto parsed_name = validate_and_declare_signal(name, "output", topic_name); if (!parsed_name.empty()) { - outputs_.insert_or_assign(name, publishers); + outputs_.insert_or_assign(parsed_name, publishers); } } @@ -403,10 +406,15 @@ void BaseControllerInterface::add_outputs() { [&](StringPublishers& pub) { pub.first = get_node()->create_publisher(topic, qos_); pub.second = std::make_shared>(pub.first); + }, + [&](CustomPublishers& pub) { + custom_output_configuration_callables_.at(name)(pub, name); }}, publishers); + } catch (const std::bad_any_cast& ex) { + RCLCPP_ERROR(get_node()->get_logger(), "Failed to add custom output '%s': %s", name.c_str(), ex.what()); } catch (const std::exception& ex) { - RCLCPP_ERROR(get_node()->get_logger(), "Failed to add input '%s': %s", name.c_str(), ex.what()); + RCLCPP_ERROR(get_node()->get_logger(), "Failed to add output '%s': %s", name.c_str(), ex.what()); } } } diff --git a/source/modulo_controllers/test/test_controller_interface.cpp b/source/modulo_controllers/test/test_controller_interface.cpp index 8fceda3d4..67cf4f053 100644 --- a/source/modulo_controllers/test/test_controller_interface.cpp +++ b/source/modulo_controllers/test/test_controller_interface.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -26,6 +27,12 @@ class FriendControllerInterface : public ControllerInterface { } }; +sensor_msgs::msg::Image make_image_msg(double width) { + auto msg = sensor_msgs::msg::Image(); + msg.width = width; + return msg; +} + using BoolT = std::tuple; using DoubleT = std::tuple; using DoubleVecT = std::tuple, std_msgs::msg::Float64MultiArray>; @@ -33,6 +40,7 @@ using IntT = std::tuple; using StringT = std::tuple; using CartesianStateT = std::tuple; using JointStateT = std::tuple; +using ImageT = std::tuple; template T write_std_msg(const T& message_data) { @@ -49,6 +57,12 @@ T write_state_msg(const T& message_data) { return copy; } +ImageT write_image_msg(const ImageT& message_data) { + auto copy = message_data; + std::get<1>(copy) = std::get<0>(message_data); + return copy; +} + template T read_std_msg(const T& message_data) { auto copy = message_data; @@ -63,6 +77,12 @@ T read_state_msg(const T& message_data) { return copy; } +ImageT read_image_msg(const ImageT& message_data) { + auto copy = message_data; + std::get<0>(copy) = std::get<1>(message_data); + return copy; +} + template bool std_msg_equal(const T& sent, const T& received) { return std::get<0>(sent) == std::get<0>(received); @@ -74,12 +94,16 @@ bool encoded_state_equal(const T& sent, const T& received) { return equal && std::get<0>(sent).data().isApprox(std::get<0>(received).data()); } +bool sensor_msg_equal(const ImageT& sent, const ImageT& received) { + return std::get<0>(sent).width == std::get<0>(received).width; +} + template using SignalT = std::vector, std::function, std::function>>; static std::tuple< SignalT, SignalT, SignalT, SignalT, SignalT, SignalT, - SignalT> + SignalT, SignalT> signal_test_cases{ {std::make_tuple( std::make_tuple(true, std_msgs::msg::Bool()), write_std_msg, read_std_msg, @@ -100,7 +124,9 @@ static std::tuple< write_state_msg, read_state_msg, encoded_state_equal)}, {std::make_tuple( std::make_tuple(JointState::Random("test", 3), modulo_core::EncodedState()), write_state_msg, - read_state_msg, encoded_state_equal)}}; + read_state_msg, encoded_state_equal)}, + {std::make_tuple( + std::make_tuple(make_image_msg(1), make_image_msg(2)), write_image_msg, read_image_msg, sensor_msg_equal)}}; template class ControllerInterfaceTest : public ::testing::Test { @@ -177,7 +203,6 @@ TYPED_TEST_P(ControllerInterfaceTest, OutputTest) { for (auto [message_data, write_func, read_func, validation_func] : this->test_cases_) { auto data = std::get<0>(message_data); this->interface_->template write_output("output", data); - // rclcpp::spin_some(this->interface_->get_node()->get_node_base_interface()); auto return_code = rclcpp::spin_until_future_complete(test_node.get_node_base_interface(), test_node.get_sub_future(), 200ms); ASSERT_EQ(return_code, rclcpp::FutureReturnCode::SUCCESS); @@ -188,5 +213,5 @@ TYPED_TEST_P(ControllerInterfaceTest, OutputTest) { REGISTER_TYPED_TEST_CASE_P(ControllerInterfaceTest, ConfigureErrorTest, InputTest, OutputTest); -typedef ::testing::Types SignalTypes; +typedef ::testing::Types SignalTypes; INSTANTIATE_TYPED_TEST_CASE_P(TestPrefix, ControllerInterfaceTest, SignalTypes); diff --git a/source/modulo_core/CMakeLists.txt b/source/modulo_core/CMakeLists.txt index 0a3485821..37e447c91 100644 --- a/source/modulo_core/CMakeLists.txt +++ b/source/modulo_core/CMakeLists.txt @@ -6,9 +6,9 @@ if(NOT CMAKE_C_STANDARD) set(CMAKE_C_STANDARD 99) endif() -# default to C++17 +# default to C++20 if(NOT CMAKE_CXX_STANDARD) - set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD 20) endif() if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") diff --git a/source/modulo_core/include/modulo_core/communication/MessagePair.hpp b/source/modulo_core/include/modulo_core/communication/MessagePair.hpp index 20e0bb338..53077045d 100644 --- a/source/modulo_core/include/modulo_core/communication/MessagePair.hpp +++ b/source/modulo_core/include/modulo_core/communication/MessagePair.hpp @@ -3,6 +3,7 @@ #include #include "modulo_core/communication/MessagePairInterface.hpp" +#include "modulo_core/concepts.hpp" #include "modulo_core/translators/message_readers.hpp" #include "modulo_core/translators/message_writers.hpp" @@ -25,6 +26,15 @@ class MessagePair : public MessagePairInterface { */ MessagePair(std::shared_ptr data, std::shared_ptr clock); + /** + * @brief Constructor of the MessagePair that requires custom message types only + * @param data The pointer referring to the data stored in the MessagePair + * @param clock The ROS clock for translating messages + */ + MessagePair(std::shared_ptr data, std::shared_ptr clock) + requires concepts::CustomT && concepts::CustomT + : MessagePairInterface(MessageType::CUSTOM_MESSAGE), data_(std::move(data)), clock_(std::move(clock)) {} + /** * @brief Write the value of the data pointer to a ROS message. * @return The value of the data pointer as a ROS message @@ -53,6 +63,47 @@ class MessagePair : public MessagePairInterface { void set_data(const std::shared_ptr& data); private: + /** + * @brief Write the value of the data pointer to a ROS message using modulo translators. + * @return The value of the data pointer as a ROS message + * @throws modulo_core::exceptions::MessageTranslationException if the data could not be written to message + */ + [[nodiscard]] MsgT write_translated_message() const; + + /** + * @brief Write the value of the data pointer to a ROS message using modulo translators for encoded states. + * @return The value of the data pointer as a ROS message + * @throws modulo_core::exceptions::MessageTranslationException if the data could not be written to message + */ + [[nodiscard]] MsgT write_encoded_message() const; + + /** + * @brief Write the value of the data pointer to a ROS message by direct assignment. + * @return The value of the data pointer as a ROS message + */ + [[nodiscard]] MsgT write_raw_message() const; + + /** + * @brief Read a ROS message and store the value in the data pointer using modulo translators. + * @param message The ROS message to read + * @throws modulo_core::exceptions::MessageTranslationException if the message could not be read + */ + void read_translated_message(const MsgT& message); + + /** + * @brief Read a ROS message and store the value in the data pointer using modulo translators for encoded states. + * @param message The ROS message to read + * @throws modulo_core::exceptions::MessageTranslationException if the message could not be read + */ + void read_encoded_message(const MsgT& message); + + /** + * @brief Read a ROS message and store the value in the data pointer by direct assignment. + * @param message The ROS message to read + * @throws modulo_core::exceptions::MessageTranslationException if the message could not be read + */ + void read_raw_message(const MsgT& message); + std::shared_ptr data_; ///< Pointer referring to the data stored in the MessagePair std::shared_ptr clock_;///< ROS clock for translating messages }; @@ -62,37 +113,67 @@ inline MsgT MessagePair::write_message() const { if (this->data_ == nullptr) { throw exceptions::NullPointerException("The message pair data is not set, nothing to write"); } + + MsgT message; + if constexpr (concepts::CustomT && concepts::CustomT) { + message = write_raw_message(); + } else if constexpr (std::same_as) { + message = write_encoded_message(); + } else { + message = write_translated_message(); + } + return message; +} + +template +inline MsgT MessagePair::write_translated_message() const { auto message = MsgT(); translators::write_message(message, *this->data_, clock_->now()); return message; } template<> -inline EncodedState MessagePair::write_message() const { - if (this->data_ == nullptr) { - throw exceptions::NullPointerException("The message pair data is not set, nothing to write"); - } +inline EncodedState MessagePair::write_encoded_message() const { auto message = EncodedState(); translators::write_message(message, this->data_, clock_->now()); return message; } +template +inline MsgT MessagePair::write_raw_message() const { + return *this->data_; +} + template inline void MessagePair::read_message(const MsgT& message) { if (this->data_ == nullptr) { throw exceptions::NullPointerException("The message pair data is not set, nothing to read"); } + + if constexpr (concepts::CustomT && concepts::CustomT) { + read_raw_message(message); + } else if constexpr (std::same_as) { + read_encoded_message(message); + } else { + read_translated_message(message); + } +} + +template +inline void MessagePair::read_translated_message(const MsgT& message) { translators::read_message(*this->data_, message); } template<> -inline void MessagePair::read_message(const EncodedState& message) { - if (this->data_ == nullptr) { - throw exceptions::NullPointerException("The message pair data is not set, nothing to read"); - } +inline void MessagePair::read_encoded_message(const EncodedState& message) { translators::read_message(this->data_, message); } +template +inline void MessagePair::read_raw_message(const MsgT& message) { + *this->data_ = message; +} + template inline std::shared_ptr MessagePair::get_data() const { return this->data_; @@ -106,38 +187,44 @@ inline void MessagePair::set_data(const std::shared_ptr& dat this->data_ = data; } -template +template inline std::shared_ptr make_shared_message_pair(const std::shared_ptr& data, const std::shared_ptr& clock) { return std::make_shared>( std::dynamic_pointer_cast(data), clock); } -template<> +template +inline std::shared_ptr +make_shared_message_pair(const std::shared_ptr& data, const std::shared_ptr& clock) { + return std::make_shared>(data, clock); +} + +template inline std::shared_ptr make_shared_message_pair(const std::shared_ptr& data, const std::shared_ptr& clock) { return std::make_shared>(data, clock); } -template<> +template inline std::shared_ptr make_shared_message_pair(const std::shared_ptr& data, const std::shared_ptr& clock) { return std::make_shared>(data, clock); } -template<> +template> inline std::shared_ptr make_shared_message_pair( const std::shared_ptr>& data, const std::shared_ptr& clock) { return std::make_shared>>(data, clock); } -template<> +template inline std::shared_ptr make_shared_message_pair(const std::shared_ptr& data, const std::shared_ptr& clock) { return std::make_shared>(data, clock); } -template<> +template inline std::shared_ptr make_shared_message_pair(const std::shared_ptr& data, const std::shared_ptr& clock) { return std::make_shared>(data, clock); diff --git a/source/modulo_core/include/modulo_core/communication/MessageType.hpp b/source/modulo_core/include/modulo_core/communication/MessageType.hpp index 6b387e6a7..1ae03e45f 100644 --- a/source/modulo_core/include/modulo_core/communication/MessageType.hpp +++ b/source/modulo_core/include/modulo_core/communication/MessageType.hpp @@ -10,5 +10,5 @@ namespace modulo_core::communication { * @brief Enum of all supported ROS message types for the MessagePairInterface * @see MessagePairInterface */ -enum class MessageType { BOOL, FLOAT64, FLOAT64_MULTI_ARRAY, INT32, STRING, ENCODED_STATE }; +enum class MessageType { BOOL, FLOAT64, FLOAT64_MULTI_ARRAY, INT32, STRING, ENCODED_STATE, CUSTOM_MESSAGE }; }// namespace modulo_core::communication diff --git a/source/modulo_core/include/modulo_core/communication/PublisherHandler.hpp b/source/modulo_core/include/modulo_core/communication/PublisherHandler.hpp index 8775234f7..177ffcc1f 100644 --- a/source/modulo_core/include/modulo_core/communication/PublisherHandler.hpp +++ b/source/modulo_core/include/modulo_core/communication/PublisherHandler.hpp @@ -4,6 +4,8 @@ #include "modulo_core/communication/PublisherInterface.hpp" +#include + namespace modulo_core::communication { /** @@ -29,16 +31,19 @@ class PublisherHandler : public PublisherInterface { ~PublisherHandler() override; /** - * @brief Activate the ROS publisher if applicable. - * @throws modulo_core::exceptions::NullPointerException if the publisher pointer is null - */ - void on_activate(); + * @copydoc PublisherInterface::activate + */ + virtual void activate() override; /** - * @brief Deactivate the ROS publisher if applicable. - * @throws modulo_core::exceptions::NullPointerException if the publisher pointer is null - */ - void on_deactivate(); + * @copydoc PublisherInterface::deactivate + */ + virtual void deactivate() override; + + /** + * @copydoc PublisherInterface::publish + */ + void publish() override; /** * @brief Publish the ROS message through the ROS publisher. @@ -56,6 +61,8 @@ class PublisherHandler : public PublisherInterface { private: std::shared_ptr publisher_;///< The ROS publisher + + using PublisherInterface::message_pair_; }; template @@ -68,34 +75,46 @@ PublisherHandler::~PublisherHandler() { } template -void PublisherHandler::on_activate() { - if (this->publisher_ == nullptr) { - throw exceptions::NullPointerException("Publisher not set"); - } - if (this->get_type() == PublisherType::LIFECYCLE_PUBLISHER) { +inline void PublisherHandler::activate() { + if constexpr (std::derived_from>) { + if (this->publisher_ == nullptr) { + throw exceptions::NullPointerException("Publisher not set"); + } try { this->publisher_->on_activate(); } catch (const std::exception& ex) { throw exceptions::CoreException(ex.what()); } - } else { - RCLCPP_DEBUG(rclcpp::get_logger("PublisherHandler"), "Only LifecyclePublishers can be deactivated"); } } template -void PublisherHandler::on_deactivate() { - if (this->publisher_ == nullptr) { - throw exceptions::NullPointerException("Publisher not set"); - } - if (this->get_type() == PublisherType::LIFECYCLE_PUBLISHER) { +inline void PublisherHandler::deactivate() { + if constexpr (std::derived_from>) { + if (this->publisher_ == nullptr) { + throw exceptions::NullPointerException("Publisher not set"); + } try { this->publisher_->on_deactivate(); } catch (const std::exception& ex) { throw exceptions::CoreException(ex.what()); } - } else { - RCLCPP_DEBUG(rclcpp::get_logger("PublisherHandler"), "Only LifecyclePublishers can be deactivated"); + } +} + +template +inline void PublisherHandler::publish() { + try { + if constexpr (concepts::CustomT && !concepts::TranslatedMsgT) { + if (this->message_pair_ == nullptr) { + throw exceptions::NullPointerException("Message pair is not set, nothing to publish"); + } + publish(this->message_pair_->write()); + } else { + PublisherInterface::publish(); + } + } catch (const exceptions::CoreException& ex) { + throw; } } @@ -112,7 +131,7 @@ void PublisherHandler::publish(const MsgT& message) const { } template -std::shared_ptr +inline std::shared_ptr PublisherHandler::create_publisher_interface(const std::shared_ptr& message_pair) { std::shared_ptr publisher_interface; try { diff --git a/source/modulo_core/include/modulo_core/communication/PublisherInterface.hpp b/source/modulo_core/include/modulo_core/communication/PublisherInterface.hpp index d0280ea50..aa291134a 100644 --- a/source/modulo_core/include/modulo_core/communication/PublisherInterface.hpp +++ b/source/modulo_core/include/modulo_core/communication/PublisherInterface.hpp @@ -65,7 +65,7 @@ class PublisherInterface : public std::enable_shared_from_this message_pair_;///< The pointer to the stored MessagePair instance + private: /** * @brief Publish the data stored in the message pair through the ROS publisher of a derived PublisherHandler instance @@ -120,8 +123,7 @@ class PublisherInterface : public std::enable_shared_from_this void publish(const MsgT& message); - PublisherType type_; ///< The type of the publisher interface - std::shared_ptr message_pair_;///< The pointer to the stored MessagePair instance + PublisherType type_;///< The type of the publisher interface }; template diff --git a/source/modulo_core/include/modulo_core/communication/SubscriptionHandler.hpp b/source/modulo_core/include/modulo_core/communication/SubscriptionHandler.hpp index d590ff7d4..1baad707a 100644 --- a/source/modulo_core/include/modulo_core/communication/SubscriptionHandler.hpp +++ b/source/modulo_core/include/modulo_core/communication/SubscriptionHandler.hpp @@ -2,6 +2,8 @@ #include "modulo_core/communication/SubscriptionInterface.hpp" +#include "modulo_core/concepts.hpp" + namespace modulo_core::communication { /** @@ -15,8 +17,11 @@ class SubscriptionHandler : public SubscriptionInterface { /** * @brief Constructor with the message pair. * @param message_pair The pointer to the message pair with the data that should be updated through the subscription + * @param logger An optional ROS logger to do logging from the subscription callback */ - explicit SubscriptionHandler(std::shared_ptr message_pair = nullptr); + explicit SubscriptionHandler( + std::shared_ptr message_pair = nullptr, + const rclcpp::Logger& logger = rclcpp::get_logger("SubscriptionHandler")); /** * @brief Destructor to explicitly reset the subscription pointer. @@ -54,7 +59,7 @@ class SubscriptionHandler : public SubscriptionInterface { std::function)> get_callback(const std::function& user_callback); /** - * @brief Create a SubscriptionInterface pointer through an instance of a SubscriptionHandler by providing a ROS + * @brief Create a SubscriptionInterface pointer through an instance of a SubscriptionHandler by providing a ROS * subscription. * @details This throws a NullPointerException if the ROS subscription is null. * @see SubscriptionHandler::set_subscription @@ -70,12 +75,34 @@ class SubscriptionHandler : public SubscriptionInterface { */ void handle_callback_exceptions(); + /** + * @brief Get a callback function that will be associated with the ROS subscription to receive and translate + * internally supported messages. + * @details This variant also takes a user callback function to execute after the message is received and translated. + * @param user_callback Void callback function for additional logic after the message is received and translated. + */ + std::function)> get_translated_callback(); + + /** + * @brief Get a callback function that will be associated with the ROS subscription to receive and translate generic + * messages. + * @details This variant also takes a user callback function to execute after the message is received and translated. + * @param user_callback Void callback function for additional logic after the message is received and translated. + */ + std::function)> get_raw_callback(); + std::shared_ptr> subscription_;///< The pointer referring to the ROS subscription + rclcpp::Logger logger_; ///< ROS logger for logging warnings std::shared_ptr clock_; ///< ROS clock for throttling log std::function user_callback_ = [] { };///< User callback to be executed after the subscription callback }; +template +SubscriptionHandler::SubscriptionHandler( + std::shared_ptr message_pair, const rclcpp::Logger& logger) + : SubscriptionInterface(std::move(message_pair)), logger_(logger), clock_(std::make_shared()) {} + template SubscriptionHandler::~SubscriptionHandler() { this->subscription_.reset(); @@ -94,6 +121,27 @@ void SubscriptionHandler::set_subscription(const std::shared_ptrsubscription_ = subscription; } +template +inline std::function)> SubscriptionHandler::get_callback() { + if constexpr (concepts::TranslatedMsgT) { + return get_translated_callback(); + } else { + return get_raw_callback(); + } +} + +template +std::function)> SubscriptionHandler::get_raw_callback() { + return [this](const std::shared_ptr message) { + try { + this->get_message_pair()->template read(*message); + this->user_callback_(); + } catch (...) { + this->handle_callback_exceptions(); + } + }; +} + template void SubscriptionHandler::set_user_callback(const std::function& user_callback) { this->user_callback_ = user_callback; @@ -120,12 +168,10 @@ void SubscriptionHandler::handle_callback_exceptions() { throw; } catch (const exceptions::CoreException& ex) { RCLCPP_WARN_STREAM_THROTTLE( - rclcpp::get_logger("SubscriptionHandler"), *this->clock_, 1000, - "Exception in subscription callback: " << ex.what()); + this->logger_, *this->clock_, 1000, "Exception in subscription callback: " << ex.what()); } catch (const std::exception& ex) { RCLCPP_WARN_STREAM_THROTTLE( - rclcpp::get_logger("SubscriptionHandler"), *this->clock_, 1000, - "Unhandled exception in subscription user callback: " << ex.what()); + this->logger_, *this->clock_, 1000, "Unhandled exception in subscription user callback: " << ex.what()); } } diff --git a/source/modulo_core/include/modulo_core/communication/SubscriptionInterface.hpp b/source/modulo_core/include/modulo_core/communication/SubscriptionInterface.hpp index be25a5bd2..618b45d47 100644 --- a/source/modulo_core/include/modulo_core/communication/SubscriptionInterface.hpp +++ b/source/modulo_core/include/modulo_core/communication/SubscriptionInterface.hpp @@ -66,7 +66,7 @@ class SubscriptionInterface : public std::enable_shared_from_this& message_pair); -private: +protected: std::shared_ptr message_pair_;///< The pointer to the stored MessagePair instance }; diff --git a/source/modulo_core/include/modulo_core/concepts.hpp b/source/modulo_core/include/modulo_core/concepts.hpp new file mode 100644 index 000000000..a721241f4 --- /dev/null +++ b/source/modulo_core/include/modulo_core/concepts.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include "modulo_core/EncodedState.hpp" + +#include +#include +#include +#include +#include +#include + +namespace modulo_core::concepts { + +// Data type concepts + +template +concept PrimitiveDataT = std::same_as || std::same_as || std::same_as> + || std::same_as || std::same_as; + +template +concept CoreDataT = std::derived_from || PrimitiveDataT; + +// Message type concepts + +template +concept TranslatedMsgT = std::same_as || std::same_as + || std::same_as || std::same_as + || std::same_as || std::same_as; + +template +concept CustomT = !CoreDataT && !std::same_as; + +}// namespace modulo_core::concepts diff --git a/source/modulo_core/src/communication/PublisherInterface.cpp b/source/modulo_core/src/communication/PublisherInterface.cpp index 0dac4a7e5..aa3fd9ef1 100644 --- a/source/modulo_core/src/communication/PublisherInterface.cpp +++ b/source/modulo_core/src/communication/PublisherInterface.cpp @@ -15,72 +15,16 @@ namespace modulo_core::communication { PublisherInterface::PublisherInterface(PublisherType type, std::shared_ptr message_pair) - : type_(type), message_pair_(std::move(message_pair)) {} + : message_pair_(std::move(message_pair)), type_(type) {} void PublisherInterface::activate() { - if (this->message_pair_ == nullptr) { - throw exceptions::NullPointerException("Message pair is not set, cannot deduce message type"); - } - switch (this->message_pair_->get_type()) { - case MessageType::BOOL: - this->template get_handler, std_msgs::msg::Bool>() - ->on_activate(); - break; - case MessageType::FLOAT64: - this->template get_handler, std_msgs::msg::Float64>() - ->on_activate(); - break; - case MessageType::FLOAT64_MULTI_ARRAY: - this->template get_handler< - rclcpp_lifecycle::LifecyclePublisher, - std_msgs::msg::Float64MultiArray>() - ->on_activate(); - break; - case MessageType::INT32: - this->template get_handler, std_msgs::msg::Int32>() - ->on_activate(); - break; - case MessageType::STRING: - this->template get_handler, std_msgs::msg::String>() - ->on_activate(); - break; - case MessageType::ENCODED_STATE: - this->template get_handler, EncodedState>()->on_activate(); - break; - } + throw exceptions::CoreException( + "The derived publisher handler is required to override this function to handle activation"); } void PublisherInterface::deactivate() { - if (this->message_pair_ == nullptr) { - throw exceptions::NullPointerException("Message pair is not set, cannot deduce message type"); - } - switch (this->message_pair_->get_type()) { - case MessageType::BOOL: - this->template get_handler, std_msgs::msg::Bool>() - ->on_deactivate(); - break; - case MessageType::FLOAT64: - this->template get_handler, std_msgs::msg::Float64>() - ->on_deactivate(); - break; - case MessageType::FLOAT64_MULTI_ARRAY: - this->template get_handler< - rclcpp_lifecycle::LifecyclePublisher, - std_msgs::msg::Float64MultiArray>() - ->on_deactivate(); - break; - case MessageType::INT32: - this->template get_handler, std_msgs::msg::Int32>() - ->on_deactivate(); - break; - case MessageType::STRING: - this->template get_handler, std_msgs::msg::String>() - ->on_deactivate(); - break; - case MessageType::ENCODED_STATE: - this->template get_handler, EncodedState>()->on_deactivate(); - break; - } + throw exceptions::CoreException( + "The derived publisher handler is required to override this function to handle deactivation"); } void PublisherInterface::publish() { @@ -111,6 +55,8 @@ void PublisherInterface::publish() { this->publish(this->message_pair_->write()); } break; + default: + break; } } catch (const exceptions::CoreException& ex) { throw; diff --git a/source/modulo_core/src/communication/SubscriptionHandler.cpp b/source/modulo_core/src/communication/SubscriptionHandler.cpp index baed8a1bd..a1cf6134c 100644 --- a/source/modulo_core/src/communication/SubscriptionHandler.cpp +++ b/source/modulo_core/src/communication/SubscriptionHandler.cpp @@ -4,33 +4,9 @@ namespace modulo_core::communication { -template<> -SubscriptionHandler::SubscriptionHandler(std::shared_ptr message_pair) - : SubscriptionInterface(std::move(message_pair)), clock_(std::make_shared()) {} - -template<> -SubscriptionHandler::SubscriptionHandler(std::shared_ptr message_pair) - : SubscriptionInterface(std::move(message_pair)), clock_(std::make_shared()) {} -template<> -SubscriptionHandler::SubscriptionHandler( - std::shared_ptr message_pair) - : SubscriptionInterface(std::move(message_pair)), clock_(std::make_shared()) {} - -template<> -SubscriptionHandler::SubscriptionHandler(std::shared_ptr message_pair) - : SubscriptionInterface(std::move(message_pair)), clock_(std::make_shared()) {} - -template<> -SubscriptionHandler::SubscriptionHandler(std::shared_ptr message_pair) - : SubscriptionInterface(std::move(message_pair)), clock_(std::make_shared()) {} - -template<> -SubscriptionHandler::SubscriptionHandler(std::shared_ptr message_pair) - : SubscriptionInterface(std::move(message_pair)), clock_(std::make_shared()) {} - template<> std::function)> -SubscriptionHandler::get_callback() { +SubscriptionHandler::get_translated_callback() { return [this](const std::shared_ptr message) { try { this->get_message_pair()->template read(*message); @@ -43,7 +19,7 @@ SubscriptionHandler::get_callback() { template<> std::function)> -SubscriptionHandler::get_callback() { +SubscriptionHandler::get_translated_callback() { return [this](const std::shared_ptr message) { try { this->get_message_pair()->template read(*message); @@ -56,7 +32,7 @@ SubscriptionHandler::get_callback() { template<> std::function)> -SubscriptionHandler::get_callback() { +SubscriptionHandler::get_translated_callback() { return [this](const std::shared_ptr message) { try { this->get_message_pair()->template read>(*message); @@ -69,7 +45,7 @@ SubscriptionHandler::get_callback() { template<> std::function)> -SubscriptionHandler::get_callback() { +SubscriptionHandler::get_translated_callback() { return [this](const std::shared_ptr message) { try { this->get_message_pair()->template read(*message); @@ -82,7 +58,7 @@ SubscriptionHandler::get_callback() { template<> std::function)> -SubscriptionHandler::get_callback() { +SubscriptionHandler::get_translated_callback() { return [this](const std::shared_ptr message) { try { this->get_message_pair()->template read(*message); @@ -94,7 +70,7 @@ SubscriptionHandler::get_callback() { } template<> -std::function)> SubscriptionHandler::get_callback() { +std::function)> SubscriptionHandler::get_translated_callback() { return [this](const std::shared_ptr message) { try { this->get_message_pair()->template read(*message); @@ -104,4 +80,5 @@ std::function)> SubscriptionHandler +#include + using namespace modulo_core::communication; class CommunicationTest : public ::testing::Test { @@ -60,3 +63,17 @@ TEST_F(CommunicationTest, BasicTypes) { this->communicate(1, 2); this->communicate("this", "that"); } + +TEST_F(CommunicationTest, CustomTypes) { + sensor_msgs::msg::Image initial_image; + initial_image.height = 480; + sensor_msgs::msg::Image new_image; + new_image.height = 320; + this->communicate(initial_image, new_image); + + sensor_msgs::msg::Imu initial_imu; + initial_imu.linear_acceleration.x = 1.0; + sensor_msgs::msg::Imu new_imu; + new_imu.linear_acceleration.x = 0.5; + this->communicate(initial_imu, new_imu); +} diff --git a/source/modulo_core/test/cpp/communication/test_message_pair.cpp b/source/modulo_core/test/cpp/communication/test_message_pair.cpp index 3de869a01..9ad7af733 100644 --- a/source/modulo_core/test/cpp/communication/test_message_pair.cpp +++ b/source/modulo_core/test/cpp/communication/test_message_pair.cpp @@ -1,6 +1,7 @@ #include #include "modulo_core/communication/MessagePair.hpp" +#include using namespace modulo_core::communication; @@ -28,6 +29,29 @@ test_message_interface(const DataT& initial_value, const DataT& new_value, const EXPECT_EQ(initial_value, *message_pair->get_data()); } +template +static void test_custom_message_interface( + const DataT& initial_value, const DataT& new_value, const std::shared_ptr clock) { + auto data = std::make_shared(initial_value); + auto message_pair = std::make_shared>(data, clock); + EXPECT_EQ(initial_value, *message_pair->get_data()); + EXPECT_EQ(initial_value, message_pair->write_message()); + + std::shared_ptr message_pair_interface(message_pair); + auto message = message_pair_interface->template write(); + EXPECT_EQ(initial_value, message); + + *data = new_value; + EXPECT_EQ(new_value, *message_pair->get_data()); + EXPECT_EQ(new_value, message_pair->write_message()); + message = message_pair_interface->template write(); + EXPECT_EQ(new_value, message); + + message = initial_value; + message_pair_interface->template read(message); + EXPECT_EQ(initial_value, *message_pair->get_data()); +} + class MessagePairTest : public ::testing::Test { protected: void SetUp() override { clock_ = std::make_shared(); } @@ -73,3 +97,21 @@ TEST_F(MessagePairTest, EncodedState) { EXPECT_TRUE(initial_value.data().isApprox( std::dynamic_pointer_cast(message_pair->get_data())->data())); } + +TEST_F(MessagePairTest, GenericTypes) { + float initial_float = 0.1; + float new_float = 0.2; + test_custom_message_interface(initial_float, new_float, clock_); + + geometry_msgs::msg::Twist initial_value; + initial_value.linear.x = 0.1; + geometry_msgs::msg::Twist new_value; + initial_value.linear.x = 42; + test_custom_message_interface(initial_value, new_value, clock_); + + sensor_msgs::msg::Image initial_img; + initial_img.height = 320; + sensor_msgs::msg::Image new_img; + new_img.height = 480; + test_custom_message_interface(initial_img, new_img, clock_); +} diff --git a/source/modulo_core/test/cpp/communication/test_publisher_handler.cpp b/source/modulo_core/test/cpp/communication/test_publisher_handler.cpp index 35d59b391..a2a1c0a10 100644 --- a/source/modulo_core/test/cpp/communication/test_publisher_handler.cpp +++ b/source/modulo_core/test/cpp/communication/test_publisher_handler.cpp @@ -5,6 +5,9 @@ #include "modulo_core/communication/MessagePair.hpp" #include "modulo_core/communication/PublisherHandler.hpp" +#include +#include + using namespace modulo_core::communication; template @@ -70,3 +73,8 @@ TEST_F(PublisherTest, EncodedState) { publisher_interface = publisher_handler->create_publisher_interface(message_pair); EXPECT_NO_THROW(publisher_interface->publish()); } + +TEST_F(PublisherTest, CustomTypes) { + test_publisher_interface(node, sensor_msgs::msg::Image()); + test_publisher_interface(node, sensor_msgs::msg::Imu()); +} diff --git a/source/modulo_core/test/cpp/communication/test_subscription_handler.cpp b/source/modulo_core/test/cpp/communication/test_subscription_handler.cpp index 2844df324..c3e3b22b5 100644 --- a/source/modulo_core/test/cpp/communication/test_subscription_handler.cpp +++ b/source/modulo_core/test/cpp/communication/test_subscription_handler.cpp @@ -4,6 +4,9 @@ #include "modulo_core/communication/SubscriptionHandler.hpp" +#include +#include + using namespace modulo_core::communication; template @@ -55,3 +58,13 @@ TEST_F(SubscriptionTest, EncodedState) { // use in subscription interface auto subscription_interface = subscription_handler->create_subscription_interface(subscription); } + +TEST_F(SubscriptionTest, CustomTypes) { + sensor_msgs::msg::Image image; + image.height = 480; + test_subscription_interface(node, image); + + sensor_msgs::msg::Imu imu; + imu.linear_acceleration.x = 1.0; + test_subscription_interface(node, imu); +} diff --git a/source/modulo_utils/CMakeLists.txt b/source/modulo_utils/CMakeLists.txt index 9de7fe03c..33a710628 100644 --- a/source/modulo_utils/CMakeLists.txt +++ b/source/modulo_utils/CMakeLists.txt @@ -6,9 +6,9 @@ if(NOT CMAKE_C_STANDARD) set(CMAKE_C_STANDARD 99) endif() -# default to C++17 +# default to C++20 if(NOT CMAKE_CXX_STANDARD) - set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD 20) endif() if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")