[llama] Explicit option to use static tables. (#155)
When exporting, it is better to leave table construction dynamic and let the
compiler move it to initialization time, rather than materializing large tables
(which can be max_context_length^2 in size). We set static_tables=False
unconditionally on export while leaving it True for eager use.

Contains a workaround for #156.
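
For a sense of scale, a rough back-of-the-envelope estimate of the tables this flag controls (illustrative numbers, not taken from the commit):

# Assumed dimensions for illustration only: context_length=4096, rope dimension=128,
# bool causal mask, complex64 rotary table.
context_length = 4096
rope_dim = 128

causal_mask_bytes = context_length * context_length * 1       # ~16 MiB of bool data
rotary_table_bytes = context_length * (rope_dim // 2) * 8     # ~2 MiB of complex64 data

print(f"causal mask:  {causal_mask_bytes / 2**20:.1f} MiB")
print(f"rotary table: {rotary_table_bytes / 2**20:.1f} MiB")
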
stellaraccident authored Aug 26, 2024
1 parent 7b11628 commit 686d9a8
Showing 4 changed files with 45 additions and 14 deletions.
sharktank/sharktank/examples/export_paged_llm_v1.py: 1 change (1 addition, 0 deletions)
@@ -50,6 +50,7 @@ def main():

hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
llama_config = LlamaModelConfig(hp)
llama_config.static_tables = False # Rely on the compiler for hoisting tables.
llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"
model = PagedLlamaModelV1(dataset.root_theta, llama_config)

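
For contrast, an eager-execution caller would keep the default; a minimal usage sketch assuming the same dataset and config objects as above (only the flag differs from the export path):

# Eager use: leave static_tables at its default (True) so tables are built once at init.
llama_config = LlamaModelConfig(hp)
llama_config.kv_cache_type = "direct"
model = PagedLlamaModelV1(dataset.root_theta, llama_config)
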
sharktank/sharktank/layers/causal_llm.py: 17 changes (11 additions, 6 deletions)
@@ -28,7 +28,8 @@ def __init__(
theta: Theta,
*,
context_length: int,
static_context_mask: bool = True,
static_tables: bool = True,
static_context_mask: bool = False,
device: Optional[torch.device] = None,
activation_dtype: torch.dtype = torch.float32,
attention_dtype: torch.dtype = torch.float32,
@@ -39,7 +40,7 @@ def __init__(
self.attention_dtype = attention_dtype
self.context_length = context_length

if static_context_mask:
if static_tables:
self.register_buffer(
"causal_context_mask", self.generate_causal_context_mask()
)
@@ -66,10 +67,12 @@ def _maximally_negative_value(self, dtype):

def generate_causal_context_mask(self) -> torch.Tensor:
context_length = self.context_length
unary_broadcast_ones = torch.ones([1, 1], dtype=torch.bool, device=self.device)
context_broadcast_ones = unary_broadcast_ones.expand(
context_length, context_length
)
causal_context_mask = torch.triu(
torch.ones(
[context_length, context_length], dtype=torch.bool, device=self.device
),
context_broadcast_ones,
diagonal=1,
)[None, None, :, :]
return causal_context_mask
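
As a standalone check of what the rewritten construction produces (a small sketch with a tiny context_length, not part of the diff):

import torch

# Same construction as above, shrunk to context_length=4 so the result is easy to read.
context_length = 4
unary_broadcast_ones = torch.ones([1, 1], dtype=torch.bool)
context_broadcast_ones = unary_broadcast_ones.expand(context_length, context_length)
mask = torch.triu(context_broadcast_ones, diagonal=1)[None, None, :, :]
# mask[0, 0]:
# [[False,  True,  True,  True],
#  [False, False,  True,  True],
#  [False, False, False,  True],
#  [False, False, False, False]]
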
@@ -114,9 +117,11 @@ def attention_mask(
scenarios can benefit from managing this in different ways.
"""
if causal_context_mask is None:
# Try to use the statically generated.
causal_context_mask = self.causal_context_mask
if causal_context_mask is None:
causal_context_mask = self._generate_causal_context_mask()
# Fallback to dynamically generated.
causal_context_mask = self.generate_causal_context_mask()

# Combine the causal context mask and input mask.
dtype = self.attention_dtype
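
The combination step is truncated above; as a hedged sketch of the common pattern (not the file's exact code), the boolean causal mask and a boolean input/padding mask are merged and converted to an additive float mask:

import torch

def combine_masks(causal_context_mask: torch.Tensor,
                  input_mask: torch.Tensor,
                  dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # True means "masked out"; broadcast the [bs, sl] input mask over the query dim.
    boolean_mask = causal_context_mask | input_mask[:, None, None, :]
    numeric_mask = torch.zeros(boolean_mask.shape, dtype=dtype)
    # Masked positions get the most negative representable value for the dtype.
    return numeric_mask.masked_fill_(boolean_mask, torch.finfo(dtype).min)
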
sharktank/sharktank/layers/rotary_embedding.py: 31 changes (23 additions, 8 deletions)
@@ -21,14 +21,29 @@ def __init__(
max_seqlen: int,
device: Optional[torch.device] = None,
use_hf: bool = False,
static_tables: bool = True,
):
super().__init__()
# Force static_tables until compiler limitations are solved.
# See https://github.com/nod-ai/sharktank/issues/156
static_tables = True
self.device = device
self.rope_dimension_count = rope_dimension_count
self.max_seqlen = max_seqlen
self.use_hf = use_hf
self._table = self._create_rotary_embed_table(
max_seqlen=max_seqlen,
dim=rope_dimension_count,
)
if static_tables:
self.register_buffer(
"static_rotary_embed_table", self._create_rotary_embed_table()
)
else:
self.static_rotary_embed_table = None

@property
def rotary_embed_table(self):
if self.static_rotary_embed_table is None:
return self._create_rotary_embed_table()
else:
return self.static_rotary_embed_table

def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int):
# xq_, xk_ shape: bs, sl, _, dim
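
Distilled, the pattern above is: register the table as a buffer for eager use, or rebuild it on access so an exporting compiler can hoist the computation. A minimal sketch with illustrative names (not taken from the file):

import torch

class LazyTableModule(torch.nn.Module):
    def __init__(self, *, static_tables: bool = True):
        super().__init__()
        if static_tables:
            # Eager path: materialize once and keep as a buffer.
            self.register_buffer("static_table", self._create_table())
        else:
            # Export path: recompute on each access; the compiler can hoist it.
            self.static_table = None

    @property
    def table(self) -> torch.Tensor:
        if self.static_table is None:
            return self._create_table()
        return self.static_table

    def _create_table(self) -> torch.Tensor:
        return torch.arange(8, dtype=torch.float32)  # placeholder computation
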
@@ -80,7 +95,7 @@ def create_ordering_tensor(dim):
_, sl, _, dim = xq_.shape

# Offset the table based on starting position.
freqs_cis = self._table[start_index : start_index + sl, :]
freqs_cis = self.rotary_embed_table[start_index : start_index + sl, :]
assert freqs_cis.shape[-1] == dim
assert (
freqs_cis.shape[0] >= sl
@@ -139,7 +154,7 @@ def compute_batch_mask(
) + start_positions.unsqueeze(1)
# Broadcast lookup to [b, ...].
self.trace_tensor("rope.positions_seq", positions_seq)
freqs_cis = self._table[positions_seq]
freqs_cis = self.rotary_embed_table[positions_seq]

# Unsqueeze a unit dim for attention heads.
broadcast_freqs_cis = freqs_cis.unsqueeze(2)
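
The broadcast freqs_cis is later consumed by apply_batched_mask; a hedged sketch of the standard complex rotary multiplication (illustrative, not copied from this file, which also supports a use_hf layout):

import torch

def apply_rotary(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: [bs, sl, heads, dim] real; freqs_cis: complex and broadcastable to
    # [bs, sl, heads, dim // 2] (e.g. the [bs, sl, 1, dim // 2] tensor built above).
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rotated = torch.view_as_real(x_complex * freqs_cis).flatten(3)
    return x_rotated.type_as(x)
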
@@ -167,10 +182,10 @@ def apply_batched_mask(

def _create_rotary_embed_table(
self,
max_seqlen: int,
dim: int,
theta_value: float = 10000.0,
):
dim = self.rope_dimension_count
max_seqlen = self.max_seqlen
freqs = 1.0 / (
theta_value
** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim)
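
The tail of _create_rotary_embed_table is truncated here, so the remaining steps below are illustrative assumptions rather than a copy of the file: the standard construction takes an outer product of positions with the inverse frequencies and converts it to complex phases.

import torch

def make_rotary_table(max_seqlen: int, dim: int, theta_value: float = 10000.0) -> torch.Tensor:
    # Inverse frequencies, as in the snippet above.
    freqs = 1.0 / (theta_value ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    positions = torch.arange(max_seqlen, dtype=freqs.dtype)
    angles = torch.outer(positions, freqs)                 # [max_seqlen, dim // 2]
    # Complex "cis" table: unit magnitude, one angle per (position, frequency) pair.
    return torch.polar(torch.ones_like(angles), angles)
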
sharktank/sharktank/models/llama/llama.py: 10 changes (10 additions, 0 deletions)
@@ -52,6 +52,14 @@ class LlamaModelConfig:
# rotary embedding).
use_hf: bool = False

# If true, then the model may pre-initialize certain tables during
# init. This can be better for eager execution, but when capturing a program,
# it is often better to preserve the calculation explicitly and rely on
# the compiler to transform it to an initialization-time step. This can
# be the difference between many gigabytes of static data being embedded in
# the program or not.
static_tables: bool = True

def create_kv_cache(self) -> BaseKVCache:
hp = self.hp
if self.kv_cache_type == "direct":
@@ -110,6 +118,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
super().__init__(
theta,
context_length=config.hp.context_length,
static_tables=config.static_tables,
device=config.device,
activation_dtype=config.activation_dtype,
attention_dtype=config.attention_dtype,
@@ -131,6 +140,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
max_seqlen=hp.context_length,
device=self.device,
use_hf=self.use_hf,
static_tables=config.static_tables,
),
)
self.add_module(
