From e693210a7ce84acee8049d096798d7614599f2a3 Mon Sep 17 00:00:00 2001 From: baishuai Date: Thu, 9 Jan 2025 01:49:28 +0800 Subject: [PATCH 01/42] add qwen2.5vl --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/qwen2_5_vl.md | 327 +++ src/transformers/__init__.py | 22 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 3 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 1 + .../models/qwen2_5_vl/__init__.py | 79 + .../qwen2_5_vl/configuration_qwen2_5_vl.py | 237 ++ .../qwen2_5_vl/image_processing_qwen2_5_vl.py | 480 ++++ .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1971 +++++++++++++++++ .../qwen2_5_vl/processing_qwen2_5_vl.py | 218 ++ tests/models/qwen2_5_vl/__init__.py | 0 .../test_image_processing_qwen2_5_vl.py | 252 +++ .../qwen2_5_vl/test_modeling_qwen2_5_vl.py | 546 +++++ .../qwen2_5_vl/test_processor_qwen2_5_vl.py | 113 + 17 files changed, 4255 insertions(+) create mode 100644 docs/source/en/model_doc/qwen2_5_vl.md create mode 100644 src/transformers/models/qwen2_5_vl/__init__.py create mode 100644 src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py create mode 100644 src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py create mode 100644 src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py create mode 100644 src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py create mode 100644 tests/models/qwen2_5_vl/__init__.py create mode 100644 tests/models/qwen2_5_vl/test_image_processing_qwen2_5_vl.py create mode 100644 tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py create mode 100644 tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f130ddbf72b9..ec55ec92947d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -922,6 +922,8 @@ title: Pix2Struct - local: model_doc/pixtral title: Pixtral + - local: model_doc/qwen2_5_vl + title: Qwen2_5_VL - local: model_doc/qwen2_audio title: Qwen2Audio - local: model_doc/qwen2_vl diff --git a/docs/source/en/model_doc/qwen2_5_vl.md b/docs/source/en/model_doc/qwen2_5_vl.md new file mode 100644 index 000000000000..20ca767723ff --- /dev/null +++ b/docs/source/en/model_doc/qwen2_5_vl.md @@ -0,0 +1,327 @@ + + +# Qwen2.5-VL + +## Overview + +The [Qwen2_5-VL](https://qwenlm.github.io/blog/qwen2_5-vl/) model is an update to [Qwen2-VL](https://arxiv.org/abs/2409.12191) from Qwen team, Alibaba Group. + +The abstract from this update is the following: + +*Qwen2_5-VL marks a major step forward from Qwen2-VL, built upon the latest Qwen2.5 LLM. We've accelerated training and testing through the strategic implementation of window attention within the ViT. The ViT architecture itself has been refined with SwiGLU and RMSNorm, aligning it more closely with the LLM's structure. A key innovation is the expansion of native dynamic resolution to encompass the temporal dimension, in addition to spatial aspects. Furthermore, we've upgraded MRoPE, incorporating absolute time alignment on the time axis to allow the model to effectively capture temporal dynamics, regardless of frame rate, leading to superior video understanding.* + +## Usage example + +### Single Media inference + +The model can accept both images and videos as input. Here's an example code for inference. + +```python + +from PIL import Image +import requests +import torch +from torchvision import io +from typing import Dict +from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor + +# Load the model in half-precision on the available device(s) +model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", device_map="auto") +processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + +# Image +url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" +image = Image.open(requests.get(url, stream=True).raw) + +conversation = [ + { + "role":"user", + "content":[ + { + "type":"image", + }, + { + "type":"text", + "text":"Describe this image." + } + ] + } +] + + +# Preprocess the inputs +text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) +# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>\n<|im_start|>assistant\n' + +inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt") +inputs = inputs.to('cuda') + +# Inference: Generation of the output +output_ids = model.generate(**inputs, max_new_tokens=128) +generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)] +output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) +print(output_text) + +# Video +def fetch_video(ele: Dict, nframe_factor=2): + if isinstance(ele['video'], str): + def round_by_factor(number: int, factor: int) -> int: + return round(number / factor) * factor + + video = ele["video"] + if video.startswith("file://"): + video = video[7:] + + video, _, info = io.read_video( + video, + start_pts=ele.get("video_start", 0.0), + end_pts=ele.get("video_end", None), + pts_unit="sec", + output_format="TCHW", + ) + assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" + if "nframes" in ele: + nframes = round_by_factor(ele["nframes"], nframe_factor) + else: + fps = ele.get("fps", 1.0) + nframes = round_by_factor(video.size(0) / info["video_fps"] * fps, nframe_factor) + idx = torch.linspace(0, video.size(0) - 1, nframes, dtype=torch.int64) + return video[idx] + +video_info = {"type": "video", "video": "/path/to/video.mp4", "fps": 1.0} +video = fetch_video(video_info) +conversation = [ + { + "role": "user", + "content": [ + {"type": "video"}, + {"type": "text", "text": "What happened in the video?"}, + ], + } +] + +# Preprocess the inputs +text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) +# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>What happened in the video?<|im_end|>\n<|im_start|>assistant\n' + +# Qwen2.5VL modifies the time positional encoding (MRoPE) according to the video's frame rate (FPS). +# Therefore, the video's FPS information needs to be provided as input. +inputs = processor(text=[text_prompt], videos=[video], fps=[1.0], padding=True, return_tensors="pt") +inputs = inputs.to('cuda') + +# Inference: Generation of the output +output_ids = model.generate(**inputs, max_new_tokens=128) +generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)] +output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) +print(output_text) +``` + +### Batch Mixed Media Inference + +The model can batch inputs composed of mixed samples of various types such as images, videos, and text. Here is an example. + +```python +image1 = Image.open("/path/to/image1.jpg") +image2 = Image.open("/path/to/image2.jpg") +image3 = Image.open("/path/to/image3.jpg") +image4 = Image.open("/path/to/image4.jpg") +image5 = Image.open("/path/to/image5.jpg") +video = fetch_video({ + "type": "video", + "video": "/path/to/video.mp4", + "fps": 1.0 +}) + +# Conversation for the first image +conversation1 = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "Describe this image."} + ] + } +] + +# Conversation with two images +conversation2 = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "image"}, + {"type": "text", "text": "What is written in the pictures?"} + ] + } +] + +# Conversation with pure text +conversation3 = [ + { + "role": "user", + "content": "who are you?" + } +] + + +# Conversation with mixed midia +conversation4 = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "image"}, + {"type": "video"}, + {"type": "text", "text": "What are the common elements in these medias?"}, + ], + } +] + +conversations = [conversation1, conversation2, conversation3, conversation4] +# Preparation for batch inference +texts = [processor.apply_chat_template(msg, add_generation_prompt=True) for msg in conversations] +inputs = processor( + text=texts, + images=[image1, image2, image3, image4, image5], + videos=[video], + padding=True, + return_tensors="pt", +) +inputs = inputs.to('cuda') + +# Batch Inference +output_ids = model.generate(**inputs, max_new_tokens=128) +generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)] +output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) +print(output_text) +``` + +### Usage Tips + +#### Image Resolution trade-off + +The model supports a wide range of resolution inputs. By default, it uses the native resolution for input, but higher resolutions can enhance performance at the cost of more computation. Users can set the minimum and maximum number of pixels to achieve an optimal configuration for their needs. + +```python +min_pixels = 224*224 +max_pixels = 2048*2048 +processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels) +``` + +In case of limited GPU RAM, one can reduce the resolution as follows: + +```python +min_pixels = 256*28*28 +max_pixels = 1024*28*28 +processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels) +``` +This ensures each image gets encoded using a number between 256-1024 tokens. The 28 comes from the fact that the model uses a patch size of 14 and a temporal patch size of 2 (14 x 2 = 28). + +#### Multiple Image Inputs + +By default, images and video content are directly included in the conversation. When handling multiple images, it's helpful to add labels to the images and videos for better reference. Users can control this behavior with the following settings: + +```python +conversation = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "Hello, how are you?"} + ] + }, + { + "role": "assistant", + "content": "I'm doing well, thank you for asking. How can I assist you today?" + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Can you describe these images and video?"}, + {"type": "image"}, + {"type": "image"}, + {"type": "video"}, + {"type": "text", "text": "These are from my vacation."} + ] + }, + { + "role": "assistant", + "content": "I'd be happy to describe the images and video for you. Could you please provide more context about your vacation?" + }, + { + "role": "user", + "content": "It was a trip to the mountains. Can you see the details in the images and video?" + } +] + +# default: +prompt_without_id = processor.apply_chat_template(conversation, add_generation_prompt=True) +# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?<|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|><|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n' + + +# add ids +prompt_with_id = processor.apply_chat_template(conversation, add_generation_prompt=True, add_vision_id=True) +# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?Picture 2: <|vision_start|><|image_pad|><|vision_end|>Picture 3: <|vision_start|><|image_pad|><|vision_end|>Video 1: <|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n' + +``` + +#### Flash-Attention 2 to speed up generation + +First, make sure to install the latest version of Flash Attention 2: + +```bash +pip install -U flash-attn --no-build-isolation +``` + +Also, you should have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`. + +To load and run a model using Flash Attention-2, simply add `attn_implementation="flash_attention_2"` when loading the model as follows: + +```python +from transformers import Qwen2_5_VLForConditionalGeneration + +model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", +) +``` + + + +## Qwen2_5_VLConfig + +[[autodoc]] Qwen2_5_VLConfig + +## Qwen2_5_VLImageProcessor + +[[autodoc]] Qwen2_5_VLImageProcessor + - preprocess + +## Qwen2_5_VLProcessor + +[[autodoc]] Qwen2_5_VLProcessor + +## Qwen2_5_VLModel + +[[autodoc]] Qwen2_5_VLModel + - forward + +## Qwen2_5_VLForConditionalGeneration + +[[autodoc]] Qwen2_5_VLForConditionalGeneration + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2b4980306c53..cf043ca6e0c1 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -711,6 +711,10 @@ "Qwen2VLConfig", "Qwen2VLProcessor", ], + "models.qwen2_5_vl": [ + "Qwen2_5_VLConfig", + "Qwen2_5_VLProcessor", + ], "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"], "models.recurrent_gemma": ["RecurrentGemmaConfig"], "models.reformer": ["ReformerConfig"], @@ -1254,6 +1258,7 @@ _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"]) _import_structure["models.pvt"].extend(["PvtImageProcessor"]) _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"]) + _import_structure["models.qwen2_5_vl"].extend(["Qwen2_5_VLImageProcessor"]) _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"]) _import_structure["models.sam"].extend(["SamImageProcessor"]) _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"]) @@ -3266,6 +3271,13 @@ "Qwen2VLPreTrainedModel", ] ) + _import_structure["models.qwen2_5_vl"].extend( + [ + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5_VLModel", + "Qwen2_5_VLPreTrainedModel", + ] + ) _import_structure["models.rag"].extend( [ "RagModel", @@ -5737,6 +5749,10 @@ from .models.pvt import PvtConfig from .models.pvt_v2 import PvtV2Config from .models.qwen2 import Qwen2Config, Qwen2Tokenizer + from .models.qwen2_5_vl import ( + Qwen2_5_VLConfig, + Qwen2_5_VLProcessor, + ) from .models.qwen2_audio import ( Qwen2AudioConfig, Qwen2AudioEncoderConfig, @@ -6313,6 +6329,7 @@ PoolFormerImageProcessor, ) from .models.pvt import PvtImageProcessor + from .models.qwen2_5_vl import Qwen2_5_VLImageProcessor from .models.qwen2_vl import Qwen2VLImageProcessor from .models.rt_detr import RTDetrImageProcessor from .models.sam import SamImageProcessor @@ -7914,6 +7931,11 @@ Qwen2Model, Qwen2PreTrainedModel, ) + from .models.qwen2_5_vl import ( + Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLModel, + Qwen2_5_VLPreTrainedModel, + ) from .models.qwen2_audio import ( Qwen2AudioEncoder, Qwen2AudioForConditionalGeneration, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 150eca78e388..dd7d93f8d162 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -213,6 +213,7 @@ pvt, pvt_v2, qwen2, + qwen2_5_vl, qwen2_audio, qwen2_moe, qwen2_vl, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index e8deb34018e5..331338509390 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -234,6 +234,7 @@ ("pvt_v2", "PvtV2Config"), ("qdqbert", "QDQBertConfig"), ("qwen2", "Qwen2Config"), + ("qwen2_5_vl", "Qwen2_5_VLConfig"), ("qwen2_audio", "Qwen2AudioConfig"), ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"), ("qwen2_moe", "Qwen2MoeConfig"), @@ -566,6 +567,7 @@ ("pvt_v2", "PVTv2"), ("qdqbert", "QDQBert"), ("qwen2", "Qwen2"), + ("qwen2_5_vl", "Qwen2_5_VL"), ("qwen2_audio", "Qwen2Audio"), ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoE"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d2d52a77579a..82ad3806dc40 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -219,6 +219,7 @@ ("pvt_v2", "PvtV2Model"), ("qdqbert", "QDQBertModel"), ("qwen2", "Qwen2Model"), + ("qwen2_5_vl", "Qwen2_5_VLModel"), ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoeModel"), ("qwen2_vl", "Qwen2VLModel"), @@ -779,6 +780,7 @@ ("mllama", "MllamaForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), + ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), ("qwen2_vl", "Qwen2VLForConditionalGeneration"), ("video_llava", "VideoLlavaForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"), @@ -812,6 +814,7 @@ ("paligemma", "PaliGemmaForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), ("pixtral", "LlavaForConditionalGeneration"), + ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), ("qwen2_vl", "Qwen2VLForConditionalGeneration"), ("udop", "UdopForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 815e2ca755be..f510cf12db07 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -88,6 +88,7 @@ ("pix2struct", "Pix2StructProcessor"), ("pixtral", "PixtralProcessor"), ("pop2piano", "Pop2PianoProcessor"), + ("qwen2_5_vl", "Qwen2_5_VLProcessor"), ("qwen2_audio", "Qwen2AudioProcessor"), ("qwen2_vl", "Qwen2VLProcessor"), ("sam", "SamProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 2f6624057e0f..9df36cc71640 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -418,6 +418,7 @@ "Qwen2TokenizerFast" if is_tokenizers_available() else None, ), ), + ("qwen2_5_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), ("qwen2_audio", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), ( "qwen2_moe", diff --git a/src/transformers/models/qwen2_5_vl/__init__.py b/src/transformers/models/qwen2_5_vl/__init__.py new file mode 100644 index 000000000000..e2ee26089d5b --- /dev/null +++ b/src/transformers/models/qwen2_5_vl/__init__.py @@ -0,0 +1,79 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_vision_available, +) + + +_import_structure = { + "configuration_qwen2_5_vl": ["Qwen2_5_VLConfig"], + "processing_qwen2_5_vl": ["Qwen2_5_VLProcessor"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_qwen2_5_vl"] = [ + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5_VLModel", + "Qwen2_5_VLPreTrainedModel", + ] + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_qwen2_5_vl"] = ["Qwen2_5_VLImageProcessor"] + + +if TYPE_CHECKING: + from .configuration_qwen2_5_vl import Qwen2_5_VLConfig + from .processing_qwen2_5_vl import Qwen2_5_VLProcessor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_qwen2_5_vl import ( + Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLModel, + Qwen2_5_VLPreTrainedModel, + ) + + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_qwen2_5_vl import Qwen2_5_VLImageProcessor + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py new file mode 100644 index 000000000000..41b2fa1f92a9 --- /dev/null +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2.5-VL model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2_5_VLVisionConfig(PretrainedConfig): + model_type = "qwen2_5_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + tokens_per_second=25, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=list(range(7, 32, 8)), + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + self.window_size = window_size + self.fullatt_block_indexes = fullatt_block_indexes + self.out_hidden_size = out_hidden_size + + +class Qwen2_5_VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a + Qwen2.5-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2.5-VL-7B-Instruct [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the Qwen2.5-VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2_5_VLModel`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 29568): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 80): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 80): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + vision_config (`Dict`, *optional*): + The config for the visual encoder initialization. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + + ```python + >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig + + >>> # Initializing a Qwen2.5-VL configuration + >>> configuration = Qwen2_5_VLConfig() + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # Initializing a model from the Qwen2.5-VL configuration + >>> model = Qwen2_5_VLForConditionalGeneration(configuration) + ```""" + + model_type = "qwen2_5_vl" + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen2_5_VLVisionConfig() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations + # one can set it to "linear"/"dynamic" etc. to have scaled RoPE + # TODO: @raushan update config in the hub + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py new file mode 100644 index 000000000000..5bc8b2db5104 --- /dev/null +++ b/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py @@ -0,0 +1,480 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Qwen2.5-VL.""" + +import math +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + VideoInput, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + is_valid_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + from PIL import Image + + +def make_batched_images(images) -> List[List[ImageInput]]: + """ + Accepts images in list or nested list format, and makes a list of images for preprocessing. + + Args: + images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): + The input image. + + Returns: + list: A list of images. + """ + if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): + return [img for img_list in images for img in img_list] + + elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + return images + + elif is_valid_image(images): + return [images] + + raise ValueError(f"Could not make batched images from {images}") + + +# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos +def make_batched_videos(videos) -> List[VideoInput]: + if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): + return videos + + elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): + if isinstance(videos[0], Image.Image): + return [videos] + elif len(videos[0].shape) == 4: + return [list(video) for video in videos] + + elif is_valid_image(videos) and len(videos.shape) == 4: + return [list(videos)] + + raise ValueError(f"Could not make batched video from {videos}") + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 56 * 56, + max_pixels: int = 14 * 14 * 4 * 1280, +): + """Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + + """ + if height < factor or width < factor: + raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") + elif max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +class Qwen2_5_VLImageProcessor(BaseImageProcessor): + r""" + Constructs a Qwen2.5-VL image processor that dynamically resizes images based on the original images. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + min_pixels (`int`, *optional*, defaults to `56 * 56`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`): + The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to 14): + The spacial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. + """ + + model_input_names = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + ] + + def __init__( + self, + do_resize: bool = True, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + min_pixels: int = 56 * 56, + max_pixels: int = 28 * 28 * 1280, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} + self.do_convert_rgb = do_convert_rgb + + def _preprocess( + self, + images: Union[ImageInput, VideoInput], + do_resize: bool = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`. + vision_info (`List[Dict]`, *optional*): + Optional list of dictionaries containing additional information about vision inputs. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + images = make_list_of_images(images) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + height, width = get_image_size(images[0], channel_dim=input_data_format) + resized_height, resized_width = height, width + processed_images = [] + for image in images: + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = resize( + image, + size=(resized_height, resized_width), + resample=resample, + input_data_format=input_data_format, + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, + mean=image_mean, + std=image_std, + input_data_format=input_data_format, + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + patches = np.array(processed_images) + if data_format == ChannelDimension.LAST: + patches = patches.transpose(0, 3, 1, 2) + if patches.shape[0] == 1: + patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1)) + channel = patches.shape[1] + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h, grid_w = ( + resized_height // self.patch_size, + resized_width // self.patch_size, + ) + patches = patches.reshape( + grid_t, + self.temporal_patch_size, + channel, + grid_h // self.merge_size, + self.merge_size, + self.patch_size, + grid_w // self.merge_size, + self.merge_size, + self.patch_size, + ) + patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, + channel * self.temporal_patch_size * self.patch_size * self.patch_size, + ) + + return flatten_patches, (grid_t, grid_h, grid_w) + + def preprocess( + self, + images: ImageInput, + videos: VideoInput = None, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + videos (`VideoInput`): + Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If + passing in videos with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + if images is not None: + images = make_batched_images(images) + if videos is not None: + videos = make_batched_videos(videos) + + if images is not None and not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if images is not None: + pixel_values, vision_grid_thws = [], [] + for image in images: + patches, image_grid_thw = self._preprocess( + image, + do_resize=do_resize, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + ) + pixel_values.extend(patches) + vision_grid_thws.append(image_grid_thw) + pixel_values = np.array(pixel_values) + vision_grid_thws = np.array(vision_grid_thws) + data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws} + + if videos is not None: + pixel_values, vision_grid_thws = [], [] + for images in videos: + patches, video_grid_thw = self._preprocess( + images, + do_resize=do_resize, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + ) + pixel_values.extend(patches) + vision_grid_thws.append(video_grid_thw) + pixel_values = np.array(pixel_values) + vision_grid_thws = np.array(vision_grid_thws) + data = { + "pixel_values_videos": pixel_values, + "video_grid_thw": vision_grid_thws, + } + + return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py new file mode 100644 index 000000000000..c18c7ea896a3 --- /dev/null +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -0,0 +1,1971 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2.5-VL model.""" + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_outputs import ( + BaseModelOutputWithPast, + ModelOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + + from ...modeling_flash_attention_utils import _flash_attention_forward +else: + flash_attn_varlen_func = None + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +@dataclass +# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLCausalLMOutputWithPast with Qwen2VL->Qwen2_5_VL +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLRotaryEmbedding with Qwen2VL->Qwen2_5_VL +class Qwen2_5_VLRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[Qwen2_5_VLConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`Qwen2_5_VLRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.hidden_size = hidden_size + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d( + in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.hidden_size) + return hidden_states + + +class PatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention +class VisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.full( + [1, seq_length, seq_length], + torch.finfo(q.dtype).min, + device=q.device, + dtype=q.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.VisionFlashAttention2 +class VisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.VisionSdpaAttention +class VisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_VL_VISION_ATTENTION_CLASSES = { + "eager": VisionAttention, + "flash_attention_2": VisionFlashAttention2, + "sdpa": VisionSdpaAttention, +} + + +class Qwen2_5_VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + + self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + self.mlp = Qwen2_5_VLMLP(config, bias=True) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLAttention with Qwen2VL->Qwen2_5_VL +class Qwen2_5_VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2_5_VLRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2.5-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLFlashAttention2 with Qwen2VL->Qwen2_5_VL +class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): + """ + Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLSdpaAttention with Qwen2VL->Qwen2_5_VL +class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): + """ + Qwen2_5 Sdpaattention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2_5_VLSdpaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2_5_VLModel is using Qwen2_5_VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_VL_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLAttention, + "flash_attention_2": Qwen2_5_VLFlashAttention2, + "sdpa": Qwen2_5_VLSdpaAttention, +} + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2_5_VLMLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +QWEN2_5_VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2_5_VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + QWEN2_5_VL_START_DOCSTRING, +) +# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLPreTrainedModel with Qwen2VL->Qwen2_5_VL +class Qwen2_5_VLPreTrainedModel(PreTrainedModel): + config_class = Qwen2_5_VLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): + config_class = Qwen2_5_VLVisionConfig + _no_split_modules = ["Qwen2_5_VLVisionBlock"] + + def __init__(self, config) -> None: + super().__init__(config) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + hidden_size=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = PatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def get_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.up_proj.weight.dtype + + def get_device(self) -> torch.device: + return self.blocks[0].mlp.up_proj.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb + ) + else: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + QWEN2_5_VL_START_DOCSTRING, +) +# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLModel with Qwen2VL->Qwen2_5_VL +class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): + def __init__(self, config: Qwen2_5_VLConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5_VLConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2_5_VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +QWEN2_VL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses + [`Qwen2VLImageProcessor`] for processing images. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses + [`Qwen2VLImageProcessor`] for processing videos. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. +""" + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: torch.LongTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embeddin for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = ( + ( + torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) + * second_per_grid_t + * self.config.vision_config.tokens_per_second + ) + .long() + .flatten() + ) + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "cache_position": cache_position, + "second_per_grid_ts": second_per_grid_ts, + } + ) + return model_inputs diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py new file mode 100644 index 000000000000..e54a1b52e43e --- /dev/null +++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -0,0 +1,218 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Qwen2.5-VL. +""" + +from typing import List, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, VideoInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2_5_VLVideosKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_VLVideosKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2_5_VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2_5_VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + image_processor_class = "Qwen2_5_VLImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2_5_VLImageProcessor's [`~Qwen2_5_VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." + ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/tests/models/qwen2_5_vl/__init__.py b/tests/models/qwen2_5_vl/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/qwen2_5_vl/test_image_processing_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_image_processing_qwen2_5_vl.py new file mode 100644 index 000000000000..4c991ec710ef --- /dev/null +++ b/tests/models/qwen2_5_vl/test_image_processing_qwen2_5_vl.py @@ -0,0 +1,252 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from transformers.models.qwen2_5_vl.image_processing_qwen2_5_vl import smart_resize +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ( + ImageProcessingTestMixin, + prepare_image_inputs, +) + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import Qwen2_5_VLImageProcessor + + +class Qwen2_5_VLImageProcessingTester: + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + min_resolution=56, + max_resolution=1024, + min_pixels=56 * 56, + max_pixels=28 * 28 * 1280, + do_normalize=True, + image_mean=OPENAI_CLIP_MEAN, + image_std=OPENAI_CLIP_STD, + do_resize=True, + patch_size=14, + temporal_patch_size=2, + merge_size=2, + do_convert_rgb=True, + ): + self.parent = parent + self.batch_size = batch_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.num_channels = num_channels + self.image_mean = OPENAI_CLIP_MEAN + self.image_std = OPENAI_CLIP_STD + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.do_resize = do_resize + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "min_pixels": self.min_pixels, + "max_pixels": self.max_pixels, + "patch_size": self.patch_size, + "temporal_patch_size": self.temporal_patch_size, + "merge_size": self.merge_size, + } + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + images = prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + return [[image] for image in images] + + +@require_torch +@require_vision +class Qwen2_5_VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = Qwen2_5_VLImageProcessor if is_vision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = Qwen2_5_VLImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "min_pixels")) + self.assertTrue(hasattr(image_processing, "max_pixels")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + self.assertTrue(hasattr(image_processing, "patch_size")) + self.assertTrue(hasattr(image_processing, "temporal_patch_size")) + self.assertTrue(hasattr(image_processing, "merge_size")) + + def test_image_processor_from_dict_with_kwargs(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.min_pixels, 56 * 56) + self.assertEqual(image_processor.max_pixels, 28 * 28 * 1280) + + image_processor = self.image_processing_class.from_dict( + self.image_processor_dict, min_pixels=256 * 256, max_pixels=640 * 640 + ) + self.assertEqual(image_processor.min_pixels, 256 * 256) + self.assertEqual(image_processor.max_pixels, 640 * 640) + + def test_select_best_resolution(self): + # Test with a final resize resolution + best_resolution = smart_resize(561, 278, factor=28) + self.assertEqual(best_resolution, (560, 280)) + + def test_call_pil(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image in image_inputs: + self.assertIsInstance(image[0], Image.Image) + + # Test not batched input + prcocess_out = image_processing(image_inputs[0], return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (4900, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + # Test batched + prcocess_out = image_processing(image_inputs, return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (34300, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + def test_call_numpy(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) + for image in image_inputs: + self.assertIsInstance(image[0], np.ndarray) + + # Test not batched input + prcocess_out = image_processing(image_inputs[0], return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (4900, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + # Test batched + prcocess_out = image_processing(image_inputs, return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (34300, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + def test_call_pytorch(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + + for image in image_inputs: + self.assertIsInstance(image[0], torch.Tensor) + + # Test not batched input + prcocess_out = image_processing(image_inputs[0], return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (4900, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + # Test batched + prcocess_out = image_processing(image_inputs, return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (34300, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + @unittest.skip(reason="Qwen2_5_VLImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") + def test_call_numpy_4_channels(self): + pass + + def test_nested_input(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + + # Test batched as a list of images + prcocess_out = image_processing(image_inputs, return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (34300, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + # Test batched as a nested list of images, where each sublist is one batch + image_inputs_nested = image_inputs[:3] + image_inputs[3:] + prcocess_out = image_processing(image_inputs_nested, return_tensors="pt") + encoded_images_nested = prcocess_out.pixel_values + image_grid_thws_nested = prcocess_out.image_grid_thw + expected_output_image_shape = (34300, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) + self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + # Image processor should return same pixel values, independently of ipnut format + self.assertTrue((encoded_images_nested == encoded_images).all()) + self.assertTrue((image_grid_thws_nested == expected_image_grid_thws).all()) diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py new file mode 100644 index 000000000000..e41ebf0ca25e --- /dev/null +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -0,0 +1,546 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Qwen2.5-VL model.""" + +import gc +import unittest + +import requests + +from transformers import ( + AutoProcessor, + Qwen2_5_VLConfig, + Qwen2_5_VLForConditionalGeneration, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + require_flash_attn, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + _config_zero_init, + floats_tensor, + ids_tensor, +) + + +if is_torch_available(): + import torch + +else: + is_torch_greater_or_equal_than_2_0 = False + +if is_vision_available(): + from PIL import Image + + +class Qwen2_5_VLVisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=3, + seq_length=7, + num_channels=3, + ignore_index=-100, + image_size=14, + bos_token_id=0, + eos_token_id=1, + pad_token_id=2, + vision_start_token_id=3, + image_token_id=4, + video_token_id=5, + hidden_act="silu", + hidden_size=32, + vocab_size=99, + intermediate_size=37, + max_position_embeddings=512, + max_window_layers=3, + model_type="qwen2_5_vl", + num_attention_heads=4, + num_hidden_layers=4, + num_key_value_heads=2, + rope_theta=10000, + tie_word_embeddings=True, + is_training=True, + vision_config={ + "depth": 2, + "in_chans": 3, + "hidden_act": "silu", + "intermediate_size": 32, + "out_hidden_size": 32, + "hidden_size": 32, + "num_heads": 4, + "patch_size": 14, + "spatial_patch_size": 14, + "spatial_merge_size": 1, + "temporal_patch_size": 2, + }, + rope_scaling={"type": "mrope", "mrope_section": [2, 1, 1]}, + ): + self.parent = parent + self.ignore_index = ignore_index + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.vision_start_token_id = vision_start_token_id + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.hidden_act = hidden_act + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.max_window_layers = max_window_layers + self.model_type = model_type + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.num_key_value_heads = num_key_value_heads + self.rope_theta = rope_theta + self.tie_word_embeddings = tie_word_embeddings + self.vision_config = vision_config + self.rope_scaling = rope_scaling + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.is_training = is_training + self.vocab_size = vocab_size + self.num_image_tokens = 32 + self.seq_length = seq_length + self.num_image_tokens + + def get_config(self): + return Qwen2_5_VLConfig( + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + hidden_act=self.hidden_act, + max_position_embeddings=self.max_position_embeddings, + vision_config=self.vision_config, + model_type=self.model_type, + max_window_layers=self.max_window_layers, + rope_scaling=self.rope_scaling, + tie_word_embeddings=self.tie_word_embeddings, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + vision_start_token_id=self.vision_start_token_id, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vocab_size=self.vocab_size, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + patch_size = config.vision_config.patch_size + temporal_patch_size = config.vision_config.temporal_patch_size + pixel_values = floats_tensor( + [ + self.batch_size * (self.image_size**2) // (patch_size**2), + self.num_channels * (patch_size**2) * temporal_patch_size, + ] + ) + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + input_ids[:, -1] = self.pad_token_id + input_ids[input_ids == self.video_token_id] = self.pad_token_id + input_ids[input_ids == self.image_token_id] = self.pad_token_id + input_ids[:, self.num_image_tokens] = self.image_token_id + labels = torch.zeros( + (self.batch_size, self.seq_length), + dtype=torch.long, + device=torch_device, + ) + inputs_dict = { + "pixel_values": pixel_values, + "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size), + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + return config, inputs_dict + + def create_and_check_qwen2_5_vl_model_fp16_forward( + self, config, input_ids, pixel_values, attention_mask, image_grid_thw + ): + model = Qwen2_5_VLForConditionalGeneration(config=config) + model.to(torch_device) + model.half() + model.eval() + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + image_grid_thw=image_grid_thw, + pixel_values=pixel_values.to(torch.bfloat16), + return_dict=True, + )["logits"] + self.parent.assertFalse(torch.isnan(logits).any().item()) + + def create_and_check_qwen2_5_vl_model_fp16_autocast_forward( + self, config, input_ids, pixel_values, attention_mask, image_grid_thw + ): + config.torch_dtype = torch.float16 + model = Qwen2_5_VLForConditionalGeneration(config=config) + model.to(torch_device) + model.eval() + with torch.autocast(device_type="cuda", dtype=torch.float16): + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + image_grid_thw=image_grid_thw, + pixel_values=pixel_values.to(torch.bfloat16), + return_dict=True, + )["logits"] + self.parent.assertFalse(torch.isnan(logits).any().item()) + + +@require_torch +class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + """ + Model tester for `Qwen2_5_VLForConditionalGeneration`. + """ + + all_model_classes = (Qwen2_5_VLForConditionalGeneration,) if is_torch_available() else () + all_generative_model_classes = (Qwen2_5_VLForConditionalGeneration,) if is_torch_available() else () + test_pruning = False + test_head_masking = False + + def setUp(self): + self.model_tester = Qwen2_5_VLVisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=Qwen2_5_VLConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_mismatching_num_image_tokens(self): + """ + Tests that VLMs through an error with explicit message saying what is wrong + when number of images don't match number of image tokens in the text. + Also we need to test multi-image cases when one prompr has multiple image tokens. + """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + _ = model(**input_dict) # successfull forward with no modifications + + # remove one image but leave the image token in text + patch_size = config.vision_config.patch_size + one_img_length = (self.model_tester.image_size**2) // (patch_size**2) + input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...] + input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...] + with self.assertRaises(ValueError): + _ = model(**input_dict) + + # simulate multi-image case by concatenating inputs where each has exactly one image/image-token + input_ids = input_dict["input_ids"][:1] + pixel_values = input_dict["pixel_values"][:one_img_length] + image_grid_thw = input_dict["image_grid_thw"][:1] + input_ids = torch.cat([input_ids, input_ids], dim=0) + + # one image and two image tokens raise an error + with self.assertRaises(ValueError): + _ = model( + input_ids=input_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + # two images and two image tokens don't raise an error + pixel_values = torch.cat([pixel_values, pixel_values], dim=0) + image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0) + _ = model( + input_ids=input_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + @unittest.skip(reason="Feedforward chunking is not yet supported") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="CPU offload is not yet supported") + def test_cpu_offload(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_disk_offload_bin(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_disk_offload_safetensors(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_model_parallelism(self): + pass + + @unittest.skip(reason="Compile not yet supported because in Qwen2_5_VL models") + def test_sdpa_can_compile_dynamic(self): + pass + + @unittest.skip(reason="Compile not yet supported because in Qwen2_5_VL models") + def test_sdpa_can_dispatch_on_flash(self): + pass + + @unittest.skip(reason="Got `CUDA error: misaligned address` with PyTorch 2.0.0.") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip(reason="We cannot configure to output a smaller model.") + def test_model_is_small(self): + pass + + @unittest.skip( + reason="Qwen2.5-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that" + ) + def test_beam_search_low_memory(self): + pass + + @unittest.skip( + reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs" + ) + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`") + def test_generate_compile_fullgraph(self): + pass + + +@require_torch +class Qwen2_5_VLIntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + self.messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What kind of dog is this?"}, + ], + } + ] + url = "https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/demo_small.jpg" + self.image = Image.open(requests.get(url, stream=True).raw) + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + @slow + def test_small_model_integration_test(self): + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto" + ) + + text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + inputs = self.processor(text=[text], images=[self.image], return_tensors="pt") + + expected_input_ids = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655, 151655] # fmt: skip + assert expected_input_ids == inputs.input_ids[0].tolist()[:17] + + expected_pixel_slice = torch.tensor( + [ + [0.8792, 0.8792, 0.9084], + [1.1858, 1.1858, 1.2296], + [1.2004, 1.2004, 1.2150], + [1.4340, 1.4340, 1.4194], + [1.3902, 1.4048, 1.4194], + [1.5216, 1.5362, 1.5362], + ], + dtype=torch.float32, + device="cpu", + ) + assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=3e-3) + + # verify generation + inputs = inputs.to(torch_device) + + output = model.generate(**inputs, max_new_tokens=30) + EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets" + + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + def test_small_model_integration_test_batch(self): + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto" + ) + text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + inputs = self.processor(text=[text, text], images=[self.image, self.image], return_tensors="pt").to( + torch_device + ) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = [ + 'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices', + 'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets' + ] # fmt: skip + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + def test_small_model_integration_test_batch_wo_image(self): + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto" + ) + text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + messages2 = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who are you?"}, + ] + text2 = self.processor.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + inputs = self.processor(text=[text, text2], images=[self.image], padding=True, return_tensors="pt").to( + torch_device + ) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = [ + 'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets', + 'system\nYou are a helpful assistant.\nuser\nWho are you?\nassistant\nI am Qwen, a large language model created by Alibaba Cloud. I am designed to assist with various tasks and answer questions to the best of my' + ] # fmt: skip + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + def test_small_model_integration_test_batch_different_resolutions(self): + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto" + ) + text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + text2 = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + image2 = self.image.resize((224, 224)) + inputs = self.processor( + text=[text, text2], + images=[self.image, image2], + padding=True, + return_tensors="pt", + ).to(torch_device) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = [ + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + ] + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_flash_attn + @require_torch_gpu + def test_small_model_integration_test_batch_flashatt2(self): + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map="auto", + ) + text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + inputs = self.processor(text=[text, text], images=[self.image, self.image], return_tensors="pt").to( + torch_device + ) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = [ + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + ] + + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True)[0], + self.processor.batch_decode(output, skip_special_tokens=True)[1], + ) + + @slow + @require_flash_attn + @require_torch_gpu + def test_small_model_integration_test_batch_wo_image_flashatt2(self): + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map="auto", + ) + text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + messages2 = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who are you?"}, + ] + text2 = self.processor.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + inputs = self.processor(text=[text, text2], images=[self.image], padding=True, return_tensors="pt").to( + torch_device + ) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = [ + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + "system\nYou are a helpful assistant.\nuser\nWho are you?\nassistant\nI am Qwen, a large language model created by Alibaba Cloud. I am designed to answer a wide range of questions and provide information on various topics", + ] + + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) diff --git a/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py new file mode 100644 index 000000000000..ff7f04e3b4c7 --- /dev/null +++ b/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py @@ -0,0 +1,113 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import tempfile +import unittest + +import pytest + +from transformers import AutoProcessor, Qwen2Tokenizer +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from transformers import Qwen2_5_VLImageProcessor, Qwen2_5_VLProcessor + + +@require_vision +@require_torch +class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = Qwen2_5_VLProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = Qwen2_5_VLProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", patch_size=4) + processor.save_pretrained(self.tmpdirname) + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def test_save_load_pretrained_default(self): + tokenizer = self.get_tokenizer() + image_processor = self.get_image_processor() + + processor = Qwen2_5_VLProcessor(tokenizer=tokenizer, image_processor=image_processor) + processor.save_pretrained(self.tmpdirname) + processor = Qwen2_5_VLProcessor.from_pretrained(self.tmpdirname, use_fast=False) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) + self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string()) + self.assertIsInstance(processor.tokenizer, Qwen2Tokenizer) + self.assertIsInstance(processor.image_processor, Qwen2_5_VLImageProcessor) + + def test_image_processor(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = Qwen2_5_VLProcessor(tokenizer=tokenizer, image_processor=image_processor) + + image_input = self.prepare_image_inputs() + + input_image_proc = image_processor(image_input, return_tensors="np") + input_processor = processor(images=image_input, text="dummy", return_tensors="np") + + for key in input_image_proc.keys(): + self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_processor(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = Qwen2_5_VLProcessor(tokenizer=tokenizer, image_processor=image_processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + inputs = processor(text=input_str, images=image_input) + + self.assertListEqual( + list(inputs.keys()), + ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"], + ) + + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() + + # test if it raises when no text is passed + with pytest.raises(TypeError): + processor(images=image_input) + + def test_model_input_names(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = Qwen2_5_VLProcessor(tokenizer=tokenizer, image_processor=image_processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + video_inputs = self.prepare_video_inputs() + + inputs = processor(text=input_str, images=image_input, videos=video_inputs) + + self.assertListEqual(list(inputs.keys()), processor.model_input_names) From 8a713a90cca403fa9434b5aa0ee2c0ef2d7361a0 Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Thu, 9 Jan 2025 02:59:46 +0800 Subject: [PATCH 02/42] fix --- src/transformers/__init__.py | 24 ++--- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 91 ++++++------------- 2 files changed, 38 insertions(+), 77 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cf043ca6e0c1..24ec37d3f344 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -701,6 +701,10 @@ "Qwen2Config", "Qwen2Tokenizer", ], + "models.qwen2_5_vl": [ + "Qwen2_5_VLConfig", + "Qwen2_5_VLProcessor", + ], "models.qwen2_audio": [ "Qwen2AudioConfig", "Qwen2AudioEncoderConfig", @@ -711,10 +715,6 @@ "Qwen2VLConfig", "Qwen2VLProcessor", ], - "models.qwen2_5_vl": [ - "Qwen2_5_VLConfig", - "Qwen2_5_VLProcessor", - ], "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"], "models.recurrent_gemma": ["RecurrentGemmaConfig"], "models.reformer": ["ReformerConfig"], @@ -1257,8 +1257,8 @@ _import_structure["models.pixtral"].append("PixtralImageProcessor") _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"]) _import_structure["models.pvt"].extend(["PvtImageProcessor"]) - _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"]) _import_structure["models.qwen2_5_vl"].extend(["Qwen2_5_VLImageProcessor"]) + _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"]) _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"]) _import_structure["models.sam"].extend(["SamImageProcessor"]) _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"]) @@ -3247,6 +3247,13 @@ "Qwen2PreTrainedModel", ] ) + _import_structure["models.qwen2_5_vl"].extend( + [ + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5_VLModel", + "Qwen2_5_VLPreTrainedModel", + ] + ) _import_structure["models.qwen2_audio"].extend( [ "Qwen2AudioEncoder", @@ -3271,13 +3278,6 @@ "Qwen2VLPreTrainedModel", ] ) - _import_structure["models.qwen2_5_vl"].extend( - [ - "Qwen2_5_VLForConditionalGeneration", - "Qwen2_5_VLModel", - "Qwen2_5_VLPreTrainedModel", - ] - ) _import_structure["models.rag"].extend( [ "RagModel", diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index c18c7ea896a3..bc168d63b834 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -368,10 +368,7 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.proj = nn.Linear(dim, dim) def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None, + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) @@ -379,17 +376,10 @@ def forward( k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) attention_mask = torch.full( - [1, seq_length, seq_length], - torch.finfo(q.dtype).min, - device=q.device, - dtype=q.dtype, + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype ) for i in range(1, len(cu_seqlens)): - attention_mask[ - ..., - cu_seqlens[i - 1] : cu_seqlens[i], - cu_seqlens[i - 1] : cu_seqlens[i], - ] = 0 + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 q = q.transpose(0, 1) k = k.transpose(0, 1) @@ -413,10 +403,7 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.proj = nn.Linear(dim, dim) def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None, + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) @@ -440,10 +427,7 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.proj = nn.Linear(dim, dim) def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None, + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) @@ -452,11 +436,7 @@ def forward( attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) for i in range(1, len(cu_seqlens)): - attention_mask[ - ..., - cu_seqlens[i - 1] : cu_seqlens[i], - cu_seqlens[i - 1] : cu_seqlens[i], - ] = True + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) @@ -570,9 +550,9 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_multimodal_rotary_pos_emb( @@ -580,11 +560,7 @@ def forward( ) if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, - } # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -597,7 +573,7 @@ def forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - # Fix precision issues in Qwen2.5-VL float16 inference + # Fix precision issues in Qwen2-VL float16 inference # Replace inf values with zeros in attention weights to prevent NaN propagation if query_states.dtype == torch.float16: attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) @@ -659,9 +635,9 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) # Because the input can be padded, the absolute sequence length depends on the max position id. cos, sin = position_embeddings @@ -670,11 +646,7 @@ def forward( ) if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, - } # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -743,8 +715,8 @@ def forward( # Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLSdpaAttention with Qwen2VL->Qwen2_5_VL class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): """ - Qwen2_5 Sdpaattention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2_5_VLSdpaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ @@ -783,9 +755,9 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_multimodal_rotary_pos_emb( @@ -793,11 +765,7 @@ def forward( ) if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, - } # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -1201,9 +1169,7 @@ def forward( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) # the hard coded `3` is for temporal, height and width. @@ -1213,11 +1179,7 @@ def forward( position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values, - output_attentions, + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) hidden_states = inputs_embeds @@ -1283,6 +1245,7 @@ def forward( attentions=all_self_attns, ) + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1359,6 +1322,7 @@ def _update_causal_mask( return causal_mask @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2_5_VL def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1400,10 +1364,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=device, + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.sliding_window is not None: From be1f811c9db80d5476b3ab69dfadc2ec6dedf2fc Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Thu, 9 Jan 2025 03:12:22 +0800 Subject: [PATCH 03/42] pass check table --- docs/source/en/index.md | 1 + src/transformers/utils/dummy_pt_objects.py | 21 +++++++++++++++++++ .../utils/dummy_vision_objects.py | 7 +++++++ 3 files changed, 29 insertions(+) diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 0135f8f0eb91..6bc3820270d1 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -282,6 +282,7 @@ Flax), PyTorch, and/or TensorFlow. | [PVTv2](model_doc/pvt_v2) | ✅ | ❌ | ❌ | | [QDQBert](model_doc/qdqbert) | ✅ | ❌ | ❌ | | [Qwen2](model_doc/qwen2) | ✅ | ❌ | ❌ | +| [Qwen2_5_VL](model_doc/qwen2_5_vl) | ✅ | ❌ | ❌ | | [Qwen2Audio](model_doc/qwen2_audio) | ✅ | ❌ | ❌ | | [Qwen2MoE](model_doc/qwen2_moe) | ✅ | ❌ | ❌ | | [Qwen2VL](model_doc/qwen2_vl) | ✅ | ❌ | ❌ | diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 76df70f19203..92d8a6c4d979 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -7808,6 +7808,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Qwen2_5_VLForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen2_5_VLModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen2_5_VLPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Qwen2AudioEncoder(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index e9e87a4b3d44..185538c8825f 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -555,6 +555,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class Qwen2_5_VLImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class Qwen2VLImageProcessor(metaclass=DummyObject): _backends = ["vision"] From a666eb02f6f6953fbef50c7b41280503303d7688 Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Wed, 15 Jan 2025 02:14:49 +0800 Subject: [PATCH 04/42] add modular file --- .../models/qwen2_5_vl/modular_qwen_2_5_vl.py | 917 ++++++++++++++++++ .../qwen2_5_vl/test_processor_qwen2_5_vl.py | 2 +- utils/check_repo.py | 1 + 3 files changed, 919 insertions(+), 1 deletion(-) create mode 100644 src/transformers/models/qwen2_5_vl/modular_qwen_2_5_vl.py diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen_2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen_2_5_vl.py new file mode 100644 index 000000000000..91e2d70d17ce --- /dev/null +++ b/src/transformers/models/qwen2_5_vl/modular_qwen_2_5_vl.py @@ -0,0 +1,917 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2.5-VL model.""" + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN + +from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VisionTransformerPretrainedModel, + Qwen2VLModel, + Qwen2VLForConditionalGeneration, + Qwen2RMSNorm, + Qwen2VLVisionBlock, + Qwen2VLPreTrainedModel, + Qwen2VLCausalLMOutputWithPast, + PatchMerger, + VisionAttention, + VisionSdpaAttention, + PatchEmbed, +) + +from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.layers.rotary import apply_rotary_emb + + from ...modeling_flash_attention_utils import _flash_attention_forward +else: + flash_attn_varlen_func = None + apply_rotary_emb = None + + +def apply_rotary_pos_emb_flashatt( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + tensor_ = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) + return output + + +class Qwen2_5_VLVisionConfig(PretrainedConfig): + model_type = "qwen2_5_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + tokens_per_second=25, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=list(range(7, 32, 8)), + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + self.window_size = window_size + self.fullatt_block_indexes = fullatt_block_indexes + self.out_hidden_size = out_hidden_size + + +class Qwen2_5_VLConfig(Qwen2VLConfig): + model_type = "qwen2_5_vl" + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + rms_norm_eps=rms_norm_eps, + use_cache=use_cache, + tie_word_embeddings=tie_word_embeddings, + rope_theta=rope_theta, + use_sliding_window=use_sliding_window, + sliding_window=sliding_window, + max_window_layers=max_window_layers, + attention_dropout=attention_dropout, + vision_config=vision_config, + rope_scaling=rope_scaling, + **kwargs, + ) + if isinstance(vision_config, dict): + self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen2_5_VLVisionConfig() + + +class Qwen2_5_VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj( + self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) + ) + + +class Qwen2_5_VLPatchMerger(PatchMerger): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__(dim, context_dim, spatial_merge_size) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + + +class VisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states) + .reshape(seq_length, 3, self.num_heads, -1) + .permute(1, 0, 2, 3) + .unbind(0) + ) + q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func( + q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen + ).reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_5_VL_VISION_ATTENTION_CLASSES = { + "eager": VisionAttention, + "flash_attention_2": VisionFlashAttention2, + "sdpa": VisionSdpaAttention, +} + + +class Qwen2_5_VLVisionBlock(Qwen2VLVisionBlock): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__(config, attn_implementation) + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + self.mlp = Qwen2_5_VLMLP(config, bias=True) + + +class Qwen2_5_VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): + config_class = Qwen2_5_VLVisionConfig + _no_split_modules = ["Qwen2_5_VLVisionBlock"] + + def __init__(self, config) -> None: + super().__init__(config) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + hidden_size=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [ + Qwen2_5_VLVisionBlock(config, config._attn_implementation) + for _ in range(config.depth) + ] + ) + self.merger = Qwen2_5_VLPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def get_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.up_proj.weight.dtype + + def get_device(self) -> torch.device: + return self.blocks[0].mlp.up_proj.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w + ) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb + ) + else: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + ) + + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config( + config.vision_config + ) + + def get_rope_index( + self, + input_ids: torch.LongTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embeddin for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + * second_per_grid_t + * self.config.vision_config.tokens_per_second + ) + .long() + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = ( + position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if ( + position_ids is None + and input_ids is not None + and (attention_mask is None or attention_mask.ndim == 2) + ): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + cache_position is not None and cache_position[0] == 0 + ) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = ( + self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "cache_position": cache_position, + "second_per_grid_ts": second_per_grid_ts, + } + ) + return model_inputs \ No newline at end of file diff --git a/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py index ff7f04e3b4c7..2bc47d76acd4 100644 --- a/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py @@ -36,7 +36,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase): def setUp(self): self.tmpdirname = tempfile.mkdtemp() - processor = Qwen2_5_VLProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", patch_size=4) + processor = Qwen2_5_VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", patch_size=4) processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs): diff --git a/utils/check_repo.py b/utils/check_repo.py index d20760bcf75e..a1190b68ed30 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -70,6 +70,7 @@ "Pop2PianoStack", "Qwen2AudioEncoder", "Qwen2VisionTransformerPretrainedModel", + "Qwen2_5_VisionTransformerPretrainedModel", "SwitchTransformersStack", "TFDPRSpanPredictor", "MaskFormerSwinModel", From a96b32d64a267678d1cedb913326bee80da12eb6 Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Wed, 15 Jan 2025 02:32:00 +0800 Subject: [PATCH 05/42] fix style --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 15 +- .../models/qwen2_5_vl/modular_qwen_2_5_vl.py | 216 +++++------------- 2 files changed, 73 insertions(+), 158 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index bc168d63b834..93716db2d4d4 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -54,10 +54,12 @@ if is_flash_attn_2_available(): from flash_attn import flash_attn_varlen_func + from flash_attn.layers.rotary import apply_rotary_emb from ...modeling_flash_attention_utils import _flash_attention_forward else: flash_attn_varlen_func = None + apply_rotary_emb = None logger = logging.get_logger(__name__) @@ -294,6 +296,14 @@ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> to return output +def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + tensor_ = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) + return output + + class VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() @@ -372,6 +382,7 @@ def forward( ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) @@ -407,8 +418,8 @@ def forward( ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen_2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen_2_5_vl.py index 91e2d70d17ce..d74d53b37ba6 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen_2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen_2_5_vl.py @@ -19,8 +19,6 @@ # limitations under the License. """PyTorch Qwen2.5-VL model.""" -import math -from dataclasses import dataclass from typing import List, Optional, Tuple, Union import torch @@ -29,38 +27,36 @@ import torch.utils.checkpoint from torch.nn import CrossEntropyLoss -from ...activations import ACT2FN - +from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VisionTransformerPretrainedModel, - Qwen2VLModel, - Qwen2VLForConditionalGeneration, + PatchEmbed, + PatchMerger, Qwen2RMSNorm, - Qwen2VLVisionBlock, - Qwen2VLPreTrainedModel, Qwen2VLCausalLMOutputWithPast, - PatchMerger, + Qwen2VLForConditionalGeneration, + Qwen2VLPreTrainedModel, + Qwen2VLVisionBlock, VisionAttention, + VisionRotaryEmbedding, VisionSdpaAttention, - PatchEmbed, ) -from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig +from ...activations import ACT2FN +from ...cache_utils import StaticCache +from ...configuration_utils import PretrainedConfig +from ...utils import is_flash_attn_2_available if is_flash_attn_2_available(): from flash_attn import flash_attn_varlen_func from flash_attn.layers.rotary import apply_rotary_emb - from ...modeling_flash_attention_utils import _flash_attention_forward else: flash_attn_varlen_func = None apply_rotary_emb = None -def apply_rotary_pos_emb_flashatt( - tensor: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: +def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: tensor_ = tensor.float() cos = freqs.cos() sin = freqs.sin() @@ -174,9 +170,7 @@ def __init__(self, config, bias: bool = False): self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_state): - return self.down_proj( - self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) - ) + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) class Qwen2_5_VLPatchMerger(PatchMerger): @@ -199,19 +193,14 @@ def forward( rotary_pos_emb: torch.Tensor = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] - q, k, v = ( - self.qkv(hidden_states) - .reshape(seq_length, 3, self.num_heads, -1) - .permute(1, 0, 2, 3) - .unbind(0) - ) + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attn_output = flash_attn_varlen_func( - q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen - ).reshape(seq_length, -1) + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) attn_output = self.proj(attn_output) return attn_output @@ -257,10 +246,7 @@ def __init__(self, config) -> None: self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList( - [ - Qwen2_5_VLVisionBlock(config, config._attn_implementation) - for _ in range(config.depth) - ] + [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] ) self.merger = Qwen2_5_VLPatchMerger( dim=config.out_hidden_size, @@ -308,18 +294,14 @@ def get_window_index(self, grid_thw): window_index: list = [] cu_window_seqlens: list = [0] window_index_id = 0 - vit_merger_window_size = ( - self.window_size // self.spatial_merge_size // self.patch_size - ) + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size for grid_t, grid_h, grid_w in grid_thw: llm_grid_h, llm_grid_w = ( grid_h // self.spatial_merge_size, grid_w // self.spatial_merge_size, ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w - ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size @@ -342,18 +324,14 @@ def get_window_index(self, grid_thw): index_padded = index_padded.reshape(-1) index_new = index_padded[index_padded != -100] window_index.append(index_new + window_index_id) - cu_seqlens_tmp = ( - seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - ) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() window_index = torch.cat(window_index, dim=0) return window_index, cu_window_seqlens - def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) window_index, cu_window_seqlens = self.get_window_index(grid_thw) @@ -365,20 +343,14 @@ def forward( cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) seq_len, _ = hidden_states.size() - hidden_states = hidden_states.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 - ) + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) - rotary_pos_emb = rotary_pos_emb.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 - ) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - cu_seqlens = torch.repeat_interleave( - grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] - ).cumsum( + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, # Select dtype based on the following factors: # - FA2 requires that cu_seqlens_q must have dtype int32 @@ -417,9 +389,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): def __init__(self, config): super().__init__(config) - self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config( - config.vision_config - ) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) def get_rope_index( self, @@ -487,9 +457,7 @@ def get_rope_index( video_token_id = self.config.video_token_id vision_start_token_id = self.config.vision_start_token_id mrope_position_deltas = [] - if input_ids is not None and ( - image_grid_thw is not None or video_grid_thw is not None - ): + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): total_input_ids = input_ids if attention_mask is None: attention_mask = torch.ones_like(total_input_ids) @@ -505,9 +473,7 @@ def get_rope_index( for i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere( - input_ids == vision_start_token_id - ).squeeze(1) + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) vision_tokens = input_ids[vision_start_indices + 1] image_nums = (vision_tokens == image_token_id).sum() video_nums = (vision_tokens == video_token_id).sum() @@ -555,75 +521,39 @@ def get_rope_index( ) text_len = ed - st - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) t_index = ( ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) + torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) * second_per_grid_t * self.config.vision_config.tokens_per_second ) .long() .flatten() ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( - position_ids.device - ) - mrope_position_deltas.append( - llm_positions.max() + 1 - len(total_input_ids[i]) - ) - mrope_position_deltas = torch.tensor( - mrope_position_deltas, device=input_ids.device - ).unsqueeze(1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas else: if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = ( - position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) - ) - max_position_ids = position_ids.max(0, keepdim=False)[0].max( - -1, keepdim=True - )[0] + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: position_ids = ( @@ -639,10 +569,6 @@ def get_rope_index( return position_ids, mrope_position_deltas - @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC - ) def forward( self, input_ids: torch.LongTensor = None, @@ -703,19 +629,11 @@ def forward( "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) @@ -734,9 +652,7 @@ def forward( .expand_as(inputs_embeds) .to(inputs_embeds.device) ) - image_embeds = image_embeds.to( - inputs_embeds.device, inputs_embeds.dtype - ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: @@ -754,24 +670,16 @@ def forward( .expand_as(inputs_embeds) .to(inputs_embeds.device) ) - video_embeds = video_embeds.to( - inputs_embeds.device, inputs_embeds.dtype - ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if attention_mask is not None: attention_mask = attention_mask.to(inputs_embeds.device) # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if ( - position_ids is None - and input_ids is not None - and (attention_mask is None or attention_mask.ndim == 2) - ): + if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only - if ( - cache_position is not None and cache_position[0] == 0 - ) or self.rope_deltas is None: + if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( input_ids, image_grid_thw, @@ -863,9 +771,7 @@ def prepare_inputs_for_generation( if past_key_values is not None: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] - elif ( - input_ids.shape[1] != cache_position.shape[0] - ): # Default case (the "else", a no op, is Exception 2) + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] if cache_position[0] != 0: @@ -886,18 +792,16 @@ def prepare_inputs_for_generation( batch_size, sequence_length = input_ids.shape device = input_ids.device - attention_mask = ( - self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - config=self.config, - past_key_values=past_key_values, - ) + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, ) model_inputs.update( @@ -914,4 +818,4 @@ def prepare_inputs_for_generation( "second_per_grid_ts": second_per_grid_ts, } ) - return model_inputs \ No newline at end of file + return model_inputs From 8191ce9bad44fa683abc071ec30f807f2286910e Mon Sep 17 00:00:00 2001 From: ShuaiBai623 <43326198+ShuaiBai623@users.noreply.github.com> Date: Wed, 15 Jan 2025 02:40:06 +0800 Subject: [PATCH 06/42] Update src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py Co-authored-by: Minho Shim <6764739+minostauros@users.noreply.github.com> --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 93716db2d4d4..db41a1de8739 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1513,7 +1513,7 @@ def get_decoder(self): def get_rope_index( self, - input_ids: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, From 2b70136fe18c80a72b2d7e5d0d6ddcf0479b9ace Mon Sep 17 00:00:00 2001 From: ShuaiBai623 <43326198+ShuaiBai623@users.noreply.github.com> Date: Wed, 15 Jan 2025 02:40:21 +0800 Subject: [PATCH 07/42] Update src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py Co-authored-by: Minho Shim <6764739+minostauros@users.noreply.github.com> --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index db41a1de8739..c3c710792d6c 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1799,7 +1799,7 @@ def forward( attention_mask = attention_mask.to(inputs_embeds.device) # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2): + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( From c23055b742efdcf072c47773e504b4be50740935 Mon Sep 17 00:00:00 2001 From: ShuaiBai623 <43326198+ShuaiBai623@users.noreply.github.com> Date: Wed, 15 Jan 2025 02:44:04 +0800 Subject: [PATCH 08/42] Update src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py Co-authored-by: Minho Shim <6764739+minostauros@users.noreply.github.com> --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index c3c710792d6c..ceefe0ef5445 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1672,7 +1672,7 @@ def get_rope_index( if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: From b1e2f6881b0ded5d067b1da479f99967610dfb97 Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Wed, 15 Jan 2025 03:02:55 +0800 Subject: [PATCH 09/42] padd copy check --- .../models/qwen2_5_vl/image_processing_qwen2_5_vl.py | 1 + src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py index 5bc8b2db5104..cb2c0057ca09 100644 --- a/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py @@ -169,6 +169,7 @@ class Qwen2_5_VLImageProcessor(BaseImageProcessor): "image_grid_thw", "pixel_values_videos", "video_grid_thw", + "second_per_grid_ts", ] def __init__( diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index ceefe0ef5445..577318e16c02 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -405,7 +405,6 @@ def forward( return attn_output -# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.VisionFlashAttention2 class VisionFlashAttention2(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() From 9d0e553e45aaa743f628d43440f539ef1be8a664 Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Fri, 17 Jan 2025 17:37:50 +0800 Subject: [PATCH 10/42] use modular --- .../qwen2_5_vl/configuration_qwen2_5_vl.py | 86 +- .../qwen2_5_vl/image_processing_qwen2_5_vl.py | 60 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1026 +++++++++-------- ...r_qwen_2_5_vl.py => modular_qwen2_5_vl.py} | 235 +++- .../qwen2_5_vl/processing_qwen2_5_vl.py | 26 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 3 + 6 files changed, 852 insertions(+), 584 deletions(-) rename src/transformers/models/qwen2_5_vl/{modular_qwen_2_5_vl.py => modular_qwen2_5_vl.py} (76%) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index 41b2fa1f92a9..08cfa44ea4b0 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -1,6 +1,17 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. # +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,14 +23,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Qwen2.5-VL model configuration""" - from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation -from ...utils import logging - - -logger = logging.get_logger(__name__) class Qwen2_5_VLVisionConfig(PretrainedConfig): @@ -38,7 +43,7 @@ def __init__( patch_size=14, spatial_merge_size=2, temporal_patch_size=2, - tokens_per_second=25, + tokens_per_second=4, window_size=112, out_hidden_size=3584, fullatt_block_indexes=list(range(7, 32, 8)), @@ -62,12 +67,44 @@ def __init__( self.out_hidden_size = out_hidden_size +class Qwen25VisionConfig(PretrainedConfig): + model_type = "qwen2_5_" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="quick_gelu", + mlp_ratio=4, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + + class Qwen2_5_VLConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a - Qwen2.5-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - Qwen2.5-VL-7B-Instruct [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -75,7 +112,7 @@ class Qwen2_5_VLConfig(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 152064): - Vocabulary size of the Qwen2.5-VL model. Defines the number of different tokens that can be represented by the + Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Qwen2_5_VLModel`] hidden_size (`int`, *optional*, defaults to 8192): Dimension of the hidden representations. @@ -158,19 +195,29 @@ class Qwen2_5_VLConfig(PretrainedConfig): ```python >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig - >>> # Initializing a Qwen2.5-VL configuration + >>> # Initializing a Qwen2_5_VL style configuration >>> configuration = Qwen2_5_VLConfig() + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2_5_VLForConditionalGeneration(configuration) + >>> # Accessing the model configuration >>> configuration = model.config - - >>> # Initializing a model from the Qwen2.5-VL configuration - >>> model = Qwen2_5_VLForConditionalGeneration(configuration) ```""" model_type = "qwen2_5_vl" sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen2_5_VL` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } def __init__( self, @@ -195,10 +242,11 @@ def __init__( rope_scaling=None, **kwargs, ): + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) if isinstance(vision_config, dict): - self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) + self.vision_config = Qwen25VisionConfig(**vision_config) elif vision_config is None: - self.vision_config = Qwen2_5_VLVisionConfig() + self.vision_config = Qwen25VisionConfig() self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -233,5 +281,7 @@ def __init__( self.rope_scaling["type"] = "default" self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self, ignore_keys={"mrope_section"}) - - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + if isinstance(vision_config, dict): + self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen2_5_VLVisionConfig() diff --git a/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py index cb2c0057ca09..f0b405febc78 100644 --- a/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. # @@ -17,19 +23,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Image processor class for Qwen2.5-VL.""" - import math from typing import Dict, List, Optional, Union import numpy as np -from ...image_processing_utils import BaseImageProcessor, BatchFeature -from ...image_transforms import ( - convert_to_rgb, - resize, - to_channel_dimension_format, -) +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils import BaseImageProcessor +from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format from ...image_utils import ( OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, @@ -49,13 +50,13 @@ from ...utils import TensorType, is_vision_available, logging -logger = logging.get_logger(__name__) - - if is_vision_available(): from PIL import Image +logger = logging.get_logger(__name__) + + def make_batched_images(images) -> List[List[ImageInput]]: """ Accepts images in list or nested list format, and makes a list of images for preprocessing. @@ -79,7 +80,6 @@ def make_batched_images(images) -> List[List[ImageInput]]: raise ValueError(f"Could not make batched images from {images}") -# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos def make_batched_videos(videos) -> List[VideoInput]: if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): return videos @@ -97,11 +97,7 @@ def make_batched_videos(videos) -> List[VideoInput]: def smart_resize( - height: int, - width: int, - factor: int = 28, - min_pixels: int = 56 * 56, - max_pixels: int = 14 * 14 * 4 * 1280, + height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280 ): """Rescales the image so that the following conditions are met: @@ -262,7 +258,7 @@ def _preprocess( # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] - if is_scaled_image(images[0]) and do_rescale: + if do_rescale and is_scaled_image(images[0]): logger.warning_once( "It looks like you are trying to rescale already rescaled images. If the input" " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." @@ -284,10 +280,7 @@ def _preprocess( max_pixels=self.max_pixels, ) image = resize( - image, - size=(resized_height, resized_width), - resample=resample, - input_data_format=input_data_format, + image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format ) if do_rescale: @@ -295,10 +288,7 @@ def _preprocess( if do_normalize: image = self.normalize( - image=image, - mean=image_mean, - std=image_std, - input_data_format=input_data_format, + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format ) image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) @@ -307,14 +297,12 @@ def _preprocess( patches = np.array(processed_images) if data_format == ChannelDimension.LAST: patches = patches.transpose(0, 3, 1, 2) - if patches.shape[0] == 1: - patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1)) + if patches.shape[0] % self.temporal_patch_size != 0: + repeats = np.repeat(patches[-1][np.newaxis], self.temporal_patch_size - 1, axis=0) + patches = np.concatenate([patches, repeats], axis=0) channel = patches.shape[1] grid_t = patches.shape[0] // self.temporal_patch_size - grid_h, grid_w = ( - resized_height // self.patch_size, - resized_width // self.patch_size, - ) + grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size patches = patches.reshape( grid_t, self.temporal_patch_size, @@ -328,8 +316,7 @@ def _preprocess( ) patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) flatten_patches = patches.reshape( - grid_t * grid_h * grid_w, - channel * self.temporal_patch_size * self.patch_size * self.patch_size, + grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size ) return flatten_patches, (grid_t, grid_h, grid_w) @@ -473,9 +460,6 @@ def preprocess( vision_grid_thws.append(video_grid_thw) pixel_values = np.array(pixel_values) vision_grid_thws = np.array(vision_grid_thws) - data = { - "pixel_values_videos": pixel_values, - "video_grid_thw": vision_grid_thws, - } + data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws} return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 577318e16c02..a78dc64952f1 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. # @@ -17,7 +23,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch Qwen2.5-VL model.""" import math from dataclasses import dataclass @@ -26,21 +31,16 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torch.utils.checkpoint from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import ( - AttentionMaskConverter, -) -from ...modeling_outputs import ( - BaseModelOutputWithPast, - ModelOutput, -) +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel +from ...processing_utils import VideosKwargs from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -56,19 +56,453 @@ from flash_attn import flash_attn_varlen_func from flash_attn.layers.rotary import apply_rotary_emb - from ...modeling_flash_attention_utils import _flash_attention_forward -else: - flash_attn_varlen_func = None - apply_rotary_emb = None +else: + flash_attn_varlen_func = None + apply_rotary_emb = None + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward +else: + flash_attn_varlen_func = None + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +class Qwen2_5_VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + tensor_ = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) + return output + + +class Qwen2_5_VLVisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2_5_VLVisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5_VLVisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_5_VL_VISION_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLVisionAttention, + "flash_attention_2": Qwen2_5_VLVisionFlashAttention2, + "sdpa": Qwen2_5_VLVisionSdpaAttention, +} + + +class Qwen2_5_VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + self.mlp = Qwen2_5_VLMLP(config, bias=True) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +Qwen2_5_VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2_5_VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VisionTransformerPretrainedModel(PreTrainedModel): + config_class = Qwen2_5_VLVisionConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2_5_VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = Qwen2_5_VLPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def get_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.up_proj.weight.dtype + + def get_device(self) -> torch.device: + return self.blocks[0].mlp.up_proj.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb + ) + else: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + ) -logger = logging.get_logger(__name__) + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] -_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + return hidden_states @dataclass -# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLCausalLMOutputWithPast with Qwen2VL->Qwen2_5_VL class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): """ Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. @@ -107,42 +541,6 @@ class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): rope_deltas: Optional[torch.LongTensor] = None -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class Qwen2_5_VLMLP(nn.Module): - def __init__(self, config, bias: bool = False): - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) - - -# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLRotaryEmbedding with Qwen2VL->Qwen2_5_VL class Qwen2_5_VLRotaryEmbedding(nn.Module): def __init__( self, @@ -211,7 +609,7 @@ def forward(self, x, position_ids): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) - # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids + # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for thw grids # So we expand the inv_freq to shape (3, ...) inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) @@ -231,261 +629,67 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). - - Explanation: - Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding - sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For - vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. - Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. - For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, - height and width) of text embedding is always the same, so the text embedding rotary position embedding has no - difference with modern LLMs. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output - - -def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - tensor_ = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) - return output - - -class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs - - -class PatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 14, - temporal_patch_size: int = 2, - in_channels: int = 3, - hidden_size: int = 1152, - ) -> None: - super().__init__() - self.patch_size = patch_size - self.temporal_patch_size = temporal_patch_size - self.in_channels = in_channels - self.hidden_size = hidden_size - - kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d( - in_channels, - hidden_size, - kernel_size=kernel_size, - stride=kernel_size, - bias=False, - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - target_dtype = self.proj.weight.dtype - hidden_states = hidden_states.view( - -1, - self.in_channels, - self.temporal_patch_size, - self.patch_size, - self.patch_size, - ) - hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.hidden_size) - return hidden_states - - -class PatchMerger(nn.Module): - def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: - super().__init__() - self.hidden_size = context_dim * (spatial_merge_size**2) - self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) - self.mlp = nn.Sequential( - nn.Linear(self.hidden_size, self.hidden_size), - nn.GELU(), - nn.Linear(self.hidden_size, dim), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) - return x - - -# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention -class VisionAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - attention_mask = torch.full( - [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) - attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) - attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output - - -class VisionFlashAttention2(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( - seq_length, -1 - ) - attn_output = self.proj(attn_output) - return attn_output - - -# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.VisionSdpaAttention -class VisionSdpaAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj -QWEN2_VL_VISION_ATTENTION_CLASSES = { - "eager": VisionAttention, - "flash_attention_2": VisionFlashAttention2, - "sdpa": VisionSdpaAttention, -} +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). -class Qwen2_5_VLVisionBlock(nn.Module): - def __init__(self, config, attn_implementation: str = "sdpa") -> None: - super().__init__() - self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) - self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. - self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]( - config.hidden_size, num_heads=config.num_heads - ) - self.mlp = Qwen2_5_VLMLP(config, bias=True) + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: - hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - ) - hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) - return hidden_states + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed -# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -498,7 +702,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLAttention with Qwen2VL->Qwen2_5_VL class Qwen2_5_VLAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -610,7 +813,6 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLFlashAttention2 with Qwen2VL->Qwen2_5_VL class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): """ Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention` @@ -722,7 +924,6 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLSdpaAttention with Qwen2VL->Qwen2_5_VL class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -814,7 +1015,7 @@ def forward( return attn_output, None, past_key_value -QWEN2_VL_ATTENTION_CLASSES = { +QWEN2_5_VL_ATTENTION_CLASSES = { "eager": Qwen2_5_VLAttention, "flash_attention_2": Qwen2_5_VLFlashAttention2, "sdpa": Qwen2_5_VLSdpaAttention, @@ -831,9 +1032,9 @@ def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." ) - self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - self.mlp = Qwen2_5_VLMLP(config) + self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -905,28 +1106,10 @@ def forward( return outputs -QWEN2_5_VL_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2_5_VLConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - @add_start_docstrings( "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", - QWEN2_5_VL_START_DOCSTRING, + Qwen2_5_VL_START_DOCSTRING, ) -# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLPreTrainedModel with Qwen2VL->Qwen2_5_VL class Qwen2_5_VLPreTrainedModel(PreTrainedModel): config_class = Qwen2_5_VLConfig base_model_prefix = "model" @@ -938,6 +1121,9 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = True + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv3d)): @@ -950,170 +1136,10 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): - config_class = Qwen2_5_VLVisionConfig - _no_split_modules = ["Qwen2_5_VLVisionBlock"] - - def __init__(self, config) -> None: - super().__init__(config) - self.spatial_merge_size = config.spatial_merge_size - self.patch_size = config.patch_size - self.fullatt_block_indexes = config.fullatt_block_indexes - self.window_size = config.window_size - self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size - - self.patch_embed = PatchEmbed( - patch_size=config.patch_size, - temporal_patch_size=config.temporal_patch_size, - in_channels=config.in_channels, - hidden_size=config.hidden_size, - ) - - head_dim = config.hidden_size // config.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) - - self.blocks = nn.ModuleList( - [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] - ) - self.merger = PatchMerger( - dim=config.out_hidden_size, - context_dim=config.hidden_size, - spatial_merge_size=config.spatial_merge_size, - ) - self.gradient_checkpointing = False - - def get_dtype(self) -> torch.dtype: - return self.blocks[0].mlp.up_proj.weight.dtype - - def get_device(self) -> torch.device: - return self.blocks[0].mlp.up_proj.weight.device - - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size - - for grid_t, grid_h, grid_w in grid_thw: - llm_grid_h, llm_grid_w = ( - grid_h // self.spatial_merge_size, - grid_w // self.spatial_merge_size, - ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) - index_padded = index_padded.reshape( - grid_t, - num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size, - ) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, - num_windows_h * num_windows_w, - vit_merger_window_size, - vit_merger_window_size, - ) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() - window_index = torch.cat(window_index, dim=0) - - return window_index, cu_window_seqlens - - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: - hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) - - seq_len, _ = hidden_states.size() - hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - hidden_states = hidden_states[window_index, :, :] - hidden_states = hidden_states.reshape(seq_len, -1) - rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - - for layer_num, blk in enumerate(self.blocks): - if layer_num in self.fullatt_block_indexes: - cu_seqlens_now = cu_seqlens - else: - cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb - ) - else: - hidden_states = blk( - hidden_states, - cu_seqlens=cu_seqlens_now, - rotary_pos_emb=rotary_pos_emb, - ) - hidden_states = self.merger(hidden_states) - reverse_indices = torch.argsort(window_index) - hidden_states = hidden_states[reverse_indices, :] - - return hidden_states - - @add_start_docstrings( "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", - QWEN2_5_VL_START_DOCSTRING, + Qwen2_5_VL_START_DOCSTRING, ) -# Copied from transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLModel with Qwen2VL->Qwen2_5_VL class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): def __init__(self, config: Qwen2_5_VLConfig): super().__init__(config) @@ -1255,7 +1281,6 @@ def forward( attentions=all_self_attns, ) - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1332,7 +1357,6 @@ def _update_causal_mask( return causal_mask @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2_5_VL def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1400,7 +1424,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -QWEN2_VL_INPUTS_DOCSTRING = r""" +QWEN2_5_VL_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide @@ -1461,18 +1485,16 @@ def _prepare_4d_causal_attention_mask_with_cache_position( Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): The tensors corresponding to the input images. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses - [`Qwen2VLImageProcessor`] for processing images. + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing images. pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): The tensors corresponding to the input videos. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses - [`Qwen2VLImageProcessor`] for processing videos. + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing videos. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): - The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. """ @@ -1480,6 +1502,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2_5_VLVisionBlock"] def __init__(self, config): super().__init__(config) @@ -1512,7 +1536,7 @@ def get_decoder(self): def get_rope_index( self, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, @@ -1671,7 +1695,7 @@ def get_rope_index( if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: @@ -1688,7 +1712,7 @@ def get_rope_index( return position_ids, mrope_position_deltas - @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -1798,7 +1822,7 @@ def forward( attention_mask = attention_mask.to(inputs_embeds.device) # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( @@ -1940,3 +1964,7 @@ def prepare_inputs_for_generation( } ) return model_inputs + + +class Qwen2_5_VLVideosKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen_2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py similarity index 76% rename from src/transformers/models/qwen2_5_vl/modular_qwen_2_5_vl.py rename to src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index d74d53b37ba6..26eeae5fba93 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen_2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -19,6 +19,7 @@ # limitations under the License. """PyTorch Qwen2.5-VL model.""" +from dataclasses import dataclass from typing import List, Optional, Tuple, Union import torch @@ -28,6 +29,7 @@ from torch.nn import CrossEntropyLoss from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig +from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor from transformers.models.qwen2_vl.modeling_qwen2_vl import ( PatchEmbed, PatchMerger, @@ -35,15 +37,19 @@ Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration, Qwen2VLPreTrainedModel, - Qwen2VLVisionBlock, VisionAttention, VisionRotaryEmbedding, VisionSdpaAttention, ) +from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor from ...activations import ACT2FN from ...cache_utils import StaticCache from ...configuration_utils import PretrainedConfig +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, VideoInput +from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_flash_attn_2_available @@ -80,7 +86,7 @@ def __init__( patch_size=14, spatial_merge_size=2, temporal_patch_size=2, - tokens_per_second=25, + tokens_per_second=4, window_size=112, out_hidden_size=3584, fullatt_block_indexes=list(range(7, 32, 8)), @@ -179,7 +185,7 @@ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> N self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) -class VisionFlashAttention2(nn.Module): +class Qwen2_5_VLVisionFlashAttention2(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() self.num_heads = num_heads @@ -205,16 +211,24 @@ def forward( return attn_output +class Qwen2_5_VLVisionAttention(VisionAttention): + pass + + +class Qwen2_5_VLVisionSdpaAttention(VisionSdpaAttention): + pass + + QWEN2_5_VL_VISION_ATTENTION_CLASSES = { - "eager": VisionAttention, - "flash_attention_2": VisionFlashAttention2, - "sdpa": VisionSdpaAttention, + "eager": Qwen2_5_VLVisionAttention, + "flash_attention_2": Qwen2_5_VLVisionFlashAttention2, + "sdpa": Qwen2_5_VLVisionSdpaAttention, } -class Qwen2_5_VLVisionBlock(Qwen2VLVisionBlock): +class Qwen2_5_VLVisionBlock(nn.Module): def __init__(self, config, attn_implementation: str = "sdpa") -> None: - super().__init__(config, attn_implementation) + super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( @@ -222,13 +236,22 @@ def __init__(self, config, attn_implementation: str = "sdpa") -> None: ) self.mlp = Qwen2_5_VLMLP(config, bias=True) + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + class Qwen2_5_VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): config_class = Qwen2_5_VLVisionConfig _no_split_modules = ["Qwen2_5_VLVisionBlock"] - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) self.spatial_merge_size = config.spatial_merge_size self.patch_size = config.patch_size self.fullatt_block_indexes = config.fullatt_block_indexes @@ -239,7 +262,7 @@ def __init__(self, config) -> None: patch_size=config.patch_size, temporal_patch_size=config.temporal_patch_size, in_channels=config.in_channels, - hidden_size=config.hidden_size, + embed_dim=config.hidden_size, ) head_dim = config.hidden_size // config.num_heads @@ -383,6 +406,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. return hidden_states +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast): + pass + + class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): config_class = Qwen2_5_VLConfig _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2_5_VLVisionBlock"] @@ -588,7 +616,7 @@ def forward( rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, - ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -738,7 +766,7 @@ def forward( output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - return Qwen2VLCausalLMOutputWithPast( + return Qwen2_5_VLCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, @@ -819,3 +847,184 @@ def prepare_inputs_for_generation( } ) return model_inputs + + +class Qwen2_5_VLImageProcessor(Qwen2VLImageProcessor): + r""" + Constructs a Qwen2.5-VL image processor that dynamically resizes images based on the original images. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + min_pixels (`int`, *optional*, defaults to `56 * 56`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`): + The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to 14): + The spacial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. + """ + + model_input_names = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grid_ts", + ] + + +class Qwen2_5_VLVideosKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_VLVideosKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(Qwen2VLProcessor): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2_5_VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2_5_VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + image_processor_class = "Qwen2_5_VLImageProcessor" + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2_5_VLImageProcessor's [`~Qwen2_5_VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." + ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index e54a1b52e43e..7207148fcee8 100644 --- a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. # @@ -17,24 +23,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Processor class for Qwen2.5-VL. -""" - from typing import List, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, VideoInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import logging - - -logger = logging.get_logger(__name__) - - -class Qwen2_5_VLVideosKwargs(VideosKwargs, total=False): - fps: Union[List[float], float] +from .modeling_qwen2_5_vl import Qwen2_5_VLVideosKwargs class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): @@ -63,6 +58,7 @@ class Qwen2_5_VLProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] valid_kwargs = ["chat_template"] + image_processor_class = "Qwen2_5_VLImageProcessor" tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") @@ -206,9 +202,7 @@ def post_process_image_text_to_text(self, generated_outputs): `List[str]`: The decoded text. """ return self.tokenizer.batch_decode( - generated_outputs, - skip_special_tokens=True, - clean_up_tokenization_spaces=False, + generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False ) @property diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index ea169987b94c..abb4a023a6bc 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -917,6 +917,9 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = True + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv3d)): From 23e6635c3c29771ab7f642755146a03aee831d4f Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Fri, 17 Jan 2025 17:55:07 +0800 Subject: [PATCH 11/42] fix --- utils/check_repo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/check_repo.py b/utils/check_repo.py index a1190b68ed30..79af3e7764a1 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -138,6 +138,7 @@ "SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model. "ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model "Qwen2VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2VLForConditionalGeneration. + "Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration. "MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests "MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests ] From aa0dc95c2f288788c2cf06eb3e17bb182cea66d0 Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Fri, 17 Jan 2025 18:29:05 +0800 Subject: [PATCH 12/42] fix --- .../qwen2_5_vl/configuration_qwen2_5_vl.py | 40 +------------------ .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 4 +- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 10 ++--- utils/check_repo.py | 2 +- 4 files changed, 10 insertions(+), 46 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index 08cfa44ea4b0..9a8bf3286518 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -27,7 +27,7 @@ from ...modeling_rope_utils import rope_config_validation -class Qwen2_5_VLVisionConfig(PretrainedConfig): +class Qwen25VisionConfig(PretrainedConfig): model_type = "qwen2_5_vl" base_config_key = "vision_config" @@ -67,38 +67,6 @@ def __init__( self.out_hidden_size = out_hidden_size -class Qwen25VisionConfig(PretrainedConfig): - model_type = "qwen2_5_" - base_config_key = "vision_config" - - def __init__( - self, - depth=32, - embed_dim=1280, - hidden_size=3584, - hidden_act="quick_gelu", - mlp_ratio=4, - num_heads=16, - in_channels=3, - patch_size=14, - spatial_merge_size=2, - temporal_patch_size=2, - **kwargs, - ): - super().__init__(**kwargs) - - self.depth = depth - self.embed_dim = embed_dim - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.mlp_ratio = mlp_ratio - self.num_heads = num_heads - self.in_channels = in_channels - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size - - class Qwen2_5_VLConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a @@ -206,7 +174,7 @@ class Qwen2_5_VLConfig(PretrainedConfig): ```""" model_type = "qwen2_5_vl" - sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} + sub_configs = {"vision_config": Qwen25VisionConfig} keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `Qwen2_5_VL` base_model_tp_plan = { @@ -281,7 +249,3 @@ def __init__( self.rope_scaling["type"] = "default" self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self, ignore_keys={"mrope_section"}) - if isinstance(vision_config, dict): - self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) - elif vision_config is None: - self.vision_config = Qwen2_5_VLVisionConfig() diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index a78dc64952f1..d186a626abfe 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -49,7 +49,7 @@ logging, replace_return_docstrings, ) -from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig +from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen25VisionConfig if is_flash_attn_2_available(): @@ -325,7 +325,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: Qwen2_5_VL_START_DOCSTRING, ) class Qwen2_5_VisionTransformerPretrainedModel(PreTrainedModel): - config_class = Qwen2_5_VLVisionConfig + config_class = Qwen25VisionConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen2_5_VLVisionBlock"] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 26eeae5fba93..f88770f17beb 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -70,7 +70,7 @@ def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> return output -class Qwen2_5_VLVisionConfig(PretrainedConfig): +class Qwen25VisionConfig(PretrainedConfig): model_type = "qwen2_5_vl" base_config_key = "vision_config" @@ -112,7 +112,7 @@ def __init__( class Qwen2_5_VLConfig(Qwen2VLConfig): model_type = "qwen2_5_vl" - sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} + sub_configs = {"vision_config": Qwen25VisionConfig} def __init__( self, @@ -160,9 +160,9 @@ def __init__( **kwargs, ) if isinstance(vision_config, dict): - self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) + self.vision_config = Qwen25VisionConfig(**vision_config) elif vision_config is None: - self.vision_config = Qwen2_5_VLVisionConfig() + self.vision_config = Qwen25VisionConfig() class Qwen2_5_VLMLP(nn.Module): @@ -247,7 +247,7 @@ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: class Qwen2_5_VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): - config_class = Qwen2_5_VLVisionConfig + config_class = Qwen25VisionConfig _no_split_modules = ["Qwen2_5_VLVisionBlock"] def __init__(self, config, *inputs, **kwargs) -> None: diff --git a/utils/check_repo.py b/utils/check_repo.py index 79af3e7764a1..23f27a662c5b 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -138,7 +138,7 @@ "SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model. "ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model "Qwen2VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2VLForConditionalGeneration. - "Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration. + "Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration. "MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests "MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests ] From 05a1ab156f675a4895d0470fb1ee960444111e23 Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Fri, 17 Jan 2025 18:44:54 +0800 Subject: [PATCH 13/42] fix --- src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py | 2 -- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index 9a8bf3286518..0e3e7d27b176 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -34,7 +34,6 @@ class Qwen25VisionConfig(PretrainedConfig): def __init__( self, depth=32, - embed_dim=1280, hidden_size=3584, hidden_act="silu", intermediate_size=3420, @@ -52,7 +51,6 @@ def __init__( super().__init__(**kwargs) self.depth = depth - self.embed_dim = embed_dim self.hidden_size = hidden_size self.hidden_act = hidden_act self.intermediate_size = intermediate_size diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index f88770f17beb..0d3cb481c7b7 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -77,7 +77,6 @@ class Qwen25VisionConfig(PretrainedConfig): def __init__( self, depth=32, - embed_dim=1280, hidden_size=3584, hidden_act="silu", intermediate_size=3420, @@ -95,7 +94,6 @@ def __init__( super().__init__(**kwargs) self.depth = depth - self.embed_dim = embed_dim self.hidden_size = hidden_size self.hidden_act = hidden_act self.intermediate_size = intermediate_size From 00b05f2d0059f4c6545d23984bd07b39fcc5fa26 Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Fri, 17 Jan 2025 18:56:47 +0800 Subject: [PATCH 14/42] update flashatt2&sdpa support_list --- docs/source/en/perf_infer_gpu_one.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 019fbebd35be..36c960d95c8e 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -95,6 +95,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder) * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel) * [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel) +* [Qwen2.5VL](https://huggingface.co/docs/transformers/model_doc/qwen2_5_vl#transformers.Qwen2_5_VLModel) * [RAG](https://huggingface.co/docs/transformers/model_doc/rag#transformers.RagModel) * [SpeechEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/speech_encoder_decoder#transformers.SpeechEncoderDecoderModel) * [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision_encoder_decoder#transformers.VisionEncoderDecoderModel) @@ -292,6 +293,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model) * [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder) * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel) +* [Qwen2.5VL](https://huggingface.co/docs/transformers/model_doc/qwen2_5_vl#transformers.Qwen2_5_VLModel) * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel) * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) From 47eb0694c35ba82356ff373be4b11370da710e70 Mon Sep 17 00:00:00 2001 From: ShuaiBai623 <43326198+ShuaiBai623@users.noreply.github.com> Date: Sat, 18 Jan 2025 10:53:25 +0800 Subject: [PATCH 15/42] Update docs/source/en/_toctree.yml Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ec55ec92947d..8575203f14d4 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -923,7 +923,7 @@ - local: model_doc/pixtral title: Pixtral - local: model_doc/qwen2_5_vl - title: Qwen2_5_VL + title: Qwen2.5-VL - local: model_doc/qwen2_audio title: Qwen2Audio - local: model_doc/qwen2_vl From 35b9dbb2703d055b7f8f2ce2de1d2cd46f56dbbe Mon Sep 17 00:00:00 2001 From: ShuaiBai623 <43326198+ShuaiBai623@users.noreply.github.com> Date: Sat, 18 Jan 2025 10:54:14 +0800 Subject: [PATCH 16/42] Update docs/source/en/model_doc/qwen2_5_vl.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/model_doc/qwen2_5_vl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/qwen2_5_vl.md b/docs/source/en/model_doc/qwen2_5_vl.md index 20ca767723ff..9ef229e07632 100644 --- a/docs/source/en/model_doc/qwen2_5_vl.md +++ b/docs/source/en/model_doc/qwen2_5_vl.md @@ -18,7 +18,7 @@ rendered properly in your Markdown viewer. ## Overview -The [Qwen2_5-VL](https://qwenlm.github.io/blog/qwen2_5-vl/) model is an update to [Qwen2-VL](https://arxiv.org/abs/2409.12191) from Qwen team, Alibaba Group. +The [Qwen2.5-VL](https://qwenlm.github.io/blog/qwen2_5-vl/) model is an update to [Qwen2-VL](https://arxiv.org/abs/2409.12191) from Qwen team, Alibaba Group. The abstract from this update is the following: From 67a10dae4893dc8f58dfab31297cce3b06135cac Mon Sep 17 00:00:00 2001 From: ShuaiBai623 <43326198+ShuaiBai623@users.noreply.github.com> Date: Sat, 18 Jan 2025 10:54:30 +0800 Subject: [PATCH 17/42] Update docs/source/en/model_doc/qwen2_5_vl.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/model_doc/qwen2_5_vl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/qwen2_5_vl.md b/docs/source/en/model_doc/qwen2_5_vl.md index 9ef229e07632..6c055c9b570d 100644 --- a/docs/source/en/model_doc/qwen2_5_vl.md +++ b/docs/source/en/model_doc/qwen2_5_vl.md @@ -22,7 +22,7 @@ The [Qwen2.5-VL](https://qwenlm.github.io/blog/qwen2_5-vl/) model is an update t The abstract from this update is the following: -*Qwen2_5-VL marks a major step forward from Qwen2-VL, built upon the latest Qwen2.5 LLM. We've accelerated training and testing through the strategic implementation of window attention within the ViT. The ViT architecture itself has been refined with SwiGLU and RMSNorm, aligning it more closely with the LLM's structure. A key innovation is the expansion of native dynamic resolution to encompass the temporal dimension, in addition to spatial aspects. Furthermore, we've upgraded MRoPE, incorporating absolute time alignment on the time axis to allow the model to effectively capture temporal dynamics, regardless of frame rate, leading to superior video understanding.* +*Qwen2.5-VL marks a major step forward from Qwen2-VL, built upon the latest Qwen2.5 LLM. We've accelerated training and testing through the strategic implementation of window attention within the ViT. The ViT architecture itself has been refined with SwiGLU and RMSNorm, aligning it more closely with the LLM's structure. A key innovation is the expansion of native dynamic resolution to encompass the temporal dimension, in addition to spatial aspects. Furthermore, we've upgraded MRoPE, incorporating absolute time alignment on the time axis to allow the model to effectively capture temporal dynamics, regardless of frame rate, leading to superior video understanding.* ## Usage example From e37ce587349bf0418f4c984ee6dce9d8f13ad403 Mon Sep 17 00:00:00 2001 From: ShuaiBai623 <43326198+ShuaiBai623@users.noreply.github.com> Date: Sat, 18 Jan 2025 10:54:42 +0800 Subject: [PATCH 18/42] Update docs/source/en/model_doc/qwen2_5_vl.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/model_doc/qwen2_5_vl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/qwen2_5_vl.md b/docs/source/en/model_doc/qwen2_5_vl.md index 6c055c9b570d..7897c5287308 100644 --- a/docs/source/en/model_doc/qwen2_5_vl.md +++ b/docs/source/en/model_doc/qwen2_5_vl.md @@ -287,7 +287,7 @@ First, make sure to install the latest version of Flash Attention 2: pip install -U flash-attn --no-build-isolation ``` -Also, you should have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`. +Also, you should have hardware that is compatible with FlashAttention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`. To load and run a model using Flash Attention-2, simply add `attn_implementation="flash_attention_2"` when loading the model as follows: From fcbb355a2426b48ed7e4de9c456d6f7dfe06a9f7 Mon Sep 17 00:00:00 2001 From: ShuaiBai623 <43326198+ShuaiBai623@users.noreply.github.com> Date: Sat, 18 Jan 2025 10:54:57 +0800 Subject: [PATCH 19/42] Update docs/source/en/model_doc/qwen2_5_vl.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/model_doc/qwen2_5_vl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/qwen2_5_vl.md b/docs/source/en/model_doc/qwen2_5_vl.md index 7897c5287308..66093ee271d2 100644 --- a/docs/source/en/model_doc/qwen2_5_vl.md +++ b/docs/source/en/model_doc/qwen2_5_vl.md @@ -289,7 +289,7 @@ pip install -U flash-attn --no-build-isolation Also, you should have hardware that is compatible with FlashAttention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`. -To load and run a model using Flash Attention-2, simply add `attn_implementation="flash_attention_2"` when loading the model as follows: +To load and run a model using FlashAttention-2, add `attn_implementation="flash_attention_2"` when loading the model: ```python from transformers import Qwen2_5_VLForConditionalGeneration From 7f3c66e7578f08470d2a8684784130c0e47e842e Mon Sep 17 00:00:00 2001 From: ShuaiBai623 <43326198+ShuaiBai623@users.noreply.github.com> Date: Sat, 18 Jan 2025 10:55:09 +0800 Subject: [PATCH 20/42] Update src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 0d3cb481c7b7..a5b8bf3fb445 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -431,7 +431,7 @@ def get_rope_index( Explanation: Each embedding sequence contains vision embedding and text embedding or just contains text embedding. - For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs. + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. Examples: input_ids: [T T T T T], here T is for text. temporal position_ids: [0, 1, 2, 3, 4] From f5a4c2b53fd149fbf451a8be3d4d9c68cebb6469 Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Sat, 18 Jan 2025 11:17:17 +0800 Subject: [PATCH 21/42] update config --- .../qwen2_5_vl/configuration_qwen2_5_vl.py | 33 ++-- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 6 +- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 178 +++++++++++++++--- 3 files changed, 164 insertions(+), 53 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index 0e3e7d27b176..d4e06987cb5d 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -27,7 +27,7 @@ from ...modeling_rope_utils import rope_config_validation -class Qwen25VisionConfig(PretrainedConfig): +class Qwen2_5_VLVisionConfig(PretrainedConfig): model_type = "qwen2_5_vl" base_config_key = "vision_config" @@ -68,9 +68,9 @@ def __init__( class Qwen2_5_VLConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a - Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + Qwen2.5-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + Qwen2.5-VL-7B-Instruct [Qwen/Qwen2_5_VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -78,7 +78,7 @@ class Qwen2_5_VLConfig(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 152064): - Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the + Vocabulary size of the Qwen2.5-VL model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Qwen2_5_VLModel`] hidden_size (`int`, *optional*, defaults to 8192): Dimension of the hidden representations. @@ -164,26 +164,16 @@ class Qwen2_5_VLConfig(PretrainedConfig): >>> # Initializing a Qwen2_5_VL style configuration >>> configuration = Qwen2_5_VLConfig() - >>> # Initializing a model from the Qwen2-VL-7B style configuration - >>> model = Qwen2_5_VLForConditionalGeneration(configuration) - >>> # Accessing the model configuration >>> configuration = model.config + + >>> # Initializing a model from the Qwen2_5_VL configuration + >>> model = Qwen2_5_VLForConditionalGeneration(configuration) ```""" model_type = "qwen2_5_vl" - sub_configs = {"vision_config": Qwen25VisionConfig} + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} keys_to_ignore_at_inference = ["past_key_values"] - # Default tensor parallel plan for base model `Qwen2_5_VL` - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } def __init__( self, @@ -208,11 +198,10 @@ def __init__( rope_scaling=None, **kwargs, ): - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) if isinstance(vision_config, dict): - self.vision_config = Qwen25VisionConfig(**vision_config) + self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) elif vision_config is None: - self.vision_config = Qwen25VisionConfig() + self.vision_config = Qwen2_5_VLVisionConfig() self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -247,3 +236,5 @@ def __init__( self.rope_scaling["type"] = "default" self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self, ignore_keys={"mrope_section"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index d186a626abfe..7acf088e1a9a 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -49,7 +49,7 @@ logging, replace_return_docstrings, ) -from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen25VisionConfig +from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig if is_flash_attn_2_available(): @@ -325,7 +325,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: Qwen2_5_VL_START_DOCSTRING, ) class Qwen2_5_VisionTransformerPretrainedModel(PreTrainedModel): - config_class = Qwen25VisionConfig + config_class = Qwen2_5_VLVisionConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen2_5_VLVisionBlock"] @@ -1548,7 +1548,7 @@ def get_rope_index( Explanation: Each embedding sequence contains vision embedding and text embedding or just contains text embedding. - For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs. + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. Examples: input_ids: [T T T T T], here T is for text. temporal position_ids: [0, 1, 2, 3, 4] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index a5b8bf3fb445..47c391d3d1f1 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -28,7 +28,6 @@ import torch.utils.checkpoint from torch.nn import CrossEntropyLoss -from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor from transformers.models.qwen2_vl.modeling_qwen2_vl import ( PatchEmbed, @@ -48,6 +47,7 @@ from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, VideoInput +from ...modeling_rope_utils import rope_config_validation from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_flash_attn_2_available @@ -70,7 +70,7 @@ def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> return output -class Qwen25VisionConfig(PretrainedConfig): +class Qwen2_5_VLVisionConfig(PretrainedConfig): model_type = "qwen2_5_vl" base_config_key = "vision_config" @@ -108,9 +108,115 @@ def __init__( self.out_hidden_size = out_hidden_size -class Qwen2_5_VLConfig(Qwen2VLConfig): +class Qwen2_5_VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a + Qwen2.5-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2.5-VL-7B-Instruct [Qwen/Qwen2_5_VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the Qwen2.5-VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2_5_VLModel`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 29568): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 80): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 80): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + vision_config (`Dict`, *optional*): + The config for the visual encoder initialization. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + + ```python + >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig + + >>> # Initializing a Qwen2_5_VL style configuration + >>> configuration = Qwen2_5_VLConfig() + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # Initializing a model from the Qwen2_5_VL configuration + >>> model = Qwen2_5_VLForConditionalGeneration(configuration) + ```""" + model_type = "qwen2_5_vl" - sub_configs = {"vision_config": Qwen25VisionConfig} + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, @@ -135,32 +241,46 @@ def __init__( rope_scaling=None, **kwargs, ): - super().__init__( - vocab_size=vocab_size, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - hidden_act=hidden_act, - max_position_embeddings=max_position_embeddings, - initializer_range=initializer_range, - rms_norm_eps=rms_norm_eps, - use_cache=use_cache, - tie_word_embeddings=tie_word_embeddings, - rope_theta=rope_theta, - use_sliding_window=use_sliding_window, - sliding_window=sliding_window, - max_window_layers=max_window_layers, - attention_dropout=attention_dropout, - vision_config=vision_config, - rope_scaling=rope_scaling, - **kwargs, - ) if isinstance(vision_config, dict): - self.vision_config = Qwen25VisionConfig(**vision_config) + self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) elif vision_config is None: - self.vision_config = Qwen25VisionConfig() + self.vision_config = Qwen2_5_VLVisionConfig() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations + # one can set it to "linear"/"dynamic" etc. to have scaled RoPE + # TODO: @raushan update config in the hub + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) class Qwen2_5_VLMLP(nn.Module): @@ -245,7 +365,7 @@ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: class Qwen2_5_VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): - config_class = Qwen25VisionConfig + config_class = Qwen2_5_VLVisionConfig _no_split_modules = ["Qwen2_5_VLVisionBlock"] def __init__(self, config, *inputs, **kwargs) -> None: From f57827dd9f429f738f80544fa12b57c1730b9613 Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Sat, 18 Jan 2025 11:42:11 +0800 Subject: [PATCH 22/42] update --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 7acf088e1a9a..1490d8553c58 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1290,6 +1290,14 @@ def _update_causal_mask( output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None From b21bd62d4c6c10448f5e5e9b9e3eb4c0563f5a34 Mon Sep 17 00:00:00 2001 From: "baishuai.bs" Date: Sat, 18 Jan 2025 12:08:03 +0800 Subject: [PATCH 23/42] fix hf path --- src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py | 2 +- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index d4e06987cb5d..556c99001507 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -70,7 +70,7 @@ class Qwen2_5_VLConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a Qwen2.5-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - Qwen2.5-VL-7B-Instruct [Qwen/Qwen2_5_VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). + Qwen2.5-VL-7B-Instruct [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 47c391d3d1f1..bb1ef72a79cf 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -113,7 +113,7 @@ class Qwen2_5_VLConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a Qwen2.5-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - Qwen2.5-VL-7B-Instruct [Qwen/Qwen2_5_VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). + Qwen2.5-VL-7B-Instruct [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. From 18ea4bef6501d191d67cfdda656cf391c439de1f Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Mon, 20 Jan 2025 13:22:09 +0800 Subject: [PATCH 24/42] rename Qwen2_5_VLVideosKwargs --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 5 ----- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 4 ++-- .../models/qwen2_5_vl/processing_qwen2_5_vl.py | 9 ++++++--- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 1490d8553c58..bc4473559c77 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -40,7 +40,6 @@ from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel -from ...processing_utils import VideosKwargs from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -1972,7 +1971,3 @@ def prepare_inputs_for_generation( } ) return model_inputs - - -class Qwen2_5_VLVideosKwargs(VideosKwargs, total=False): - fps: Union[List[float], float] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index bb1ef72a79cf..dc4c54a77b98 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -1009,12 +1009,12 @@ class Qwen2_5_VLImageProcessor(Qwen2VLImageProcessor): ] -class Qwen2_5_VLVideosKwargs(VideosKwargs, total=False): +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): fps: Union[List[float], float] class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): - videos_kwargs: Qwen2_5_VLVideosKwargs + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs _defaults = { "text_kwargs": { "padding": False, diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index 7207148fcee8..6a7e1af1203d 100644 --- a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -27,13 +27,16 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, VideoInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs from ...tokenization_utils_base import PreTokenizedInput, TextInput -from .modeling_qwen2_5_vl import Qwen2_5_VLVideosKwargs + + +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): - videos_kwargs: Qwen2_5_VLVideosKwargs + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs _defaults = { "text_kwargs": { "padding": False, From faae9f4895f080bae3edcc8b1603d1f67a825c26 Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Mon, 20 Jan 2025 14:42:08 +0800 Subject: [PATCH 25/42] fix --- tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py index 2bc47d76acd4..c16e709084d5 100644 --- a/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py @@ -36,7 +36,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase): def setUp(self): self.tmpdirname = tempfile.mkdtemp() - processor = Qwen2_5_VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", patch_size=4) + processor = Qwen2_5_VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", patch_size=4) # Qwen/Qwen2.5-VL-7B-Instruct will cause the test to fail. processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs): From cd3fd21c05bba138196f9ad098cc0aaffd180fe3 Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Mon, 20 Jan 2025 14:47:32 +0800 Subject: [PATCH 26/42] fix --- tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py index c16e709084d5..2bc47d76acd4 100644 --- a/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_processor_qwen2_5_vl.py @@ -36,7 +36,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase): def setUp(self): self.tmpdirname = tempfile.mkdtemp() - processor = Qwen2_5_VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", patch_size=4) # Qwen/Qwen2.5-VL-7B-Instruct will cause the test to fail. + processor = Qwen2_5_VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", patch_size=4) processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs): From f8c465cbf65fc802e8d59ef5ec2e7ca8db8e47c5 Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Tue, 21 Jan 2025 16:00:27 +0800 Subject: [PATCH 27/42] update --- .../models/qwen2_5_vl/__init__.py | 66 +----- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 210 +++--------------- .../models/qwen2_vl/configuration_qwen2_vl.py | 4 +- .../qwen2_5_vl/test_modeling_qwen2_5_vl.py | 2 +- 4 files changed, 37 insertions(+), 245 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/__init__.py b/src/transformers/models/qwen2_5_vl/__init__.py index e2ee26089d5b..5d3cd215b070 100644 --- a/src/transformers/models/qwen2_5_vl/__init__.py +++ b/src/transformers/models/qwen2_5_vl/__init__.py @@ -13,67 +13,17 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - is_torch_available, - is_vision_available, -) - - -_import_structure = { - "configuration_qwen2_5_vl": ["Qwen2_5_VLConfig"], - "processing_qwen2_5_vl": ["Qwen2_5_VLProcessor"], -} - - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_qwen2_5_vl"] = [ - "Qwen2_5_VLForConditionalGeneration", - "Qwen2_5_VLModel", - "Qwen2_5_VLPreTrainedModel", - ] - -try: - if not is_vision_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["image_processing_qwen2_5_vl"] = ["Qwen2_5_VLImageProcessor"] +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure if TYPE_CHECKING: - from .configuration_qwen2_5_vl import Qwen2_5_VLConfig - from .processing_qwen2_5_vl import Qwen2_5_VLProcessor - - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_qwen2_5_vl import ( - Qwen2_5_VLForConditionalGeneration, - Qwen2_5_VLModel, - Qwen2_5_VLPreTrainedModel, - ) - - try: - if not is_vision_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .image_processing_qwen2_5_vl import Qwen2_5_VLImageProcessor - - + from .configuration_qwen2_5_vl import * + from .image_processing_qwen2_5_vl import * + from .modeling_qwen2_5_vl import * + from .processing_qwen2_5_vl import * else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index dc4c54a77b98..99546ac74c86 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -28,6 +28,7 @@ import torch.utils.checkpoint from torch.nn import CrossEntropyLoss +from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor from transformers.models.qwen2_vl.modeling_qwen2_vl import ( PatchEmbed, @@ -47,7 +48,6 @@ from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, VideoInput -from ...modeling_rope_utils import rope_config_validation from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_flash_attn_2_available @@ -88,7 +88,7 @@ def __init__( tokens_per_second=4, window_size=112, out_hidden_size=3584, - fullatt_block_indexes=list(range(7, 32, 8)), + fullatt_block_indexes=[7, 15, 23, 31], **kwargs, ): super().__init__(**kwargs) @@ -108,179 +108,9 @@ def __init__( self.out_hidden_size = out_hidden_size -class Qwen2_5_VLConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a - Qwen2.5-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen2.5-VL-7B-Instruct [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 152064): - Vocabulary size of the Qwen2.5-VL model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2_5_VLModel`] - hidden_size (`int`, *optional*, defaults to 8192): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 29568): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 80): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 64): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 8): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 1000000.0): - The base period of the RoPE embeddings. - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. - max_window_layers (`int`, *optional*, defaults to 80): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - vision_config (`Dict`, *optional*): - The config for the visual encoder initialization. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'llama3'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - - ```python - >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig - - >>> # Initializing a Qwen2_5_VL style configuration - >>> configuration = Qwen2_5_VLConfig() - - >>> # Accessing the model configuration - >>> configuration = model.config - - >>> # Initializing a model from the Qwen2_5_VL configuration - >>> model = Qwen2_5_VLForConditionalGeneration(configuration) - ```""" - +class Qwen2_5_VLConfig(Qwen2VLConfig): model_type = "qwen2_5_vl" sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=152064, - hidden_size=8192, - intermediate_size=29568, - num_hidden_layers=80, - num_attention_heads=64, - num_key_value_heads=8, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-05, - use_cache=True, - tie_word_embeddings=False, - rope_theta=1000000.0, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=80, - attention_dropout=0.0, - vision_config=None, - rope_scaling=None, - **kwargs, - ): - if isinstance(vision_config, dict): - self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) - elif vision_config is None: - self.vision_config = Qwen2_5_VLVisionConfig() - - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - self.rope_scaling = rope_scaling - - # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, move it to 'rope_type'. - # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations - # one can set it to "linear"/"dynamic" etc. to have scaled RoPE - # TODO: @raushan update config in the hub - if self.rope_scaling is not None and "type" in self.rope_scaling: - if self.rope_scaling["type"] == "mrope": - self.rope_scaling["type"] = "default" - self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self, ignore_keys={"mrope_section"}) - - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) class Qwen2_5_VLMLP(nn.Module): @@ -297,6 +127,12 @@ def forward(self, hidden_state): return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) +class Qwen2_5_VisionPatchEmbed(PatchEmbed): + pass + +class Qwen2_5_VisionRotaryEmbedding(VisionRotaryEmbedding): + pass + class Qwen2_5_VLPatchMerger(PatchMerger): def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: super().__init__(dim, context_dim, spatial_merge_size) @@ -376,7 +212,7 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.window_size = config.window_size self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size - self.patch_embed = PatchEmbed( + self.patch_embed = Qwen2_5_VisionPatchEmbed( patch_size=config.patch_size, temporal_patch_size=config.temporal_patch_size, in_channels=config.in_channels, @@ -384,7 +220,7 @@ def __init__(self, config, *inputs, **kwargs) -> None: ) head_dim = config.hidden_size // config.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList( [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] @@ -396,12 +232,6 @@ def __init__(self, config, *inputs, **kwargs) -> None: ) self.gradient_checkpointing = False - def get_dtype(self) -> torch.dtype: - return self.blocks[0].mlp.up_proj.weight.dtype - - def get_device(self) -> torch.device: - return self.blocks[0].mlp.up_proj.weight.device - def rot_pos_emb(self, grid_thw): pos_ids = [] for t, h, w in grid_thw: @@ -473,6 +303,16 @@ def get_window_index(self, grid_thw): return window_index, cu_window_seqlens def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) window_index, cu_window_seqlens = self.get_window_index(grid_thw) @@ -672,7 +512,9 @@ def get_rope_index( t_index = ( ( - torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) * second_per_grid_t * self.config.vision_config.tokens_per_second ) @@ -784,7 +626,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.get_dtype()) + pixel_values = pixel_values.type(self.visual.dtype()) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] @@ -802,7 +644,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + pixel_values_videos = pixel_values_videos.type(self.visual.dtype()) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py index b409de920328..49a0836cf96a 100644 --- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py @@ -198,9 +198,9 @@ def __init__( **kwargs, ): if isinstance(vision_config, dict): - self.vision_config = Qwen2VLVisionConfig(**vision_config) + self.vision_config = self.sub_configs["vision_config"](**vision_config) elif vision_config is None: - self.vision_config = Qwen2VLVisionConfig() + self.vision_config = self.sub_configs["vision_config"]() self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index e41ebf0ca25e..863e8ac0f410 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -377,7 +377,7 @@ def test_small_model_integration_test(self): inputs = self.processor(text=[text], images=[self.image], return_tensors="pt") expected_input_ids = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655, 151655] # fmt: skip - assert expected_input_ids == inputs.input_ids[0].tolist()[:17] + torch.testing.assert_close(expected_input_ids, inputs.input_ids[0].tolist()[:17]) expected_pixel_slice = torch.tensor( [ From 84247a475a18df5c16a1fc1b74be81985ffe45db Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Tue, 21 Jan 2025 16:25:13 +0800 Subject: [PATCH 28/42] excuted modular --- .../qwen2_5_vl/configuration_qwen2_5_vl.py | 28 +++-- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 100 +++++++++--------- 2 files changed, 71 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index 556c99001507..2f4eefa6bf32 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -45,7 +45,7 @@ def __init__( tokens_per_second=4, window_size=112, out_hidden_size=3584, - fullatt_block_indexes=list(range(7, 32, 8)), + fullatt_block_indexes=[7, 15, 23, 31], **kwargs, ): super().__init__(**kwargs) @@ -68,9 +68,9 @@ def __init__( class Qwen2_5_VLConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a - Qwen2.5-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - Qwen2.5-VL-7B-Instruct [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -78,7 +78,7 @@ class Qwen2_5_VLConfig(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 152064): - Vocabulary size of the Qwen2.5-VL model. Defines the number of different tokens that can be represented by the + Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Qwen2_5_VLModel`] hidden_size (`int`, *optional*, defaults to 8192): Dimension of the hidden representations. @@ -164,16 +164,26 @@ class Qwen2_5_VLConfig(PretrainedConfig): >>> # Initializing a Qwen2_5_VL style configuration >>> configuration = Qwen2_5_VLConfig() + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2_5_VLForConditionalGeneration(configuration) + >>> # Accessing the model configuration >>> configuration = model.config - - >>> # Initializing a model from the Qwen2_5_VL configuration - >>> model = Qwen2_5_VLForConditionalGeneration(configuration) ```""" model_type = "qwen2_5_vl" sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen2_5_VL` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } def __init__( self, @@ -199,9 +209,9 @@ def __init__( **kwargs, ): if isinstance(vision_config, dict): - self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) + self.vision_config = self.sub_configs["vision_config"](**vision_config) elif vision_config is None: - self.vision_config = Qwen2_5_VLVisionConfig() + self.vision_config = self.sub_configs["vision_config"]() self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index bc4473559c77..4ecf11958d34 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -85,6 +85,44 @@ def forward(self, hidden_state): return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) +class Qwen2_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -264,44 +302,6 @@ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: return hidden_states -class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs - - -class PatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 14, - temporal_patch_size: int = 2, - in_channels: int = 3, - embed_dim: int = 1152, - ) -> None: - super().__init__() - self.patch_size = patch_size - self.temporal_patch_size = temporal_patch_size - self.in_channels = in_channels - self.embed_dim = embed_dim - - kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - target_dtype = self.proj.weight.dtype - hidden_states = hidden_states.view( - -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size - ) - hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) - return hidden_states - - Qwen2_5_VL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -342,7 +342,7 @@ def __init__(self, config, *inputs, **kwargs): self.window_size = config.window_size self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size - self.patch_embed = PatchEmbed( + self.patch_embed = Qwen2_5_VisionPatchEmbed( patch_size=config.patch_size, temporal_patch_size=config.temporal_patch_size, in_channels=config.in_channels, @@ -350,7 +350,7 @@ def __init__(self, config, *inputs, **kwargs): ) head_dim = config.hidden_size // config.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList( [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] @@ -373,12 +373,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def get_dtype(self) -> torch.dtype: - return self.blocks[0].mlp.up_proj.weight.dtype - - def get_device(self) -> torch.device: - return self.blocks[0].mlp.up_proj.weight.device - def rot_pos_emb(self, grid_thw): pos_ids = [] for t, h, w in grid_thw: @@ -450,6 +444,16 @@ def get_window_index(self, grid_thw): return window_index, cu_window_seqlens def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) window_index, cu_window_seqlens = self.get_window_index(grid_thw) @@ -1790,7 +1794,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.get_dtype()) + pixel_values = pixel_values.type(self.visual.dtype()) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] @@ -1808,7 +1812,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + pixel_values_videos = pixel_values_videos.type(self.visual.dtype()) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] From 0cd1f621015254426cce9c0dbadc8185e055c1e9 Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Tue, 21 Jan 2025 16:41:16 +0800 Subject: [PATCH 29/42] rollback init --- .../models/qwen2_5_vl/__init__.py | 66 ++++++++++++++++--- 1 file changed, 58 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/__init__.py b/src/transformers/models/qwen2_5_vl/__init__.py index 5d3cd215b070..e2ee26089d5b 100644 --- a/src/transformers/models/qwen2_5_vl/__init__.py +++ b/src/transformers/models/qwen2_5_vl/__init__.py @@ -13,17 +13,67 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import _LazyModule -from ...utils.import_utils import define_import_structure +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_vision_available, +) + + +_import_structure = { + "configuration_qwen2_5_vl": ["Qwen2_5_VLConfig"], + "processing_qwen2_5_vl": ["Qwen2_5_VLProcessor"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_qwen2_5_vl"] = [ + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5_VLModel", + "Qwen2_5_VLPreTrainedModel", + ] + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_qwen2_5_vl"] = ["Qwen2_5_VLImageProcessor"] if TYPE_CHECKING: - from .configuration_qwen2_5_vl import * - from .image_processing_qwen2_5_vl import * - from .modeling_qwen2_5_vl import * - from .processing_qwen2_5_vl import * + from .configuration_qwen2_5_vl import Qwen2_5_VLConfig + from .processing_qwen2_5_vl import Qwen2_5_VLProcessor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_qwen2_5_vl import ( + Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLModel, + Qwen2_5_VLPreTrainedModel, + ) + + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_qwen2_5_vl import Qwen2_5_VLImageProcessor + + else: import sys - _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) From 446383fef787a3ee6deb75607924c325aaebd578 Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Tue, 21 Jan 2025 16:50:02 +0800 Subject: [PATCH 30/42] fix --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 4 ++-- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 4ecf11958d34..482bd35eb0af 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1794,7 +1794,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype()) + pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] @@ -1812,7 +1812,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.dtype()) + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 99546ac74c86..df6888ae219d 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -626,7 +626,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype()) + pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] @@ -644,7 +644,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.dtype()) + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] From f87160dbd28cf5b0a0b93de8b826982998e0e05b Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Tue, 21 Jan 2025 17:05:22 +0800 Subject: [PATCH 31/42] formated --- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index df6888ae219d..041269c01606 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -130,9 +130,11 @@ def forward(self, hidden_state): class Qwen2_5_VisionPatchEmbed(PatchEmbed): pass + class Qwen2_5_VisionRotaryEmbedding(VisionRotaryEmbedding): pass + class Qwen2_5_VLPatchMerger(PatchMerger): def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: super().__init__(dim, context_dim, spatial_merge_size) @@ -512,9 +514,7 @@ def get_rope_index( t_index = ( ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) + torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) * second_per_grid_t * self.config.vision_config.tokens_per_second ) From 4012f4aa7fc197d6b0211fe64f633dc572d1ddbf Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Tue, 21 Jan 2025 22:36:18 +0800 Subject: [PATCH 32/42] simpler init --- .../models/qwen2_5_vl/__init__.py | 66 +++---------------- .../qwen2_5_vl/configuration_qwen2_5_vl.py | 3 + .../qwen2_5_vl/image_processing_qwen2_5_vl.py | 3 + .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 3 + .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 10 +++ .../qwen2_5_vl/processing_qwen2_5_vl.py | 3 + 6 files changed, 30 insertions(+), 58 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/__init__.py b/src/transformers/models/qwen2_5_vl/__init__.py index e2ee26089d5b..5d3cd215b070 100644 --- a/src/transformers/models/qwen2_5_vl/__init__.py +++ b/src/transformers/models/qwen2_5_vl/__init__.py @@ -13,67 +13,17 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - is_torch_available, - is_vision_available, -) - - -_import_structure = { - "configuration_qwen2_5_vl": ["Qwen2_5_VLConfig"], - "processing_qwen2_5_vl": ["Qwen2_5_VLProcessor"], -} - - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_qwen2_5_vl"] = [ - "Qwen2_5_VLForConditionalGeneration", - "Qwen2_5_VLModel", - "Qwen2_5_VLPreTrainedModel", - ] - -try: - if not is_vision_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["image_processing_qwen2_5_vl"] = ["Qwen2_5_VLImageProcessor"] +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure if TYPE_CHECKING: - from .configuration_qwen2_5_vl import Qwen2_5_VLConfig - from .processing_qwen2_5_vl import Qwen2_5_VLProcessor - - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_qwen2_5_vl import ( - Qwen2_5_VLForConditionalGeneration, - Qwen2_5_VLModel, - Qwen2_5_VLPreTrainedModel, - ) - - try: - if not is_vision_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .image_processing_qwen2_5_vl import Qwen2_5_VLImageProcessor - - + from .configuration_qwen2_5_vl import * + from .image_processing_qwen2_5_vl import * + from .modeling_qwen2_5_vl import * + from .processing_qwen2_5_vl import * else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index 2f4eefa6bf32..a1cf06d94e87 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -248,3 +248,6 @@ def __init__( rope_config_validation(self, ignore_keys={"mrope_section"}) super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = ["Qwen2_5_VLConfig"] diff --git a/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py index f0b405febc78..7101ae603598 100644 --- a/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/image_processing_qwen2_5_vl.py @@ -463,3 +463,6 @@ def preprocess( data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws} return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Qwen2_5_VLImageProcessor"] diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 482bd35eb0af..c0d0610f4572 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1975,3 +1975,6 @@ def prepare_inputs_for_generation( } ) return model_inputs + + +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 041269c01606..05e80e8936b1 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -988,3 +988,13 @@ def __call__( text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + +__all__ = [ + "Qwen2_5_VLConfig", + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5_VLModel", + "Qwen2_5_VLPreTrainedModel", + "Qwen2_5_VLImageProcessor", + "Qwen2_5_VLProcessor", +] diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index 6a7e1af1203d..88c383d8e7e1 100644 --- a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -213,3 +213,6 @@ def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names image_processor_input_names = self.image_processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["Qwen2_5_VLProcessor"] From d671be8cb95c4cb4784f7f8fb0ad06ffaf3c87d7 Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Tue, 21 Jan 2025 22:47:52 +0800 Subject: [PATCH 33/42] fix --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index c0d0610f4572..894c25c32ae9 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1977,4 +1977,4 @@ def prepare_inputs_for_generation( return model_inputs -__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] +__all__ = ["Qwen2_5_VLForConditionalGeneration"] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 05e80e8936b1..7e85c648a5f1 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -993,8 +993,6 @@ def __call__( __all__ = [ "Qwen2_5_VLConfig", "Qwen2_5_VLForConditionalGeneration", - "Qwen2_5_VLModel", - "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLImageProcessor", "Qwen2_5_VLProcessor", ] From ad46b8781bcf8dda932299e986db080619f95b7f Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Tue, 21 Jan 2025 23:18:15 +0800 Subject: [PATCH 34/42] fix --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 894c25c32ae9..c0d0610f4572 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1977,4 +1977,4 @@ def prepare_inputs_for_generation( return model_inputs -__all__ = ["Qwen2_5_VLForConditionalGeneration"] +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 7e85c648a5f1..bf766002d275 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -37,6 +37,7 @@ Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration, Qwen2VLPreTrainedModel, + Qwen2VLModel, VisionAttention, VisionRotaryEmbedding, VisionSdpaAttention, @@ -366,6 +367,14 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. return hidden_states +def Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel): + pass + + +def Qwen2_5_VLModel(Qwen2VLModel): + pass + + @dataclass class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast): pass @@ -993,6 +1002,8 @@ def __call__( __all__ = [ "Qwen2_5_VLConfig", "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5_VLModel", + "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLImageProcessor", "Qwen2_5_VLProcessor", ] From 8390d99d35e63198be9c6cdc1fc1bcfb326cf3bc Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Tue, 21 Jan 2025 23:26:13 +0800 Subject: [PATCH 35/42] fix --- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index bf766002d275..29b710d79675 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -367,11 +367,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. return hidden_states -def Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel): +class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel): pass -def Qwen2_5_VLModel(Qwen2VLModel): +class Qwen2_5_VLModel(Qwen2VLModel): pass From 16913ab025a00f0235b6d4512dd8657397e791e4 Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Tue, 21 Jan 2025 23:47:29 +0800 Subject: [PATCH 36/42] fix --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 144 ++++++++---------- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 10 +- 2 files changed, 66 insertions(+), 88 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index c0d0610f4572..fdfa1c0b1286 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -323,11 +323,11 @@ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", Qwen2_5_VL_START_DOCSTRING, ) -class Qwen2_5_VisionTransformerPretrainedModel(PreTrainedModel): - config_class = Qwen2_5_VLVisionConfig +class Qwen2_5_VLPreTrainedModel(PreTrainedModel): + config_class = Qwen2_5_VLConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2_5_VLVisionBlock"] + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True @@ -336,6 +336,25 @@ class Qwen2_5_VisionTransformerPretrainedModel(PreTrainedModel): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): + config_class = Qwen2_5_VLVisionConfig + _no_split_modules = ["Qwen2_5_VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) self.spatial_merge_size = config.spatial_merge_size self.patch_size = config.patch_size self.fullatt_block_indexes = config.fullatt_block_indexes @@ -362,17 +381,6 @@ def __init__(self, config, *inputs, **kwargs): ) self.gradient_checkpointing = False - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - def rot_pos_emb(self, grid_thw): pos_ids = [] for t, h, w in grid_thw: @@ -505,45 +513,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. return hidden_states -@dataclass -class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): - """ - Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): - The rope index difference between sequence length and multimodal rope. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None - - class Qwen2_5_VLRotaryEmbedding(nn.Module): def __init__( self, @@ -1109,36 +1078,6 @@ def forward( return outputs -@add_start_docstrings( - "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", - Qwen2_5_VL_START_DOCSTRING, -) -class Qwen2_5_VLPreTrainedModel(PreTrainedModel): - config_class = Qwen2_5_VLConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_static_cache = True - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - @add_start_docstrings( "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", Qwen2_5_VL_START_DOCSTRING, @@ -1435,6 +1374,45 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + QWEN2_5_VL_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 29b710d79675..920dfa729a6a 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -203,7 +203,11 @@ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: return hidden_states -class Qwen2_5_VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): +class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel): + pass + + +class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): config_class = Qwen2_5_VLVisionConfig _no_split_modules = ["Qwen2_5_VLVisionBlock"] @@ -367,10 +371,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. return hidden_states -class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel): - pass - - class Qwen2_5_VLModel(Qwen2VLModel): pass From 72de3020e29e50a571c7287530646af48a65c04c Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Wed, 22 Jan 2025 00:14:45 +0800 Subject: [PATCH 37/42] fix --- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 920dfa729a6a..645a09cf2d32 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -36,8 +36,8 @@ Qwen2RMSNorm, Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration, - Qwen2VLPreTrainedModel, Qwen2VLModel, + Qwen2VLPreTrainedModel, VisionAttention, VisionRotaryEmbedding, VisionSdpaAttention, From 91a00e32810f7996f8580fc2856ef301d24cf416 Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Wed, 22 Jan 2025 00:52:10 +0800 Subject: [PATCH 38/42] update docs --- docs/source/en/model_doc/qwen2_5_vl.md | 49 ++++++-------------------- 1 file changed, 11 insertions(+), 38 deletions(-) diff --git a/docs/source/en/model_doc/qwen2_5_vl.md b/docs/source/en/model_doc/qwen2_5_vl.md index 66093ee271d2..9bfcb6ae0f62 100644 --- a/docs/source/en/model_doc/qwen2_5_vl.md +++ b/docs/source/en/model_doc/qwen2_5_vl.md @@ -37,6 +37,7 @@ import requests import torch from torchvision import io from typing import Dict +from transformers.image_utils import load_images, load_video from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor # Load the model in half-precision on the available device(s) @@ -77,33 +78,7 @@ output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, cl print(output_text) # Video -def fetch_video(ele: Dict, nframe_factor=2): - if isinstance(ele['video'], str): - def round_by_factor(number: int, factor: int) -> int: - return round(number / factor) * factor - - video = ele["video"] - if video.startswith("file://"): - video = video[7:] - - video, _, info = io.read_video( - video, - start_pts=ele.get("video_start", 0.0), - end_pts=ele.get("video_end", None), - pts_unit="sec", - output_format="TCHW", - ) - assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" - if "nframes" in ele: - nframes = round_by_factor(ele["nframes"], nframe_factor) - else: - fps = ele.get("fps", 1.0) - nframes = round_by_factor(video.size(0) / info["video_fps"] * fps, nframe_factor) - idx = torch.linspace(0, video.size(0) - 1, nframes, dtype=torch.int64) - return video[idx] - -video_info = {"type": "video", "video": "/path/to/video.mp4", "fps": 1.0} -video = fetch_video(video_info) +video = load_video(video="/path/to/video.mp4") conversation = [ { "role": "user", @@ -135,16 +110,14 @@ print(output_text) The model can batch inputs composed of mixed samples of various types such as images, videos, and text. Here is an example. ```python -image1 = Image.open("/path/to/image1.jpg") -image2 = Image.open("/path/to/image2.jpg") -image3 = Image.open("/path/to/image3.jpg") -image4 = Image.open("/path/to/image4.jpg") -image5 = Image.open("/path/to/image5.jpg") -video = fetch_video({ - "type": "video", - "video": "/path/to/video.mp4", - "fps": 1.0 -}) +images = load_images([ + "/path/to/image1.jpg", + "/path/to/image2.jpg", + "/path/to/image3.jpg", + "/path/to/image4.jpg", + "/path/to/image5.jpg", +]) +video = load_video(video="/path/to/video.mp4") # Conversation for the first image conversation1 = [ @@ -196,7 +169,7 @@ conversations = [conversation1, conversation2, conversation3, conversation4] texts = [processor.apply_chat_template(msg, add_generation_prompt=True) for msg in conversations] inputs = processor( text=texts, - images=[image1, image2, image3, image4, image5], + images=images, videos=[video], padding=True, return_tensors="pt", From ca2e16cd95aff7fbbb85fc241d2156e28f64292a Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Wed, 22 Jan 2025 13:04:41 +0800 Subject: [PATCH 39/42] fix --- tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index 863e8ac0f410..d76a7ba1e20c 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -377,7 +377,7 @@ def test_small_model_integration_test(self): inputs = self.processor(text=[text], images=[self.image], return_tensors="pt") expected_input_ids = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655, 151655] # fmt: skip - torch.testing.assert_close(expected_input_ids, inputs.input_ids[0].tolist()[:17]) + assert torch.allclose(expected_input_ids, inputs.input_ids[0].tolist()[:17], atol=3e-3) expected_pixel_slice = torch.tensor( [ From 64db9c7d50e880d37fb54eaba551f7ea8209c788 Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Wed, 22 Jan 2025 13:27:59 +0800 Subject: [PATCH 40/42] fix --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 6 +++--- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index fdfa1c0b1286..9fbbf40789e5 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1525,7 +1525,7 @@ def get_decoder(self): def get_rope_index( self, - input_ids: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, @@ -1684,7 +1684,7 @@ def get_rope_index( if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: @@ -1811,7 +1811,7 @@ def forward( attention_mask = attention_mask.to(inputs_embeds.device) # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2): + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 645a09cf2d32..9d4dd8ea6be4 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -390,7 +390,7 @@ def __init__(self, config): def get_rope_index( self, - input_ids: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, @@ -549,7 +549,7 @@ def get_rope_index( if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: @@ -674,7 +674,7 @@ def forward( attention_mask = attention_mask.to(inputs_embeds.device) # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2): + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( From c84216d72d149d5b9366eef394a97332cc27d201 Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Thu, 23 Jan 2025 15:05:43 +0800 Subject: [PATCH 41/42] update Qwen2VLRotaryEmbedding for yarn --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 51 ++++--------------- .../models/qwen2_vl/modeling_qwen2_vl.py | 51 ++++--------------- 2 files changed, 18 insertions(+), 84 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 9fbbf40789e5..fdfaddfc0e94 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -514,47 +514,20 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. class Qwen2_5_VLRotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[Qwen2_5_VLConfig] = None, - ): + def __init__(self, config: Qwen2_5_VLConfig, device=None): super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Qwen2_5_VLRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -696,8 +669,6 @@ def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True self.attention_dropout = config.attention_dropout self.rope_scaling = config.rope_scaling @@ -712,11 +683,7 @@ def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = Qwen2_5_VLRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) def forward( self, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 69532de0448c..602ee421be9a 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -100,47 +100,20 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput): class Qwen2VLRotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[Qwen2VLConfig] = None, - ): + def __init__(self, config: Qwen2VLConfig, device=None): super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -503,8 +476,6 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None): self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True self.attention_dropout = config.attention_dropout self.rope_scaling = config.rope_scaling @@ -519,11 +490,7 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = Qwen2VLRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) def forward( self, From a0b3a94538801ee3dec6b7a7c979cd664419da17 Mon Sep 17 00:00:00 2001 From: gewenbin0992 Date: Thu, 23 Jan 2025 17:28:54 +0800 Subject: [PATCH 42/42] fix --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 44 +++++++++---------- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 41 +++++++++-------- .../models/qwen2_vl/modeling_qwen2_vl.py | 3 -- 3 files changed, 40 insertions(+), 48 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index fdfaddfc0e94..2046deef0b3e 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -334,9 +334,6 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = True - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv3d)): @@ -1623,15 +1620,14 @@ def get_rope_index( st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - t_index = ( - ( - torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) - * second_per_grid_t - * self.config.vision_config.tokens_per_second - ) - .long() - .flatten() - ) + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) @@ -1747,12 +1743,12 @@ def forward( raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) @@ -1765,12 +1761,12 @@ def forward( raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 9d4dd8ea6be4..1e155d6043e0 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -521,15 +521,14 @@ def get_rope_index( st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - t_index = ( - ( - torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) - * second_per_grid_t - * self.config.vision_config.tokens_per_second - ) - .long() - .flatten() - ) + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) @@ -643,12 +642,12 @@ def forward( raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) @@ -661,12 +660,12 @@ def forward( raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 602ee421be9a..c680201f1923 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -884,9 +884,6 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = True - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv3d)):