From 064c51bfbd9198900a6be5756333385d1da7ade8 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Wed, 29 Jan 2025 10:09:56 +0100 Subject: [PATCH] fix --- tests/generation/test_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d9b4bbbe8c6..1863598335b 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1185,7 +1185,9 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): "return_dict_in_generate": True, "use_cache": True, } - output_greedy = model.generate(**generation_kwargs, **inputs_dict) + logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config) + + output_greedy = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs) # test with the same assistant model or randomly init one # in the first case all candidate tokens are accepted, in the second none is accepted @@ -1197,7 +1199,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): assistant_model.generation_config.num_assistant_tokens = 2 # see b) assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) generation_kwargs.update({"assistant_model": assistant_model}) - output_assisted = model.generate(**generation_kwargs, **inputs_dict) + output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs) # The two outputs must match and their shape must be as expected self._check_similar_generate_outputs(output_greedy, output_assisted)