Skip to content

Commit

Permalink
[unittest] add spec decoding multiple tokens generation unit tests (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sindhuvahinis authored Sep 10, 2024
1 parent 32daccd commit eeb435d
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions engines/python/setup/djl_python/tests/test_rolling_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,29 @@ def test_json_fmt(self):
self.assertEqual(json.dumps({"generated_text": "Hello world"}),
req.get_next_token())

def test_json_speculative_decoding(self):
req_input = TextInput(
request_id=0,
input_text="This is a wonderful day",
parameters={"max_new_tokens": 256},
output_formatter=_json_output_formatter,
)
req = Request(req_input)
req.request_output = TextGenerationOutput(request_id=0,
input=req_input)
req.request_output.finished = True
req.request_output.set_next_token(Token(244, "He", -0.334532))
req.request_output.set_next_token(Token(576, "llo", -0.123123))
req.request_output.set_next_token(Token(4558, " world", -0.567854,
True),
is_last_token=True,
finish_reason='length')

self.assertEqual(req.get_next_token(), "")
self.assertEqual(req.get_next_token(), "")
self.assertEqual(req.get_next_token(),
json.dumps({"generated_text": "Hello world"}))

def test_json_fmt_with_appending(self):
req_input1 = TextInput(request_id=0,
input_text="This is a wonderful day",
Expand Down Expand Up @@ -152,6 +175,47 @@ def test_jsonlines_fmt(self):
"generated_text": "Hello world"
}, json.loads(req.get_next_token()))

def test_jsonlines_speculative_decoding(self):
request_input = TextInput(request_id=0,
input_text="This is a wonderful day",
parameters={"max_new_tokens": 256},
output_formatter=_jsonlines_output_formatter)
req = Request(request_input=request_input)
req.request_output = TextGenerationOutput(request_id=0,
input=request_input)
req.request_output.finished = True
req.request_output.set_next_token(Token(244, "He", -0.334532))
print(req.get_next_token(), end='')
self.assertEqual(
{"token": {
"id": 244,
"text": "He",
"log_prob": -0.334532
}}, json.loads(req.get_next_token()))
req.reset_next_token()
req.request_output.set_next_token(Token(576, "llo", -0.123123))
print(req.get_next_token(), end='')
self.assertEqual(
{"token": {
"id": 576,
"text": "llo",
"log_prob": -0.123123
}}, json.loads(req.get_next_token()))
req.reset_next_token()
req.request_output.set_next_token(Token(4558, " world", -0.567854),
is_last_token=True,
finish_reason='length')
print(req.get_next_token(), end='')
self.assertEqual(
{
"token": {
"id": 4558,
"text": " world",
"log_prob": -0.567854
},
"generated_text": "Hello world"
}, json.loads(req.get_next_token()))

def test_sse_fmt(self):
request_input = TextInput(request_id=0,
input_text="This is a wonderful day",
Expand Down

0 comments on commit eeb435d

Please sign in to comment.