diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 98b3f1bf4..78240d614 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -50,6 +50,7 @@ def main():
 
     hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
     llama_config = LlamaModelConfig(hp)
+    llama_config.static_tables = False  # Rely on the compiler for hoisting tables.
     llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"
     model = PagedLlamaModelV1(dataset.root_theta, llama_config)
 
diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py
index 91f700789..d253af617 100644
--- a/sharktank/sharktank/layers/causal_llm.py
+++ b/sharktank/sharktank/layers/causal_llm.py
@@ -28,7 +28,8 @@ def __init__(
         theta: Theta,
         *,
         context_length: int,
-        static_context_mask: bool = True,
+        static_tables: bool = True,
+        static_context_mask: bool = False,
         device: Optional[torch.device] = None,
         activation_dtype: torch.dtype = torch.float32,
         attention_dtype: torch.dtype = torch.float32,
@@ -39,7 +40,7 @@ def __init__(
         self.attention_dtype = attention_dtype
         self.context_length = context_length
 
-        if static_context_mask:
+        if static_tables:
             self.register_buffer(
                 "causal_context_mask", self.generate_causal_context_mask()
             )
@@ -66,10 +67,12 @@ def _maximally_negative_value(self, dtype):
 
     def generate_causal_context_mask(self) -> torch.Tensor:
         context_length = self.context_length
+        unary_broadcast_ones = torch.ones([1, 1], dtype=torch.bool, device=self.device)
+        context_broadcast_ones = unary_broadcast_ones.expand(
+            context_length, context_length
+        )
         causal_context_mask = torch.triu(
-            torch.ones(
-                [context_length, context_length], dtype=torch.bool, device=self.device
-            ),
+            context_broadcast_ones,
             diagonal=1,
         )[None, None, :, :]
         return causal_context_mask
@@ -114,9 +117,11 @@ def attention_mask(
         scenarios can benefit from managing this in different ways.
         """
         if causal_context_mask is None:
+            # Try to use the statically generated.
            causal_context_mask = self.causal_context_mask
         if causal_context_mask is None:
-            causal_context_mask = self._generate_causal_context_mask()
+            # Fallback to dynamically generated.
+            causal_context_mask = self.generate_causal_context_mask()
 
         # Combine the causal context mask and input mask.
         dtype = self.attention_dtype
diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py
index 755392522..18984713d 100644
--- a/sharktank/sharktank/layers/rotary_embedding.py
+++ b/sharktank/sharktank/layers/rotary_embedding.py
@@ -21,14 +21,29 @@ def __init__(
         max_seqlen: int,
         device: Optional[torch.device] = None,
         use_hf: bool = False,
+        static_tables: bool = True,
     ):
         super().__init__()
+        # Force static_tables until compiler limitations are solved.
+        # See https://github.com/nod-ai/sharktank/issues/156
+        static_tables = True
         self.device = device
+        self.rope_dimension_count = rope_dimension_count
+        self.max_seqlen = max_seqlen
         self.use_hf = use_hf
-        self._table = self._create_rotary_embed_table(
-            max_seqlen=max_seqlen,
-            dim=rope_dimension_count,
-        )
+        if static_tables:
+            self.register_buffer(
+                "static_rotary_embed_table", self._create_rotary_embed_table()
+            )
+        else:
+            self.static_rotary_embed_table = None
+
+    @property
+    def rotary_embed_table(self):
+        if self.static_rotary_embed_table is None:
+            return self._create_rotary_embed_table()
+        else:
+            return self.static_rotary_embed_table
 
     def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int):
         # xq_, xk_ shape: bs, sl, _, dim
@@ -80,7 +95,7 @@ def create_ordering_tensor(dim):
         _, sl, _, dim = xq_.shape
 
         # Offset the table based on starting position.
-        freqs_cis = self._table[start_index : start_index + sl, :]
+        freqs_cis = self.rotary_embed_table[start_index : start_index + sl, :]
         assert freqs_cis.shape[-1] == dim
         assert (
             freqs_cis.shape[0] >= sl
@@ -139,7 +154,7 @@ def compute_batch_mask(
         ) + start_positions.unsqueeze(1)
         # Broadcast lookup to [b, ...].
         self.trace_tensor("rope.positions_seq", positions_seq)
-        freqs_cis = self._table[positions_seq]
+        freqs_cis = self.rotary_embed_table[positions_seq]
 
         # Unsqueeze a unit dim for attention heads.
         broadcast_freqs_cis = freqs_cis.unsqueeze(2)
@@ -167,10 +182,10 @@ def apply_batched_mask(
 
     def _create_rotary_embed_table(
         self,
-        max_seqlen: int,
-        dim: int,
         theta_value: float = 10000.0,
     ):
+        dim = self.rope_dimension_count
+        max_seqlen = self.max_seqlen
         freqs = 1.0 / (
             theta_value
             ** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim)
diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index aaabd3fe6..984fc6524 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -52,6 +52,14 @@ class LlamaModelConfig:
     # rotary embedding).
     use_hf: bool = False
 
+    # If true, then the model may pre-initialize certain tables during
+    # init. This can be better for eager execution but when capturing a program,
+    # it is often better to preserve the calculation explicitly and rely on
+    # the compiler to transform it to an initialization time step. This can
+    # be the difference of many gigabytes of static data being embedded in
+    # the program and not.
+    static_tables: bool = True
+
     def create_kv_cache(self) -> BaseKVCache:
         hp = self.hp
         if self.kv_cache_type == "direct":
@@ -110,6 +118,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         super().__init__(
             theta,
             context_length=config.hp.context_length,
+            static_tables=config.static_tables,
             device=config.device,
             activation_dtype=config.activation_dtype,
             attention_dtype=config.attention_dtype,
@@ -131,6 +140,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                 max_seqlen=hp.context_length,
                 device=self.device,
                 use_hf=self.use_hf,
+                static_tables=config.static_tables,
             ),
         )
         self.add_module(
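For readers who want the static-vs-lazy table pattern in isolation: the sketch below mirrors what the diff does in the rotary-embedding and causal-mask layers, where a table is either registered as a buffer at construction time (convenient for eager execution) or recomputed on demand so that a compiler capturing the program can hoist the computation to an initialization step instead of embedding the data as constants. The SinTableLayer module, its _create_table helper, and the sine table itself are illustrative stand-ins and are not part of this change.

import torch
import torch.nn as nn


class SinTableLayer(nn.Module):
    """Illustrative only: holds a lookup table either as a registered buffer
    (static_tables=True) or recomputes it lazily (static_tables=False)."""

    def __init__(self, size: int, static_tables: bool = True):
        super().__init__()
        self.size = size
        if static_tables:
            # Eager-friendly: the table becomes constant data on the module.
            self.register_buffer("static_table", self._create_table())
        else:
            # Capture-friendly: no constant data is baked into the program;
            # the calculation stays visible to the compiler.
            self.static_table = None

    def _create_table(self) -> torch.Tensor:
        return torch.sin(torch.arange(self.size, dtype=torch.float32))

    @property
    def table(self) -> torch.Tensor:
        # Fall back to on-demand computation when no buffer was registered.
        if self.static_table is None:
            return self._create_table()
        return self.static_table

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.table[idx]


# Both modes produce the same values; they differ only in where the table lives.
eager_layer = SinTableLayer(16, static_tables=True)
export_layer = SinTableLayer(16, static_tables=False)
idx = torch.tensor([0, 3, 7])
assert torch.allclose(eager_layer(idx), export_layer(idx))

This is the same trade-off the diff makes: export_paged_llm_v1.py opts into the capture-friendly mode by setting llama_config.static_tables = False, while eager use keeps the default of True.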