[WIP] Prompt caching and Vision token merging + filtering #177

Open · wants to merge 29 commits into main

Commits (29)
6151668
image feature caching
Blaizzy Jan 4, 2025
31d33d8
add attn and visionzip
Blaizzy Jan 5, 2025
6a97e8e
add sizes
Blaizzy Jan 5, 2025
7eabf44
todo: add debug visualization
Blaizzy Jan 5, 2025
fd1309e
Add token filtering and merging
Blaizzy Jan 7, 2025
8e273fd
add default merge ratio to 40
Blaizzy Jan 7, 2025
d012d01
Merge branch 'main' of https://github.com/Blaizzy/mlx-vlm into pc/pro…
Blaizzy Jan 8, 2025
7201780
use update_large_folder
Blaizzy Jan 8, 2025
21fc1b2
Fix trainer and Qwen2-VL (#179)
Blaizzy Jan 11, 2025
80f808c
add full traceback to smoke test
Blaizzy Jan 11, 2025
5681b24
testing rotating kvcache
Blaizzy Jan 18, 2025
d014f30
add prefill to baseModel
Blaizzy Jan 18, 2025
37db9b6
add prefill and merge
Blaizzy Jan 18, 2025
e676669
image feature caching
Blaizzy Jan 4, 2025
29bb5d3
add attn and visionzip
Blaizzy Jan 5, 2025
3607c94
add sizes
Blaizzy Jan 5, 2025
aebcc1b
todo: add debug visualization
Blaizzy Jan 5, 2025
5f950b3
Add token filtering and merging
Blaizzy Jan 7, 2025
875c9d6
add default merge ratio to 40
Blaizzy Jan 7, 2025
d7fbc04
use update_large_folder
Blaizzy Jan 8, 2025
d4a33db
add full traceback to smoke test
Blaizzy Jan 11, 2025
6827907
testing rotating kvcache
Blaizzy Jan 18, 2025
2ff7db1
add prefill to baseModel
Blaizzy Jan 18, 2025
897b270
add prefill and merge
Blaizzy Jan 18, 2025
4e52ce1
Merge branch 'pc/prompt-caching' of https://github.com/Blaizzy/mlx-vl…
Blaizzy Jan 18, 2025
284504c
renam temp to temperature
Blaizzy Jan 18, 2025
6b9e8d8
fix prefill
Blaizzy Jan 18, 2025
6a98147
add prefill and fix token merge & filter
Blaizzy Jan 18, 2025
b3aae82
dynamic sliding window
Blaizzy Jan 18, 2025
50 changes: 47 additions & 3 deletions mlx_vlm/generate.py
@@ -69,10 +69,38 @@ def parse_arguments():
help="Maximum number of tokens to generate.",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Temperature for sampling."
"--temperature",
type=float,
default=DEFAULT_TEMP,
help="Temperature for sampling.",
)
parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.")
parser.add_argument("--verbose", action="store_false", help="Detailed output.")
parser.add_argument(
"--merge-similar-tokens-ratio",
type=float,
default=1.0,
help="Ratio of visual tokens to keep during merging similar tokens (between 0.1 and 1.0).",
choices=[x / 10 for x in range(1, 11)],
)
parser.add_argument(
"--filter-topk-tokens-ratio",
type=float,
help="Ratio of visual tokens to keep during filtering topk tokens (between 0.1 and 1.0).",
choices=[x / 10 for x in range(1, 11)],
)
parser.add_argument(
"--max-kv-size",
type=int,
default=None,
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--prefill-step-size",
type=int,
default=256,
help="Set the prefill step size",
)
return parser.parse_args()


@@ -97,6 +125,9 @@ def main():
prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image))

kwargs = {}

if args.max_kv_size is not None:
kwargs["max_kv_size"] = args.max_kv_size
if args.resize_shape is not None:
resize_shape = args.resize_shape
if len(resize_shape) not in [1, 2]:
@@ -107,6 +138,19 @@
else resize_shape
)

kwargs["merge_similar_tokens_ratio"] = args.merge_similar_tokens_ratio
if args.filter_topk_tokens_ratio is None:
# If no filter ratio is given explicitly, derive one from the merge ratio
if args.merge_similar_tokens_ratio < 0.9:
# For aggressive merging (ratio < 0.9), keep 90% of tokens
kwargs["filter_topk_tokens_ratio"] = 0.90
else:
# For light merging (ratio >= 0.9), keep all tokens
kwargs["filter_topk_tokens_ratio"] = 1.0
else:
# Use explicitly provided filter ratio
kwargs["filter_topk_tokens_ratio"] = args.filter_topk_tokens_ratio

if args.chat:
chat = []
if args.system:
@@ -124,7 +168,7 @@
prompt,
args.image,
max_tokens=args.max_tokens,
- temp=args.temp,
+ temperature=args.temperature,
**kwargs,
):
response += chunk.text
@@ -139,7 +183,7 @@
processor,
prompt,
image=args.image,
- temp=args.temp,
+ temperature=args.temperature,
max_tokens=args.max_tokens,
verbose=args.verbose,
**kwargs,
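For reference, a minimal sketch of how the new options might be driven from the Python API on this branch. The model id and image path are placeholders, the import locations follow the existing mlx-vlm package layout, and the keyword arguments simply mirror the CLI flags added above (assuming `generate` forwards them the same way `main()` does):

```python
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

# Hypothetical model id and image path, purely for illustration.
model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"
model, processor = load(model_path)
config = load_config(model_path)

prompt = apply_chat_template(processor, config, "Describe this image.", num_images=1)

output = generate(
    model,
    processor,
    prompt,
    image=["example.jpg"],
    temperature=0.7,                 # renamed from `temp` in this PR
    max_tokens=256,
    merge_similar_tokens_ratio=0.6,  # keep ~60% of visual tokens after merging
    filter_topk_tokens_ratio=0.9,    # keep the 90% most-attended visual tokens
    max_kv_size=1024,                # cap the KV cache size
)
print(output)
```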
128 changes: 126 additions & 2 deletions mlx_vlm/models/base.py
@@ -3,6 +3,7 @@
from typing import Any, Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor as ImageProcessor
from transformers.image_processing_utils import get_size_dict
@@ -97,6 +98,10 @@ def update(self, keys, values):
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values

@property
def state(self):
return self.keys, self.values


class SimpleKVCache:
"""A simple key-value cache for transformer attention layers.
@@ -148,15 +153,15 @@ def update(self, keys, values):

class RotatingKVCache:

- def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
+ def __init__(self, head_dim, n_kv_heads, max_size, keep=None, step=256):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
- self.keep = keep
+ self.keep = keep if keep is not None else step // 2
self.keys = None
self.values = None
self.offset = 0
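Assuming the usual rotating-cache semantics (the first `keep` positions act as an attention sink and are never overwritten once the cache wraps), the new default pins `step // 2` initial positions instead of none. A small construction sketch with illustrative dimensions:

```python
from mlx_vlm.models.base import RotatingKVCache

# head_dim and n_kv_heads are illustrative; max_size bounds the cache length.
# With keep=None, the first step // 2 = 128 positions are retained when the cache rotates.
cache = RotatingKVCache(head_dim=128, n_kv_heads=8, max_size=1024, step=256)
print(cache.keep)  # 128
```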
@@ -271,3 +276,122 @@ class LanguageModelOutput:
logits: mx.array
cross_attention_states: Optional[List[mx.array]] = None
encoder_outputs: Optional[List[mx.array]] = None


class BaseModel(nn.Module):
def __init__(self):
super().__init__()
self.vision_tower = None
self.language_model = None

def prefill(self, input_embeds, cache=None, prefill_step_size=256):
# Process the prompt embeddings in fixed-size chunks to bound peak memory during prefill
num_batches = (
input_embeds.shape[1] + prefill_step_size - 1
) // prefill_step_size

if num_batches > 1:
# Slice out all full-sized chunks; the final (possibly partial) chunk is returned below
chunks = [
input_embeds[:, i * prefill_step_size : (i + 1) * prefill_step_size, :]
for i in range(num_batches - 1)
]

# Run each full-sized chunk through the language model, materializing the KV cache
# as we go and clearing the Metal buffer cache to keep memory bounded
for chunk in chunks:
mask = create_attention_mask(chunk, cache)
self.language_model(inputs_embeds=chunk, cache=cache, mask=mask)
if cache is not None:
mx.eval([c.state for c in cache])
mx.metal.clear_cache()

# Return the remaining (possibly partial) chunk for the caller to process
remaining_embeds = input_embeds[
:, (num_batches - 1) * prefill_step_size :, :
]
return remaining_embeds

return input_embeds

def get_topk_tokens(self, image_feature, attn, dominant_tokens_ratio=None):
batch_size, seq_len = image_feature.shape[:2]

k_tokens = (
int(image_feature.shape[1] * dominant_tokens_ratio)
if dominant_tokens_ratio is not None
else None
) # number of visual tokens to keep; None disables filtering
if k_tokens is None:
return image_feature
cls_idx = 0 # self.config.image_token_index

attn_rec = mx.sum(attn[:, :, cls_idx + 1 :, cls_idx], axis=1)

topk_idx = mx.argsort(attn_rec, axis=1)[:, -k_tokens:]
# use this to plot the dominant attention map
# https://github.com/dvlab-research/VisionZip/blob/demo-chat/llava/model/multimodal_encoder/clip_encoder.py#L62
# https://github.com/dvlab-research/VisionZip/blob/demo-chat/llava/serve/gradio_web_server.py#L424

# Create CLS token indices array
# Shape: (B, 1)
cls_indices = mx.full((batch_size, 1), cls_idx, dtype=mx.int32)

# Concat with CLS token index
# Add 1 to account for the offset after CLS token
dominant_idx = mx.concatenate([cls_indices, topk_idx + cls_idx + 1], axis=1)

# Gather the CLS token plus the dominant tokens (the trailing [0] assumes a batch size of 1)
image_feature = mx.take(image_feature, dominant_idx, axis=1)[0]
return image_feature

def merge_similar_visual_tokens(
self, image_feature, visual_token_ratio, merge_ratio=0.4
):
# Skip CLS token (first token)
tokens = image_feature[:, 1:]
batch_size, num_tokens, hidden_dim = tokens.shape

# Calculate target number of tokens
target_tokens = max(1, int(num_tokens * visual_token_ratio))

while num_tokens > target_tokens:
# Calculate similarities between adjacent tokens
tokens_a = tokens[:, :-1] # all except last
tokens_b = tokens[:, 1:] # all except first

# Calculate cosine similarity
a_norm = mx.sqrt(mx.sum(tokens_a * tokens_a, axis=-1, keepdims=True))
b_norm = mx.sqrt(mx.sum(tokens_b * tokens_b, axis=-1, keepdims=True))
similarities = mx.sum(tokens_a * tokens_b, axis=-1)
similarities = similarities / (a_norm.squeeze(-1) * b_norm.squeeze(-1))

# Sort similarities and get indices of pairs to merge
# Merge a fraction (merge_ratio) of the remaining excess tokens in each iteration
num_to_merge = max(1, int((num_tokens - target_tokens) * merge_ratio))
merge_indices = mx.argsort(similarities, axis=-1)[:, -num_to_merge:]

# Track which pair indices to merge (uses batch 0, i.e. assumes batch size 1)
to_merge = set(merge_indices[0].tolist())

# Merge selected pairs
new_tokens = []
i = 0
while i < num_tokens:
if i < num_tokens - 1 and i in to_merge:
# Merge this token with the next one
merged = (tokens[:, i : i + 1] + tokens[:, i + 1 : i + 2]) / 2
new_tokens.append(merged)
i += 2
elif i > 0 and (i - 1) in to_merge:
# Skip this token as it was merged in the previous step
i += 1
else:
# Keep this token as is
new_tokens.append(tokens[:, i : i + 1])
i += 1

# Update tokens
tokens = mx.concatenate(new_tokens, axis=1)
num_tokens = tokens.shape[1]

# Reattach CLS token
return mx.concatenate([image_feature[:, :1], tokens], axis=1)
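A rough, standalone illustration of the new token-merging helper, not part of the PR itself; the dummy subclass, tensor shape, and ratio are assumptions chosen only to show how the visual-token count shrinks while the CLS token is preserved:

```python
import mlx.core as mx
from mlx_vlm.models.base import BaseModel

class DummyModel(BaseModel):
    """Minimal subclass; merge_similar_visual_tokens only needs mx ops."""
    pass

model = DummyModel()

# 1 CLS token + 576 patch tokens with a hidden size of 1024 (illustrative shape).
image_feature = mx.random.normal((1, 577, 1024))

# Keep roughly half of the patch tokens by repeatedly averaging adjacent,
# highly similar pairs; the CLS token at index 0 is always reattached.
merged = model.merge_similar_visual_tokens(image_feature, visual_token_ratio=0.5)
print(image_feature.shape, "->", merged.shape)  # (1, 577, 1024) -> roughly (1, 289, 1024)
```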
150 changes: 150 additions & 0 deletions mlx_vlm/models/cache.py
@@ -0,0 +1,150 @@
import hashlib
import os
from pathlib import Path
from typing import Dict, Optional, Union

import mlx.core as mx


class VLMFeatureCache:
"""Cache for storing and retrieving image features from Vision Language Models."""

def __init__(self, cache_dir: Union[str, Path]):
"""
Initialize the feature cache.

Args:
cache_dir: Directory to store the cached features
"""
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)

def _compute_file_hash(self, file_path: Union[str, Path]) -> str:
"""
Compute SHA-256 hash of a file.

Args:
file_path: Path to the file

Returns:
str: Hex digest of the file hash
"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
# Read the file in chunks to handle large files efficiently
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()

def _get_cache_path(self, file_hash: str) -> Path:
"""
Get the cache file path for a given file hash.

Args:
file_hash: SHA-256 hash of the original file

Returns:
Path: Path where the cached features should be stored
"""
return self.cache_dir / f"{file_hash}.safetensors"

def save_features(
self,
image_path: Union[str, Path],
features: mx.array,
metadata: Optional[Dict[str, str]] = None,
) -> str:
"""
Save image features to cache.

Args:
image_path: Path to the original image file
features: Image feature array to cache (stored under the key "image_features")
metadata: Optional metadata to store with the features

Returns:
str: Hash of the cached file
"""
file_hash = self._compute_file_hash(image_path)
cache_path = self._get_cache_path(file_hash)

# Add original file path to metadata
if metadata is None:
metadata = {}
metadata["original_file"] = str(image_path)
metadata["format"] = "mlx"

# Save features using safetensors
mx.save_safetensors(str(cache_path), {"image_features": features}, metadata)
return file_hash

def load_features(
self, image_path: Union[str, Path]
) -> Optional[Dict[str, mx.array]]:
"""
Load cached features for an image if they exist.

Args:
image_path: Path to the image file

Returns:
Optional[Dict[str, mx.array]]: Cached features if they exist, None otherwise
"""
file_hash = self._compute_file_hash(image_path)
cache_path = self._get_cache_path(file_hash)

if not cache_path.exists():
return None

features = mx.load(str(cache_path))
return features

def get_metadata(self, image_path: Union[str, Path]) -> Optional[Dict[str, str]]:
"""
Get metadata for cached features if they exist.

Args:
image_path: Path to the image file

Returns:
Optional[Dict[str, str]]: Metadata if cached features exist, None otherwise
"""
file_hash = self._compute_file_hash(image_path)
cache_path = self._get_cache_path(file_hash)

if not cache_path.exists():
return None

# mx.load on a .safetensors file returns (arrays, metadata) when return_metadata=True
_, metadata = mx.load(str(cache_path), return_metadata=True)
return metadata

def clear_cache(self):
"""Remove all cached features."""
for cache_file in self.cache_dir.glob("*.safetensors"):
cache_file.unlink()

def get_cache_size(self) -> int:
"""
Get the total size of cached features in bytes.

Returns:
int: Total size of cache in bytes
"""
return sum(f.stat().st_size for f in self.cache_dir.glob("*.safetensors"))

def __contains__(self, image_path: Union[str, Path]) -> bool:
"""
Check if features for an image are cached.

Args:
image_path: Path to the image file

Returns:
bool: True if features are cached, False otherwise
"""
file_hash = self._compute_file_hash(image_path)
return self._get_cache_path(file_hash).exists()
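A minimal usage sketch for the feature cache; the placeholder image bytes, cache directory, and feature shape are assumptions, and in practice the features would come from a model's vision tower:

```python
from pathlib import Path

import mlx.core as mx
from mlx_vlm.models.cache import VLMFeatureCache

# Placeholder image file so the hash-based lookup has something to read.
image_path = Path("example.jpg")
image_path.write_bytes(b"placeholder image bytes")

cache = VLMFeatureCache("vlm_feature_cache")  # hypothetical cache directory

if image_path in cache:
    features = cache.load_features(image_path)["image_features"]
else:
    features = mx.random.normal((1, 577, 1024))  # stand-in for vision-tower output
    cache.save_features(image_path, features, metadata={"model": "demo"})

print(cache.get_metadata(image_path))
print(cache.get_cache_size(), "bytes")
```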