[V1][Core] Autotune encoder cache budget (vllm-project#11895)
Signed-off-by: Roger Wang <[email protected]>
ywang96 authored Jan 15, 2025
1 parent edce722 commit 70755e8
Showing 6 changed files with 167 additions and 50 deletions.
15 changes: 10 additions & 5 deletions vllm/config.py
@@ -1387,13 +1387,15 @@ class SchedulerConfig:

is_multimodal_model: bool = False

# FIXME(woosuk & ywang96): Below are placeholder values. We need to
# calculate the actual values from the configurations.
# Multimodal encoder run compute budget, only used in V1
max_num_encoder_input_tokens = 16384
# NOTE: The following multimodal encoder budget will be initialized to
# max_num_batched_tokens and overridden in case max multimodal embedding
# size is larger.
# TODO (ywang96): Make these configurable.
# Multimodal encoder compute budget, only used in V1
max_num_encoder_input_tokens: int = field(default=None) # type: ignore

# Multimodal encoder cache size, only used in V1
encoder_cache_size = 16384
encoder_cache_size: int = field(default=None) # type: ignore

# Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows:
@@ -1467,6 +1469,9 @@ def __post_init__(self) -> None:
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)

self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens

if self.enable_chunked_prefill:
logger.info(
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
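
For reference, a minimal standalone sketch of the new default behavior (not the real vLLM class; only the three fields touched by this diff are reproduced): both budgets now start at max_num_batched_tokens and are only raised later by the autotuning logic in encoder_cache_manager.py.

# Standalone illustration of the new SchedulerConfig defaults; not the real
# vLLM class, and only the fields shown in the diff above are reproduced.
from dataclasses import dataclass, field


@dataclass
class _SchedulerConfigSketch:
    max_num_batched_tokens: int = 8192
    # Both budgets now default to None and are filled in __post_init__.
    max_num_encoder_input_tokens: int = field(default=None)  # type: ignore
    encoder_cache_size: int = field(default=None)  # type: ignore

    def __post_init__(self) -> None:
        # Mirrors the change above: seed both budgets with the scheduler's
        # token budget instead of the old hard-coded 16384.
        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
        self.encoder_cache_size = self.max_num_batched_tokens


cfg = _SchedulerConfigSketch(max_num_batched_tokens=8192)
assert cfg.max_num_encoder_input_tokens == 8192
assert cfg.encoder_cache_size == 8192
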
29 changes: 24 additions & 5 deletions vllm/multimodal/registry.py
@@ -252,11 +252,8 @@ def get_max_tokens_per_item_by_modality(
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality
for profiling the memory usage of a model.
Note:
This is currently directly used only in V1.
Get the maximum number of tokens per data item from each modality based
on underlying model configuration.
"""
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(
@@ -272,6 +269,28 @@ def get_max_tokens_per_item_by_modality(
for key, plugin in self._plugins.items()
}

def get_max_tokens_per_item_by_nonzero_modality(
self,
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
on underlying model configuration, excluding modalities that user
explicitly disabled via `limit_mm_per_prompt`.
Note:
This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
limits_per_plugin = self._limits_by_model[model_config]

return {
key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
if limits_per_plugin[key] > 0
}

def get_max_tokens_by_modality(
self,
model_config: "ModelConfig",
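
A short standalone sketch of the filtering rule introduced here (the modality names and numbers are made up): any modality whose limit_mm_per_prompt value is 0 is dropped before the budget is computed.

# Standalone sketch of the filter applied by
# get_max_tokens_per_item_by_nonzero_modality; the values are illustrative.
max_tokens_per_item = {"image": 2048, "video": 16384}  # per-item maxima
limits_per_plugin = {"image": 5, "video": 0}            # user disabled video

nonzero = {
    key: max_tokens
    for key, max_tokens in max_tokens_per_item.items()
    if limits_per_plugin[key] > 0
}
assert nonzero == {"image": 2048}  # "video" is excluded from profiling
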
78 changes: 77 additions & 1 deletion vllm/v1/core/encoder_cache_manager.py
@@ -1,7 +1,14 @@
from typing import Dict, List, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Set, Tuple

from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.v1.request import Request

if TYPE_CHECKING:
from vllm.config import ModelConfig, SchedulerConfig

logger = init_logger(__name__)


class EncoderCacheManager:

@@ -46,3 +53,72 @@ def get_freed_ids(self) -> List[Tuple[str, int]]:
freed = self.freed
self.freed = []
return freed


def compute_encoder_budget(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations.
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
"""

if not model_config.is_multimodal_model:
return 0, 0

# TODO: handle encoder-decoder models once we support them.
(
encoder_compute_budget,
encoder_cache_size,
) = _compute_encoder_budget_multimodal(model_config, scheduler_config)

return encoder_compute_budget, encoder_cache_size


def _compute_encoder_budget_multimodal(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
"""

max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
model_config)

if not max_tokens_by_modality_dict:
logger.warning(
"All non-text modalities supported by the model have been "
"explicitly disabled via limit_mm_per_prompt. Encoder cache will "
"not be initialized.")
return 0, 0

_, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
key=lambda item: item[1])

encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
max_tokens_per_mm_item)
encoder_cache_size = max(scheduler_config.encoder_cache_size,
max_tokens_per_mm_item)

return encoder_compute_budget, encoder_cache_size
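
A standalone walk-through of the rule above, with hypothetical numbers: both budgets start at the scheduler defaults and are raised to the size of the largest single multimodal item, so that one maximum-size item always fits.

# Standalone sketch of the autotuning rule in
# _compute_encoder_budget_multimodal; all values are hypothetical.
max_num_encoder_input_tokens = 8192  # initialized to max_num_batched_tokens
encoder_cache_size = 8192            # likewise
max_tokens_by_modality = {"image": 2048, "video": 16384}

# Pick the modality with the largest per-item embedding size.
_, max_tokens_per_mm_item = max(max_tokens_by_modality.items(),
                                key=lambda item: item[1])

# Never return a budget smaller than one maximum-size multimodal item.
encoder_compute_budget = max(max_num_encoder_input_tokens,
                             max_tokens_per_mm_item)
encoder_cache_size = max(encoder_cache_size, max_tokens_per_mm_item)

assert (encoder_compute_budget, encoder_cache_size) == (16384, 16384)
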
26 changes: 18 additions & 8 deletions vllm/v1/core/scheduler.py
@@ -3,10 +3,11 @@
from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
@@ -25,6 +26,7 @@ class Scheduler:
def __init__(
self,
scheduler_config: SchedulerConfig,
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
) -> None:
@@ -69,16 +71,24 @@ def __init__(
self.running_reqs_data: Dict[str, RunningRequestData] = {}

# Encoder-related.
# Calculate encoder cache size if applicable
# NOTE: For now we use the same budget for both compute and space.
# This can be changed when we make encoder cache for embedding caching
# across requests.
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
)

# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens #noqa: E501
# NOTE(woosuk): For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
self.max_num_encoder_input_tokens = encoder_compute_budget
# NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0
# for these models.
self.encoder_cache_manager = EncoderCacheManager(
cache_size=self.scheduler_config.encoder_cache_size)
cache_size=encoder_cache_size)

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
9 changes: 6 additions & 3 deletions vllm/v1/engine/core.py
@@ -54,9 +54,12 @@ def __init__(
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

# Setup scheduler.
self.scheduler = Scheduler(vllm_config.scheduler_config,
vllm_config.cache_config,
vllm_config.lora_config)
self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config,
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
)

self.mm_input_mapper_server = MMInputMapperServer(
vllm_config.model_config)
60 changes: 32 additions & 28 deletions vllm/v1/worker/gpu_model_runner.py
@@ -20,6 +20,7 @@
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
@@ -88,8 +89,12 @@ def __init__(
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False

self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
self.encoder_cache_size = self.scheduler_config.encoder_cache_size
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size

# Lazy initialization
# self.model: nn.Module # Set after load_model
@@ -721,44 +726,30 @@ def profile_run(self) -> None:
]

# Profile with multimodal encoder & encoder cache.
if self.is_multimodal_model:

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data
# TODO: handle encoder-decoder models once we support them.
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
and self.encoder_cache_size > 0):

# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality( # noqa: E501
max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
self.model_config)

dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1])

# Check how many items of this modality can be supported by
# the encoder cache budget.
encoder_cache_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)
max_num_mm_items_encoder_budget = encoder_cache_budget // \
max_tokens_per_mm_item

# TODO: Allow users to set encoder_cache_budget in case this
# happens.
assert max_num_mm_items_encoder_budget > 0, (
f"Encoder cache budget={encoder_cache_budget} is too small to "
f"support the maximum possible size of multimodal embeddings"
f"={max_tokens_per_mm_item}.")
# the encoder budget.
encoder_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)

max_num_mm_items_encoder_budget = cdiv(encoder_budget,
max_tokens_per_mm_item)

# Check how many items of this modality can be supported by
# the decoder budget.
max_mm_items_per_req = max(
self.mm_registry.get_mm_limits_per_prompt(
self.model_config).values())
max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
self.model_config)[dummy_data_modality]

# NOTE: We do not consider max_num_batched_tokens on purpose
# because the multimodal embeddings can be generated in advance
@@ -769,6 +760,19 @@
max_num_mm_items = min(max_num_mm_items_encoder_budget,
max_num_mm_items_decoder_budget)

logger.info(
"Encoder cache will be initialized with a budget of %s tokens,"
" and profiled with %s %s items of the maximum feature size.",
encoder_budget, max_num_mm_items, dummy_data_modality)

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data

# Dummy data definition in V0 may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
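
A standalone worked example of the profiling arithmetic above (hypothetical numbers; cdiv is reproduced here so the snippet is self-contained, and the decoder-budget formula is an assumption based on the surrounding code):

# Standalone sketch of how profile_run sizes the dummy multimodal batch;
# all numbers are hypothetical, and the decoder budget is assumed to be
# max_num_reqs * the per-request limit for the chosen modality.
def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the helper used by the model runner.
    return -(-a // b)


max_num_encoder_input_tokens = 16384  # compute budget from compute_encoder_budget
encoder_cache_size = 16384            # cache budget from compute_encoder_budget
max_tokens_per_mm_item = 6000         # largest single item for the chosen modality
max_mm_items_per_req = 4              # limit_mm_per_prompt for that modality
max_num_reqs = 2                      # max number of concurrently running requests

encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)
max_num_mm_items_encoder_budget = cdiv(encoder_budget, max_tokens_per_mm_item)  # 3
max_num_mm_items_decoder_budget = max_num_reqs * max_mm_items_per_req           # 8
max_num_mm_items = min(max_num_mm_items_encoder_budget,
                       max_num_mm_items_decoder_budget)
assert max_num_mm_items == 3  # profile with 3 maximum-size items
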
