relax seq len checking in rotary_emb (microsoft#20778)
### Description
The sequence-length check is too strict for packed batching input. There are two cases for a batch of input_ids:
- Padded sequences, where every row is padded to the same length:
```
|----********|
|------------|
|--------****|
|-***********|
```
- Packed sequences, where input_ids of different lengths are concatenated into a single row:
`|----|---------|----|-|`

The max_seq_length comes either from the graph inputs or from the position_ids. In most cases we cache the max_seq_length of the rotary cache in the model and share it across all layers, so with packed batching the flattened sequence length can exceed the cached max_seq_length and the check has to be relaxed.
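
As a rough illustration (hypothetical sequence lengths, plain NumPy, not part of this change): packed batching concatenates the sequences into a single row and restarts position_ids for each sequence, so the flattened length the kernel sees can exceed the cached max_seq_length even though every individual position stays inside the cache.

```python
import numpy as np

# Hypothetical per-sequence lengths for one packed row: |----|---------|----|-|
seq_lens = [4, 9, 4, 1]

# Positions restart at 0 for every packed sequence.
position_ids = np.concatenate([np.arange(n) for n in seq_lens])
print(position_ids)          # [0 1 2 3 0 1 2 3 4 5 6 7 8 0 1 2 3 0]

total_tokens = sum(seq_lens)                 # 18 -> the sequence_length the op sees
needed_cache = int(position_ids.max()) + 1   # 9  -> what the cos/sin cache must cover

# With padded batching these two numbers coincide; with packed batching
# total_tokens can exceed the cached max_seq_length, hence the relaxed check
# when is_packed_batching=1.
print(total_tokens, needed_cache)
```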

### Motivation and Context

---------

Co-authored-by: kailums <[email protected]>
wejoncy and kailums authored Jun 8, 2024
1 parent 981893c commit bd61ae5
Showing 7 changed files with 65 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
@@ -5181,6 +5181,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>is_packed_batching</tt> : int</dt>
<dd>Whether the inputs form a packed (ragged) batch. Default value is 0 (False).</dd>
<dt><tt>num_heads</tt> : int</dt>
<dd>Number of attention heads. Default value is 0. Must use with rotary_embedding_dim</dd>
<dt><tt>rotary_embedding_dim</tt> : int</dt>
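For illustration only, a minimal sketch of setting the new attribute when building a com.microsoft RotaryEmbedding node with the onnx Python helpers; the tensor names are placeholders, not part of this PR:

```python
from onnx import helper

# Hypothetical node; "input", "position_ids", "cos_cache", "sin_cache" are placeholder names.
rotary_node = helper.make_node(
    "RotaryEmbedding",
    inputs=["input", "position_ids", "cos_cache", "sin_cache"],
    outputs=["output"],
    domain="com.microsoft",
    interleaved=0,
    is_packed_batching=1,  # declare the input as a ragged (packed) batch
)
```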
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -30,6 +30,7 @@ RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
is_packed_batching = (info.GetAttrOrDefault<int64_t>("is_packed_batching", 0) == 1);

if (rotary_embedding_dim > 0) {
ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
@@ -119,7 +120,7 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {

Tensor* output = context->Output(0, input->Shape());

- if (parameters.sequence_length > parameters.max_sequence_length) {
+ if (is_packed_batching == false && parameters.sequence_length > parameters.max_sequence_length) {
// Launch update_cos_sin_cache kernel with scale
ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
}
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -25,6 +25,7 @@ class RotaryEmbedding final : public OpKernel {
int num_heads;
int rotary_embedding_dim;
bool interleaved;
bool is_packed_batching;
};

} // namespace contrib
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -37,6 +37,7 @@ RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info)
rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
is_packed_batching = (info.GetAttrOrDefault<int64_t>("is_packed_batching", 0) == 1);
}

template <typename T>
@@ -57,7 +58,7 @@ Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {

Tensor* output = context->Output(0, input->Shape());

- if (parameters.sequence_length > parameters.max_sequence_length) {
+ if (is_packed_batching == false && parameters.sequence_length > parameters.max_sequence_length) {
// Launch update_cos_sin_cache kernel with scale
ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
}
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
@@ -22,6 +22,7 @@ class RotaryEmbedding final : public CudaKernel {
int num_heads;
int rotary_embedding_dim;
bool interleaved;
bool is_packed_batching;
};

} // namespace cuda
4 changes: 4 additions & 0 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1348,6 +1348,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Number of attention heads. Default value is 0. Must use with rotary_embedding_dim",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("is_packed_batching",
"ragged batch inputs or not. Default value is 0",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0,
"input",
"3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)",
54 changes: 53 additions & 1 deletion onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
@@ -32,6 +32,7 @@ static void RunTest(
int num_heads,
int max_sequence_length,
int64_t interleaved,
int64_t is_packed_batching,
TensorType tensor_type,
bool disable_cpu,
bool disable_cuda,
@@ -50,7 +51,7 @@
: head_size / 2};

assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0);
- assert(max_sequence_length >= sequence_length);
+ assert((is_packed_batching == 0 && max_sequence_length >= sequence_length) || is_packed_batching == 1);
if (position_ids.size() == 1) {
pos_dims = {1};
} else {
@@ -89,6 +90,10 @@
test.AddAttribute<int64_t>("num_heads", num_heads);
}

if (rotary_embedding_dim > 0) {
test.AddAttribute<int64_t>("is_packed_batching", is_packed_batching);
}

if (tensor_type == TensorType::kFloat) {
test.AddInput<float>("input", input_dims, input_data);
test.AddInput<int64_t>("position_ids", pos_dims, position_ids);
@@ -129,6 +134,7 @@ static void RunTests(const std::vector<float>& input_data,
int num_heads = 0,
int max_sequence_length = 0,
int64_t interleaved = 0,
int64_t is_packed_batching = 0,
bool use_float16 = true,
bool disable_dml = false) {
// FP32 test for CPU
@@ -144,6 +150,7 @@
num_heads,
max_sequence_length,
interleaved,
is_packed_batching,
TensorType::kFloat,
false, /* disable_cpu */
true, /* disable_cuda */
@@ -162,6 +169,7 @@
num_heads,
max_sequence_length,
interleaved,
is_packed_batching,
TensorType::kFloat,
false, /* disable_cpu */
false, /* disable_cuda */
@@ -181,6 +189,7 @@
num_heads,
max_sequence_length,
interleaved,
is_packed_batching,
TensorType::kFloat16,
true, /* disable_cpu */
false, /* disable_cuda*/
@@ -734,6 +743,49 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) {
num_heads,
max_sequence_length,
interleaved,
0, // is_packed_batching
true, /*use_fp16*/
true /*disable_dml*/);
}

TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_Batching) {
int batch_size = 1;
int sequence_length = 3;
int num_heads = 1;
int head_size = 6;
int rotary_embedding_dim = 4;
int max_sequence_length = 2;
int64_t interleaved = 0; // false

std::vector<float> input_data = {-1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f,
-1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f,
1.0076f, -0.7529f, -0.2250f, -0.4327f, -1.5071f, -0.4586f};

std::vector<int64_t> position_ids = {0, 0, 1};

std::vector<float> cos_cache = {
1.0000f, 1.0000f, 1.0000f, 0.5403f};

std::vector<float> sin_cache = {
0.0000f, 0.0000f, 0.0000f, 0.8415f};

std::vector<float> output_data = {-1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f,
-1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f,
1.0076f, -0.0427f, -0.2250f, -0.8673f, -1.5071f, -0.4586f};

RunTests(input_data,
position_ids,
cos_cache,
sin_cache,
output_data,
batch_size,
sequence_length,
head_size,
rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved,
1, // is_packed_batching
true, /*use_fp16*/
true /*disable_dml*/);
}
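
Assuming an ONNX Runtime build that includes this change and the com.microsoft contrib ops, a rough end-to-end sketch of the packed-batching case exercised above (sequence_length 3 against max_sequence_length 2, same shapes and cos/sin caches as the test; graph and input names are made up):

```python
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

batch, seq, num_heads, head_size, rot_dim, max_seq = 1, 3, 1, 6, 4, 2
hidden = num_heads * head_size

node = helper.make_node(
    "RotaryEmbedding",
    inputs=["input", "position_ids", "cos_cache", "sin_cache"],
    outputs=["output"],
    domain="com.microsoft",
    interleaved=0,
    num_heads=num_heads,
    rotary_embedding_dim=rot_dim,
    is_packed_batching=1,  # allow sequence_length (3) > max_sequence_length (2)
)

graph = helper.make_graph(
    [node],
    "rotary_packed_batching",
    inputs=[
        helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch, seq, hidden]),
        helper.make_tensor_value_info("position_ids", TensorProto.INT64, [batch, seq]),
        helper.make_tensor_value_info("cos_cache", TensorProto.FLOAT, [max_seq, rot_dim // 2]),
        helper.make_tensor_value_info("sin_cache", TensorProto.FLOAT, [max_seq, rot_dim // 2]),
    ],
    outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, seq, hidden])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
output = sess.run(None, {
    "input": np.random.randn(batch, seq, hidden).astype(np.float32),
    "position_ids": np.array([[0, 0, 1]], dtype=np.int64),  # positions restart for the packed sequences
    "cos_cache": np.array([[1.0, 1.0], [1.0, 0.5403]], dtype=np.float32),
    "sin_cache": np.array([[0.0, 0.0], [0.0, 0.8415]], dtype=np.float32),
})[0]
print(output.shape)  # (1, 3, 6)
```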
