From eeb435d27c5417ee28c83ac522360eee843bb420 Mon Sep 17 00:00:00 2001
From: Sindhu Somasundaram <56774226+sindhuvahinis@users.noreply.github.com>
Date: Mon, 9 Sep 2024 17:16:36 -0700
Subject: [PATCH] [unittest] add spec decoding multiple tokens generation unit tests (#2373)

---
 .../djl_python/tests/test_rolling_batch.py | 64 +++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/engines/python/setup/djl_python/tests/test_rolling_batch.py b/engines/python/setup/djl_python/tests/test_rolling_batch.py
index b4de55bc3..64f0bec16 100644
--- a/engines/python/setup/djl_python/tests/test_rolling_batch.py
+++ b/engines/python/setup/djl_python/tests/test_rolling_batch.py
@@ -37,6 +37,29 @@ def test_json_fmt(self):
         self.assertEqual(json.dumps({"generated_text": "Hello world"}),
                          req.get_next_token())
 
+    def test_json_speculative_decoding(self):
+        req_input = TextInput(
+            request_id=0,
+            input_text="This is a wonderful day",
+            parameters={"max_new_tokens": 256},
+            output_formatter=_json_output_formatter,
+        )
+        req = Request(req_input)
+        req.request_output = TextGenerationOutput(request_id=0,
+                                                  input=req_input)
+        req.request_output.finished = True
+        req.request_output.set_next_token(Token(244, "He", -0.334532))
+        req.request_output.set_next_token(Token(576, "llo", -0.123123))
+        req.request_output.set_next_token(Token(4558, " world", -0.567854,
+                                                True),
+                                          is_last_token=True,
+                                          finish_reason='length')
+
+        self.assertEqual(req.get_next_token(), "")
+        self.assertEqual(req.get_next_token(), "")
+        self.assertEqual(req.get_next_token(),
+                         json.dumps({"generated_text": "Hello world"}))
+
     def test_json_fmt_with_appending(self):
         req_input1 = TextInput(request_id=0,
                                input_text="This is a wonderful day",
@@ -152,6 +175,47 @@ def test_jsonlines_fmt(self):
             "generated_text": "Hello world"
         }, json.loads(req.get_next_token()))
 
+    def test_jsonlines_speculative_decoding(self):
+        request_input = TextInput(request_id=0,
+                                  input_text="This is a wonderful day",
+                                  parameters={"max_new_tokens": 256},
+                                  output_formatter=_jsonlines_output_formatter)
+        req = Request(request_input=request_input)
+        req.request_output = TextGenerationOutput(request_id=0,
+                                                  input=request_input)
+        req.request_output.finished = True
+        req.request_output.set_next_token(Token(244, "He", -0.334532))
+        print(req.get_next_token(), end='')
+        self.assertEqual(
+            {"token": {
+                "id": 244,
+                "text": "He",
+                "log_prob": -0.334532
+            }}, json.loads(req.get_next_token()))
+        req.reset_next_token()
+        req.request_output.set_next_token(Token(576, "llo", -0.123123))
+        print(req.get_next_token(), end='')
+        self.assertEqual(
+            {"token": {
+                "id": 576,
+                "text": "llo",
+                "log_prob": -0.123123
+            }}, json.loads(req.get_next_token()))
+        req.reset_next_token()
+        req.request_output.set_next_token(Token(4558, " world", -0.567854),
+                                          is_last_token=True,
+                                          finish_reason='length')
+        print(req.get_next_token(), end='')
+        self.assertEqual(
+            {
+                "token": {
+                    "id": 4558,
+                    "text": " world",
+                    "log_prob": -0.567854
+                },
+                "generated_text": "Hello world"
+            }, json.loads(req.get_next_token()))
+
     def test_sse_fmt(self):
         request_input = TextInput(request_id=0,
                                   input_text="This is a wonderful day",