diff --git a/src/onnxruntime.cc b/src/onnxruntime.cc index b71c1da..0463b8d 100644 --- a/src/onnxruntime.cc +++ b/src/onnxruntime.cc @@ -25,7 +25,6 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include - #include #include @@ -107,10 +106,10 @@ class ModelState : public BackendModel { // onnx file, return in 'session' and 'allocator' the ORT session // and allocator. TRITONSERVER_Error* LoadModel( - const std::string& artifact_name, + const std::string& artifact_name, const std::string& instance_name, const TRITONSERVER_InstanceGroupKind instance_group_kind, const int32_t instance_group_device_id, std::string* model_path, - OrtSession** session, OrtAllocator** default_allocator, + std::shared_ptr& session, OrtAllocator** default_allocator, cudaStream_t stream); const std::map>& ModelOutputs() @@ -127,6 +126,11 @@ class ModelState : public BackendModel { TRITONSERVER_Error* AutoCompleteIO( const char* key, const OnnxTensorInfoMap& io_infos); + TRITONSERVER_Error* GetSessionForGroup( + const std::string& group_name, std::shared_ptr& session); + TRITONSERVER_Error* SetSessionForGroup( + const std::string& group_name, const std::shared_ptr& session); + // Session options used when creating a ORT session. std::unique_ptr session_options_; @@ -136,6 +140,19 @@ class ModelState : public BackendModel { // is specified both in the output section and state section, it indicates // that the backend must return the output state to the client too. std::map> model_outputs_; + + // Indicate if an onnxrt session should be shared or not. This is a model + // global and applies to all instances. So, storing it in the model state + bool share_session_between_instances_; + + // maintain a map of group id to ORT session. This is only useful if + // share_session_between_instances is set to true in parameters. + // share_session_between_instances is a global model config and the user + // should be careful when setting this. There is no way to set this per + // instance group. + std::unordered_map> + group_instance_session_map_; + std::mutex group_instance_session_map_mutex_; }; TRITONSERVER_Error* @@ -206,7 +223,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) } ModelState::ModelState(TRITONBACKEND_Model* triton_model) - : BackendModel(triton_model, true /* allow_optional */) + : BackendModel(triton_model, true /* allow_optional */), share_session_between_instances_(false) { // Create session options that will be cloned and used for each // instance when creating that instance's session. @@ -359,19 +376,30 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) } } - // FIXME. Is it possible to share a single OrtSession across - // multiple instances? If so then should move loading and validation - // of the session to here instead of creating a session for each - // instance in ModelStateInstance::Create(). + // This setting will apply across multiple instance groups. + // If this value is set all instances within an instance group will share + // the ort session + { + bool share_session; + triton::common::TritonJson::Value params; + if (ModelConfig().Find("parameters", ¶ms)) { + THROW_IF_BACKEND_MODEL_ERROR(TryParseModelStringParameter( + params, "share_session_between_instances", &share_session, false)); + } + share_session_between_instances_ = share_session; + } } TRITONSERVER_Error* ModelState::LoadModel( - const std::string& artifact_name, + const std::string& artifact_name, const std::string& instance_name, const TRITONSERVER_InstanceGroupKind instance_group_kind, const int32_t instance_group_device_id, std::string* model_path, - OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream) + std::shared_ptr& session, OrtAllocator** default_allocator, + cudaStream_t stream) { + // Get the group name for the instance + std::string instance_group_name(GetInstanceGroupName(Name(), instance_name)); // Find the ONNX file that describes the model itself. If the model // configuration doesn't have an explicit model file specified then // use the default name ("model.onnx"). @@ -383,6 +411,10 @@ ModelState::LoadModel( *model_path = JoinPath( {RepositoryPath(), std::to_string(Version()), cc_model_filename}); + // get default cpu allocator + RETURN_IF_ORT_ERROR( + ort_api->GetAllocatorWithDefaultOptions(default_allocator)); + // If the model path is a directory then the actual model is // /model.onnx. { @@ -393,6 +425,32 @@ ModelState::LoadModel( } } + // Check if we are sharing the session. If so get the session pointer and + // return + if (share_session_between_instances_) { + group_instance_session_map_mutex_.lock(); + TRITONSERVER_Error* error = GetSessionForGroup(instance_group_name, session); + if (error == nullptr) { + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Reusing session for instance: ") + instance_name) + .c_str()); + // Session successfully retrieved, unlock the mutex and return the session + group_instance_session_map_mutex_.unlock(); + return nullptr; + } + // In case of error do not release lock and carry on to load session, set it in map + // to enable sharing session with other instances + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Could not find a session corresponding to instance group: ") + instance_group_name) + .c_str()); + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + TRITONSERVER_ErrorMessage(error)); + TRITONSERVER_ErrorDelete(error); + } + { bool exists; RETURN_IF_ERROR(FileExists(*model_path, &exists)); @@ -835,12 +893,23 @@ ModelState::LoadModel( glock.lock(); } + // This will be allocated by OnnxRT here but will be freed when the last + // instance of shared_ptr is released + OrtSession* session_ptr; RETURN_IF_ERROR(OnnxLoader::LoadSession( - true /* is_path */, *model_path, soptions, session)); - - // get default cpu allocator - RETURN_IF_ORT_ERROR( - ort_api->GetAllocatorWithDefaultOptions(default_allocator)); + true /* is_path */, *model_path, soptions, &session_ptr)); + session = std::shared_ptr(session_ptr, SessionDeleter()); + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Created session for instance: ") + instance_name) + .c_str()); + if (share_session_between_instances_) { + // The session was created fine this is not a critical error + LOG_IF_ERROR( + SetSessionForGroup(instance_group_name, session), + "Failed to map ort session to the group for sharing"); + group_instance_session_map_mutex_.unlock(); + } return nullptr; // success } @@ -882,9 +951,9 @@ ModelState::AutoCompleteConfig() RETURN_IF_ERROR( ModelConfig().MemberAsString("default_model_filename", &artifact_name)); - // Must cleanup 'session'. 'allocator' is default allocator which + // 'allocator' is default allocator which // is managed by ONNX Runtime so don't need to free/release - std::unique_ptr session; + std::shared_ptr session; OrtAllocator* default_allocator; std::string model_path; { @@ -913,12 +982,9 @@ ModelState::AutoCompleteConfig() } } #endif // TRITON_ENABLE_GPU - - OrtSession* sptr = nullptr; RETURN_IF_ERROR(LoadModel( - artifact_name, kind, 0, &model_path, &sptr, &default_allocator, - nullptr)); - session.reset(sptr); + artifact_name, "", kind, 0, &model_path, + session, &default_allocator, nullptr)); } OnnxTensorInfoMap input_tensor_infos; RETURN_IF_ERROR( @@ -1085,6 +1151,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos) return nullptr; // success } +TRITONSERVER_Error* +ModelState::GetSessionForGroup( + const std::string& group_name, std::shared_ptr& session) +{ + RETURN_ERROR_IF_TRUE( + group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG, + std::string("Empty group name")); + { + std::unordered_map>::iterator + session_entry; + session_entry = group_instance_session_map_.find(group_name); + RETURN_ERROR_IF_TRUE( + (session_entry == group_instance_session_map_.end()), + TRITONSERVER_ERROR_NOT_FOUND, std::string("No such group in session map: ") + group_name); + + session = session_entry->second; + } + return nullptr; +} + +TRITONSERVER_Error* +ModelState::SetSessionForGroup( + const std::string& group_name, const std::shared_ptr& session) +{ + RETURN_ERROR_IF_TRUE( + group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG, + std::string("Empty instance group name")); + group_instance_session_map_[group_name] = session; + LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("Mapped session for instance group: ") + group_name).c_str()); + return nullptr; +} + // // ModelInstanceState // @@ -1171,7 +1269,7 @@ class ModelInstanceState : public BackendModelInstance { // Onnx Runtime variables that are used across runs on this // instance. - OrtSession* session_; + std::shared_ptr session_; OrtAllocator* default_allocator_; OrtMemoryInfo* cuda_allocator_info_; const OrtMemoryInfo* cpu_allocator_info_; @@ -1223,7 +1321,7 @@ ModelInstanceState::ModelInstanceState( io_binding_(nullptr), output_buffer_(nullptr) { THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel( - ArtifactFilename(), Kind(), DeviceId(), &model_path_, &session_, + ArtifactFilename(), Name(), Kind(), DeviceId(), &model_path_, session_, &default_allocator_, CudaStream())); if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { @@ -1236,7 +1334,7 @@ ModelInstanceState::ModelInstanceState( ort_api->AllocatorGetInfo(default_allocator_, &cpu_allocator_info_)); THROW_IF_BACKEND_INSTANCE_ORT_ERROR( - ort_api->CreateIoBinding(session_, &io_binding_)); + ort_api->CreateIoBinding(session_.get(), &io_binding_)); THROW_IF_BACKEND_INSTANCE_ORT_ERROR(ort_api->CreateRunOptions(&runOptions_)); @@ -1335,9 +1433,6 @@ ModelInstanceState::~ModelInstanceState() ort_api->ReleaseRunOptions(runOptions_); ort_api->ReleaseIoBinding(io_binding_); ort_api->ReleaseMemoryInfo(cuda_allocator_info_); - if (session_ != nullptr) { - OnnxLoader::UnloadSession(session_); - } // 'default_allocator_' is default allocator which is managed by ONNX // Runtime } @@ -1399,7 +1494,7 @@ ModelInstanceState::ValidateBooleanSequenceControl( if (*have_control) { OnnxTensorInfoMap input_tensor_infos; RETURN_IF_ERROR( - InputInfos(session_, default_allocator_, input_tensor_infos)); + InputInfos(session_.get(), default_allocator_, input_tensor_infos)); const auto& iit = input_tensor_infos.find(tensor_name); if (iit == input_tensor_infos.end()) { return TRITONSERVER_ErrorNew( @@ -1456,7 +1551,7 @@ ModelInstanceState::ValidateTypedSequenceControl( if (*have_control) { OnnxTensorInfoMap input_tensor_infos; RETURN_IF_ERROR( - InputInfos(session_, default_allocator_, input_tensor_infos)); + InputInfos(session_.get(), default_allocator_, input_tensor_infos)); const auto& iit = input_tensor_infos.find(tensor_name); if (iit == input_tensor_infos.end()) { return TRITONSERVER_ErrorNew( @@ -1503,17 +1598,17 @@ TRITONSERVER_Error* ModelInstanceState::ValidateInputs(const size_t expected_input_cnt) { std::set input_tensor_names; - RETURN_IF_ERROR(InputNames(session_, input_tensor_names)); + RETURN_IF_ERROR(InputNames(session_.get(), input_tensor_names)); RETURN_IF_ERROR( - InputInfos(session_, default_allocator_, input_tensor_infos_)); + InputInfos(session_.get(), default_allocator_, input_tensor_infos_)); std::set overridable_initializer_tensor_names; RETURN_IF_ERROR(OverridableInitializerNames( - session_, overridable_initializer_tensor_names)); + session_.get(), overridable_initializer_tensor_names)); OnnxTensorInfoMap overridable_initializer_tensor_infos; RETURN_IF_ERROR(OverridableInitializerInfos( - session_, default_allocator_, overridable_initializer_tensor_infos)); + session_.get(), default_allocator_, overridable_initializer_tensor_infos)); if (input_tensor_infos_.size() != expected_input_cnt) { return TRITONSERVER_ErrorNew( @@ -1650,10 +1745,10 @@ TRITONSERVER_Error* ModelInstanceState::ValidateOutputs() { std::set output_tensor_names; - RETURN_IF_ERROR(OutputNames(session_, output_tensor_names)); + RETURN_IF_ERROR(OutputNames(session_.get(), output_tensor_names)); RETURN_IF_ERROR( - OutputInfos(session_, default_allocator_, output_tensor_infos_)); + OutputInfos(session_.get(), default_allocator_, output_tensor_infos_)); triton::common::TritonJson::Value ios; RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios)); @@ -2050,7 +2145,7 @@ ModelInstanceState::OrtRun( const uint32_t response_count) { RETURN_IF_ORT_ERROR( - ort_api->RunWithBinding(session_, runOptions_, io_binding_)); + ort_api->RunWithBinding(session_.get(), runOptions_, io_binding_)); return nullptr; } @@ -2590,7 +2685,6 @@ ModelInstanceState::ReadOutputTensors( } } - } else { char* output_buffer = nullptr; RETURN_IF_ORT_ERROR( diff --git a/src/onnxruntime_utils.cc b/src/onnxruntime_utils.cc index 5599fb4..f52af0a 100644 --- a/src/onnxruntime_utils.cc +++ b/src/onnxruntime_utils.cc @@ -25,6 +25,7 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "onnxruntime_utils.h" +#include namespace triton { namespace backend { namespace onnxruntime { @@ -550,5 +551,23 @@ CompareDimsSupported( return nullptr; // success } +std::string +GetInstanceGroupName( + const std::string& model_name, const std::string& instance_name) +{ + if (model_name.empty() || instance_name.empty()) { + return ""; + } + // Using regex search to extract instance group name from model instance name + // model instance naming follows pattern: __ + // instance group naming follows pattern: _ + std::regex group_name_regex('(' + model_name + '_' + "[0-9]" + ')'); + std::smatch group_name; + if (std::regex_search(instance_name, group_name, group_name_regex)) { + return group_name.str(1); + } + + return ""; +} }}} // namespace triton::backend::onnxruntime diff --git a/src/onnxruntime_utils.h b/src/onnxruntime_utils.h index f862a74..cd2db2c 100644 --- a/src/onnxruntime_utils.h +++ b/src/onnxruntime_utils.h @@ -157,4 +157,7 @@ TRITONSERVER_Error* CompareDimsSupported( const std::vector& model_shape, const std::vector& dims, const int max_batch_size, const bool compare_exact); +std::string GetInstanceGroupName( + const std::string& model_name, const std::string& instance_name); + }}} // namespace triton::backend::onnxruntime