diff --git a/src/cpp/src/continuous_batching_pipeline.cpp b/src/cpp/src/continuous_batching_pipeline.cpp index 90a141e9ee..c3fa9edb23 100644 --- a/src/cpp/src/continuous_batching_pipeline.cpp +++ b/src/cpp/src/continuous_batching_pipeline.cpp @@ -60,16 +60,20 @@ class ContinuousBatchingPipeline::Impl { ChatHistory m_history; + void _notify_requests_dropped_by_handle() { + // Notify the last time by pushing empty output + // This causes read_all() to unblock by adding anything to the queue + for (SequenceGroup::Ptr& request : m_requests) { + if (request->handle_dropped()) + request->push_empty_outputs(); + } + } + void _free_non_running_requests() { std::vector::iterator requests_iterator = m_requests.begin(); while (requests_iterator != m_requests.end()) { const auto& request = *requests_iterator; if(request->has_finished() || request->out_of_memory() || request->handle_dropped()) { - // Notify the last time even if there will be no results - // This causes read_all() to unblock - // Avoid notifying again once finished - if (request->out_of_memory() || request->handle_dropped()) - request->notify_handle(); for (const auto& sequence: request->get_sequences()) { m_scheduler->free_sequence(sequence->get_id()); } @@ -180,6 +184,7 @@ class ContinuousBatchingPipeline::Impl { for (size_t i = 0; i < m_requests.size(); ++i) { SequenceGroup::Ptr sequence_group = m_requests[i]; sequence_group->set_out_of_memory(); + sequence_group->notify_handle(); } _free_non_running_requests(); return; @@ -231,6 +236,15 @@ class ContinuousBatchingPipeline::Impl { timer.end(); } + // notify requests dropped by handle + + { + static ManualTimer timer("notify requests dropped by handle"); + timer.start(); + _notify_requests_dropped_by_handle(); + timer.end(); + } + // free non running requests for current step { diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index 334e3db1d6..5a49843055 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -474,20 +474,13 @@ class SequenceGroup { } void notify_handle() { - if (handle_dropped()) { - // Push anything to the queue to unblock cancelled read_all() calls - // When handle is dropped we do not care about any remaining data - push_empty_outputs(); - return; - } - if (out_of_memory()) { set_generation_status(GenerationStatus::IGNORED); } else if (has_finished()) { set_generation_status(GenerationStatus::FINISHED); } - if (m_sampling_params.is_beam_search()) { - // For beam search streaming is not available, so we notify only upon finishing + // For beam search streaming is not available, so we notify only upon finishing + if(m_sampling_params.is_beam_search()) { if (has_finished() || out_of_memory()) { push_outputs(); }