From 009a95aee71e6e25097910873d3c957ff05b1506 Mon Sep 17 00:00:00 2001
From: drbh
Date: Fri, 17 Jan 2025 12:06:45 -0500
Subject: [PATCH] Revert "feat: improve qwen2-vl startup (#2802)"

This reverts commit eecca27113538716ff14b05738eb6e8c5f7afd8a.
---
 backends/client/src/lib.rs                    |   2 +-
 backends/v2/src/client/mod.rs                 |   2 +-
 backends/v3/src/client/mod.rs                 |   2 +-
 .../test_flash_qwen2_vl_simple.json           |  26 ---
 .../models/test_flash_qwen2_vl.py             | 161 +++++++++---------
 .../models/test_flash_qwen2_vl_warmup.py      |  38 -----
 .../text_generation_server/models/__init__.py |   6 -
 .../custom_modeling/flash_qwen2_modeling.py   |   7 +-
 .../models/custom_modeling/qwen2_vl.py        |   5 +-
 .../models/flash_causal_lm.py                 |  18 +-
 .../models/vlm_causal_lm.py                   |   1 +
 11 files changed, 95 insertions(+), 173 deletions(-)
 delete mode 100644 integration-tests/models/__snapshots__/test_flash_qwen2_vl_warmup/test_flash_qwen2_vl_simple.json
 delete mode 100644 integration-tests/models/test_flash_qwen2_vl_warmup.py

diff --git a/backends/client/src/lib.rs b/backends/client/src/lib.rs
index fbe2e7e668a..45bee10ca50 100644
--- a/backends/client/src/lib.rs
+++ b/backends/client/src/lib.rs
@@ -86,6 +86,6 @@ impl ChunksToString for Vec {
     }
 }

-static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
+static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

 pub type Result = std::result::Result;
diff --git a/backends/v2/src/client/mod.rs b/backends/v2/src/client/mod.rs
index 9fe114a2c87..fa9d440645d 100644
--- a/backends/v2/src/client/mod.rs
+++ b/backends/v2/src/client/mod.rs
@@ -63,6 +63,6 @@ impl From for ClientError {
     }
 }

-static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
+static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

 pub type Result = std::result::Result;
diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs
index ab4311c3b8c..d4ac50c9c46 100644
--- a/backends/v3/src/client/mod.rs
+++ b/backends/v3/src/client/mod.rs
@@ -62,6 +62,6 @@ impl From for InputChunk {
     }
 }

-static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
+static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

 pub type Result = std::result::Result;
diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl_warmup/test_flash_qwen2_vl_simple.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl_warmup/test_flash_qwen2_vl_simple.json
deleted file mode 100644
index a986510f239..00000000000
--- a/integration-tests/models/__snapshots__/test_flash_qwen2_vl_warmup/test_flash_qwen2_vl_simple.json
+++ /dev/null
@@ -1,26 +0,0 @@
-{
-  "choices": [
-    {
-      "finish_reason": "stop",
-      "index": 0,
-      "logprobs": null,
-      "message": {
-        "content": "The correct answer is: blue",
-        "name": null,
-        "role": "assistant",
-        "tool_calls": null
-      },
-      "usage": null
-    }
-  ],
-  "created": 1733445131,
-  "id": "",
-  "model": "Qwen/Qwen2-VL-2B-Instruct",
-  "object": "chat.completion",
-  "system_fingerprint": "2.4.2-dev0-native",
-  "usage": {
-    "completion_tokens": 7,
-    "prompt_tokens": 27,
-    "total_tokens": 34
-  }
-}
diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py
index 946ab2f1efb..97a533fc5d4 100644
--- a/integration-tests/models/test_flash_qwen2_vl.py
+++ b/integration-tests/models/test_flash_qwen2_vl.py
@@ -1,80 +1,81 @@
-import pytest
-
-
-@pytest.fixture(scope="module")
-def flash_qwen2_vl_handle(launcher):
-    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
-        yield handle
-
-
-@pytest.fixture(scope="module")
-async def flash_qwen2(flash_qwen2_vl_handle):
-    await flash_qwen2_vl_handle.health(300)
-    return flash_qwen2_vl_handle.client
-
-
-@pytest.mark.private
-async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
-    response = await flash_qwen2.chat(
-        max_tokens=100,
-        seed=42,
-        messages=[
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
-                        },
-                    },
-                    {"type": "text", "text": "Describe this image."},
-                ],
-            },
-        ],
-    )
-
-    assert (
-        response.choices[0].message.content
-        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
-    )
-
-    assert response == response_snapshot
-
-
-@pytest.mark.private
-async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
-    responses = await flash_qwen2.chat(
-        max_tokens=100,
-        seed=42,
-        messages=[
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
-                        },
-                    },
-                    {"type": "text", "text": "Describe this image."},
-                ],
-            },
-        ],
-        stream=True,
-    )
-
-    count = 0
-    generated = ""
-    last_response = None
-    async for response in responses:
-        count += 1
-        generated += response.choices[0].delta.content
-        last_response = response
-
-    assert (
-        generated
-        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
-    )
-    assert count == 58
-    assert last_response == response_snapshot
+# Disabled because it's broken.
+# import pytest
+#
+#
+# @pytest.fixture(scope="module")
+# def flash_qwen2_vl_handle(launcher):
+#     with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
+#         yield handle
+#
+#
+# @pytest.fixture(scope="module")
+# async def flash_qwen2(flash_qwen2_vl_handle):
+#     await flash_qwen2_vl_handle.health(300)
+#     return flash_qwen2_vl_handle.client
+#
+#
+# @pytest.mark.private
+# async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
+#     response = await flash_qwen2.chat(
+#         max_tokens=100,
+#         seed=42,
+#         messages=[
+#             {
+#                 "role": "user",
+#                 "content": [
+#                     {
+#                         "type": "image_url",
+#                         "image_url": {
+#                             "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+#                         },
+#                     },
+#                     {"type": "text", "text": "Describe this image."},
+#                 ],
+#             },
+#         ],
+#     )
+#
+#     assert (
+#         response.choices[0].message.content
+#         == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
+#     )
+#
+#     assert response == response_snapshot
+#
+#
+# @pytest.mark.private
+# async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
+#     responses = await flash_qwen2.chat(
+#         max_tokens=100,
+#         seed=42,
+#         messages=[
+#             {
+#                 "role": "user",
+#                 "content": [
+#                     {
+#                         "type": "image_url",
+#                         "image_url": {
+#                             "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+#                         },
+#                     },
+#                     {"type": "text", "text": "Describe this image."},
+#                 ],
+#             },
+#         ],
+#         stream=True,
+#     )
+#
+#     count = 0
+#     generated = ""
+#     last_response = None
+#     async for response in responses:
+#         count += 1
+#         generated += response.choices[0].delta.content
+#         last_response = response
+#
+#     assert (
+#         generated
+#         == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
+#     )
+#     assert count == 58
+#     assert last_response == response_snapshot
diff --git a/integration-tests/models/test_flash_qwen2_vl_warmup.py b/integration-tests/models/test_flash_qwen2_vl_warmup.py
deleted file mode 100644
index 5be87ee21a3..00000000000
--- a/integration-tests/models/test_flash_qwen2_vl_warmup.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import pytest
-
-
-@pytest.fixture(scope="module")
-def flash_qwen2_vl_handle(launcher):
-    with launcher(
-        "Qwen/Qwen2-VL-2B-Instruct",
-        max_input_length=40,
-        max_batch_prefill_tokens=50,
-        max_total_tokens=51,
-    ) as handle:
-        yield handle
-
-
-@pytest.fixture(scope="module")
-async def flash_qwen2(flash_qwen2_vl_handle):
-    await flash_qwen2_vl_handle.health(300)
-    return flash_qwen2_vl_handle.client
-
-
-@pytest.mark.private
-async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
-    response = await flash_qwen2.chat(
-        max_tokens=20,
-        seed=42,
-        messages=[
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": "What is the color of the sky?"},
-                ],
-            },
-        ],
-    )
-
-    assert response.choices[0].message.content == "The correct answer is: blue"
-
-    assert response == response_snapshot
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 7521fb46556..e2d24643ecc 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -29,7 +29,6 @@
     BloomForCausalLM,
 )
 from text_generation_server.models.globals import ATTENTION
