diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 1527c2b..164b072 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -69,10 +69,38 @@ def parse_arguments(): help="Maximum number of tokens to generate.", ) parser.add_argument( - "--temp", type=float, default=DEFAULT_TEMP, help="Temperature for sampling." + "--temperature", + type=float, + default=DEFAULT_TEMP, + help="Temperature for sampling.", ) parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.") parser.add_argument("--verbose", action="store_false", help="Detailed output.") + parser.add_argument( + "--merge-similar-tokens-ratio", + type=float, + default=1.0, + help="Ratio of visual tokens to keep during merging similar tokens (between 0.1 and 1.0).", + choices=[x / 10 for x in range(1, 11)], + ) + parser.add_argument( + "--filter-topk-tokens-ratio", + type=float, + help="Ratio of visual tokens to keep during filtering topk tokens (between 0.1 and 1.0).", + choices=[x / 10 for x in range(1, 11)], + ) + parser.add_argument( + "--max-kv-size", + type=int, + default=None, + help="Set the maximum key-value cache size", + ) + parser.add_argument( + "--prefill-step-size", + type=int, + default=256, + help="Set the prefill step size", + ) return parser.parse_args() @@ -97,6 +125,9 @@ def main(): prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image)) kwargs = {} + + if args.max_kv_size is not None: + kwargs["max_kv_size"] = args.max_kv_size if args.resize_shape is not None: resize_shape = args.resize_shape if len(resize_shape) not in [1, 2]: @@ -107,6 +138,19 @@ def main(): else resize_shape ) + kwargs["merge_similar_tokens_ratio"] = args.merge_similar_tokens_ratio + if args.filter_topk_tokens_ratio is None: + # If merge ratio is specified but filter ratio isn't, automatically set filter ratio + if args.merge_similar_tokens_ratio < 0.9: + # For aggressive merging (ratio < 0.9), keep 90% of tokens + kwargs["filter_topk_tokens_ratio"] = 0.90 + else: + # For light merging (ratio >= 0.9), keep all tokens + kwargs["filter_topk_tokens_ratio"] = 1.0 + else: + # Use explicitly provided filter ratio + kwargs["filter_topk_tokens_ratio"] = args.filter_topk_tokens_ratio + if args.chat: chat = [] if args.system: @@ -124,7 +168,7 @@ def main(): prompt, args.image, max_tokens=args.max_tokens, - temp=args.temp, + temperature=args.temperature, **kwargs, ): response += chunk.text @@ -139,7 +183,7 @@ def main(): processor, prompt, image=args.image, - temp=args.temp, + temperature=args.temperature, max_tokens=args.max_tokens, verbose=args.verbose, **kwargs, diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py index ee47eb6..03d9994 100644 --- a/mlx_vlm/models/base.py +++ b/mlx_vlm/models/base.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional import mlx.core as mx +import mlx.nn as nn from PIL import Image from transformers.image_processing_utils import BaseImageProcessor as ImageProcessor from transformers.image_processing_utils import get_size_dict @@ -97,6 +98,10 @@ def update(self, keys, values): self.keys[..., prev : self.offset, :] = keys self.values[..., prev : self.offset, :] = values + @property + def state(self): + return self.keys, self.values + class SimpleKVCache: """A simple key-value cache for transformer attention layers. 
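Note on the generate.py options above: a minimal sketch of how the new flags are expected to surface through the Python API, assuming the package-level load/generate helpers and load_config from mlx_vlm.utils, and assuming main() forwards these kwargs unchanged down to the model (as it does for max_kv_size and the two token ratios). The model path and image file are placeholders, and prefill_step_size reaching the model's prefill() end to end is an assumption, not something this hunk shows.

from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model, processor = load("<model-path-or-hf-repo>")      # placeholder path
config = load_config("<model-path-or-hf-repo>")
prompt = apply_chat_template(processor, config, "Describe this image.", num_images=1)
output = generate(
    model,
    processor,
    prompt,
    image=["example.jpg"],             # placeholder image
    temperature=0.0,                   # renamed from --temp / temp=
    max_tokens=256,
    max_kv_size=1024,                  # switches generate_step to a RotatingKVCache
    merge_similar_tokens_ratio=0.5,    # keep ~50% of visual tokens after merging
    filter_topk_tokens_ratio=0.9,      # keep the 90% most CLS-attended visual tokens
    prefill_step_size=256,             # assumed to reach the model's prefill()
    verbose=True,
)

The same two ratios are what generate() in utils.py later uses to scale the reported prompt-token count.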
@@ -148,7 +153,7 @@ def update(self, keys, values): class RotatingKVCache: - def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256): + def __init__(self, head_dim, n_kv_heads, max_size, keep=None, step=256): self.n_kv_heads = n_kv_heads if isinstance(head_dim, int): self.k_head_dim = self.v_head_dim = head_dim @@ -156,7 +161,7 @@ def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256): self.k_head_dim, self.v_head_dim = head_dim else: raise ValueError("head_dim must be an int or a tuple of two ints") - self.keep = keep + self.keep = keep if keep is not None else step // 2 self.keys = None self.values = None self.offset = 0 @@ -271,3 +276,122 @@ class LanguageModelOutput: logits: mx.array cross_attention_states: Optional[List[mx.array]] = None encoder_outputs: Optional[List[mx.array]] = None + + +class BaseModel(nn.Module): + def __init__(self): + super().__init__() + self.vision_tower = None + self.language_model = None + + def prefill(self, input_embeds, cache=None, prefill_step_size=256): + # Split the prompt into fixed-size chunks to bound memory use during prefill + num_batches = ( + input_embeds.shape[1] + prefill_step_size - 1 + ) // prefill_step_size + + if num_batches > 1: + # Slice out every full-sized chunk up front + slices = [ + input_embeds[:, i * prefill_step_size : (i + 1) * prefill_step_size, :] + for i in range(num_batches - 1) + ] + + # Run the full chunks through the language model one at a time, evaluating the cache after each so peak memory stays bounded + for slice in slices: + mask = create_attention_mask(slice, cache) + self.language_model(inputs_embeds=slice, cache=cache, mask=mask) + if cache is not None: + mx.eval([c.state for c in cache]) + mx.metal.clear_cache() + + # Return the remaining (partial) chunk for the caller to process + remaining_embeds = input_embeds[ + :, (num_batches - 1) * prefill_step_size :, : + ] + return remaining_embeds + + return input_embeds + + def get_topk_tokens(self, image_feature, attn, dominant_tokens_ratio=None): + batch_size, seq_len = image_feature.shape[:2] + + k_tokens = ( + int(image_feature.shape[1] * dominant_tokens_ratio) + if dominant_tokens_ratio is not None + else None + ) # number of visual tokens to keep, per dominant_tokens_ratio + if k_tokens is None: + return image_feature + cls_idx = 0 # index of the CLS token in the visual sequence + + attn_rec = mx.sum(attn[:, :, cls_idx + 1 :, cls_idx], axis=1) + + topk_idx = mx.argsort(attn_rec, axis=1)[:, -k_tokens:] + # use this to plot the dominant attention map + # https://github.com/dvlab-research/VisionZip/blob/demo-chat/llava/model/multimodal_encoder/clip_encoder.py#L62 + # https://github.com/dvlab-research/VisionZip/blob/demo-chat/llava/serve/gradio_web_server.py#L424 + + # Create CLS token indices array + # Shape: (B, 1) + cls_indices = mx.full((batch_size, 1), cls_idx, dtype=mx.int32) + + # Concat with CLS token index + # Add 1 to account for the offset after CLS token + dominant_idx = mx.concatenate([cls_indices, topk_idx + cls_idx + 1], axis=1) + + image_feature = mx.take(image_feature, dominant_idx, axis=1)[0] + return image_feature + + def merge_similar_visual_tokens( + self, image_feature, visual_token_ratio, merge_ratio=0.4 + ): + # Skip CLS token (first token) + tokens = image_feature[:, 1:] + batch_size, num_tokens, hidden_dim = tokens.shape + + # Calculate target number of tokens + target_tokens = max(1, int(num_tokens * visual_token_ratio)) + + while num_tokens > target_tokens: + # Calculate similarities between adjacent tokens + tokens_a = tokens[:, :-1] # all except last + tokens_b = tokens[:, 1:] # all except first + + # Calculate cosine similarity + a_norm = mx.sqrt(mx.sum(tokens_a * tokens_a, axis=-1,
keepdims=True)) + b_norm = mx.sqrt(mx.sum(tokens_b * tokens_b, axis=-1, keepdims=True)) + similarities = mx.sum(tokens_a * tokens_b, axis=-1) + similarities = similarities / (a_norm.squeeze(-1) * b_norm.squeeze(-1)) + + # Sort similarities and get indices of pairs to merge + # We'll merge about 50% of remaining excess tokens in each iteration + num_to_merge = max(1, int((num_tokens - target_tokens) * merge_ratio)) + merge_indices = mx.argsort(similarities, axis=-1)[:, -num_to_merge:] + + # Create a list to track which indices to merge + to_merge = set(merge_indices[0].tolist()) + + # Merge selected pairs + new_tokens = [] + i = 0 + while i < num_tokens: + if i < num_tokens - 1 and i in to_merge: + # Merge this token with the next one + merged = (tokens[:, i : i + 1] + tokens[:, i + 1 : i + 2]) / 2 + new_tokens.append(merged) + i += 2 + elif i > 0 and (i - 1) in to_merge: + # Skip this token as it was merged in the previous step + i += 1 + else: + # Keep this token as is + new_tokens.append(tokens[:, i : i + 1]) + i += 1 + + # Update tokens + tokens = mx.concatenate(new_tokens, axis=1) + num_tokens = tokens.shape[1] + + # Reattach CLS token + return mx.concatenate([image_feature[:, :1], tokens], axis=1) diff --git a/mlx_vlm/models/cache.py b/mlx_vlm/models/cache.py new file mode 100644 index 0000000..c444b4e --- /dev/null +++ b/mlx_vlm/models/cache.py @@ -0,0 +1,150 @@ +import hashlib +import os +from pathlib import Path +from typing import Dict, Optional, Union + +import mlx.core as mx +from safetensors.torch import load_file, save_file + + +class VLMFeatureCache: + """Cache for storing and retrieving image features from Vision Language Models.""" + + def __init__(self, cache_dir: Union[str, Path]): + """ + Initialize the feature cache. + + Args: + cache_dir: Directory to store the cached features + """ + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + + def _compute_file_hash(self, file_path: Union[str, Path]) -> str: + """ + Compute SHA-256 hash of a file. + + Args: + file_path: Path to the file + + Returns: + str: Hex digest of the file hash + """ + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + # Read the file in chunks to handle large files efficiently + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() + + def _get_cache_path(self, file_hash: str) -> Path: + """ + Get the cache file path for a given file hash. + + Args: + file_hash: SHA-256 hash of the original file + + Returns: + Path: Path where the cached features should be stored + """ + return self.cache_dir / f"{file_hash}.safetensors" + + def save_features( + self, + image_path: Union[str, Path], + features: Dict[str, mx.array], + metadata: Optional[Dict[str, str]] = None, + ) -> str: + """ + Save image features to cache. 
+ + Args: + image_path: Path to the original image file + features: Dictionary of feature tensors to cache + metadata: Optional metadata to store with the features + + Returns: + str: Hash of the cached file + """ + file_hash = self._compute_file_hash(image_path) + cache_path = self._get_cache_path(file_hash) + + # Add original file path to metadata + if metadata is None: + metadata = {} + metadata["original_file"] = str(image_path) + metadata["format"] = "mlx" + + # Save features using safetensors + mx.save_safetensors(str(cache_path), {"image_features": features}, metadata) + return file_hash + + def load_features( + self, image_path: Union[str, Path] + ) -> Optional[Dict[str, mx.array]]: + """ + Load cached features for an image if they exist. + + Args: + image_path: Path to the image file + + Returns: + Optional[Dict[str, mx.array]]: Cached features if they exist, None otherwise + """ + file_hash = self._compute_file_hash(image_path) + cache_path = self._get_cache_path(file_hash) + + if not cache_path.exists(): + return None + + features = mx.load(str(cache_path)) + return features + + def get_metadata(self, image_path: Union[str, Path]) -> Optional[Dict[str, str]]: + """ + Get metadata for cached features if they exist. + + Args: + image_path: Path to the image file + + Returns: + Optional[Dict[str, str]]: Metadata if cached features exist, None otherwise + """ + file_hash = self._compute_file_hash(image_path) + cache_path = self._get_cache_path(file_hash) + + if not cache_path.exists(): + return None + + return load_file(cache_path) + + def clear_cache(self): + """Remove all cached features.""" + for cache_file in self.cache_dir.glob("*.safetensors"): + cache_file.unlink() + + def get_cache_size(self) -> int: + """ + Get the total size of cached features in bytes. + + Returns: + int: Total size of cache in bytes + """ + return sum( + f.stat().st_size for f in self.cache_dir.glob("*.safetensors") + ) + + def __contains__(self, image_path: Union[str, Path]) -> bool: + """ + Check if features for an image are cached.
+ + Args: + image_path: Path to the image file + + Returns: + bool: True if features are cached, False otherwise + """ + file_hash = self._compute_file_hash(image_path) + return self._get_cache_path(file_hash).exists() diff --git a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py index 3629bcc..f1ebe91 100644 --- a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +++ b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py @@ -442,57 +442,19 @@ def __call__( images_spatial_crop = kwargs.get("images_spatial_crop", None) images_seq_mask = kwargs.get("images_seq_mask", None) - input_embeddings = self.get_input_embeddings( + prefill_step_size = kwargs.pop("prefill_step_size", 256) + inputs_embeds = self.get_input_embeddings( input_ids, pixel_values, images_spatial_crop, images_seq_mask ) + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size + ) logits = self.language_model( - input_ids, cache=cache, inputs_embeds=input_embeddings + input_ids, cache=cache, inputs_embeds=inputs_embeds ) return logits - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - model_config = json.load(f) - - model_config = ModelConfig.from_dict(model_config) - - model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) - model_config.projector_config = ProjectorConfig.from_dict( - model_config.projector_config - ) - model_config.text_config = TextConfig.from_dict(model_config.text_config) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = VisionModel.sanitize(weights) - weights = LanguageModel.sanitize(weights) - - model.load_weights(list(weights.items())) - return model - @staticmethod def sanitize(weights): def transform_key(key): diff --git a/mlx_vlm/models/florence2/florence2.py b/mlx_vlm/models/florence2/florence2.py index 7895c4e..cf59a3b 100644 --- a/mlx_vlm/models/florence2/florence2.py +++ b/mlx_vlm/models/florence2/florence2.py @@ -351,46 +351,6 @@ def __call__( return outputs - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - model_config = json.load(f) - - model_config = ModelConfig.from_dict(model_config) - - model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) - model_config.text_config = TextConfig.from_dict(model_config) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = VisionModel.sanitize(weights) - weights = LanguageModel.sanitize(weights) - - model.load_weights(list(weights.items())) - return model - @staticmethod def sanitize(weights): sanitized_weights = {} diff --git 
a/mlx_vlm/models/idefics2/idefics2.py b/mlx_vlm/models/idefics2/idefics2.py index 6de5642..bb09b78 100644 --- a/mlx_vlm/models/idefics2/idefics2.py +++ b/mlx_vlm/models/idefics2/idefics2.py @@ -259,57 +259,17 @@ def __call__( cache=None, **kwargs, ): - input_embeddings = self.get_input_embeddings(input_ids, pixel_values) + prefill_step_size = kwargs.pop("prefill_step_size", 256) + inputs_embeds = self.get_input_embeddings(input_ids, pixel_values) + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size + ) logits = self.language_model( - inputs=input_ids, cache=cache, inputs_embeds=input_embeddings + inputs=input_ids, cache=cache, inputs_embeds=inputs_embeds ) return logits - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - config = json.load(f) - - text_config = AutoConfig.from_pretrained(config["text_config"]["model_type"]) - text_config = text_config.to_dict() - config["text_config"] = text_config - model_config = ModelConfig.from_dict(config) - model_config.vision_config = VisionConfig.from_dict(config["vision_config"]) - model_config.text_config = TextConfig.from_dict(config["text_config"]) - model_config.perceiver_config = PerceiverConfig.from_dict( - config["perceiver_config"] - ) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = model.sanitize(weights=weights) - weights = VisionModel(model_config.vision_config).sanitize(weights=weights) - weights = LanguageModel(model_config.text_config).sanitize(weights=weights) - model.load_weights(list(weights.items())) - return model - def sanitize(self, weights): weights = { ( diff --git a/mlx_vlm/models/idefics3/idefics3.py b/mlx_vlm/models/idefics3/idefics3.py index 025909a..16f6334 100644 --- a/mlx_vlm/models/idefics3/idefics3.py +++ b/mlx_vlm/models/idefics3/idefics3.py @@ -144,54 +144,17 @@ def __call__( cache=None, **kwargs, ): - input_embeddings = self.get_input_embeddings(input_ids, pixel_values) + prefill_step_size = kwargs.pop("prefill_step_size", 256) + inputs_embeds = self.get_input_embeddings(input_ids, pixel_values) + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size + ) logits = self.language_model( - inputs=input_ids, cache=cache, inputs_embeds=input_embeddings + inputs=input_ids, cache=cache, inputs_embeds=inputs_embeds ) return logits - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - config = json.load(f) - - text_config = AutoConfig.from_pretrained(config["text_config"]["model_type"]) - text_config = text_config.to_dict() - config["text_config"] = text_config - model_config = ModelConfig.from_dict(config) - model_config.vision_config = VisionConfig.from_dict(config["vision_config"]) - model_config.text_config = 
TextConfig.from_dict(config["text_config"]) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = model.sanitize(weights=weights) - weights = VisionModel(model_config.vision_config).sanitize(weights=weights) - weights = LanguageModel(model_config.text_config).sanitize(weights=weights) - model.load_weights(list(weights.items())) - return model - def sanitize(self, weights): weights = { ( diff --git a/mlx_vlm/models/llava/language.py b/mlx_vlm/models/llava/language.py index 3efef2e..af507a0 100644 --- a/mlx_vlm/models/llava/language.py +++ b/mlx_vlm/models/llava/language.py @@ -22,6 +22,7 @@ class TextConfig: rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, str]]] = None tie_word_embeddings: bool = False + sliding_window: Optional[int] = None @classmethod def from_dict(cls, params): @@ -51,9 +52,9 @@ def __init__(self, config: TextConfig): super().__init__() dim = config.hidden_size + self.config = config self.n_heads = n_heads = config.num_attention_heads self.n_kv_heads = n_kv_heads = config.num_key_value_heads - self.repeats = n_heads // n_kv_heads head_dim = config.hidden_size // n_heads @@ -88,7 +89,8 @@ def __call__( mask: Optional[mx.array] = None, cache: Optional[KVCache] = None, ) -> mx.array: - B, L, D = x.shape + + B, L, D = x.shape[:3] queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) @@ -101,6 +103,12 @@ def __call__( queries = self.rope(queries, offset=cache.offset) keys = self.rope(keys, offset=cache.offset) keys, values = cache.update_and_fetch(keys, values) + if self.config.sliding_window: + keys = keys[:, :, self.config.sliding_window :, :] + values = values[:, :, self.config.sliding_window :, :] + if mask is not None: + mask = mask[:, self.config.sliding_window :] else: queries = self.rope(queries) keys = self.rope(keys) @@ -164,9 +172,10 @@ def __init__(self, config: TextConfig): def __call__( self, - inputs: mx.array, + inputs: mx.array = None, cache=None, inputs_embeds=None, + mask: Optional[mx.array] = None, ): # for passing merged input embeddings if inputs_embeds is None: @@ -174,11 +183,12 @@ def __call__( else: h = inputs_embeds - mask = create_attention_mask(h) - if cache is None: cache = [None] * len(self.layers) + if mask is None: + mask = create_attention_mask(h, cache) + for layer, c in zip(self.layers, cache): h = layer(h, mask, c) @@ -200,12 +210,12 @@ def __init__(self, config: TextConfig): def __call__( self, - inputs: mx.array, + inputs: mx.array = None, cache=None, inputs_embeds=None, mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache, inputs_embeds) + out = self.model(inputs, cache, inputs_embeds, mask) if self.config.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/mlx_vlm/models/llava/llava.py b/mlx_vlm/models/llava/llava.py index 696d9db..7412263 100644 --- a/mlx_vlm/models/llava/llava.py +++ b/mlx_vlm/models/llava/llava.py @@ -10,6 +10,7 @@ import numpy as np from huggingface_hub import snapshot_download +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -54,7 +55,7 @@ def __call__(self, x: mx.array) -> mx.array: return x -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config =
config @@ -68,6 +69,8 @@ def get_input_embeddings( self, input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, + merge_similar_tokens_ratio: Optional[float] = 1, + filter_topk_tokens_ratio: Optional[float] = 1, ): if pixel_values is None: return self.language_model.model.embed_tokens(input_ids) @@ -76,17 +79,31 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) # Get the ouptut hidden states from the vision model - *_, hidden_states = self.vision_tower( - pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True + *_, hidden_states, all_attn = self.vision_tower( + pixel_values.transpose(0, 2, 3, 1), + output_hidden_states=True, + output_attn=True, ) + # Get the attention from the desired layer + all_attn = all_attn[self.vision_feature_layer] # Select the hidden states from the desired layer selected_image_feature = hidden_states[self.vision_feature_layer] + # Select dominant tokens + selected_image_feature = self.get_topk_tokens( + selected_image_feature, all_attn, filter_topk_tokens_ratio + ) + + # Merge similar tokens + selected_image_feature = self.merge_similar_visual_tokens( + selected_image_feature, merge_similar_tokens_ratio + ) + if self.vision_feature_select_strategy == "default": selected_image_feature = selected_image_feature[:, 1:] elif self.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature + pass # Keep full image features without modification else: raise ValueError( "Unexpected feature selection strategy: " @@ -111,12 +128,6 @@ def _merge_input_ids_with_image_features( # Positions of tokens in input_ids, assuming batch size is 1 image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() - if len(image_positions) != num_images: - raise ValueError( - f"The number of image tokens ({len(image_positions)}) does not " - f" match the number of image inputs ({num_images})." 
- ) - text_segments = [] start_idx = 0 @@ -140,48 +151,22 @@ def __call__( cache=None, **kwargs, ): - input_embddings = self.get_input_embeddings(input_ids, pixel_values) - logits = self.language_model( - input_ids, cache=cache, inputs_embeds=input_embddings + merge_similar_tokens_ratio = kwargs.get("merge_similar_tokens_ratio", 1) + filter_topk_tokens_ratio = kwargs.get("filter_topk_tokens_ratio", 1) + prefill_step_size = kwargs.pop("prefill_step_size", 256) + inputs_embeds = self.get_input_embeddings( + input_ids, + pixel_values, + merge_similar_tokens_ratio, + filter_topk_tokens_ratio, ) - return logits - - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size ) + self.config.text_config.sliding_window = 4096 - with open(path / "config.json", "r") as f: - model_config = json.load(f) - - model_config = ModelConfig.from_dict(model_config) - - model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) - model_config.text_config = TextConfig.from_dict(model_config.text_config) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = VisionModel.sanitize(weights) - weights = LanguageModel.sanitize(weights) - - model.load_weights(list(weights.items())) - return model + logits = self.language_model( + input_ids, cache=cache, inputs_embeds=inputs_embeds + ) + return logits diff --git a/mlx_vlm/models/llava/vision.py b/mlx_vlm/models/llava/vision.py index 31c2734..0a1fdc1 100644 --- a/mlx_vlm/models/llava/vision.py +++ b/mlx_vlm/models/llava/vision.py @@ -84,7 +84,7 @@ def __init__( self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) - def __call__(self, queries, keys, values, mask=None): + def __call__(self, queries, keys, values, mask=None, output_attn=False): queries = self.q_proj(queries) keys = self.k_proj(keys) values = self.v_proj(values) @@ -96,12 +96,15 @@ def __call__(self, queries, keys, values, mask=None): keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) - output = mx.fast.scaled_dot_product_attention( + attn = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=self.scale, mask=mask ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + output = attn.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(output) + if output_attn: + return self.out_proj(output), attn + else: + return self.out_proj(output) class MLP(nn.Module): @@ -128,13 +131,15 @@ def __init__(self, config: VisionConfig): self.mlp = MLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + def __call__( + self, x: mx.array, mask: Optional[mx.array] = None, output_attn: bool = False + ) -> mx.array: y = self.layer_norm1(x) - y = self.self_attn(y, y, y, mask) + y, attn = self.self_attn(y, y, y, mask, output_attn=output_attn) x = x + y y = 
self.layer_norm2(x) y = self.mlp(y) - return x + y + return x + y, attn class Encoder(nn.Module): @@ -207,20 +212,24 @@ def __call__( self, x: mx.array, output_hidden_states: Optional[bool] = None, + output_attn: bool = False, ) -> mx.array: x = self.embeddings(x) if self.config.model_type == "clip_vision_model": x = self.pre_layrnorm(x) encoder_states = (x,) if output_hidden_states else None + all_attn = () if output_attn else None for l in self.encoder.layers: - x = l(x, mask=None) + x, attn = l(x, mask=None, output_attn=output_attn) if output_hidden_states: encoder_states = encoder_states + (x,) + if output_attn: + all_attn = all_attn + (attn,) pooler_output = self.post_layernorm(x[:, 0, :]) - return pooler_output, x, encoder_states + return pooler_output, x, encoder_states, all_attn class VisionModel(nn.Module): @@ -234,9 +243,12 @@ def __init__(self, config: VisionConfig): self.vision_model = ClipVisionModel(config) def __call__( - self, x: mx.array, output_hidden_states: Optional[bool] = None + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + output_attn: bool = False, ) -> mx.array: - return self.vision_model(x, output_hidden_states) + return self.vision_model(x, output_hidden_states, output_attn) def sanitize(self, weights): sanitized_weights = {} diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py index 730d4dc..3b0c733 100644 --- a/mlx_vlm/models/llava_bunny/llava_bunny.py +++ b/mlx_vlm/models/llava_bunny/llava_bunny.py @@ -139,19 +139,31 @@ def get_input_embeddings( self, input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, + merge_similar_tokens_ratio: Optional[float] = 1, + filter_topk_tokens_ratio: Optional[float] = 1, ): if pixel_values is None: return self.language_model.model.embed_tokens(input_ids) inputs_embeds = self.language_model.model.embed_tokens(input_ids) - *_, hidden_state = self.vision_tower( - pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True + *_, hidden_state, all_attn = self.vision_tower( + pixel_values.transpose(0, 2, 3, 1), + output_hidden_states=True, + output_attn=True, ) image_features = hidden_state[-1].astype(pixel_values.dtype) assert image_features.shape[-2] == 729 + image_features = self.get_topk_tokens( + image_features, all_attn, filter_topk_tokens_ratio + ) + + image_features = self.merge_similar_visual_tokens( + image_features, merge_similar_tokens_ratio + ) + image_features = self.mm_projector(image_features) final_inputs_embeds = self._prepare_inputs_for_multimodal( @@ -192,59 +204,24 @@ def __call__( cache: Optional[Tuple[mx.array, mx.array]] = None, **kwargs, ): - input_embeddings = self.get_input_embeddings(input_ids, pixel_values) + prefill_step_size = kwargs.pop("prefill_step_size", 256) + merge_similar_tokens_ratio = kwargs.get("merge_similar_tokens_ratio", 1) + filter_topk_tokens_ratio = kwargs.get("filter_topk_tokens_ratio", 1) + inputs_embeds = self.get_input_embeddings( + input_ids, + pixel_values, + merge_similar_tokens_ratio, + filter_topk_tokens_ratio, + ) + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size + ) logits = self.language_model( - inputs=input_ids, cache=cache, inputs_embeds=input_embeddings, mask=mask + inputs=input_ids, cache=cache, inputs_embeds=inputs_embeds, mask=mask ) return logits - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - 
repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - config = json.load(f) - - siglip_config = AutoConfig.from_pretrained(config["mm_vision_tower"]) - text_config = AutoConfig.from_pretrained(config["language_model"]) - siglip_config = siglip_config.to_dict() - text_config = text_config.to_dict() - config["vision_config"] = siglip_config["vision_config"] - config["text_config"] = text_config - - model_config = ModelConfig.from_dict(config) - model_config.vision_config = VisionConfig.from_dict(config["vision_config"]) - model_config.text_config = TextConfig.from_dict(config["text_config"]) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = model.sanitize(weights=weights) - - weights = VisionModel(model_config.vision_config).sanitize(weights=weights) - weights = LanguageModel(model_config.text_config).sanitize(weights=weights) - model.load_weights(list(weights.items())) - return model - def sanitize(self, weights): weights = { ( diff --git a/mlx_vlm/models/llava_bunny/vision.py b/mlx_vlm/models/llava_bunny/vision.py index df3e3c5..21aa099 100644 --- a/mlx_vlm/models/llava_bunny/vision.py +++ b/mlx_vlm/models/llava_bunny/vision.py @@ -83,7 +83,7 @@ def __init__( self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) - def __call__(self, queries, keys, values, mask=None): + def __call__(self, queries, keys, values, mask=None, output_attn=False): queries = self.q_proj(queries) keys = self.k_proj(keys) values = self.v_proj(values) @@ -95,11 +95,15 @@ def __call__(self, queries, keys, values, mask=None): keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) - output = mx.fast.scaled_dot_product_attention( + attn = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=self.scale, mask=mask ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(output) + output = attn.transpose(0, 2, 1, 3).reshape(B, L, -1) + + if output_attn: + return self.out_proj(output), attn + else: + return self.out_proj(output) class MHA(nn.Module): @@ -124,7 +128,7 @@ def __init__( self.in_proj = nn.Linear(dims, dims * 3, bias=bias) self.out_proj = nn.Linear(dims, dims, bias=bias) - def __call__(self, queries: mx.array, kv: mx.array, mask=None): + def __call__(self, queries: mx.array, kv: mx.array, mask=None, output_attn=False): B, L, D = queries.shape qkv = self.in_proj(queries) @@ -137,11 +141,15 @@ def __call__(self, queries: mx.array, kv: mx.array, mask=None): keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) - output = mx.fast.scaled_dot_product_attention( + attn = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=self.scale, mask=mask ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(output) + output = attn.transpose(0, 2, 1, 3).reshape(B, L, -1) + + if output_attn: + return self.out_proj(output), attn + else: + return self.out_proj(output) class MLP(nn.Module): @@ -170,11 +178,11 @@ def __init__(self, config: VisionConfig): def __call__(self, x: mx.array, mask: 
Optional[mx.array] = None) -> mx.array: y = self.layer_norm1(x) - y = self.self_attn(y, y, y, mask) + y, attn = self.self_attn(y, y, y, mask, output_attn=True) x = x + y y = self.layer_norm2(x) y = self.mlp(y) - return x + y + return x + y, attn class Encoder(nn.Module): @@ -225,19 +233,22 @@ def __call__( self, x: mx.array, output_hidden_states: Optional[bool] = None, + output_attn: Optional[bool] = None, ) -> mx.array: x = self.embeddings(x) encoder_states = (x,) if output_hidden_states else None + all_attn = () if output_attn else None for l in self.encoder.layers: - x = l(x, mask=None) + x, attn = l(x, mask=None) if output_hidden_states: encoder_states = encoder_states + (x,) + if output_attn: + all_attn = all_attn + (attn,) pooler_output = self.post_layernorm(x[:, 0, :]) pooler_output = self.head(pooler_output) - return pooler_output, x, encoder_states + return pooler_output, x, encoder_states, all_attn class SigLipMultiheadAttentionPoolingHead(nn.Module): diff --git a/mlx_vlm/models/llava_next/llava_next.py b/mlx_vlm/models/llava_next/llava_next.py index f10649f..6f45145 100644 --- a/mlx_vlm/models/llava_next/llava_next.py +++ b/mlx_vlm/models/llava_next/llava_next.py @@ -100,6 +100,7 @@ def get_input_embeddings( # Pass image features through the multi-modal projector image_features = self.multi_modal_projector(selected_image_feature) if self.image_newline is not None: self.image_newline = np.array(self.image_newline)[None, None, :] self.image_newline = np.broadcast_to( @@ -144,49 +145,13 @@ def __call__( cache=None, **kwargs, ): - - input_embddings = self.get_input_embeddings(input_ids, pixel_values) + prefill_step_size = kwargs.pop("prefill_step_size", 256) + inputs_embeds = self.get_input_embeddings(input_ids, pixel_values) + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size ) logits = self.language_model( - input_ids, cache=cache, inputs_embeds=input_embddings + input_ids, cache=cache, inputs_embeds=inputs_embeds ) return logits - - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - model_config = json.load(f) - - model_config = ModelConfig.from_dict(model_config) - - model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) - model_config.text_config = TextConfig.from_dict(model_config.text_config) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = VisionModel.sanitize(weights) - weights = LanguageModel.sanitize(weights) - - model.load_weights(list(weights.items())) - return model diff --git a/mlx_vlm/models/mllama/mllama.py b/mlx_vlm/models/mllama/mllama.py index 4bb8bc2..c5fe1cd 100644 --- a/mlx_vlm/models/mllama/mllama.py +++ b/mlx_vlm/models/mllama/mllama.py @@ -9,7 +9,7 @@ import mlx.nn as nn from huggingface_hub import snapshot_download -from ..base import KVCache +from ..base import BaseModel, KVCache from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -36,7 +36,7 @@ def from_dict(cls, params): ) -class Model(nn.Module): +class
Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config @@ -60,6 +60,7 @@ def __call__( aspect_ratio_ids = kwargs.pop("aspect_ratio_ids", None) aspect_ratio_mask = kwargs.pop("aspect_ratio_mask", None) cross_attention_mask = kwargs.pop("cross_attention_mask", None) + prefill_step_size = kwargs.pop("prefill_step_size", 256) inputs_embeds = None @@ -111,6 +112,11 @@ def __call__( :, :, cache_position ] + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size + ) + # Process language input outputs = self.language_model( input_ids=input_ids, @@ -158,46 +164,6 @@ def _prepare_cross_attention_mask( return cross_attention_mask, full_text_row_masked_out_mask - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - model_config = json.load(f) - - model_config = ModelConfig.from_dict(model_config) - - model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) - model_config.text_config = TextConfig.from_dict(model_config) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = VisionModel.sanitize(weights) - weights = LanguageModel.sanitize(weights) - - model.load_weights(list(weights.items())) - return model - def sanitize(self, weights): def transform_key(key): if "vision_tower" not in key: diff --git a/mlx_vlm/models/molmo/language.py b/mlx_vlm/models/molmo/language.py index 90b781a..59518cf 100644 --- a/mlx_vlm/models/molmo/language.py +++ b/mlx_vlm/models/molmo/language.py @@ -79,7 +79,7 @@ def __init__(self, config: TextConfig): self.act = SwiGLU() def __call__(self, x, mask=None, cache=None): - batch_size, seq_len, D = x.shape + batch_size, seq_len, D = x.shape[:3] attn_in = self.attn_norm(x) qkv = self.att_proj(attn_in) @@ -184,7 +184,7 @@ def __call__( cache = [None] * self.config.n_layers if mask is None: - mask = create_attention_mask(h) + mask = create_attention_mask(h, cache) for block, c in zip(self.blocks, cache): h = block(h, mask, c) @@ -212,7 +212,7 @@ def __init__(self, config: TextConfig): def __call__( self, - input_ids: mx.array, + input_ids: mx.array = None, inputs_embeds: Optional[mx.array] = None, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None, diff --git a/mlx_vlm/models/molmo/molmo.py b/mlx_vlm/models/molmo/molmo.py index e3e345d..262c84b 100644 --- a/mlx_vlm/models/molmo/molmo.py +++ b/mlx_vlm/models/molmo/molmo.py @@ -10,6 +10,7 @@ import numpy as np from huggingface_hub import snapshot_download +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -36,26 +37,39 @@ def from_dict(cls, params): ) -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config self.language_model = LanguageModel(config.text_config) self.vision_tower = VisionModel(config.vision_config) - def __call__( - self, - input_ids: mx.array, - pixel_values: mx.array, - mask: mx.array, - cache=None, - **kwargs, - ) -> Dict[str, Union[mx.array, 
List[Tuple[mx.array, mx.array]]]]: + def merge_features(self, input_ids: mx.array, image_features: mx.array, **kwargs): + image_input_idx = kwargs.get("image_input_idx", None) + if image_input_idx is not None and image_input_idx.ndim == 2: + image_input_idx = mx.expand_dims(image_input_idx, 0) + elif image_input_idx is None: + raise ValueError("image_input_idx must be provided") + + batch_size, seq_len = input_ids.shape[:2] + num_image, num_patch = image_input_idx.shape[1:3] + image_input_idx = image_input_idx.reshape(batch_size, num_image * num_patch) + + valid = np.where(image_input_idx >= 0)[0].tolist() + batch_idx = mx.arange(batch_size) + batch_idx = mx.tile(batch_idx[:, None], [1, image_features.shape[1]]) + + input_embeddings = self.language_model.model.wte(input_ids) + input_embeddings[batch_idx[valid], image_input_idx[valid]] += image_features[ + valid + ] + return input_embeddings + + def encode_image(self, input_ids: mx.array, pixel_values: mx.array, **kwargs): if input_ids.ndim == 1: input_ids = input_ids[None, :] batch_size, seq_len = input_ids.shape - image_input_idx = kwargs.get("image_input_idx", None) image_masks = kwargs.get("image_masks", None) @@ -94,69 +108,53 @@ def __call__( image_features = image_features.reshape( batch_size, num_image * num_patch, -1 ) - image_input_idx = image_input_idx.reshape(batch_size, num_image * num_patch) + input_embeddings = self.merge_features( + input_ids, image_features, image_input_idx=image_input_idx + ) + else: + input_embeddings = None - valid = np.where(image_input_idx >= 0)[0].tolist() - batch_idx = mx.arange(batch_size) - batch_idx = mx.tile(batch_idx[:, None], [1, image_features.shape[1]]) + return input_embeddings, image_features - input_embeddings = self.language_model.model.wte(input_ids) - input_embeddings[ - batch_idx[valid], image_input_idx[valid] - ] += image_features[valid] + def __call__( + self, + input_ids: mx.array, + pixel_values: mx.array, + mask: mx.array, + cache=None, + **kwargs, + ) -> Dict[str, Union[mx.array, List[Tuple[mx.array, mx.array]]]]: + input_ids = input_ids[None, :] + image_features = kwargs.get("image_features", None) + merged_features = kwargs.get("merged_features", None) + prefill_step_size = kwargs.pop("prefill_step_size", 256) + + if pixel_values is None: + inputs_embeds = self.language_model.model.wte(input_ids)[0] + elif image_features is None: + inputs_embeds, image_features = self.encode_image( + input_ids, pixel_values, **kwargs + ) + elif merged_features is None: + inputs_embeds = self.merge_features(input_ids, **kwargs) else: - input_embeddings = None + inputs_embeds = merged_features + + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size + ) # Forward pass through the language model logits = self.language_model( input_ids, - inputs_embeds=input_embeddings, + inputs_embeds=inputs_embeds, mask=mask, cache=cache, ) return logits - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - model_config = json.load(f) - - model_config = ModelConfig.from_dict(model_config) - - model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) - model_config.text_config = TextConfig.from_dict(model_config.text_config) - - model = 
Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = VisionModel.sanitize(weights) - weights = LanguageModel.sanitize(weights) - - model.load_weights(list(weights.items())) - return model - def sanitize(self, weights): def transform_key(key): if "model.transformer" in key: diff --git a/mlx_vlm/models/multi_modality/multi_modality.py b/mlx_vlm/models/multi_modality/multi_modality.py index d512abc..1374474 100644 --- a/mlx_vlm/models/multi_modality/multi_modality.py +++ b/mlx_vlm/models/multi_modality/multi_modality.py @@ -372,52 +372,13 @@ def __call__( cache=None, **kwargs, ): - - input_embeddings = self.get_input_embeddings(input_ids, pixel_values) + prefill_step_size = kwargs.pop("prefill_step_size", 256) + inputs_embeds = self.get_input_embeddings(input_ids, pixel_values) + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size + ) logits = self.language_model( - input_ids, cache=cache, inputs_embeds=input_embeddings + input_ids, cache=cache, inputs_embeds=inputs_embeds ) return logits - - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - model_config = json.load(f) - - model_config = ModelConfig.from_dict(model_config) - - model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) - model_config.projector_config = ProjectorConfig.from_dict( - model_config.projector_config - ) - model_config.text_config = TextConfig.from_dict(model_config.text_config) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = VisionModel.sanitize(weights) - weights = LanguageModel.sanitize(weights) - - model.load_weights(list(weights.items())) - return model diff --git a/mlx_vlm/models/paligemma/paligemma.py b/mlx_vlm/models/paligemma/paligemma.py index b8179b9..b346436 100644 --- a/mlx_vlm/models/paligemma/paligemma.py +++ b/mlx_vlm/models/paligemma/paligemma.py @@ -142,9 +142,14 @@ def __call__( cache: Optional[mx.array] = None, **kwargs, ): - input_embeddings, final_attention_mask_4d = self.get_input_embeddings( + prefill_step_size = kwargs.pop("prefill_step_size", 256) + inputs_embeds, final_attention_mask_4d = self.get_input_embeddings( input_ids, pixel_values, mask ) + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size + ) logits = self.language_model( inputs=input_ids, @@ -153,42 +158,3 @@ def __call__( mask=final_attention_mask_4d, ) return logits - - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - config = json.load(f) - - model_config = ModelConfig.from_dict(config) - 
model_config.vision_config = VisionConfig.from_dict(config["vision_config"]) - model_config.text_config = TextConfig.from_dict(config["text_config"]) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = model.sanitize(weights=weights) - - weights = VisionModel(model_config.vision_config).sanitize(weights=weights) - model.load_weights(list(weights.items())) - return model diff --git a/mlx_vlm/models/phi3_v/phi3_v.py b/mlx_vlm/models/phi3_v/phi3_v.py index 6f2d412..6b613f8 100644 --- a/mlx_vlm/models/phi3_v/phi3_v.py +++ b/mlx_vlm/models/phi3_v/phi3_v.py @@ -177,12 +177,17 @@ def __call__( pixel_values=None, image_sizes=None, cache=None, + **kwargs, ): + prefill_step_size = kwargs.pop("prefill_step_size", 256) + h = self.embed_tokens(inputs) p = np.argwhere(inputs < 0).tolist() if pixel_values is not None: h = self.vision_embed_tokens(pixel_values, h, image_sizes, p) + else: + h = self.prefill(h, cache=cache, prefill_step_size=prefill_step_size) mask = create_attention_mask(h) @@ -212,6 +217,7 @@ def __call__( image_sizes=None, **kwargs, ): + out = self.model(inputs, pixel_values, image_sizes, cache) logits = self.lm_head(out).astype(self.lm_head.weight.dtype) return LanguageModelOutput(logits=logits) diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index 37dbc3c..b167e1b 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -144,52 +144,17 @@ def __call__( cache=None, **kwargs, ): - input_embddings = self.get_input_embeddings(input_ids, pixel_values) + prefill_step_size = kwargs.pop("prefill_step_size", 256) + inputs_embeds = self.get_input_embeddings(input_ids, pixel_values) + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size + ) logits = self.language_model( - input_ids, cache=cache, inputs_embeds=input_embddings + input_ids, cache=cache, inputs_embeds=inputs_embeds ) return logits - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - model_config = json.load(f) - - model_config = ModelConfig.from_dict(model_config) - - model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) - model_config.text_config = TextConfig.from_dict(model_config.text_config) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = VisionModel.sanitize(weights) - weights = LanguageModel.sanitize(weights) - - model.load_weights(list(weights.items())) - return model - def sanitize(self, weights): def transform_key(key): if "vision_tower" in key and "vision_model" not in key: diff --git a/mlx_vlm/models/qwen2_vl/language.py b/mlx_vlm/models/qwen2_vl/language.py index 1295de2..1effcd1 100644 --- a/mlx_vlm/models/qwen2_vl/language.py +++ b/mlx_vlm/models/qwen2_vl/language.py @@ -92,7 +92,7 @@ def __call__( offset = cache.offset if cache else 0 - if mask is not None: + if mask 
is not None and keys.shape[-2] == 1: mask = mask[..., : keys.shape[-2]] queries = self.rotary_emb(queries, offset=offset) @@ -160,20 +160,22 @@ def __init__(self, args: TextConfig): def __call__( self, - inputs: mx.array, + input_ids: mx.array, cache=None, inputs_embeds: Optional[mx.array] = None, + mask: Optional[mx.array] = None, ): if inputs_embeds is None: - h = self.embed_tokens(inputs) + h = self.embed_tokens(input_ids) else: h = inputs_embeds - mask = create_attention_mask(h, cache) - if cache is None: cache = [None] * len(self.layers) + if mask is None: + mask = create_attention_mask(h, cache) + for layer, c in zip(self.layers, cache): h = layer(h, mask, c) @@ -195,12 +197,12 @@ def __init__(self, args: TextConfig): def __call__( self, - inputs: mx.array, + input_ids: mx.array = None, cache=None, inputs_embeds: Optional[mx.array] = None, mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds) + out = self.model(input_ids, cache=cache, inputs_embeds=inputs_embeds, mask=mask) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index bc90762..f883bea 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -10,6 +10,7 @@ import numpy as np from huggingface_hub import snapshot_download +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -42,7 +43,7 @@ def from_dict(cls, params): ) -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config @@ -99,56 +100,24 @@ def __call__( cache=None, **kwargs, ): + prefill_step_size = kwargs.pop("prefill_step_size", 256) image_grid_thw = kwargs.pop("image_grid_thw", None) if image_grid_thw is not None: image_grid_thw = mx.array(image_grid_thw) - input_embddings = self.get_input_embeddings( + inputs_embeds = self.get_input_embeddings( input_ids, pixel_values, image_grid_thw ) - logits = self.language_model(None, cache=cache, inputs_embeds=input_embddings) - return logits - - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) + if pixel_values is None: + inputs_embeds = self.prefill( + inputs_embeds, cache=cache, prefill_step_size=prefill_step_size ) - with open(path / "config.json", "r") as f: - model_config = json.load(f) - - model_config = ModelConfig.from_dict(model_config) - - model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) - model_config.text_config = TextConfig.from_dict(model_config) - - model = Model(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = VisionModel.sanitize(weights) - weights = LanguageModel.sanitize(weights) - - model.load_weights(list(weights.items())) - return model + logits = self.language_model( + input_ids=input_ids, cache=cache, inputs_embeds=inputs_embeds + ) + return logits def sanitize(self, weights): def transform_key(key): diff --git a/mlx_vlm/models/qwen2_vl/vision.py b/mlx_vlm/models/qwen2_vl/vision.py index bd699a4..bd32514 
100644 --- a/mlx_vlm/models/qwen2_vl/vision.py +++ b/mlx_vlm/models/qwen2_vl/vision.py @@ -88,7 +88,7 @@ def __call__(self, seqlen: int) -> mx.array: inv_freq = 1.0 / ( self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim) ) - seq = mx.arange(seqlen, dtype=inv_freq.dtype) + seq = mx.arange(seqlen.tolist(), dtype=inv_freq.dtype) freqs = mx.outer(seq, inv_freq) return freqs diff --git a/mlx_vlm/tests/test_smoke.py b/mlx_vlm/tests/test_smoke.py index 70877a0..07ad175 100644 --- a/mlx_vlm/tests/test_smoke.py +++ b/mlx_vlm/tests/test_smoke.py @@ -130,8 +130,11 @@ def test_generation( console.print(f"[bold green]✓[/] {test_type} generation successful") return False - except Exception as e: - console.print(f"[bold red]✗[/] {test_type} generation failed: {str(e)}") + except: + import traceback + + console.print(f"[bold red]✗[/] {test_type} generation failed:") + console.print(traceback.format_exc()) return True diff --git a/mlx_vlm/tests/test_trainer.py b/mlx_vlm/tests/test_trainer.py index 97339e7..70dd370 100644 --- a/mlx_vlm/tests/test_trainer.py +++ b/mlx_vlm/tests/test_trainer.py @@ -47,15 +47,15 @@ def test_dataset_getitem(self, mock_prepare_inputs, mock_get_prompt): mock_get_prompt.return_value = "Mocked prompt" - mock_prepare_inputs.return_value = ( - mx.array([1, 2, 3]), # input_ids - mx.array( + mock_prepare_inputs.return_value = { + "input_ids": mx.array([1, 2, 3]), # input_ids + "pixel_values": mx.array( [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] ), # pixel_values - mx.array([1, 1, 1]), # mask - (1, 1, 1), # image_grid_thw - [224, 224], # image_sizes - ) + "attention_mask": mx.array([1, 1, 1]), # mask + "image_grid_thw": (1, 1, 1), # image_grid_thw + "image_sizes": [224, 224], # image_sizes + } result = dataset[0] diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 213b0ef..22f61f2 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -89,27 +89,21 @@ def __getitem__(self, idx): image_token_index = self.config["image_token_index"] inputs = prepare_inputs( - self.image_processor, self.processor, images, prompts, image_token_index, self.image_resize_shape, ) - input_ids, pixel_values, mask = inputs[:3] + input_ids = inputs["input_ids"] + pixel_values = inputs["pixel_values"] + mask = inputs["attention_mask"] kwargs = { k: v - for k, v in zip( - [ - "image_grid_thw", - "image_sizes", - "aspect_ratio_ids", - "aspect_ratio_mask", - "cross_attention_mask", - ], - inputs[3:], - ) + for k, v in inputs.items() + if k not in ["input_ids", "pixel_values", "attention_mask"] } + if mask is None: mask = mx.ones_like(input_ids) @@ -226,16 +220,11 @@ def loss_fn(self, model, batch): input_ids = input_ids[:, :-1] - kwargs = {} - image_keys = [ - "image_grid_thw", - "image_sizes", - "aspect_ratio_ids", - "aspect_ratio_mask", - "cross_attention_mask", - ] - if any(key in batch for key in image_keys): - kwargs = {key: batch[key] for key in image_keys if key in batch} + kwargs = { + k: v + for k, v in batch.items() + if k not in ["input_ids", "pixel_values", "attention_mask"] + } # Forward pass outputs = model(input_ids, pixel_values, attention_mask, **kwargs) diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index e1d0c37..d8ff538 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -25,7 +25,9 @@ PreTrainedTokenizerFast, ) -from .models.base import BaseImageProcessor, KVCache, SimpleKVCache +from mlx_vlm.models.cache import VLMFeatureCache + +from .models.base import BaseImageProcessor, KVCache, RotatingKVCache, SimpleKVCache from 
.sample_utils import top_p_sampling from .tokenizer_utils import load_tokenizer from .trainer import apply_lora_layers @@ -421,7 +423,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): api = HfApi() api.create_repo(repo_id=upload_repo, exist_ok=True) - api.upload_folder( + api.upload_large_folder( folder_path=path, repo_id=upload_repo, repo_type="model", @@ -841,11 +843,12 @@ def generate_step( mask, *, max_tokens: int = 256, - temp: float = 0.0, + temperature: float = 0.0, repetition_penalty: Optional[float] = None, repetition_context_size: Optional[int] = 20, top_p: float = 1.0, logit_bias: Optional[Dict[int, float]] = None, + max_kv_size: Optional[int] = None, **kwargs, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ @@ -863,6 +866,7 @@ def generate_step( top_p (float, optional): Nulceus sampling, higher means model considers more less likely words. logit_bias (dictionary, optional): Additive logit bias. + max_kv_size (int, optional): Set the maximum key-value cache size. Yields: Generator[Tuple[mx.array, mx.array], None, None]: A generator producing @@ -876,13 +880,13 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: logits[:, indices] += values logprobs = logits - mx.logsumexp(logits) - if temp == 0: + if temperature == 0: token = mx.argmax(logits, axis=-1) else: if top_p > 0 and top_p < 1.0: - token = top_p_sampling(logits, top_p, temp) + token = top_p_sampling(logits, top_p, temperature) else: - token = mx.random.categorical(logits * (1 / temp)) + token = mx.random.categorical(logits * (1 / temperature)) return token, logprobs @@ -907,7 +911,18 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: (SimpleKVCache(), SimpleKVCache()) for n in model.language_model.layers ] else: - cache = [KVCache(model.language_model.head_dim, n) for n in kv_heads] + if max_kv_size is None: + cache = [KVCache(model.language_model.head_dim, n) for n in kv_heads] + else: + cache = [ + RotatingKVCache( + model.language_model.head_dim, + n, + max_size=max_kv_size, + keep=max_kv_size // 2 if pixel_values is None else 4, + ) + for n in kv_heads + ] repetition_context = input_ids.reshape(-1).tolist() @@ -1098,6 +1113,9 @@ def generate( text = "" last_response = None + merge_similar_tokens_ratio = kwargs.get("merge_similar_tokens_ratio", 1) + filter_topk_tokens_ratio = kwargs.get("filter_topk_tokens_ratio", 1) + for response in stream_generate(model, processor, prompt, image, **kwargs): if verbose: print(response.text, end="", flush=True) @@ -1109,8 +1127,12 @@ def generate( if len(text) == 0: print("No text generated for this prompt") return + + total_tokens = ( + last_response.prompt_tokens * merge_similar_tokens_ratio + ) * filter_topk_tokens_ratio print( - f"Prompt: {last_response.prompt_tokens} tokens, " + f"Prompt: {int(total_tokens)} tokens, " f"{last_response.prompt_tps:.3f} tokens-per-sec" ) print(