From bd61ae530bf3a5671b0b220cba30b96afa3c4161 Mon Sep 17 00:00:00 2001
From: wejoncy
Date: Sat, 8 Jun 2024 18:39:06 +0800
Subject: [PATCH] relax seq len checking in rotary_emb (#20778)

### Description
The sequence-length check is too strict for packed-batching (ragged) inputs. There are two cases for a batch of input_ids:
- padded sequences, where every input is padded to the same length:
```
|----********|
|------------|
|--------****|
|-***********|
```
- packed sequences, where the input_ids have different lengths:
`|----|---------|----|-|`

The max_seq_length comes either from the graph inputs or from position_ids, while in most cases we cache the max_seq_length of the rotary cache in the model and share it among all layers.

### Motivation and Context

---------

Co-authored-by: kailums
---
 docs/ContribOperators.md                      |  2 +
 .../contrib_ops/cpu/bert/rotary_embedding.cc  |  3 +-
 .../contrib_ops/cpu/bert/rotary_embedding.h   |  1 +
 .../contrib_ops/cuda/bert/rotary_embedding.cc |  3 +-
 .../contrib_ops/cuda/bert/rotary_embedding.h  |  1 +
 .../core/graph/contrib_ops/bert_defs.cc       |  4 ++
 .../contrib_ops/rotary_embedding_op_test.cc   | 54 ++++++++++++++++++-
 7 files changed, 65 insertions(+), 3 deletions(-)

diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 4d7493bd69650..45306c852a906 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -5181,6 +5181,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 interleaved : int
 Rotate using interleaved pattern. Default value is 0 (False).
+is_packed_batching : int
+ragged batch inputs or not. Default value is 0
 num_heads : int
 Number of attention heads. Default value is 0. Must use with rotary_embedding_dim
 rotary_embedding_dim : int
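
To make the new `is_packed_batching` ("ragged batch inputs") attribute concrete, the small standalone sketch below is an illustration only, not taken from this patch, and the sequence lengths in it are invented: it shows how position_ids are typically laid out for a packed batch and why the packed row length can exceed the cached max_sequence_length.

```
// Illustration only (not part of this patch): position_ids for a ragged
// ("packed") batch. The sequence lengths below are made up for the example.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Three prompts of lengths 4, 9 and 2 packed back-to-back into one row:
  // |----|---------|--|
  const std::vector<int64_t> seq_lens = {4, 9, 2};

  std::vector<int64_t> position_ids;  // shape (1, total_tokens) for the packed row
  for (int64_t len : seq_lens) {
    for (int64_t p = 0; p < len; ++p) {
      position_ids.push_back(p);  // positions restart at 0 for every sequence
    }
  }

  // The kernel sees the packed row length (15 here) as sequence_length, which can
  // exceed the per-sequence max_sequence_length of the cached cos/sin tables even
  // though every individual position id still fits in the cache. That is why the
  // strict sequence_length <= max_sequence_length check is skipped when
  // is_packed_batching is 1.
  std::printf("packed sequence_length = %zu\n", position_ids.size());
  return 0;
}
```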
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
index 195ebdf6a4811..6732f8b96cce2 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -30,6 +30,7 @@ RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
   rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
   num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
   interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
+  is_packed_batching = (info.GetAttrOrDefault<int64_t>("is_packed_batching", 0) == 1);
 
   if (rotary_embedding_dim > 0) {
     ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
@@ -119,7 +120,7 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const {
 
   Tensor* output = context->Output(0, input->Shape());
 
-  if (parameters.sequence_length > parameters.max_sequence_length) {
+  if (is_packed_batching == false && parameters.sequence_length > parameters.max_sequence_length) {
     // Launch update_cos_sin_cache kernel with scale
     ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
   }
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
index b291db538d1d1..4664cf1de2c87 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -25,6 +25,7 @@ class RotaryEmbedding final : public OpKernel {
   int num_heads;
   int rotary_embedding_dim;
   bool interleaved;
+  bool is_packed_batching;
 };
 
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
index ab7479f2938fe..eef33192e6e6b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -37,6 +37,7 @@ RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info)
   rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
   num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
   interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
+  is_packed_batching = (info.GetAttrOrDefault<int64_t>("is_packed_batching", 0) == 1);
 }
 
 template <typename T>
@@ -57,7 +58,7 @@ Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
 
   Tensor* output = context->Output(0, input->Shape());
 
-  if (parameters.sequence_length > parameters.max_sequence_length) {
+  if (is_packed_batching == false && parameters.sequence_length > parameters.max_sequence_length) {
     // Launch update_cos_sin_cache kernel with scale
     ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
   }
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
index d52f61d670444..55f654a2cea8f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
@@ -22,6 +22,7 @@ class RotaryEmbedding final : public CudaKernel {
   int num_heads;
   int rotary_embedding_dim;
   bool interleaved;
+  bool is_packed_batching;
 };
 
 }  // namespace cuda
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 916f0c92fd38d..2a14ba1db4bb7 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1348,6 +1348,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
               "Number of attention heads. Default value is 0. Must use with rotary_embedding_dim",
               AttributeProto::INT,
               OPTIONAL_VALUE)
+        .Attr("is_packed_batching",
+              "ragged batch inputs or not. Default value is 0",
+              AttributeProto::INT,
+              OPTIONAL_VALUE)
         .Input(0,
                "input",
                "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)",
diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
index e64de0e6da16a..89552da58b938 100644
--- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
+++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
@@ -32,6 +32,7 @@ static void RunTest(
     int num_heads,
     int max_sequence_length,
     int64_t interleaved,
+    int64_t is_packed_batching,
     TensorType tensor_type,
     bool disable_cpu,
     bool disable_cuda,
@@ -50,7 +51,7 @@ static void RunTest(
                                              : head_size / 2};
 
   assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0);
-  assert(max_sequence_length >= sequence_length);
+  assert((is_packed_batching == 0 && max_sequence_length >= sequence_length) || is_packed_batching == 1);
   if (position_ids.size() == 1) {
     pos_dims = {1};
   } else {
@@ -89,6 +90,10 @@ static void RunTest(
     test.AddAttribute<int64_t>("num_heads", num_heads);
   }
 
+  if (rotary_embedding_dim > 0) {
+    test.AddAttribute<int64_t>("is_packed_batching", is_packed_batching);
+  }
+
   if (tensor_type == TensorType::kFloat) {
     test.AddInput<float>("input", input_dims, input_data);
     test.AddInput<int64_t>("position_ids", pos_dims, position_ids);
@@ -129,6 +134,7 @@ static void RunTests(const std::vector<float>& input_data,
                      int num_heads = 0,
                      int max_sequence_length = 0,
                      int64_t interleaved = 0,
+                     int64_t is_packed_batching = 0,
                      bool use_float16 = true,
                      bool disable_dml = false) {
   // FP32 test for CPU
@@ -144,6 +150,7 @@ static void RunTests(const std::vector<float>& input_data,
           num_heads,
           max_sequence_length,
           interleaved,
+          is_packed_batching,
           TensorType::kFloat,
           false, /* disable_cpu */
           true,  /* disable_cuda */
@@ -162,6 +169,7 @@ static void RunTests(const std::vector<float>& input_data,
           num_heads,
           max_sequence_length,
           interleaved,
+          is_packed_batching,
           TensorType::kFloat,
           false, /* disable_cpu */
           false, /* disable_cuda */
@@ -181,6 +189,7 @@ static void RunTests(const std::vector<float>& input_data,
           num_heads,
           max_sequence_length,
           interleaved,
+          is_packed_batching,
           TensorType::kFloat16,
           true,  /* disable_cpu */
           false, /* disable_cuda*/
@@ -734,6 +743,49 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) {
            num_heads,
            max_sequence_length,
            interleaved,
+           0,  // is_packed_batching
            true,  /*use_fp16*/
            true /*disable_dml*/);
 }
+
+TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_Batching) {
+  int batch_size = 1;
+  int sequence_length = 3;
+  int num_heads = 1;
+  int head_size = 6;
+  int rotary_embedding_dim = 4;
+  int max_sequence_length = 2;
+  int64_t interleaved = 0;  // false
+
+  std::vector<float> input_data = {-1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f,
+                                   -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f,
+                                   1.0076f, -0.7529f, -0.2250f, -0.4327f, -1.5071f, -0.4586f};
+
+  std::vector<int64_t> position_ids = {0, 0, 1};
+
+  std::vector<float> cos_cache = {
+      1.0000f, 1.0000f, 1.0000f, 0.5403f};
+
+  std::vector<float> sin_cache = {
+      0.0000f, 0.0000f, 0.0000f, 0.8415f};
+
+  std::vector<float> output_data = {-1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f,
+                                    -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f,
+                                    1.0076f, -0.0427f, -0.2250f, -0.8673f, -1.5071f, -0.4586f};
+
+  RunTests(input_data,
+           position_ids,
+           cos_cache,
+           sin_cache,
+           output_data,
+           batch_size,
+           sequence_length,
+           head_size,
+           rotary_embedding_dim,
+           num_heads,
+           max_sequence_length,
+           interleaved,
+           1,  // is_packed_batching
+           true,  /*use_fp16*/
+           true /*disable_dml*/);
+}
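
As a sanity check on the expected values in the new Phi_Packed_Batching test, the standalone snippet below is an illustration only (not part of the patch). It applies the standard non-interleaved (rotate-half) rotation to the third token, whose position id is 1, using row 1 of the cos/sin caches; with rotary_embedding_dim = 4 only the first four of the six dims are rotated and the last two pass through.

```
// Standalone sanity check (not part of the patch) for the expected values in the
// new Phi_Packed_Batching test above: non-interleaved rotation of the token at
// position 1 with rotary_embedding_dim = 4 and head_size = 6.
#include <cmath>
#include <cstdio>

int main() {
  const int rotary_dim = 4;
  const int half = rotary_dim / 2;

  // Token at position id 1 from the test's input_data (its last 6 floats).
  float x[6] = {1.0076f, -0.7529f, -0.2250f, -0.4327f, -1.5071f, -0.4586f};
  // Row 1 of cos_cache / sin_cache, i.e. the entries for position_id == 1.
  float cos1[2] = {1.0000f, 0.5403f};
  float sin1[2] = {0.0000f, 0.8415f};

  float y[6];
  for (int i = 0; i < half; ++i) {
    float x1 = x[i];         // first half of the rotary dims
    float x2 = x[i + half];  // second half of the rotary dims
    y[i] = x1 * cos1[i] - x2 * sin1[i];
    y[i + half] = x2 * cos1[i] + x1 * sin1[i];
  }
  y[4] = x[4];  // dims beyond rotary_embedding_dim pass through unchanged
  y[5] = x[5];

  // Prints roughly 1.0076 -0.0427 -0.2250 -0.8673 -1.5071 -0.4586,
  // matching the last row of output_data in the test.
  for (float v : y) std::printf("%.4f ", v);
  std::printf("\n");
  return 0;
}
```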