-import text_generation_server.models.globals as globals
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -1218,11 +1217,6 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
-        # TODO: remove edge case when cuda graph issue is resolved for BS=2 with Qwen2-VL
-        logger.warning(
-            "Qwen2-VL requires cuda graphs to be greater than 2. Removing all cuda graphs with a batch size equal or less than 2."
-        )
-        globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
         return VlmCausalLM(
             model_id=model_id,
             model_class=Qwen2VLForConditionalGeneration,
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 01d3bf1a377..cc4039b1cbc 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -138,12 +138,7 @@ def forward(
             dim=-1,
         )

-        self.rotary_emb(
-            query,
-            torch.select(kv, dim=1, index=0),
-            cos[: query.shape[0], ...],
-            sin[: query.shape[0], ...],
-        )
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 95cf6a318e0..a8e1e8c1593 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -517,11 +517,11 @@ def forward(
         pixel_values: torch.FloatTensor = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         video_grid_thw: Optional[torch.LongTensor] = None,
-        pixel_attention_mask: Optional[torch.Tensor] = None,
+        pixel_attention_mask=None,
         image_sizes: Optional[torch.LongTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states: Optional[torch.Tensor] = None,
-        image_indices: Optional[torch.Tensor] = None,
+        image_indices=None,
     ):
         inputs_embeds = self.embed_tokens(input_ids)

@@ -533,7 +533,6 @@ def forward(
             ).squeeze(0)
             inputs_embeds[input_ids == self.image_token_id] = image_embeds

-        max_s = max(max_s, inputs_embeds.size(0))
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index f2d27db9181..d097c54fc2c 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -56,13 +56,11 @@
     MEM_POOL,
     ATTENTION,
     BLOCK_SIZE,
+    CUDA_GRAPHS,
     REQUEST_LOGPROBS,
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
-
-# avoid coping CUDA_GRAPHS value by importing globals as a module
-import text_generation_server.models.globals as globals
 from text_generation_server.layers.attention import KVCache, Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -1637,8 +1635,8 @@ def warmup(
                 int(val)
                 for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
             ]
-        elif globals.CUDA_GRAPHS is not None:
-            tuning_sequences = globals.CUDA_GRAPHS
+        elif CUDA_GRAPHS is not None:
+            tuning_sequences = CUDA_GRAPHS
         else:
             tuning_sequences = [1, 2, 3, 4, 5, 6, 7]

@@ -1677,14 +1675,13 @@ def warmup(
                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
             )
-        if globals.CUDA_GRAPHS:
+        if CUDA_GRAPHS:
             try:
                 log_master(
-                    logger.info,
-                    f"Cuda Graphs are enabled for sizes {globals.CUDA_GRAPHS}",
+                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
                 )
                 # Warmup cuda graphs
-                for bs in globals.CUDA_GRAPHS:
+                for bs in CUDA_GRAPHS:
                     synchronize(self.device)
                     free_memory = get_free_memory(
                         self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
                     )
@@ -1708,8 +1705,7 @@ def warmup(
                 logger.exception("Decode cuda graph warmup failed")
         else:
             log_master(
-                logger.info,
-                f"Cuda Graphs are disabled (CUDA_GRAPHS={globals.CUDA_GRAPHS}).",
+                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
             )

         assert max_input_tokens is not None
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 4d6ea84e023..db78341d1ed 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -236,6 +236,7 @@ def batch_tokenized_inputs(
                     w = image.width * 2
                     h = image.height * 2
                     image = image.resize((w, h))
+
                 if config.model_type == "llava_next":
                     images.append(image)
                 else: