Add support for sharing an ORT session #248

Status: Open. Wants to merge 1 commit into base: main.
170 changes: 132 additions & 38 deletions src/onnxruntime.cc
@@ -25,7 +25,6 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <stdint.h>

#include <mutex>
#include <vector>

@@ -107,10 +106,10 @@ class ModelState : public BackendModel {
// onnx file, return in 'session' and 'allocator' the ORT session
// and allocator.
TRITONSERVER_Error* LoadModel(
const std::string& artifact_name,
const std::string& artifact_name, const std::string& instance_name,
const TRITONSERVER_InstanceGroupKind instance_group_kind,
const int32_t instance_group_device_id, std::string* model_path,
OrtSession** session, OrtAllocator** default_allocator,
std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
cudaStream_t stream);

const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
@@ -127,6 +126,11 @@ class ModelState : public BackendModel {
TRITONSERVER_Error* AutoCompleteIO(
const char* key, const OnnxTensorInfoMap& io_infos);

TRITONSERVER_Error* GetSessionForGroup(
const std::string& group_name, std::shared_ptr<OrtSession>& session);
TRITONSERVER_Error* SetSessionForGroup(
const std::string& group_name, const std::shared_ptr<OrtSession>& session);

// Session options used when creating a ORT session.
std::unique_ptr<OrtSessionOptions, SessionOptionsDeleter> session_options_;

@@ -136,6 +140,19 @@ class ModelState : public BackendModel {
// is specified both in the output section and state section, it indicates
// that the backend must return the output state to the client too.
std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;

// Indicates whether an ONNX Runtime session should be shared between
// instances. This is a model-global setting that applies to all
// instances, so it is stored in the model state.
bool share_session_between_instances_;

// Map of instance group name to ORT session. This is only used when
// share_session_between_instances is set to true in the model parameters.
// Because share_session_between_instances is a model-global setting that
// cannot be set per instance group, users should enable it with care.
std::unordered_map<std::string, std::shared_ptr<OrtSession>>
group_instance_session_map_;
std::mutex group_instance_session_map_mutex_;
};

TRITONSERVER_Error*
@@ -206,7 +223,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
}

ModelState::ModelState(TRITONBACKEND_Model* triton_model)
: BackendModel(triton_model, true /* allow_optional */)
: BackendModel(triton_model, true /* allow_optional */),
  share_session_between_instances_(false)
{
// Create session options that will be cloned and used for each
// instance when creating that instance's session.
@@ -359,19 +376,30 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
}
}

// FIXME. Is it possible to share a single OrtSession across
// multiple instances? If so then should move loading and validation
// of the session to here instead of creating a session for each
// instance in ModelStateInstance::Create().
// This setting applies across all instance groups: when it is enabled,
// every instance within an instance group shares one ORT session.
{
  bool share_session = false;  // default when the parameter is absent
  triton::common::TritonJson::Value params;
  if (ModelConfig().Find("parameters", &params)) {
    THROW_IF_BACKEND_MODEL_ERROR(TryParseModelStringParameter(
        params, "share_session_between_instances", &share_session, false));
  }
  share_session_between_instances_ = share_session;
}
}
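Note: a model would opt into this behavior through its config.pbtxt parameters. A minimal sketch (the parameter name is taken from this diff; the stanza layout is standard Triton model-config syntax):

parameters {
  key: "share_session_between_instances"
  value: { string_value: "true" }
}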

TRITONSERVER_Error*
ModelState::LoadModel(
const std::string& artifact_name,
const std::string& artifact_name, const std::string& instance_name,
const TRITONSERVER_InstanceGroupKind instance_group_kind,
const int32_t instance_group_device_id, std::string* model_path,
OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream)
std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
cudaStream_t stream)
{
// Get the group name for the instance
std::string instance_group_name(GetInstanceGroupName(Name(), instance_name));
// Find the ONNX file that describes the model itself. If the model
// configuration doesn't have an explicit model file specified then
// use the default name ("model.onnx").
@@ -383,6 +411,10 @@ ModelState::LoadModel(
*model_path = JoinPath(
{RepositoryPath(), std::to_string(Version()), cc_model_filename});

// get default cpu allocator
RETURN_IF_ORT_ERROR(
ort_api->GetAllocatorWithDefaultOptions(default_allocator));

// If the model path is a directory then the actual model is
// <dir>/model.onnx.
{
@@ -393,6 +425,32 @@ ModelState::LoadModel(
}
}

// Check whether sessions are shared between instances. If a session
// already exists for this instance group, reuse it and return.
if (share_session_between_instances_) {
  group_instance_session_map_mutex_.lock();
  TRITONSERVER_Error* error =
      GetSessionForGroup(instance_group_name, session);
  if (error == nullptr) {
    LOG_MESSAGE(
        TRITONSERVER_LOG_INFO,
        (std::string("Reusing session for instance: ") + instance_name)
            .c_str());
    // Session successfully retrieved; unlock the mutex and return.
    group_instance_session_map_mutex_.unlock();
    return nullptr;
  }
  // On error, keep holding the lock and fall through to load the session,
  // then record it in the map so other instances can share it.
  LOG_MESSAGE(
      TRITONSERVER_LOG_INFO,
      (std::string("Could not find a session for instance group: ") +
       instance_group_name)
          .c_str());
  LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, TRITONSERVER_ErrorMessage(error));
  TRITONSERVER_ErrorDelete(error);
}
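Aside: the manual lock()/unlock() above must stay balanced across three exit paths (the reuse return here, the error fall-through, and the unlock after SetSessionForGroup further down). A sketch of the same control flow with an RAII guard, using only names that appear in this diff (<mutex> is already included at the top of the file):

// Sketch only: std::unique_lock releases the mutex on every return path.
std::unique_lock<std::mutex> lk(
    group_instance_session_map_mutex_, std::defer_lock);
if (share_session_between_instances_) {
  lk.lock();
  TRITONSERVER_Error* error =
      GetSessionForGroup(instance_group_name, session);
  if (error == nullptr) {
    return nullptr;  // reuse an existing session; ~unique_lock unlocks
  }
  TRITONSERVER_ErrorDelete(error);
  // Fall through with the lock held; the session is loaded below and
  // published with SetSessionForGroup() before 'lk' goes out of scope.
}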

{
bool exists;
RETURN_IF_ERROR(FileExists(*model_path, &exists));
@@ -835,12 +893,23 @@ ModelState::LoadModel(
glock.lock();
}

// The raw session is allocated by ONNX Runtime here but is freed when the
// last shared_ptr reference is released.
OrtSession* session_ptr;
RETURN_IF_ERROR(OnnxLoader::LoadSession(
true /* is_path */, *model_path, soptions, session));

// get default cpu allocator
RETURN_IF_ORT_ERROR(
ort_api->GetAllocatorWithDefaultOptions(default_allocator));
true /* is_path */, *model_path, soptions, &session_ptr));
session = std::shared_ptr<OrtSession>(session_ptr, SessionDeleter());
LOG_MESSAGE(
TRITONSERVER_LOG_INFO,
(std::string("Created session for instance: ") + instance_name)
.c_str());
if (share_session_between_instances_) {
// The session itself was created successfully, so failing to record it
// in the map is not a critical error.
LOG_IF_ERROR(
SetSessionForGroup(instance_group_name, session),
"Failed to map ort session to the group for sharing");
group_instance_session_map_mutex_.unlock();
}

return nullptr; // success
}
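For context, the shared_ptr created above relies on SessionDeleter to release the session once the last owner drops it. The deleter is defined elsewhere in the backend; a sketch of the shape it needs, assuming it funnels destruction through OnnxLoader (the real SessionDeleter may differ):

// Invoked when the last shared_ptr<OrtSession> reference is released.
struct SessionDeleterSketch {
  void operator()(OrtSession* session) const {
    if (session != nullptr) {
      LOG_IF_ERROR(
          OnnxLoader::UnloadSession(session),
          "failed to unload ORT session");
    }
  }
};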
@@ -882,9 +951,9 @@ ModelState::AutoCompleteConfig()
RETURN_IF_ERROR(
ModelConfig().MemberAsString("default_model_filename", &artifact_name));

// Must cleanup 'session'. 'allocator' is default allocator which
// 'allocator' is the default allocator, which is managed by ONNX Runtime,
// so there is no need to free/release it
std::unique_ptr<OrtSession, SessionDeleter> session;
std::shared_ptr<OrtSession> session;
OrtAllocator* default_allocator;
std::string model_path;
{
@@ -913,12 +982,9 @@ ModelState::AutoCompleteConfig()
}
}
#endif // TRITON_ENABLE_GPU

OrtSession* sptr = nullptr;
RETURN_IF_ERROR(LoadModel(
artifact_name, kind, 0, &model_path, &sptr, &default_allocator,
nullptr));
session.reset(sptr);
artifact_name, "", kind, 0, &model_path,
session, &default_allocator, nullptr));
}
OnnxTensorInfoMap input_tensor_infos;
RETURN_IF_ERROR(
@@ -1085,6 +1151,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
return nullptr; // success
}

TRITONSERVER_Error*
ModelState::GetSessionForGroup(
const std::string& group_name, std::shared_ptr<OrtSession>& session)
{
RETURN_ERROR_IF_TRUE(
group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
std::string("Empty group name"));
{
auto session_entry = group_instance_session_map_.find(group_name);
RETURN_ERROR_IF_TRUE(
    (session_entry == group_instance_session_map_.end()),
    TRITONSERVER_ERROR_NOT_FOUND,
    std::string("No such group in session map: ") + group_name);

session = session_entry->second;
}
return nullptr;
}

TRITONSERVER_Error*
ModelState::SetSessionForGroup(
const std::string& group_name, const std::shared_ptr<OrtSession>& session)
{
RETURN_ERROR_IF_TRUE(
group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
std::string("Empty instance group name"));
group_instance_session_map_[group_name] = session;
LOG_MESSAGE(
    TRITONSERVER_LOG_INFO,
    (std::string("Mapped session for instance group: ") + group_name)
        .c_str());
return nullptr;
}
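Note that neither helper acquires the map mutex itself; the caller is expected to hold group_instance_session_map_mutex_, as LoadModel() does above. A hypothetical call sequence (the group name is illustrative):

std::lock_guard<std::mutex> lk(group_instance_session_map_mutex_);
std::shared_ptr<OrtSession> session;
TRITONSERVER_Error* err = GetSessionForGroup("densenet_0", session);
if (err != nullptr) {
  TRITONSERVER_ErrorDelete(err);
  // Not found: create the session, then publish it for other instances.
  // ... load 'session' via OnnxLoader::LoadSession() ...
  LOG_IF_ERROR(
      SetSessionForGroup("densenet_0", session),
      "failed to map session to group");
}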

//
// ModelInstanceState
//
@@ -1171,7 +1269,7 @@ class ModelInstanceState : public BackendModelInstance {

// Onnx Runtime variables that are used across runs on this
// instance.
OrtSession* session_;
std::shared_ptr<OrtSession> session_;
OrtAllocator* default_allocator_;
OrtMemoryInfo* cuda_allocator_info_;
const OrtMemoryInfo* cpu_allocator_info_;
@@ -1223,7 +1321,7 @@ ModelInstanceState::ModelInstanceState(
io_binding_(nullptr), output_buffer_(nullptr)
{
THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
ArtifactFilename(), Kind(), DeviceId(), &model_path_, &session_,
ArtifactFilename(), Name(), Kind(), DeviceId(), &model_path_, session_,
&default_allocator_, CudaStream()));

if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -1236,7 +1334,7 @@ ModelInstanceState::ModelInstanceState(
ort_api->AllocatorGetInfo(default_allocator_, &cpu_allocator_info_));

THROW_IF_BACKEND_INSTANCE_ORT_ERROR(
ort_api->CreateIoBinding(session_, &io_binding_));
ort_api->CreateIoBinding(session_.get(), &io_binding_));

THROW_IF_BACKEND_INSTANCE_ORT_ERROR(ort_api->CreateRunOptions(&runOptions_));

@@ -1335,9 +1433,6 @@ ModelInstanceState::~ModelInstanceState()
ort_api->ReleaseRunOptions(runOptions_);
ort_api->ReleaseIoBinding(io_binding_);
ort_api->ReleaseMemoryInfo(cuda_allocator_info_);
if (session_ != nullptr) {
OnnxLoader::UnloadSession(session_);
}
// 'default_allocator_' is default allocator which is managed by ONNX
// Runtime
}
@@ -1399,7 +1494,7 @@ ModelInstanceState::ValidateBooleanSequenceControl(
if (*have_control) {
OnnxTensorInfoMap input_tensor_infos;
RETURN_IF_ERROR(
InputInfos(session_, default_allocator_, input_tensor_infos));
InputInfos(session_.get(), default_allocator_, input_tensor_infos));
const auto& iit = input_tensor_infos.find(tensor_name);
if (iit == input_tensor_infos.end()) {
return TRITONSERVER_ErrorNew(
@@ -1456,7 +1551,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
if (*have_control) {
OnnxTensorInfoMap input_tensor_infos;
RETURN_IF_ERROR(
InputInfos(session_, default_allocator_, input_tensor_infos));
InputInfos(session_.get(), default_allocator_, input_tensor_infos));
const auto& iit = input_tensor_infos.find(tensor_name);
if (iit == input_tensor_infos.end()) {
return TRITONSERVER_ErrorNew(
@@ -1503,17 +1598,17 @@ TRITONSERVER_Error*
ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
{
std::set<std::string> input_tensor_names;
RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
RETURN_IF_ERROR(InputNames(session_.get(), input_tensor_names));
RETURN_IF_ERROR(
InputInfos(session_, default_allocator_, input_tensor_infos_));
InputInfos(session_.get(), default_allocator_, input_tensor_infos_));

std::set<std::string> overridable_initializer_tensor_names;
RETURN_IF_ERROR(OverridableInitializerNames(
session_, overridable_initializer_tensor_names));
session_.get(), overridable_initializer_tensor_names));

OnnxTensorInfoMap overridable_initializer_tensor_infos;
RETURN_IF_ERROR(OverridableInitializerInfos(
session_, default_allocator_, overridable_initializer_tensor_infos));
session_.get(), default_allocator_, overridable_initializer_tensor_infos));

if (input_tensor_infos_.size() != expected_input_cnt) {
return TRITONSERVER_ErrorNew(
@@ -1650,10 +1745,10 @@ TRITONSERVER_Error*
ModelInstanceState::ValidateOutputs()
{
std::set<std::string> output_tensor_names;
RETURN_IF_ERROR(OutputNames(session_, output_tensor_names));
RETURN_IF_ERROR(OutputNames(session_.get(), output_tensor_names));

RETURN_IF_ERROR(
OutputInfos(session_, default_allocator_, output_tensor_infos_));
OutputInfos(session_.get(), default_allocator_, output_tensor_infos_));

triton::common::TritonJson::Value ios;
RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios));
@@ -2050,7 +2145,7 @@ ModelInstanceState::OrtRun(
const uint32_t response_count)
{
RETURN_IF_ORT_ERROR(
ort_api->RunWithBinding(session_, runOptions_, io_binding_));
ort_api->RunWithBinding(session_.get(), runOptions_, io_binding_));
return nullptr;
}

@@ -2590,7 +2685,6 @@ ModelInstanceState::ReadOutputTensors(
}
}


} else {
char* output_buffer = nullptr;
RETURN_IF_ORT_ERROR(
19 changes: 19 additions & 0 deletions src/onnxruntime_utils.cc
@@ -25,6 +25,7 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "onnxruntime_utils.h"
#include <regex>

namespace triton { namespace backend { namespace onnxruntime {

@@ -550,5 +551,23 @@ CompareDimsSupported(
return nullptr; // success
}

std::string
GetInstanceGroupName(
const std::string& model_name, const std::string& instance_name)
{
if (model_name.empty() || instance_name.empty()) {
return "";
}
// Use a regex search to extract the instance group name from the model
// instance name. Model instance names follow the pattern
// <model name>_<instance group index>_<instance index>; instance group
// names follow <model name>_<instance group index>. The group index may
// have more than one digit, hence [0-9]+. Note this assumes the model
// name contains no regex metacharacters.
std::regex group_name_regex("(" + model_name + "_[0-9]+)");
std::smatch group_name;
if (std::regex_search(instance_name, group_name, group_name_regex)) {
return group_name.str(1);
}

return "";
}
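A hypothetical usage sketch of the naming convention this helper encodes (model and instance names are illustrative; the multi-digit case is why the pattern uses [0-9]+):

#include <cassert>

void ExampleGroupNames() {
  // <model name>_<group index>_<instance index> -> <model name>_<group index>
  assert(GetInstanceGroupName("densenet", "densenet_0_1") == "densenet_0");
  assert(GetInstanceGroupName("densenet", "densenet_12_3") == "densenet_12");
  // Empty inputs yield an empty group name.
  assert(GetInstanceGroupName("", "densenet_0_1") == "");
}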

}}} // namespace triton::backend::onnxruntime
3 changes: 3 additions & 0 deletions src/onnxruntime_utils.h
@@ -157,4 +157,7 @@ TRITONSERVER_Error* CompareDimsSupported(
const std::vector<int64_t>& model_shape, const std::vector<int64_t>& dims,
const int max_batch_size, const bool compare_exact);

std::string GetInstanceGroupName(
const std::string& model_name, const std::string& instance_name);

}}} // namespace triton::backend::onnxruntime