add bei specific user migrations in trt-llm config (#1361)
* add bei specific user migrations

* ruff fmt
michaelfeil authored Feb 3, 2025
1 parent 53db85f commit 9eafca8
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions truss/base/trt_llm_config.py
@@ -11,6 +11,8 @@
 from huggingface_hub.utils import validate_repo_id
 from pydantic import BaseModel, PydanticDeprecatedSince20, model_validator, validator
 
+from truss.base.constants import BEI_REQUIRED_MAX_NUM_TOKENS
+
 logger = logging.getLogger(__name__)
 # Suppress Pydantic V1 warnings, because we have to use it for backwards compat.
 warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
@@ -119,6 +121,12 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
     class Config:
         extra = "forbid"
 
+    def __init__(self, **data):
+        super().__init__(**data)
+        self._validate_kv_cache_flags()
+        self._validate_speculator_config()
+        self._bei_specfic_migration()
+
     @validator("max_beam_width")
     def check_max_beam_width(cls, v: int):
         if isinstance(v, int):
@@ -128,6 +136,29 @@ def check_max_beam_width(cls, v: int):
             )
         return v
 
+    def _bei_specfic_migration(self):
+        """performs embedding-specific optimizations (no kv-cache, high batch size)"""
+        if self.base_model == TrussTRTLLMModel.ENCODER:
+            # Encoder-specific settings
+            logger.info(
+                f"Your setting of `build.max_seq_len={self.max_seq_len}` is not used and "
+                "automatically inferred from the model repo config.json -> `max_position_embeddings`"
+            )
+
+            if self.max_num_tokens < BEI_REQUIRED_MAX_NUM_TOKENS:
+                logger.warning(
+                    f"build.max_num_tokens={self.max_num_tokens}, upgrading to {BEI_REQUIRED_MAX_NUM_TOKENS}"
+                )
+                self.max_num_tokens = BEI_REQUIRED_MAX_NUM_TOKENS
+            self.plugin_configuration.paged_kv_cache = False
+            self.plugin_configuration.use_paged_context_fmha = False
+
+            if "_kv" in self.quantization_type.value:
+                raise ValueError(
+                    "encoder does not have a kv-cache, therefore a kv-specific datatype is not valid; "
+                    f"you selected build.quantization_type {self.quantization_type}"
+                )
+
     def _validate_kv_cache_flags(self):
         if not self.plugin_configuration.paged_kv_cache and (
             self.plugin_configuration.use_paged_context_fmha
@@ -176,11 +207,6 @@ def max_draft_len(self) -> Optional[int]:
             return self.speculator.num_draft_tokens
         return None
 
-    def __init__(self, **data):
-        super().__init__(**data)
-        self._validate_kv_cache_flags()
-        self._validate_speculator_config()
-
 
 class TrussSpeculatorConfiguration(BaseModel):
     speculative_decoding_mode: TrussSpecDecMode = TrussSpecDecMode.DRAFT_EXTERNAL
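
For context, the following standalone sketch (not part of the commit) mirrors what the new encoder branch does: for encoder/embedding (BEI) builds it raises max_num_tokens to the BEI minimum, switches off the kv-cache plugins, and rejects kv-cache-specific quantization types. The SimplifiedBuildConfig class, its field names, and the constant value are hypothetical stand-ins for illustration only; the real logic lives in TrussTRTLLMBuildConfiguration in the diff above.

import logging
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

# Hypothetical stand-in value; the real number comes from
# truss.base.constants.BEI_REQUIRED_MAX_NUM_TOKENS.
BEI_REQUIRED_MAX_NUM_TOKENS = 16384


@dataclass
class PluginConfig:
    paged_kv_cache: bool = True
    use_paged_context_fmha: bool = True


@dataclass
class SimplifiedBuildConfig:
    """Toy stand-in for TrussTRTLLMBuildConfiguration (illustration only)."""
    is_encoder: bool = False
    max_num_tokens: int = 8192
    quantization_type: str = "no_quant"
    plugin_configuration: PluginConfig = field(default_factory=PluginConfig)

    def bei_specific_migration(self) -> None:
        """Mirrors the encoder branch added in this commit."""
        if not self.is_encoder:
            return
        # Embedding models keep no kv-cache between requests, so raise
        # max_num_tokens to the BEI minimum and disable kv-cache plugins.
        if self.max_num_tokens < BEI_REQUIRED_MAX_NUM_TOKENS:
            logger.warning(
                "max_num_tokens=%s, upgrading to %s",
                self.max_num_tokens,
                BEI_REQUIRED_MAX_NUM_TOKENS,
            )
            self.max_num_tokens = BEI_REQUIRED_MAX_NUM_TOKENS
        self.plugin_configuration.paged_kv_cache = False
        self.plugin_configuration.use_paged_context_fmha = False
        if "_kv" in self.quantization_type:
            raise ValueError(
                "encoder does not have a kv-cache, therefore a kv-specific "
                f"datatype is not valid; you selected {self.quantization_type}"
            )


config = SimplifiedBuildConfig(is_encoder=True, max_num_tokens=1024)
config.bei_specific_migration()
assert config.max_num_tokens == BEI_REQUIRED_MAX_NUM_TOKENS
assert config.plugin_configuration.paged_kv_cache is False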
